diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index 2a06f78d..3eec7d69 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -114,6 +114,9 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere return resp, nil } +var acs []*strategy.AssignedCluster +var aiTaskList []*models.TaskAi + func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []struct { imageResult *types.ImageResult file multipart.File @@ -211,7 +214,6 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s cs = append(cs, s) } - var aiTaskList []*models.TaskAi tx := svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList) if tx.Error != nil { return nil, tx.Error @@ -219,7 +221,6 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s } //change cluster status if len(clusters) != len(cs) { - var acs []*strategy.AssignedCluster for _, cluster := range clusters { if contains(cs, cluster.ClusterId) { continue @@ -261,7 +262,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s imageNumIdx = imageNumIdx + c.imageNum wg.Add(len(new_images)) - go sendInferReq(new_images, c, &wg, result_ch) + go sendInferReq(new_images, c, &wg, *svcCtx, result_ch) } wg.Wait() close(result_ch) @@ -297,7 +298,7 @@ func sendInferReq(images []struct { clusterId string clusterName string imageNum int32 -}, wg *sync.WaitGroup, ch chan<- *types.ImageResult) { +}, wg *sync.WaitGroup, svcCtx svc.ServiceContext, ch chan<- *types.ImageResult) { for _, image := range images { go func(t struct { imageResult *types.ImageResult @@ -339,7 +340,27 @@ func sendInferReq(images []struct { t.imageResult.ImageResult = r t.imageResult.ClusterName = c.clusterName t.imageResult.Card = c.urls[idx].Card + for _, ac := range acs { + for _, task := range aiTaskList { + if ac.ClusterId == strconv.Itoa(int(task.ClusterId)) && ac.ClusterId == t.imageResult.ClusterId { + taskAiSub := &models.TaskAiSub{ + Id: task.Id, + ImageName: t.imageResult.ImageName, + Result: t.imageResult.ImageResult, + Card: t.imageResult.Card, + ClusterId: task.ClusterId, + ClusterName: t.imageResult.ClusterName, + } + tx := svcCtx.DbEngin.Save(&taskAiSub) + if tx.Error != nil { + logx.Errorf(err.Error()) + } + } + continue + } + continue + } ch <- t.imageResult wg.Done() return diff --git a/pkg/models/taskaisubmodel.go b/pkg/models/taskaisubmodel.go new file mode 100644 index 00000000..207772c6 --- /dev/null +++ b/pkg/models/taskaisubmodel.go @@ -0,0 +1,29 @@ +package models + +import "github.com/zeromicro/go-zero/core/stores/sqlx" + +var _ TaskAiSubModel = (*customTaskAiSubModel)(nil) + +type ( + // TaskAiSubModel is an interface to be customized, add more methods here, + // and implement the added methods in customTaskAiSubModel. + TaskAiSubModel interface { + taskAiSubModel + withSession(session sqlx.Session) TaskAiSubModel + } + + customTaskAiSubModel struct { + *defaultTaskAiSubModel + } +) + +// NewTaskAiSubModel returns a model for the database table. +func NewTaskAiSubModel(conn sqlx.SqlConn) TaskAiSubModel { + return &customTaskAiSubModel{ + defaultTaskAiSubModel: newTaskAiSubModel(conn), + } +} + +func (m *customTaskAiSubModel) withSession(session sqlx.Session) TaskAiSubModel { + return NewTaskAiSubModel(sqlx.NewSqlConnFromSession(session)) +} diff --git a/pkg/models/taskaisubmodel_gen.go b/pkg/models/taskaisubmodel_gen.go new file mode 100644 index 00000000..53e48575 --- /dev/null +++ b/pkg/models/taskaisubmodel_gen.go @@ -0,0 +1,88 @@ +// Code generated by goctl. DO NOT EDIT. + +package models + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/zeromicro/go-zero/core/stores/builder" + "github.com/zeromicro/go-zero/core/stores/sqlc" + "github.com/zeromicro/go-zero/core/stores/sqlx" + "github.com/zeromicro/go-zero/core/stringx" +) + +var ( + taskAiSubFieldNames = builder.RawFieldNames(&TaskAiSub{}) + taskAiSubRows = strings.Join(taskAiSubFieldNames, ",") + taskAiSubRowsExpectAutoSet = strings.Join(stringx.Remove(taskAiSubFieldNames, "`id`", "`create_at`", "`create_time`", "`created_at`", "`update_at`", "`update_time`", "`updated_at`"), ",") + taskAiSubRowsWithPlaceHolder = strings.Join(stringx.Remove(taskAiSubFieldNames, "`id`", "`create_at`", "`create_time`", "`created_at`", "`update_at`", "`update_time`", "`updated_at`"), "=?,") + "=?" +) + +type ( + taskAiSubModel interface { + Insert(ctx context.Context, data *TaskAiSub) (sql.Result, error) + FindOne(ctx context.Context, id int64) (*TaskAiSub, error) + Update(ctx context.Context, data *TaskAiSub) error + Delete(ctx context.Context, id int64) error + } + + defaultTaskAiSubModel struct { + conn sqlx.SqlConn + table string + } + + TaskAiSub struct { + Id int64 `db:"id"` // id + ImageName string `db:"image_name"` // 图片名称 + Result string `db:"result"` // 识别结果 + Card string `db:"card"` // 加速卡 + ClusterId int64 `db:"cluster_id"` // 集群id + ClusterName string `db:"cluster_name"` // 集群名称 + } +) + +func newTaskAiSubModel(conn sqlx.SqlConn) *defaultTaskAiSubModel { + return &defaultTaskAiSubModel{ + conn: conn, + table: "`task_ai_sub`", + } +} + +func (m *defaultTaskAiSubModel) Delete(ctx context.Context, id int64) error { + query := fmt.Sprintf("delete from %s where `id` = ?", m.table) + _, err := m.conn.ExecCtx(ctx, query, id) + return err +} + +func (m *defaultTaskAiSubModel) FindOne(ctx context.Context, id int64) (*TaskAiSub, error) { + query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", taskAiSubRows, m.table) + var resp TaskAiSub + err := m.conn.QueryRowCtx(ctx, &resp, query, id) + switch err { + case nil: + return &resp, nil + case sqlc.ErrNotFound: + return nil, ErrNotFound + default: + return nil, err + } +} + +func (m *defaultTaskAiSubModel) Insert(ctx context.Context, data *TaskAiSub) (sql.Result, error) { + query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?, ?, ?)", m.table, taskAiSubRowsExpectAutoSet) + ret, err := m.conn.ExecCtx(ctx, query, data.ImageName, data.Result, data.Card, data.ClusterId, data.ClusterName) + return ret, err +} + +func (m *defaultTaskAiSubModel) Update(ctx context.Context, data *TaskAiSub) error { + query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, taskAiSubRowsWithPlaceHolder) + _, err := m.conn.ExecCtx(ctx, query, data.ImageName, data.Result, data.Card, data.ClusterId, data.ClusterName, data.Id) + return err +} + +func (m *defaultTaskAiSubModel) tableName() string { + return m.table +}