跳转至

类型工具

astrum.type_tools 是高级类型匹配辅助模块,普通用户一般不需要直接调用。

astrum.type_tools.TypeMatchError

Bases: TypeError

自定义类型匹配错误

Custom type matching error.

astrum.type_tools.match_input_type

match_input_type(func_obj: Any, allow_data_model: AllowDataModelType, key: Optional[str] = None, index: Optional[int] = None, single_object_mode: bool = False) -> None

匹配输入参数的类型。

Match the type of input parameters.

源代码位于: src/astrum/type_tools.py
def match_input_type(func_obj: Any, allow_data_model: AllowDataModelType, key: Optional[str] = None, index: Optional[int] = None, single_object_mode: bool = False) -> None:
    """
    匹配输入参数的类型。

    Match the type of input parameters.
    """
    real_func = _get_real_function(func_obj)
    sig = inspect.signature(real_func)
    params = list(sig.parameters.values())

    target_param = None

    # 路由逻辑:单对象模式 vs Key/Index 定位
    if single_object_mode:
        if not params:
            raise ValueError(f"启用单对象模式失败: 函数 {real_func.__name__} 没有输入参数")
        target_param = params[0]  # 单对象模式忽略 key 和 index,直接取第一个
    else:
        if key is None and index is None:
            raise ValueError("非单对象模式下,必须提供 key 或 index。如果不提供,请将 single_object_mode 设为 True")

        if key is not None:  # Key 优先级最高,有 Key 忽略 Index
            if key in sig.parameters:
                target_param = sig.parameters[key]
            else:
                raise ValueError(f"函数 {real_func.__name__} 中未找到参数名: {key}")
        elif index is not None:
            if 0 <= index < len(params):
                target_param = params[index]
            else:
                raise IndexError(f"索引 {index} 超出函数参数范围 (共 {len(params)} 个参数)")

    # 检查注解
    annotated_type = target_param.annotation
    if annotated_type == inspect.Parameter.empty:
        warnings.warn(f"警告:参数 '{target_param.name}' 没有定义类型注解,跳过严格类型检查。")
        return

    if not _is_type_compatible(annotated_type, allow_data_model):
        raise TypeMatchError(f"输入类型不匹配!参数 '{target_param.name}' 类型为 {annotated_type}, " f"但要求的 allow_data_model 为 {allow_data_model}")

astrum.type_tools.match_output_type

match_output_type(func_obj: Any, allow_data_model: AllowDataModelType, max_depth: int = 5) -> None

匹配输出值的类型。 融合版:支持标准注解解析 + 高级跨函数 AST 静态类型推断。

Match the type of output values. Merged version: supports standard annotation parsing plus advanced cross-function AST static type inference.

源代码位于: src/astrum/type_tools.py
def match_output_type(func_obj: Any, allow_data_model: AllowDataModelType, max_depth: int = 5) -> None:
    """
    匹配输出值的类型。
    融合版:支持标准注解解析 + 高级跨函数 AST 静态类型推断。

    Match the type of output values.
    Merged version: supports standard annotation parsing plus advanced
    cross-function AST static type inference.
    """
    real_func = _get_real_function(func_obj)
    sig = inspect.signature(real_func)

    # ==========================================
    # 第一阶段:优先使用标准的返回类型注解
    # ==========================================
    if sig.return_annotation != inspect.Signature.empty:
        if _is_type_compatible(sig.return_annotation, allow_data_model):
            return
        raise TypeMatchError(f"输出类型不匹配!函数声明的返回类型为 {sig.return_annotation}, " f"而要求的 allow_data_model 为 {allow_data_model}")

    # ==========================================
    # 第二阶段:启动高级 AST 推断引擎 (跨函数 & 运算推断)
    # ==========================================
    warnings.warn(f"函数 {real_func.__name__} 没有定义返回类型注解,正在启动高级 AST 跨函数类型推断...")

    # 提取全局变量空间,用于突破单函数的墙,寻找外部函数
    global_vars = real_func.__globals__ if hasattr(real_func, "__globals__") else {}

    def parse_and_infer(func_to_parse: Callable, current_depth: int) -> set:
        if current_depth > max_depth:
            return {Any}

        try:
            source = textwrap.dedent(inspect.getsource(func_to_parse))
            tree = ast.parse(source)
        except Exception:
            return {Any}

        # 1. 收集当前函数的局部变量
        context_vars = {}
        for node in ast.walk(tree):
            if isinstance(node, ast.Assign):
                for target in node.targets:
                    if isinstance(target, ast.Name):
                        context_vars[target.id] = node.value

        # 2. 节点推断逻辑
        def infer_node(node: ast.AST, depth: int) -> Any:
            if depth > max_depth:
                return Any

            # 数学运算
            if isinstance(node, ast.BinOp):
                return int
            # 基础结构
            if isinstance(node, ast.Dict):
                return dict
            if isinstance(node, (ast.List, ast.ListComp)):
                return list
            if isinstance(node, ast.Constant):
                return type(node.value)

            # 跨函数或类实例化识别
            if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
                func_name = node.func.id
                # 如果调用的是同环境下的另一个函数,则深度跳入解析
                if func_name in global_vars and callable(global_vars[func_name]):
                    sub_types = parse_and_infer(global_vars[func_name], depth + 1)
                    return list(sub_types)[0] if sub_types else Any
                return func_name  # 当作 BaseModel 类名处理

            # 局部变量溯源
            if isinstance(node, ast.Name) and node.id in context_vars:
                return infer_node(context_vars[node.id], depth + 1)

            return Any

        # 3. 提取所有 Return 的结果
        inferred = set()
        for node in ast.walk(tree):
            if isinstance(node, ast.Return) and node.value:
                inferred.add(infer_node(node.value, current_depth))
        return inferred

    # 获取推断类型集合
    inferred_types = parse_and_infer(real_func, 0)

    if not inferred_types:
        return  # 如果函数没有返回值 (返回 None),忽略

    # ==========================================
    # 第三阶段:将推断出的类型与 allow_data_model 进行严格比对 (校验与拦截)
    # ==========================================
    for g_type in inferred_types:
        # 1. 字典兼容
        if g_type == dict and allow_data_model is dict:
            continue
        # 2. 列表兼容
        if g_type == list and allow_data_model is list:
            continue
        # 3. Pydantic 模型兼容
        if isinstance(g_type, str) and isinstance(allow_data_model, type) and issubclass(allow_data_model, BaseModel):
            if g_type == allow_data_model.__name__:
                continue
        # 4. 无法推断的动态类型 (兜底放行)
        if g_type == Any:
            warnings.warn("AST推断达到最大深度或遇到不可推断的动态类型,已跳过严格校验。")
            continue

        # 5. 如果以上都不满足,说明推断类型与要求的类型不符,触发异常拦截!
        raise TypeMatchError(f"AST 推断输出类型不匹配!推断出的返回格式为 '{g_type}' 语法结构, " f"但要求的 allow_data_model 为 {allow_data_model}")