Skip to content

Scheduler Execution

When using DynamicScheduler directly, the external contract is to construct the scheduler and call execute(). Other scheduler helper methods are documented as internal APIs.

astrum.scheduler.DynamicScheduler

DynamicScheduler(tasks: list[tuple[str, TaskCallable]], task_order: list[TaskOrder], task_data_refs: dict[str, TaskData] | None = None, has_data_path: bool = False, ignore_tail_task: list[str] | None = None, concurrency_context: Semaphore | None = None, *, task_retries: dict[str, int] | None = None, silence: bool = True)

Async DAG scheduler for predeclared task orders.

每个任务以 (task_id, callable) 的形式传入;调度器在调度期才真正调用 callable,从而支持基于 :class:TaskData 的动态参数装配。不再支持将"已 call 但未 run 的协程对象"作为 task;这种用法会在 __init__ 阶段直接抛 :class:TypeError。

Each task is passed in as (task_id, callable); the scheduler only actually calls the callable during scheduling, which supports dynamic parameter assembly based on :class:TaskData. Passing an "already called but not yet run coroutine object" as a task is no longer supported; that usage raises :class:TypeError directly during __init__.

源代码位于: src/astrum/scheduler.py
def __init__(
    self,
    tasks: list[tuple[str, TaskCallable]],
    task_order: list[TaskOrder],
    task_data_refs: dict[str, TaskData] | None = None,
    has_data_path: bool = False,
    ignore_tail_task: list[str] | None = None,
    concurrency_context: asyncio.Semaphore | None = None,
    *,
    task_retries: dict[str, int] | None = None,
    silence: bool = True,
) -> None:
    """Store the scheduler configuration and reset its runtime state.

    Args:
        tasks: ``(task_id, callable)`` pairs; checked by
            ``_validate_task_inputs`` before being stored.
        task_order: Declared task orders; stored unchanged.
        task_data_refs: Mapping of names to ``TaskData``; a falsy value
            is replaced by a new empty dict.
        has_data_path: Stored unchanged.
        ignore_tail_task: Task ids kept as a list; a falsy value is
            replaced by a new empty list.
        concurrency_context: Optional semaphore; stored unchanged.
        task_retries: Mapping of task id to an int (presumably a retry
            budget -- confirm against the retry helper); a falsy value is
            replaced by a new empty dict.
        silence: Stored unchanged.

    Raises:
        Whatever ``_validate_task_inputs`` raises for unacceptable
        ``tasks``.
    """
    # These four attributes are assigned before validation runs, in the
    # same order as before, in case the validator reads any of them.
    self.ignore_tail_task = ignore_tail_task if ignore_tail_task else []
    self.task_data_refs = task_data_refs if task_data_refs else {}
    self.has_data_path = has_data_path
    self.silence = silence

    self._validate_task_inputs(tasks)

    # Inputs retained once validation has passed.
    self.task_retries = task_retries if task_retries else {}
    self.tasks = tasks
    self.task_order = task_order
    self.concurrency_context = concurrency_context

    # Runtime bookkeeping, starting from a "nothing has run" state.
    self.execution_state = ExecutionState.NONE
    self.stage_status: StageStatus | None = None
    self.current_stage = -1
    self._task_outputs: dict[str, Any] = {}
    self._task_attempts: dict[str, list[TaskAttemptStatistics]] = {}
    self.task_return_set: dict[str, Any] = {}

execute async

execute() -> ExecutionReport
源代码位于: src/astrum/scheduler.py
async def execute(self) -> ExecutionReport:
    """Plan the run, execute it stage by stage, and return a report.

    Flow:

    1. Build the plan with ``get_execute_timeline()``. If planning raises
       an ``Exception``, the state becomes ``FAILED`` and the result of
       ``_create_failed_report(...)`` is returned immediately.
    2. For every stage, first await the already-started tasks listed in
       ``stage.wait_for_tasks`` (names that were never started are
       skipped), storing each result in ``self._task_outputs``; then start
       every task in ``stage.start_tasks`` as an ``asyncio`` task.
    3. After the stage loop, handle every started task that has no
       statistics record yet: tasks named in ``self.ignore_tail_task`` are
       handed back un-awaited in the report's ``tail_tasks``; the others
       are awaited (when nothing failed) or cancelled (after a failure).

    Returns:
        An ``ExecutionReport`` with timings, per-stage and per-task
        statistics, the final execution state and collected error messages.

    Note:
        ``Exception`` subclasses raised by tasks or by the stage loop are
        captured into the report rather than propagated. ``BaseException``
        subclasses that are not ``Exception`` are not caught by the stage
        loop.
    """
    total_start_time = time.time()
    planning_start = time.time()
    self.current_stage = -1
    self.execution_state = ExecutionState.PLANNING
    self._task_attempts = {}

    try:
        plan = self.get_execute_timeline()
    except Exception as exc:
        self.execution_state = ExecutionState.FAILED
        return self._create_failed_report(total_start_time, time.time(), planning_start, str(exc))

    planning_end = time.time()
    execution_start = time.time()
    self.execution_state = ExecutionState.RUNNING

    stage_statistics: list[TaskStageStatistics] = []
    task_statistics: list[TaskExecutionStatistics] = []
    error_summary: list[str] = []
    task_map: dict[str, asyncio.Task[Any]] = {}
    task_start_times: dict[str, float] = {}
    completed_tasks_recorded: set[str] = set()

    execution_failed = False
    failed_task_name: str | None = None
    first_error: BaseException | None = None

    try:
        for stage in plan.stages:
            stage_start_time = time.time()
            self.current_stage = stage.stage_id
            self.stage_status = StageStatus.PENDING

            # Phase 1: collect the results this stage depends on.
            for task_name in stage.wait_for_tasks:
                if task_name not in task_map:
                    continue

                try:
                    self._task_outputs[task_name] = await task_map[task_name]
                    self._record_task_stat(
                        task_statistics,
                        completed_tasks_recorded,
                        task_name,
                        stage.stage_id,
                        task_start_times[task_name],
                        "completed",
                    )
                except Exception as exc:
                    self._record_task_stat(
                        task_statistics,
                        completed_tasks_recorded,
                        task_name,
                        stage.stage_id,
                        task_start_times[task_name],
                        "failed",
                        str(exc),
                    )
                    error_summary.append(f"Task '{task_name}' failed during execution: {exc}")
                    execution_failed = True
                    failed_task_name = task_name
                    first_error = exc
                    self.execution_state = ExecutionState.FAILED
                    self.stage_status = StageStatus.FAILED
                    # Re-raise to leave the stage loop; the outer handler
                    # sees execution_failed is already set and adds nothing.
                    raise

            # Phase 2: launch this stage's tasks without awaiting them.
            self.stage_status = StageStatus.RUNNING
            for task_name in stage.start_tasks:
                if task_name in task_map:
                    raise TaskDuplicateExecutionError(task_name)

                task_callable = self.find_task_by_name(task_name)
                if task_callable is None:
                    raise TaskNotFoundError(task_name)

                task_start_times[task_name] = time.time()
                task_map[task_name] = asyncio.create_task(self._run_with_concurrency(self._invoke_task_with_retries(task_name, task_callable, error_summary)))

            self.stage_status = StageStatus.COMPLETED
            stage_end_time = time.time()
            stage_statistics.append(
                TaskStageStatistics(
                    stage_id=stage.stage_id,
                    stage_name=f"Stage_{stage.stage_id}",
                    start_time=stage_start_time,
                    end_time=stage_end_time,
                    duration=stage_end_time - stage_start_time,
                    parallel_task_count=len(stage.parallel_tasks),
                    wait_task_count=len(stage.wait_for_tasks),
                    parallel_tasks=stage.parallel_tasks.copy(),
                    wait_tasks=stage.wait_for_tasks.copy(),
                )
            )

        if not execution_failed:
            self.execution_state = ExecutionState.COMPLETED

    except Exception as exc:
        # Only scheduling-level errors (duplicate/unknown task, etc.) reach
        # here unrecorded; task failures were already recorded above.
        if not execution_failed:
            self.execution_state = ExecutionState.FAILED
            error_summary.append(f"Execution failed: {exc}")
            execution_failed = True
            first_error = exc

    # Settle every started task that has no statistics record yet.
    tail_tasks: dict[str, asyncio.Task[Any]] = {}
    for task_name, task in task_map.items():
        if task_name in self.ignore_tail_task:
            # Returned to the caller un-awaited via the report.
            tail_tasks[task_name] = task
            continue

        if task_name in completed_tasks_recorded:
            continue

        if execution_failed:
            task.cancel()
            try:
                late_result = await task
            except asyncio.CancelledError:
                self._record_task_stat(
                    task_statistics,
                    completed_tasks_recorded,
                    task_name,
                    -1,
                    task_start_times[task_name],
                    "cancelled",
                    "Task cancelled due to previous failure",
                )
            except Exception as exc:
                self._record_task_stat(
                    task_statistics,
                    completed_tasks_recorded,
                    task_name,
                    -1,
                    task_start_times[task_name],
                    "failed",
                    str(exc),
                )
            else:
                # cancel() has no effect on a task that already finished
                # (and a task may absorb the cancellation and return), so
                # the await can succeed. Record that outcome instead of
                # silently dropping the task from the statistics.
                self._task_outputs[task_name] = late_result
                self._record_task_stat(
                    task_statistics,
                    completed_tasks_recorded,
                    task_name,
                    -1,
                    task_start_times[task_name],
                    "completed",
                )
        else:
            try:
                self._task_outputs[task_name] = await task
                self._record_task_stat(
                    task_statistics,
                    completed_tasks_recorded,
                    task_name,
                    -1,
                    task_start_times[task_name],
                    "completed",
                )
            except Exception as exc:
                error_summary.append(f"Final task '{task_name}' failed: {exc}")
                self._record_task_stat(
                    task_statistics,
                    completed_tasks_recorded,
                    task_name,
                    -1,
                    task_start_times[task_name],
                    "failed",
                    str(exc),
                )
                self.execution_state = ExecutionState.FAILED
                # From here on the remaining tasks take the cancel branch.
                execution_failed = True
                if failed_task_name is None:
                    failed_task_name = task_name
                    first_error = exc

    # Explicit None checks: an empty task id or a falsy exception object
    # must not suppress the headline error message.
    if execution_failed and failed_task_name is not None and first_error is not None:
        error_summary.insert(0, f"Execution interrupted; task '{failed_task_name}' failed: {first_error}")

    total_end_time = time.time()
    successful_tasks = len([task for task in task_statistics if task.status == "completed"])
    failed_tasks = len([task for task in task_statistics if task.status == "failed"])

    return ExecutionReport(
        total_start_time=total_start_time,
        total_end_time=total_end_time,
        total_duration=total_end_time - total_start_time,
        planning_duration=planning_end - planning_start,
        execution_duration=total_end_time - execution_start,
        total_tasks=len(self.task_order),
        total_stages=len(plan.stages),
        max_parallelism=plan.max_parallelism,
        successful_tasks=successful_tasks,
        failed_tasks=failed_tasks,
        stage_statistics=stage_statistics,
        task_statistics=task_statistics,
        execution_state=self.execution_state.value,
        error_summary=error_summary,
        original_tasks=self._serialize_task_orders(),
        tail_tasks=tail_tasks,
        task_return_set=self.task_return_set,
    )