diff --git a/api/internal/logic/schedule/downloadalgothmcodelogic.go b/api/internal/logic/schedule/downloadalgothmcodelogic.go index 81b96579..03800fe7 100644 --- a/api/internal/logic/schedule/downloadalgothmcodelogic.go +++ b/api/internal/logic/schedule/downloadalgothmcodelogic.go @@ -2,6 +2,7 @@ package schedule import ( "context" + "strings" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" @@ -26,7 +27,7 @@ func NewDownloadAlgothmCodeLogic(ctx context.Context, svcCtx *svc.ServiceContext func (l *DownloadAlgothmCodeLogic) DownloadAlgorithmCode(req *types.DownloadAlgorithmCodeReq) (resp *types.DownloadAlgorithmCodeResp, err error) { resp = &types.DownloadAlgorithmCodeResp{} code, err := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[req.AdapterId][req.ClusterId].DownloadAlgorithmCode(l.ctx, - req.ResourceType, req.Card, req.TaskType, req.Dataset, req.Algorithm) + req.ResourceType, strings.ToLower(req.Card), req.TaskType, req.Dataset, req.Algorithm) if err != nil { return nil, err } diff --git a/api/internal/logic/schedule/getcomputecardsbyclusterlogic.go b/api/internal/logic/schedule/getcomputecardsbyclusterlogic.go index 772a5ce6..4cb94f91 100644 --- a/api/internal/logic/schedule/getcomputecardsbyclusterlogic.go +++ b/api/internal/logic/schedule/getcomputecardsbyclusterlogic.go @@ -24,7 +24,12 @@ func NewGetComputeCardsByClusterLogic(ctx context.Context, svcCtx *svc.ServiceCo } func (l *GetComputeCardsByClusterLogic) GetComputeCardsByCluster(req *types.GetComputeCardsByClusterReq) (resp *types.GetComputeCardsByClusterResp, err error) { - // todo: add your logic here and delete this line + resp = &types.GetComputeCardsByClusterResp{} + cards, err := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[req.AdapterId][req.ClusterId].GetComputeCards(l.ctx) + if err != nil { + return nil, err + } + resp.Cards = cards - return + return resp, nil } diff --git a/api/internal/scheduler/service/collector/collector.go b/api/internal/scheduler/service/collector/collector.go index 453d710c..5e6a7940 100644 --- a/api/internal/scheduler/service/collector/collector.go +++ b/api/internal/scheduler/service/collector/collector.go @@ -10,6 +10,7 @@ type AiCollector interface { GetTrainingTask(ctx context.Context, taskId string) (*Task, error) DownloadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string) (string, error) UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error + GetComputeCards(ctx context.Context) ([]string, error) } type ResourceStats struct { diff --git a/api/internal/storeLink/modelarts.go b/api/internal/storeLink/modelarts.go index 7bb6db2d..5843eeff 100644 --- a/api/internal/storeLink/modelarts.go +++ b/api/internal/storeLink/modelarts.go @@ -162,6 +162,10 @@ func (m *ModelArtsLink) GetAlgorithms(ctx context.Context) ([]*collector.Algorit return nil, nil } +func (m *ModelArtsLink) GetComputeCards(ctx context.Context) ([]string, error) { + return nil, nil +} + func (m *ModelArtsLink) DownloadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string) (string, error) { return "", nil } diff --git a/api/internal/storeLink/octopus.go b/api/internal/storeLink/octopus.go index a088e56a..3b1d2521 100644 --- a/api/internal/storeLink/octopus.go +++ b/api/internal/storeLink/octopus.go @@ -46,12 +46,14 @@ const ( SUIYUAN = "suiyuan" SAILINGSI = "sailingsi" MLU = "MLU" + BIV100 = "BI-V100" CAMBRICONMLU290 = 256 GCU = "GCU" ENFLAME = "enflame" EnflameT20 = 128 BASE_TOPS = 128 CAMBRICON = "cambricon" + TIANSHU = "天数" TRAIN_CMD = "cd /code; python train.py" VERSION = "V1" DOMAIN = "http://192.168.242.41:8001/" @@ -59,8 +61,9 @@ const ( var ( cardAliasMap = map[string]string{ - MLU: CAMBRICON, - GCU: ENFLAME, + MLU: CAMBRICON, + GCU: ENFLAME, + BIV100: TIANSHU, } cardTopsMap = map[string]float64{ MLU: CAMBRICONMLU290, @@ -340,11 +343,54 @@ func (o *OctopusLink) GetAlgorithms(ctx context.Context) ([]*collector.Algorithm return algorithms, nil } +func (o *OctopusLink) GetComputeCards(ctx context.Context) ([]string, error) { + var cards []string + for s, _ := range cardAliasMap { + cards = append(cards, s) + } + return cards, nil +} + func (o *OctopusLink) DownloadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string) (string, error) { + var name string + if resourceType == CARD { + name = dataset + UNDERSCORE + algorithm + UNDERSCORE + card + } else { + name = dataset + UNDERSCORE + algorithm + UNDERSCORE + CPU + } + + req := &octopus.GetMyAlgorithmListReq{ + Platform: o.platform, + PageIndex: o.pageIndex, + PageSize: o.pageSize, + } + resp, err := o.octopusRpc.GetMyAlgorithmList(ctx, req) + if err != nil { + return "", err + } + if !resp.Success { + return "", errors.New("failed to get algorithmList") + } + + var algorithmId string + for _, a := range resp.Payload.Algorithms { + if strings.ToLower(a.FrameworkName) != taskType { + continue + } + if a.AlgorithmName == name { + algorithmId = a.AlgorithmId + break + } + } + + if algorithmId == "" { + return "", errors.New("algorithmId not found") + } + dcReq := &octopus.DownloadCompressReq{ Platform: o.platform, Version: VERSION, - AlgorithmId: "", + AlgorithmId: algorithmId, } dcResp, err := o.octopusRpc.DownloadCompress(ctx, dcReq) if err != nil { @@ -358,7 +404,7 @@ func (o *OctopusLink) DownloadAlgorithmCode(ctx context.Context, resourceType st daReq := &octopus.DownloadAlgorithmReq{ Platform: o.platform, Version: VERSION, - AlgorithmId: "", + AlgorithmId: algorithmId, CompressAt: dcResp.Payload.CompressAt, Domain: DOMAIN, } @@ -591,16 +637,6 @@ func (o *OctopusLink) generateImageId(ctx context.Context, option *option.AiOpti } func (o *OctopusLink) generateAlgorithmId(ctx context.Context, option *option.AiOption) error { - // temporarily set algorithm to cnn - if option.AlgorithmName == "" { - switch option.DatasetsName { - case "cifar10": - option.AlgorithmName = "cnn" - case "mnist": - option.AlgorithmName = "fcn" - } - } - req := &octopus.GetMyAlgorithmListReq{ Platform: o.platform, PageIndex: o.pageIndex, diff --git a/api/internal/storeLink/shuguangai.go b/api/internal/storeLink/shuguangai.go index 4f783357..16811591 100644 --- a/api/internal/storeLink/shuguangai.go +++ b/api/internal/storeLink/shuguangai.go @@ -447,6 +447,12 @@ func (s *ShuguangAi) GetAlgorithms(ctx context.Context) ([]*collector.Algorithm, return algorithms, nil } +func (s *ShuguangAi) GetComputeCards(ctx context.Context) ([]string, error) { + var cards []string + cards = append(cards, DCU) + return cards, nil +} + func (s *ShuguangAi) DownloadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string) (string, error) { algoName := dataset + DASH + algorithm req := &hpcAC.GetFileReq{