From 26ac5724000dcbcecb124836e43a6cc0e2ca19df Mon Sep 17 00:00:00 2001 From: qiwang <1364512070@qq.com> Date: Thu, 24 Oct 2024 15:05:55 +0800 Subject: [PATCH] fix: add task_ai_async Former-commit-id: 55fdcb01ab8719726afec147431087c2f2d284f8 --- client/types.go | 100 +++++++++++----------- internal/logic/core/pulltaskinfologic.go | 35 +------- internal/logic/core/pushtaskinfologic.go | 4 +- internal/storeLink/modelarts.go | 16 +--- pkg/models/taskaiasynchronousmodel.go | 29 +++++++ pkg/models/taskaiasynchronousmodel_gen.go | 98 +++++++++++++++++++++ 6 files changed, 181 insertions(+), 101 deletions(-) create mode 100644 pkg/models/taskaiasynchronousmodel.go create mode 100644 pkg/models/taskaiasynchronousmodel_gen.go diff --git a/client/types.go b/client/types.go index 97d96867..80f842be 100644 --- a/client/types.go +++ b/client/types.go @@ -136,61 +136,61 @@ type CloudInfo struct { } type AiInfo struct { - Id int64 `json:"id"` // id - AdapterId int64 `json:"adapterId,omitempty,optional"` - AdapterName string `json:"adapterName,omitempty,optional"` - ClusterId int64 `json:"clusterId,omitempty,optional"` - ClusterIds []int64 `json:"clusterIds,omitempty,optional"` - TaskId int64 `json:"taskId,omitempty"` - TaskName string `json:"taskName,omitempty"` - Replica int32 `json:"replica,omitempty"` - ResourceType string `json:"resourceType,omitempty"` - CpuCoreNum int32 `json:"cpuCoreNum,omitempty"` - TaskType string `json:"taskType,omitempty"` - DatasetsName string `json:"datasetsName,omitempty"` - ProjectId string `json:"project_id,omitempty"` - StrategyName string `json:"strategyName,omitempty"` - ClusterToStaticWeight map[string]int32 `json:"clusterToStaticWeight,omitempty"` - Tops float64 `json:"tops,omitempty"` - ComputeCard string `json:"computeCard,omitempty,optional"` - CodeType string `json:"codeType,omitempty,optional"` - ClusterName string `json:"clusterName,omitempty,optional"` - ModelName string `json:"ModelName,omitempty,optional"` - AlgorithmName string `json:"algorithmName,omitempty,optional"` - Strategy string `json:"strategy,omitempty"` + Id int64 `json:"id"` // id + AdapterId int64 `json:"adapterId,omitempty,optional"` + AdapterName string `json:"adapterName,omitempty,optional"` + ClusterId int64 `json:"clusterId,omitempty,optional"` + ClusterIds []int64 `json:"clusterIds,omitempty,optional"` + TaskId int64 `json:"taskId,omitempty"` + ClusterName string `json:"clusterName,omitempty,optional"` + ImageId string `json:"imageId,omitempty"` + ResourceId string `json:"resourceId,omitempty"` + AlgorithmId string `json:"algorithmId,omitempty"` + MetadataName string `json:"metadataName,omitempty"` + Command string `json:"command,omitempty"` + Environments string `json:"environments,omitempty"` + Parameters string `json:"parameters,omitempty"` - ImageId string `json:"imageId,omitempty"` - SpecId string `json:"specId,omitempty"` - DatasetsId string `json:"datasetsId,omitempty"` - CodeId string `json:"codeId,omitempty"` - ResourceId string `json:"resourceId,omitempty"` - AlgorithmId string `json:"algorithmId,omitempty"` - MetadataName string `json:"metadataName,omitempty"` + Name string `json:"name,omitempty"` + Status string `json:"status,omitempty"` + StartTime string `json:"startTime,omitempty"` + //RunningTime int64 `json:"runningTime,omitempty"` + JobId string `json:"jobId,omitempty"` + FlavorId string `json:"flavorId,omitempty"` + //TaskName string `json:"taskName,omitempty"` + //Replica int32 `json:"replica,omitempty"` + //ResourceType string `json:"resourceType,omitempty"` + //CpuCoreNum int32 `json:"cpuCoreNum,omitempty"` + //TaskType string `json:"taskType,omitempty"` + //DatasetsName string `json:"datasetsName,omitempty"` + //ProjectId string `json:"project_id,omitempty"` + //StrategyName string `json:"strategyName,omitempty"` + //ClusterToStaticWeight map[string]int32 `json:"clusterToStaticWeight,omitempty"` + //Tops float64 `json:"tops,omitempty"` + //ComputeCard string `json:"computeCard,omitempty,optional"` + //CodeType string `json:"codeType,omitempty,optional"` - Cmd string `json:"cmd,omitempty"` - Envs []string `json:"envs,omitempty"` - Params []string `json:"params,omitempty"` - Environments string `json:"environments,omitempty"` - Parameters string `json:"parameters,omitempty"` + //ModelName string `json:"ModelName,omitempty,optional"` + //AlgorithmName string `json:"algorithmName,omitempty,optional"` + //Strategy string `json:"strategy,omitempty"` + //Envs []string `json:"envs,omitempty"` + //Params []string `json:"params,omitempty"` - Name string `json:"name,omitempty"` - Status string `json:"status,omitempty"` - StartTime string `json:"startTime,omitempty"` - RunningTime int64 `json:"runningTime,omitempty"` - Result string `json:"result,omitempty"` - JobId string `json:"jobId,omitempty"` + //SpecId string `json:"specId,omitempty"` + //DatasetsId string `json:"datasetsId,omitempty"` + //CodeId string `json:"codeId,omitempty"` + //Result string `json:"result,omitempty"` - Datasets string `json:"datasets,omitempty"` - AlgorithmCode string `json:"algorithmCode,omitempty"` - Image string `json:"image,omitempty"` + //Datasets string `json:"datasets,omitempty"` + //AlgorithmCode string `json:"algorithmCode,omitempty"` + //Image string `json:"image,omitempty"` - CreateTime string `json:"createTime,omitempty"` - ImageUrl string `json:"imageUrl,omitempty"` - Command string `json:"command,omitempty"` - FlavorId string `json:"flavorId,omitempty"` - SubscriptionId string `json:"subscriptionId,omitempty"` - ItemVersionId string `json:"itemVersionId,omitempty"` - ObsUrl string `json:"obsUrl,omitempty"` + //CreateTime string `json:"createTime,omitempty"` + //ImageUrl string `json:"imageUrl,omitempty"` + + //SubscriptionId string `json:"subscriptionId,omitempty"` + //ItemVersionId string `json:"itemVersionId,omitempty"` + //ObsUrl string `json:"obsUrl,omitempty"` } type VmInfo struct { diff --git a/internal/logic/core/pulltaskinfologic.go b/internal/logic/core/pulltaskinfologic.go index d13e1ca0..671e0770 100644 --- a/internal/logic/core/pulltaskinfologic.go +++ b/internal/logic/core/pulltaskinfologic.go @@ -74,45 +74,12 @@ func (l *PullTaskInfoLogic) PullTaskInfo(req *clientCore.PullTaskInfoReq) (*clie } case 1: - var aiModelList []models.TaskAiModelarts + var aiModelList []models.TaskAiAsynchronous err := findModelList(req.AdapterId, l.svcCtx.DbEngin, &aiModelList) if err != nil { return nil, err } utils.Convert(aiModelList, &resp.AiInfoList) - /*if len(resp.AiInfoList) > 0 { - for i, aiInfo := range aiModelList { - if resp.AiInfoList[i].Environments != "" { - // 定义一个map来存储解析后的JSON数据 - var result map[string]interface{} - // 解析JSON字符串 - err := json.Unmarshal([]byte(resp.AiInfoList[i].Environments), &result) - if err != nil { - log.Fatalf("Error parsing JSON: %v", err) - } - // 如果你需要将解析后的map再次转换为JSON字符串,可以使用json.MarshalIndent - formattedJSON, err := json.MarshalIndent(result, "", " ") - aiInfo.Environments = string(formattedJSON) - fmt.Println(aiInfo.Environments) - resp.AiInfoList[i].Environments = aiInfo.Environments - } - if resp.AiInfoList[i].Parameters != "" { - // 定义一个map来存储解析后的JSON数据 - var result []interface{} - // 解析JSON字符串 - err := json.Unmarshal([]byte(resp.AiInfoList[i].Parameters), &result) - if err != nil { - log.Fatalf("Error parsing JSON: %v", err) - } - // 如果你需要将解析后的map再次转换为JSON字符串,可以使用json.MarshalIndent - formattedJSON, err := json.MarshalIndent(result, "", " ") - aiInfo.Parameters = string(formattedJSON) - fmt.Println(aiInfo.Parameters) - resp.AiInfoList[i].Parameters = aiInfo.Parameters - } - - } - }*/ } return &resp, nil } diff --git a/internal/logic/core/pushtaskinfologic.go b/internal/logic/core/pushtaskinfologic.go index 7e0f28d5..bd028241 100644 --- a/internal/logic/core/pushtaskinfologic.go +++ b/internal/logic/core/pushtaskinfologic.go @@ -89,8 +89,8 @@ func (l *PushTaskInfoLogic) PushTaskInfo(req *clientCore.PushTaskInfoReq) (*clie } case 1: for _, aiInfo := range req.AiInfoList { - l.svcCtx.DbEngin.Exec("update task_ai set status = ?,start_time = ?,project_id = ?,job_id = ? where participant_id = ? and task_id = ? and name = ?", - aiInfo.Status, aiInfo.StartTime, aiInfo.ProjectId, aiInfo.JobId, req.AdapterId, aiInfo.TaskId, aiInfo.Name) + l.svcCtx.DbEngin.Exec("update task_ai_asynchronous set status = ?,start_time = ?,job_id = ? where cluster_id = ? and task_id = ? and name = ?", + aiInfo.Status, aiInfo.StartTime, aiInfo.JobId, aiInfo.ClusterId, aiInfo.TaskId, aiInfo.Name) noticeInfo := clientCore.NoticeInfo{ TaskId: aiInfo.TaskId, AdapterId: aiInfo.AdapterId, diff --git a/internal/storeLink/modelarts.go b/internal/storeLink/modelarts.go index 4ce605b7..86fedfe1 100644 --- a/internal/storeLink/modelarts.go +++ b/internal/storeLink/modelarts.go @@ -22,14 +22,12 @@ import ( "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/collector" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" - "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils/timeutils" "gitlink.org.cn/JointCloud/pcm-modelarts/client/imagesservice" "gitlink.org.cn/JointCloud/pcm-modelarts/client/modelartsservice" "gitlink.org.cn/JointCloud/pcm-modelarts/modelarts" modelartsclient "gitlink.org.cn/JointCloud/pcm-modelarts/modelarts" - "gorm.io/gorm" "log" "mime/multipart" "strconv" @@ -56,7 +54,6 @@ type ModelArtsLink struct { Version string ModelId string ModelType string - DbEngin *gorm.DB } // Version 结构体表示版本号 @@ -174,20 +171,9 @@ func (m *ModelArtsLink) SubmitTask(ctx context.Context, imageId string, cmd stri NodeCount: 1, }, }, - //Platform: m.platform, - Platform: "modelarts-CloudBrain2", + Platform: m.platform, } resp, err := m.modelArtsRpc.CreateTrainingJob(ctx, req) - aiModelarts := models.TaskAiModelarts{} - aiModelarts.ImageId = imageId - aiModelarts.FlavorId = resourceId - aiModelarts.Cmd = cmd - //aiModelarts.TaskId = - tx := m.DbEngin.Table("task_ai_modelarts").Create(&aiModelarts) - if tx.Error != nil { - return tx.Error, nil - } - if err != nil { return nil, err } diff --git a/pkg/models/taskaiasynchronousmodel.go b/pkg/models/taskaiasynchronousmodel.go new file mode 100644 index 00000000..8943c365 --- /dev/null +++ b/pkg/models/taskaiasynchronousmodel.go @@ -0,0 +1,29 @@ +package models + +import "github.com/zeromicro/go-zero/core/stores/sqlx" + +var _ TaskAiAsynchronousModel = (*customTaskAiAsynchronousModel)(nil) + +type ( + // TaskAiAsynchronousModel is an interface to be customized, add more methods here, + // and implement the added methods in customTaskAiAsynchronousModel. + TaskAiAsynchronousModel interface { + taskAiAsynchronousModel + withSession(session sqlx.Session) TaskAiAsynchronousModel + } + + customTaskAiAsynchronousModel struct { + *defaultTaskAiAsynchronousModel + } +) + +// NewTaskAiAsynchronousModel returns a model for the database table. +func NewTaskAiAsynchronousModel(conn sqlx.SqlConn) TaskAiAsynchronousModel { + return &customTaskAiAsynchronousModel{ + defaultTaskAiAsynchronousModel: newTaskAiAsynchronousModel(conn), + } +} + +func (m *customTaskAiAsynchronousModel) withSession(session sqlx.Session) TaskAiAsynchronousModel { + return NewTaskAiAsynchronousModel(sqlx.NewSqlConnFromSession(session)) +} diff --git a/pkg/models/taskaiasynchronousmodel_gen.go b/pkg/models/taskaiasynchronousmodel_gen.go new file mode 100644 index 00000000..cca85a77 --- /dev/null +++ b/pkg/models/taskaiasynchronousmodel_gen.go @@ -0,0 +1,98 @@ +// Code generated by goctl. DO NOT EDIT. + +package models + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/zeromicro/go-zero/core/stores/builder" + "github.com/zeromicro/go-zero/core/stores/sqlx" + "github.com/zeromicro/go-zero/core/stringx" +) + +var ( + taskAiAsynchronousFieldNames = builder.RawFieldNames(&TaskAiAsynchronous{}) + taskAiAsynchronousRows = strings.Join(taskAiAsynchronousFieldNames, ",") + taskAiAsynchronousRowsExpectAutoSet = strings.Join(stringx.Remove(taskAiAsynchronousFieldNames, "`create_at`", "`create_time`", "`created_at`", "`update_at`", "`update_time`", "`updated_at`"), ",") + taskAiAsynchronousRowsWithPlaceHolder = strings.Join(stringx.Remove(taskAiAsynchronousFieldNames, "`id`", "`create_at`", "`create_time`", "`created_at`", "`update_at`", "`update_time`", "`updated_at`"), "=?,") + "=?" +) + +type ( + taskAiAsynchronousModel interface { + Insert(ctx context.Context, data *TaskAiAsynchronous) (sql.Result, error) + FindOne(ctx context.Context, id int64) (*TaskAiAsynchronous, error) + Update(ctx context.Context, data *TaskAiAsynchronous) error + Delete(ctx context.Context, id int64) error + } + + defaultTaskAiAsynchronousModel struct { + conn sqlx.SqlConn + table string + } + + TaskAiAsynchronous struct { + Id int64 `db:"id"` // 训练作业资源规格id + TaskId int64 `db:"task_id"` // 任务id + AdapterId int64 `db:"adapter_id"` // 执行任务的适配器id + AdapterName string `db:"adapter_name"` // 适配器名称 + ClusterId int64 `db:"cluster_id"` // 集群id + ClusterName string `db:"cluster_name"` // 集群名称 + Name string `db:"name"` // 任务名 + Replica int64 `db:"replica"` // 执行数 + JobId string `db:"job_id"` // 集群返回任务id + StartTime string `db:"start_time"` // 开始时间 + RunningTime string `db:"running_time"` // 运行时间 + Result string `db:"result"` // 运行结果 + DeletedAt string `db:"deleted_at"` // 删除时间 + ImageId string `db:"image_id"` // 镜像id + Cmd string `db:"cmd"` // 命令行 + FlavorId string `db:"flavor_id"` // 训练作业资源规格id + Status string `db:"status"` // 任务状态 + } +) + +func newTaskAiAsynchronousModel(conn sqlx.SqlConn) *defaultTaskAiAsynchronousModel { + return &defaultTaskAiAsynchronousModel{ + conn: conn, + table: "`task_ai_asynchronous`", + } +} + +func (m *defaultTaskAiAsynchronousModel) Delete(ctx context.Context, id int64) error { + query := fmt.Sprintf("delete from %s where `id` = ?", m.table) + _, err := m.conn.ExecCtx(ctx, query, id) + return err +} + +func (m *defaultTaskAiAsynchronousModel) FindOne(ctx context.Context, id int64) (*TaskAiAsynchronous, error) { + query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", taskAiAsynchronousRows, m.table) + var resp TaskAiAsynchronous + err := m.conn.QueryRowCtx(ctx, &resp, query, id) + switch err { + case nil: + return &resp, nil + case sqlx.ErrNotFound: + return nil, ErrNotFound + default: + return nil, err + } +} + +func (m *defaultTaskAiAsynchronousModel) Insert(ctx context.Context, data *TaskAiAsynchronous) (sql.Result, error) { + query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", m.table, taskAiAsynchronousRowsExpectAutoSet) + ret, err := m.conn.ExecCtx(ctx, query, data.Id, data.TaskId, data.AdapterId, data.AdapterName, data.ClusterId, data.ClusterName, data.Name, data.Replica, data.JobId, data.StartTime, data.RunningTime, data.Result, data.DeletedAt, data.ImageId, data.Cmd, data.FlavorId, data.Status) + return ret, err +} + +func (m *defaultTaskAiAsynchronousModel) Update(ctx context.Context, data *TaskAiAsynchronous) error { + query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, taskAiAsynchronousRowsWithPlaceHolder) + _, err := m.conn.ExecCtx(ctx, query, data.TaskId, data.AdapterId, data.AdapterName, data.ClusterId, data.ClusterName, data.Name, data.Replica, data.JobId, data.StartTime, data.RunningTime, data.Result, data.DeletedAt, data.ImageId, data.Cmd, data.FlavorId, data.Status, data.Id) + return err +} + +func (m *defaultTaskAiAsynchronousModel) tableName() string { + return m.table +}