From b7a8602155675e0f30e1733f62099b0eef3fadca Mon Sep 17 00:00:00 2001 From: qiwang <1364512070@qq.com> Date: Wed, 23 Oct 2024 14:25:15 +0800 Subject: [PATCH] fix: add task_ai_modelarts Former-commit-id: 9e9675e3bfa20f095abb9bd7424c9e382041ccf0 --- internal/logic/core/pulltaskinfologic.go | 2 +- internal/storeLink/modelarts.go | 19 +++-- pkg/models/taskaimodelartsmodel.go | 29 +++++++ pkg/models/taskaimodelartsmodel_gen.go | 98 ++++++++++++++++++++++++ 4 files changed, 139 insertions(+), 9 deletions(-) create mode 100644 pkg/models/taskaimodelartsmodel.go create mode 100644 pkg/models/taskaimodelartsmodel_gen.go diff --git a/internal/logic/core/pulltaskinfologic.go b/internal/logic/core/pulltaskinfologic.go index a727f766..d13e1ca0 100644 --- a/internal/logic/core/pulltaskinfologic.go +++ b/internal/logic/core/pulltaskinfologic.go @@ -74,7 +74,7 @@ func (l *PullTaskInfoLogic) PullTaskInfo(req *clientCore.PullTaskInfoReq) (*clie } case 1: - var aiModelList []models.TaskAi + var aiModelList []models.TaskAiModelarts err := findModelList(req.AdapterId, l.svcCtx.DbEngin, &aiModelList) if err != nil { return nil, err diff --git a/internal/storeLink/modelarts.go b/internal/storeLink/modelarts.go index 29f63753..4ce605b7 100644 --- a/internal/storeLink/modelarts.go +++ b/internal/storeLink/modelarts.go @@ -22,6 +22,7 @@ import ( "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/collector" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils/timeutils" "gitlink.org.cn/JointCloud/pcm-modelarts/client/imagesservice" @@ -63,10 +64,6 @@ type Version struct { Major, Minor, Patch int } -type AiStorage struct { - DbEngin *gorm.DB -} - // ParseVersion 从字符串解析版本号 func ParseVersion(versionStr string) (*Version, error) { parts := strings.Split(versionStr, ".") @@ -177,13 +174,19 @@ func (m *ModelArtsLink) SubmitTask(ctx context.Context, imageId string, cmd stri NodeCount: 1, }, }, - Platform: m.platform, + //Platform: m.platform, + Platform: "modelarts-CloudBrain2", } resp, err := m.modelArtsRpc.CreateTrainingJob(ctx, req) - //tx := m.DbEngin.Create(adapterId) - /*if tx.Error != nil { + aiModelarts := models.TaskAiModelarts{} + aiModelarts.ImageId = imageId + aiModelarts.FlavorId = resourceId + aiModelarts.Cmd = cmd + //aiModelarts.TaskId = + tx := m.DbEngin.Table("task_ai_modelarts").Create(&aiModelarts) + if tx.Error != nil { return tx.Error, nil - }*/ + } if err != nil { return nil, err diff --git a/pkg/models/taskaimodelartsmodel.go b/pkg/models/taskaimodelartsmodel.go new file mode 100644 index 00000000..e225f9ae --- /dev/null +++ b/pkg/models/taskaimodelartsmodel.go @@ -0,0 +1,29 @@ +package models + +import "github.com/zeromicro/go-zero/core/stores/sqlx" + +var _ TaskAiModelartsModel = (*customTaskAiModelartsModel)(nil) + +type ( + // TaskAiModelartsModel is an interface to be customized, add more methods here, + // and implement the added methods in customTaskAiModelartsModel. + TaskAiModelartsModel interface { + taskAiModelartsModel + withSession(session sqlx.Session) TaskAiModelartsModel + } + + customTaskAiModelartsModel struct { + *defaultTaskAiModelartsModel + } +) + +// NewTaskAiModelartsModel returns a model for the database table. +func NewTaskAiModelartsModel(conn sqlx.SqlConn) TaskAiModelartsModel { + return &customTaskAiModelartsModel{ + defaultTaskAiModelartsModel: newTaskAiModelartsModel(conn), + } +} + +func (m *customTaskAiModelartsModel) withSession(session sqlx.Session) TaskAiModelartsModel { + return NewTaskAiModelartsModel(sqlx.NewSqlConnFromSession(session)) +} diff --git a/pkg/models/taskaimodelartsmodel_gen.go b/pkg/models/taskaimodelartsmodel_gen.go new file mode 100644 index 00000000..244604bd --- /dev/null +++ b/pkg/models/taskaimodelartsmodel_gen.go @@ -0,0 +1,98 @@ +// Code generated by goctl. DO NOT EDIT. + +package models + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/zeromicro/go-zero/core/stores/builder" + "github.com/zeromicro/go-zero/core/stores/sqlx" + "github.com/zeromicro/go-zero/core/stringx" +) + +var ( + taskAiModelartsFieldNames = builder.RawFieldNames(&TaskAiModelarts{}) + taskAiModelartsRows = strings.Join(taskAiModelartsFieldNames, ",") + taskAiModelartsRowsExpectAutoSet = strings.Join(stringx.Remove(taskAiModelartsFieldNames, "`create_at`", "`create_time`", "`created_at`", "`update_at`", "`update_time`", "`updated_at`"), ",") + taskAiModelartsRowsWithPlaceHolder = strings.Join(stringx.Remove(taskAiModelartsFieldNames, "`id`", "`create_at`", "`create_time`", "`created_at`", "`update_at`", "`update_time`", "`updated_at`"), "=?,") + "=?" +) + +type ( + taskAiModelartsModel interface { + Insert(ctx context.Context, data *TaskAiModelarts) (sql.Result, error) + FindOne(ctx context.Context, id int64) (*TaskAiModelarts, error) + Update(ctx context.Context, data *TaskAiModelarts) error + Delete(ctx context.Context, id int64) error + } + + defaultTaskAiModelartsModel struct { + conn sqlx.SqlConn + table string + } + + TaskAiModelarts struct { + Id int64 `db:"id"` // 训练作业资源规格id + TaskId int64 `db:"task_id"` // 任务id + AdapterId int64 `db:"adapter_id"` // 执行任务的适配器id + AdapterName string `db:"adapter_name"` // 适配器名称 + ClusterId int64 `db:"cluster_id"` // 集群id + ClusterName string `db:"cluster_name"` // 集群名称 + Name string `db:"name"` // 任务名 + Replica int64 `db:"replica"` // 执行数 + JobId string `db:"job_id"` // 集群返回任务id + StartTime string `db:"start_time"` // 开始时间 + RunningTime string `db:"running_time"` // 运行时间 + Result string `db:"result"` // 运行结果 + DeletedAt string `db:"deleted_at"` // 删除时间 + ImageId string `db:"image_id"` // 镜像id + Cmd string `db:"cmd"` // 命令行 + FlavorId string `db:"flavor_id"` // 训练作业资源规格id + Status string `db:"status"` // 任务状态 + } +) + +func newTaskAiModelartsModel(conn sqlx.SqlConn) *defaultTaskAiModelartsModel { + return &defaultTaskAiModelartsModel{ + conn: conn, + table: "`task_ai_modelarts`", + } +} + +func (m *defaultTaskAiModelartsModel) Delete(ctx context.Context, id int64) error { + query := fmt.Sprintf("delete from %s where `id` = ?", m.table) + _, err := m.conn.ExecCtx(ctx, query, id) + return err +} + +func (m *defaultTaskAiModelartsModel) FindOne(ctx context.Context, id int64) (*TaskAiModelarts, error) { + query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", taskAiModelartsRows, m.table) + var resp TaskAiModelarts + err := m.conn.QueryRowCtx(ctx, &resp, query, id) + switch err { + case nil: + return &resp, nil + case sqlx.ErrNotFound: + return nil, ErrNotFound + default: + return nil, err + } +} + +func (m *defaultTaskAiModelartsModel) Insert(ctx context.Context, data *TaskAiModelarts) (sql.Result, error) { + query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", m.table, taskAiModelartsRowsExpectAutoSet) + ret, err := m.conn.ExecCtx(ctx, query, data.Id, data.TaskId, data.AdapterId, data.AdapterName, data.ClusterId, data.ClusterName, data.Name, data.Replica, data.JobId, data.StartTime, data.RunningTime, data.Result, data.DeletedAt, data.ImageId, data.Cmd, data.FlavorId, data.Status) + return ret, err +} + +func (m *defaultTaskAiModelartsModel) Update(ctx context.Context, data *TaskAiModelarts) error { + query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, taskAiModelartsRowsWithPlaceHolder) + _, err := m.conn.ExecCtx(ctx, query, data.TaskId, data.AdapterId, data.AdapterName, data.ClusterId, data.ClusterName, data.Name, data.Replica, data.JobId, data.StartTime, data.RunningTime, data.Result, data.DeletedAt, data.ImageId, data.Cmd, data.FlavorId, data.Status, data.Id) + return err +} + +func (m *defaultTaskAiModelartsModel) tableName() string { + return m.table +}