跳转至

调度执行

直接使用 DynamicScheduler 时,外部 API 主要是构造 scheduler 并调用 execute()。调度器上的其他辅助方法归入内部 API。

astrum.scheduler.DynamicScheduler

DynamicScheduler(tasks: list[tuple[str, TaskCallable]], task_order: list[TaskOrder], task_data_refs: dict[str, TaskData] | None = None, has_data_path: bool = False, ignore_tail_task: list[str] | None = None, concurrency_context: Semaphore | None = None, *, task_retries: dict[str, int] | None = None, silence: bool = True)

Async DAG scheduler for predeclared task orders.

每个任务以 (task_id, callable) 的形式传入;调度器在调度期才真正调用 callable,从而支持基于 `TaskData` 的动态参数装配。不再支持将“已 call 但未 run 的协程对象”作为 task;这种用法会在 __init__ 阶段直接抛出 `TypeError`。

Each task is passed in as (task_id, callable); the scheduler only calls the callable at scheduling time, which supports dynamic parameter assembly based on `TaskData`. Passing an already-called but not-yet-run coroutine object as a task is no longer supported; doing so raises `TypeError` directly during __init__.

源代码位于: src/astrum/scheduler.py
def __init__(
    self,
    tasks: list[tuple[str, TaskCallable]],
    task_order: list[TaskOrder],
    task_data_refs: dict[str, TaskData] | None = None,
    has_data_path: bool = False,
    ignore_tail_task: list[str] | None = None,
    concurrency_context: asyncio.Semaphore | None = None,
    *,
    task_retries: dict[str, int] | None = None,
    silence: bool = True,
) -> None:
    """Store the scheduler configuration and reset all runtime bookkeeping.

    Args:
        tasks: ``(task_id, callable)`` pairs; checked by
            ``_validate_task_inputs`` before being stored.
        task_order: Declared task orders, stored unchanged.
        task_data_refs: Mapping of names to ``TaskData``; a falsy value
            becomes a new empty dict.
        has_data_path: Flag stored unchanged on the instance.
        ignore_tail_task: Task names stored on the instance; a falsy value
            becomes a new empty list.
        concurrency_context: Optional semaphore, stored unchanged.
        task_retries: Mapping of task names to integers; a falsy value
            becomes a new empty dict.
        silence: Flag stored unchanged on the instance.

    Raises:
        Whatever ``_validate_task_inputs`` raises for unacceptable ``tasks``.
    """
    # These options are assigned before validation on purpose: the validator
    # is defined elsewhere and may read them (NOTE(review): confirm).
    self.ignore_tail_task = ignore_tail_task if ignore_tail_task else []
    self.task_data_refs = task_data_refs if task_data_refs else {}
    self.has_data_path = has_data_path
    self.silence = silence
    self._validate_task_inputs(tasks)

    # Remaining configuration is only stored once validation has passed.
    self.task_retries = task_retries if task_retries else {}
    self.tasks = tasks
    self.task_order = task_order
    self.concurrency_context = concurrency_context

    # Runtime state: nothing is running and no stage has been entered yet.
    self.execution_state = ExecutionState.NONE
    self.stage_status: StageStatus | None = None
    self.current_stage = -1

    # Per-run result containers, all starting empty.
    self._task_outputs: dict[str, Any] = {}
    self._task_attempts: dict[str, list[TaskAttemptStatistics]] = {}
    self.task_return_set: dict[str, Any] = {}

execute async

execute() -> ExecutionReport
源代码位于: src/astrum/scheduler.py
async def execute(self) -> ExecutionReport:
    """Plan the task graph, run it stage by stage, and return a report.

    The run has three phases:

    1. Planning: ``get_execute_timeline()`` builds the plan. If it raises,
       the state becomes ``FAILED`` and the result of
       ``_create_failed_report`` is returned immediately.
    2. Stage loop: for each stage, every already-started task listed in
       ``wait_for_tasks`` is awaited and recorded, then every task in
       ``start_tasks`` is launched as an ``asyncio`` task. The first task
       failure (or any other ``Exception`` raised in the loop) stops the
       loop.
    3. Settling: tasks named in ``ignore_tail_task`` are handed back
       un-awaited in the report's ``tail_tasks``. Every other task that has
       no statistics entry yet is awaited (successful run) or cancelled and
       then awaited (failed run), and recorded with stage id ``-1``.

    Returns:
        An ``ExecutionReport`` with timings, per-stage and per-task
        statistics, the final execution state and the collected error
        messages.

    Note:
        ``Exception`` subclasses raised by planning, by the stage loop, or
        by awaited tasks are captured into the report instead of being
        propagated. ``asyncio.CancelledError`` raised inside the stage loop
        is not caught there.
    """
    total_start_time = time.time()
    planning_start = time.time()
    self.current_stage = -1
    self.execution_state = ExecutionState.PLANNING
    self._task_attempts = {}

    try:
        plan = self.get_execute_timeline()
    except Exception as exc:
        self.execution_state = ExecutionState.FAILED
        return self._create_failed_report(total_start_time, time.time(), planning_start, str(exc))

    planning_end = time.time()
    execution_start = time.time()
    self.execution_state = ExecutionState.RUNNING

    stage_statistics: list[TaskStageStatistics] = []
    task_statistics: list[TaskExecutionStatistics] = []
    error_summary: list[str] = []
    # Every task launched so far, keyed by task name.
    task_map: dict[str, asyncio.Task[Any]] = {}
    task_start_times: dict[str, float] = {}
    # Names passed to _record_task_stat; used below to skip settled tasks.
    completed_tasks_recorded: set[str] = set()

    execution_failed = False
    failed_task_name: str | None = None
    first_error: BaseException | None = None

    try:
        for stage in plan.stages:
            stage_start_time = time.time()
            self.current_stage = stage.stage_id
            self.stage_status = StageStatus.PENDING

            # Step 1: collect results of earlier tasks this stage waits on.
            for task_name in stage.wait_for_tasks:
                if task_name not in task_map:
                    # Never started, so there is nothing to await.
                    continue

                try:
                    self._task_outputs[task_name] = await task_map[task_name]
                    self._record_task_stat(
                        task_statistics,
                        completed_tasks_recorded,
                        task_name,
                        stage.stage_id,
                        task_start_times[task_name],
                        "completed",
                    )
                except Exception as exc:
                    self._record_task_stat(
                        task_statistics,
                        completed_tasks_recorded,
                        task_name,
                        stage.stage_id,
                        task_start_times[task_name],
                        "failed",
                        str(exc),
                    )
                    error_summary.append(f"Task '{task_name}' failed during execution: {exc}")
                    execution_failed = True
                    failed_task_name = task_name
                    first_error = exc
                    self.execution_state = ExecutionState.FAILED
                    self.stage_status = StageStatus.FAILED
                    # Re-raise to leave the stage loop; the outer handler
                    # sees execution_failed is already set and adds nothing.
                    raise

            # Step 2: launch this stage's tasks without awaiting them.
            self.stage_status = StageStatus.RUNNING
            for task_name in stage.start_tasks:
                if task_name in task_map:
                    raise TaskDuplicateExecutionError(task_name)

                task_callable = self.find_task_by_name(task_name)
                if task_callable is None:
                    raise TaskNotFoundError(task_name)

                task_start_times[task_name] = time.time()
                task_map[task_name] = asyncio.create_task(
                    self._run_with_concurrency(
                        self._invoke_task_with_retries(task_name, task_callable, error_summary)
                    )
                )

            self.stage_status = StageStatus.COMPLETED
            stage_end_time = time.time()
            stage_statistics.append(
                TaskStageStatistics(
                    stage_id=stage.stage_id,
                    stage_name=f"Stage_{stage.stage_id}",
                    start_time=stage_start_time,
                    end_time=stage_end_time,
                    duration=stage_end_time - stage_start_time,
                    parallel_task_count=len(stage.parallel_tasks),
                    wait_task_count=len(stage.wait_for_tasks),
                    parallel_tasks=stage.parallel_tasks.copy(),
                    wait_tasks=stage.wait_for_tasks.copy(),
                )
            )

        if not execution_failed:
            self.execution_state = ExecutionState.COMPLETED

    except Exception as exc:
        # Only failures not already attributed to a specific task are
        # summarised here (e.g. duplicate or unknown task names).
        if not execution_failed:
            self.execution_state = ExecutionState.FAILED
            error_summary.append(f"Execution failed: {exc}")
            execution_failed = True
            first_error = exc

    # Settle every launched task that has no statistics entry yet.
    tail_tasks: dict[str, asyncio.Task[Any]] = {}
    for task_name, task in task_map.items():
        if task_name in self.ignore_tail_task:
            # Handed back to the caller as-is: not awaited, not cancelled.
            tail_tasks[task_name] = task
            continue

        if task_name in completed_tasks_recorded:
            continue

        if execution_failed:
            task.cancel()
            try:
                result = await task
            except asyncio.CancelledError:
                self._record_task_stat(
                    task_statistics,
                    completed_tasks_recorded,
                    task_name,
                    -1,
                    task_start_times[task_name],
                    "cancelled",
                    "Task cancelled due to previous failure",
                )
            except Exception as exc:
                self._record_task_stat(
                    task_statistics,
                    completed_tasks_recorded,
                    task_name,
                    -1,
                    task_start_times[task_name],
                    "failed",
                    str(exc),
                )
            else:
                # cancel() has no effect on a task that already finished, so
                # the await can return normally. Keep the result and record
                # the task; otherwise it would be missing from the
                # statistics and from the successful-task count.
                self._task_outputs[task_name] = result
                self._record_task_stat(
                    task_statistics,
                    completed_tasks_recorded,
                    task_name,
                    -1,
                    task_start_times[task_name],
                    "completed",
                )
        else:
            try:
                self._task_outputs[task_name] = await task
                self._record_task_stat(
                    task_statistics,
                    completed_tasks_recorded,
                    task_name,
                    -1,
                    task_start_times[task_name],
                    "completed",
                )
            except Exception as exc:
                error_summary.append(f"Final task '{task_name}' failed: {exc}")
                self._record_task_stat(
                    task_statistics,
                    completed_tasks_recorded,
                    task_name,
                    -1,
                    task_start_times[task_name],
                    "failed",
                    str(exc),
                )
                self.execution_state = ExecutionState.FAILED
                # From here on the remaining tasks take the cancel branch.
                execution_failed = True
                if failed_task_name is None:
                    failed_task_name = task_name
                    first_error = exc

    # Put the root cause first so it is the first thing a reader sees.
    if execution_failed and failed_task_name and first_error:
        error_summary.insert(0, f"Execution interrupted; task '{failed_task_name}' failed: {first_error}")

    total_end_time = time.time()
    successful_tasks = len([task for task in task_statistics if task.status == "completed"])
    failed_tasks = len([task for task in task_statistics if task.status == "failed"])

    return ExecutionReport(
        total_start_time=total_start_time,
        total_end_time=total_end_time,
        total_duration=total_end_time - total_start_time,
        planning_duration=planning_end - planning_start,
        execution_duration=total_end_time - execution_start,
        total_tasks=len(self.task_order),
        total_stages=len(plan.stages),
        max_parallelism=plan.max_parallelism,
        successful_tasks=successful_tasks,
        failed_tasks=failed_tasks,
        stage_statistics=stage_statistics,
        task_statistics=task_statistics,
        execution_state=self.execution_state.value,
        error_summary=error_summary,
        original_tasks=self._serialize_task_orders(),
        tail_tasks=tail_tasks,
        task_return_set=self.task_return_set,
    )