def match_output_type(func_obj: Any, allow_data_model: AllowDataModelType, max_depth: int = 5) -> None:
"""
匹配输出值的类型。
融合版:支持标准注解解析 + 高级跨函数 AST 静态类型推断。
Match the type of output values.
Merged version: supports standard annotation parsing plus advanced
cross-function AST static type inference.
"""
real_func = _get_real_function(func_obj)
sig = inspect.signature(real_func)
# ==========================================
# 第一阶段:优先使用标准的返回类型注解
# ==========================================
if sig.return_annotation != inspect.Signature.empty:
if _is_type_compatible(sig.return_annotation, allow_data_model):
return
raise TypeMatchError(f"输出类型不匹配!函数声明的返回类型为 {sig.return_annotation}, " f"而要求的 allow_data_model 为 {allow_data_model}")
# ==========================================
# 第二阶段:启动高级 AST 推断引擎 (跨函数 & 运算推断)
# ==========================================
warnings.warn(f"函数 {real_func.__name__} 没有定义返回类型注解,正在启动高级 AST 跨函数类型推断...")
# 提取全局变量空间,用于突破单函数的墙,寻找外部函数
global_vars = real_func.__globals__ if hasattr(real_func, "__globals__") else {}
def parse_and_infer(func_to_parse: Callable, current_depth: int) -> set:
if current_depth > max_depth:
return {Any}
try:
source = textwrap.dedent(inspect.getsource(func_to_parse))
tree = ast.parse(source)
except Exception:
return {Any}
# 1. 收集当前函数的局部变量
context_vars = {}
for node in ast.walk(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name):
context_vars[target.id] = node.value
# 2. 节点推断逻辑
def infer_node(node: ast.AST, depth: int) -> Any:
if depth > max_depth:
return Any
# 数学运算
if isinstance(node, ast.BinOp):
return int
# 基础结构
if isinstance(node, ast.Dict):
return dict
if isinstance(node, (ast.List, ast.ListComp)):
return list
if isinstance(node, ast.Constant):
return type(node.value)
# 跨函数或类实例化识别
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
func_name = node.func.id
# 如果调用的是同环境下的另一个函数,则深度跳入解析
if func_name in global_vars and callable(global_vars[func_name]):
sub_types = parse_and_infer(global_vars[func_name], depth + 1)
return list(sub_types)[0] if sub_types else Any
return func_name # 当作 BaseModel 类名处理
# 局部变量溯源
if isinstance(node, ast.Name) and node.id in context_vars:
return infer_node(context_vars[node.id], depth + 1)
return Any
# 3. 提取所有 Return 的结果
inferred = set()
for node in ast.walk(tree):
if isinstance(node, ast.Return) and node.value:
inferred.add(infer_node(node.value, current_depth))
return inferred
# 获取推断类型集合
inferred_types = parse_and_infer(real_func, 0)
if not inferred_types:
return # 如果函数没有返回值 (返回 None),忽略
# ==========================================
# 第三阶段:将推断出的类型与 allow_data_model 进行严格比对 (校验与拦截)
# ==========================================
for g_type in inferred_types:
# 1. 字典兼容
if g_type == dict and allow_data_model is dict:
continue
# 2. 列表兼容
if g_type == list and allow_data_model is list:
continue
# 3. Pydantic 模型兼容
if isinstance(g_type, str) and isinstance(allow_data_model, type) and issubclass(allow_data_model, BaseModel):
if g_type == allow_data_model.__name__:
continue
# 4. 无法推断的动态类型 (兜底放行)
if g_type == Any:
warnings.warn("AST推断达到最大深度或遇到不可推断的动态类型,已跳过严格校验。")
continue
# 5. 如果以上都不满足,说明推断类型与要求的类型不符,触发异常拦截!
raise TypeMatchError(f"AST 推断输出类型不匹配!推断出的返回格式为 '{g_type}' 语法结构, " f"但要求的 allow_data_model 为 {allow_data_model}")