diff --git a/plugin/migration/pipeline.go b/plugin/migration/pipeline.go index 0818567c..49b76b0e 100644 --- a/plugin/migration/pipeline.go +++ b/plugin/migration/pipeline.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "math" + "sync" "time" log "github.com/cihub/seelog" @@ -32,7 +33,8 @@ type DispatcherProcessor struct { id string config *DispatcherConfig - state map[string]DispatcherState + state map[string]DispatcherState + stateLock sync.Mutex pipelineTaskProcessor migration_model.Processor } @@ -89,14 +91,8 @@ func newMigrationDispatcherProcessor(c *config.Config) (pipeline.Processor, erro processor := DispatcherProcessor{ id: util.GetUUID(), config: &cfg, - state: map[string]DispatcherState{}, } - state, err := processor.getInstanceTaskState() - if err != nil { - log.Errorf("failed to get instance task state, err: %v", err) - return nil, err - } - processor.state = state + processor.refreshInstanceJobsFromES() processor.pipelineTaskProcessor = pipeline_task.NewProcessor(cfg.Elasticsearch, cfg.IndexName, cfg.LogIndexName) return &processor, nil @@ -261,7 +257,9 @@ func (p *DispatcherProcessor) handlePendingStopMajorTask(taskItem *task2.Task) e if len(tasks) == 0 { taskItem.Status = task2.StatusStopped p.sendMajorTaskNotification(taskItem) - p.saveTaskAndWriteLog(taskItem, "", nil, fmt.Sprintf("task [%s] stopped", taskItem.ID)) + p.saveTaskAndWriteLog(taskItem, "wait_for", nil, fmt.Sprintf("task [%s] stopped", taskItem.ID)) + // NOTE: we don't know how many running index_migration's stopped, so do a refresh from ES + p.refreshInstanceJobsFromES() } return nil } @@ -656,7 +654,7 @@ func (p *DispatcherProcessor) handleScheduleSubTask(taskItem *task2.Task) error if err != nil { return fmt.Errorf("get preference intance error: %w", err) } - if p.state[instance.ID].Total >= p.config.MaxTasksPerInstance { + if p.getInstanceState(instance.ID).Total >= p.config.MaxTasksPerInstance { log.Debugf("hit max tasks per instance with %d, skip dispatch", p.config.MaxTasksPerInstance) return nil } @@ -715,7 +713,8 @@ func (p *DispatcherProcessor) getPreferenceInstance(majorTaskID string) (instanc tempInst = model.Instance{} ) for _, node := range cfg.Settings.Execution.Nodes.Permit { - if p.state[node.ID].Total < total { + instanceTotal := p.getInstanceState(node.ID).Total + if instanceTotal < total { if p.config.CheckInstanceAvailable { tempInst.ID = node.ID _, err = orm.Get(&tempInst) @@ -730,7 +729,7 @@ func (p *DispatcherProcessor) getPreferenceInstance(majorTaskID string) (instanc } } instance.ID = node.ID - total = p.state[node.ID].Total + total = instanceTotal } } if instance.ID == "" && p.config.CheckInstanceAvailable { @@ -1289,6 +1288,8 @@ func (p *DispatcherProcessor) getScrollBulkPipelineTasks(taskItem *task2.Task) ( } func (p *DispatcherProcessor) decrInstanceJobs(instanceID string) { + p.stateLock.Lock() + defer p.stateLock.Unlock() if st, ok := p.state[instanceID]; ok { st.Total -= 1 p.state[instanceID] = st @@ -1296,15 +1297,40 @@ func (p *DispatcherProcessor) decrInstanceJobs(instanceID string) { } func (p *DispatcherProcessor) incrInstanceJobs(instanceID string) { + p.stateLock.Lock() + defer p.stateLock.Unlock() instanceState := p.state[instanceID] instanceState.Total = instanceState.Total + 1 p.state[instanceID] = instanceState } +func (p *DispatcherProcessor) getInstanceState(instanceID string) DispatcherState { + p.stateLock.Lock() + defer p.stateLock.Unlock() + + return p.state[instanceID] +} + +func (p *DispatcherProcessor) refreshInstanceJobsFromES() { + log.Debug("refreshing instance state from ES") + p.stateLock.Lock() + defer p.stateLock.Unlock() + + state, err := p.getInstanceTaskState() + if err != nil { + log.Errorf("failed to get instance task state, err: %v", err) + return + } + p.state = state +} + func (p *DispatcherProcessor) cleanGatewayQueue(taskItem *task2.Task) { var err error instance := model.Instance{} instanceID := taskItem.Metadata.Labels["execution_instance_id"] + if instanceID == "" { + log.Debugf("task [%s] not scheduled yet, skip cleaning queue", taskItem.ID) + } instance.ID, _ = util.ExtractString(instanceID) _, err = orm.Get(&instance) if err != nil {