写自定义算子的时候,加载阶段报错是最让人崩溃的——代码逻辑明明没问题,但 ACL 就是说 format error 或者 schema mismatch

这类问题大多数出在 metadef(元数据定义)上。metadef 描述了算子的接口规范,包括输入输出的 shape、dtype、format 等信息。如果自定义算子的实现和 metadef 描述对不上,ACL 在加载阶段就会拒绝。

元数据到底在描述什么

昇腾的算子加载流程里,metadef 是一个桥梁:一边是算子的实际实现(核函数),另一边是 ACL 的调用接口(Graph 或者单算子调用)。

# 用户代码里调用一个自定义算子
import acl

# 定义输入 tensor
input_x = acl.create_tensor(acl_dtype, shape, format, data_ptr)
input_y = acl.create_tensor(acl_dtype, shape, format, data_ptr)

# 调用算子(假设叫 my_custom_op)
op_desc = acl.create_op_desc(
    "my_custom_op",    # 算子名字
    acl.rt.op_type(),  # 算子类型
    input_x,
    input_y,
)

# 执行
acl.rt.execute_op(op_desc)

这段代码里,算子名字 my_custom_op 对应的实现必须在 metadef 里注册过。ACL 会去 metadef 定义的 schema 里校验传入的 tensor 参数是否匹配。

常见的 schema 校验失败原因

dtype 不匹配

最常见的问题:代码里传的 dtype 和 metadef 定义的不一样。

# 用户代码里传的是 float32
a = torch.randn(4, 4).npu().half()  # 这里用了 .half() 变成 FP16

# 但 metadef 里定义的是 float32
# 加载的时候 ACL 会报 dtype mismatch

排查方式:打印出实际 tensor 的 dtype,和 metadef 里的定义对照。

# 查 tensor dtype
print(f"input dtype: {tensor.dtype}")        # torch.float16
print(f"metadef expects: {metadef.input_dtype}")  # torch.float32

# 如果不匹配,要么改代码强制转 dtype,要么更新 metadef 定义

format 不匹配

昇腾 NPU 支持多种 tensor format:NCHW、NHWC、ND 等等。自定义算子如果对 format 有要求,但调用的时候传的 format 不对,也会报错。

# metadef 里定义卷积算子要求 NCHW 格式
# 但某段代码把 tensor 转成了 NHWC
a = a.permute(0, 3, 1, 2)  # NCHW -> NHWC
# 调用算子,报 format mismatch

昇腾的图编译器(Graph Compiler)一般会自动做 format 转换,但有些特殊算子不支持某些 format 之间的转换,这时候就要手动处理。

shape 不匹配

shape 校验有时候会被忽略。比如 metadef 里定义了输入要求是 4 维张量,但某段代码传了一个 3 维的 tensor(batch size = 1 的时候被 squeeze 掉了)。

# 代码里传了 3 维 tensor
x = torch.randn(1, 64, 64)  # 少了 batch 维度

# metadef 定义的是 [batch, channel, height, width],要求 4 维
# 报错:shape rank mismatch

写 metadef 的规范

metadef 文件通常是一个 JSON 或者 proto 格式的定义,描述算子的签名。

{
  "op_name": "my_custom_gemm",
  "op_type": "Gemm",
  "input_desc": [
    {
      "name": "x",
      "dtype": ["float32", "float16"],
      "format": ["ND", "NCHW"],
      "shape": [-1, -1, -1, -1]
    },
    {
      "name": "w",
      "dtype": ["float32", "float16"],
      "format": ["ND"],
      "shape": [-1, -1]
    }
  ],
  "output_desc": [
    {
      "name": "y",
      "dtype": ["float32", "float16"],
      "format": ["ND"],
      "shape": [-1, -1, -1, -1]
    }
  ],
  "attr_desc": [
    {
      "name": "transpose_a",
      "dtype": "bool",
      "default": false
    },
    {
      "name": "transpose_b",
      "dtype": "bool",
      "default": false
    }
  ]
}

-1 在 shape 里表示动态维度,运行时才确定具体值。如果你的算子只支持固定维度,这里要写具体数字。

动态 shape 的坑

动态 shape 是 metadef 里最容易出问题的部分。

比如写一个变长序列处理的算子,输入是 [batch, seq_len, hidden],seq_len 每次不一样。如果 metadef 里把 seq_len 写成固定值,推理的时候一旦实际长度和定义不符,就会报 shape 不匹配。

# metadef 定义
"shape": [8, 512, 768]  # 固定长度

# 但实际推理的时候 seq_len 可能是 256,也可能是 1024
# 报 shape mismatch

正确的写法是用 -1 表示动态维度:

"shape": [-1, -1, 768]  # batch 和 seq_len 动态,hidden 固定

但用 -1 之后,昇腾的图编译器在优化阶段可能没法做一些 shape-specific 的优化。动态 shape 是个双刃剑,用的时候要想清楚。

属性(attr)校验

metadef 里除了描述输入输出,还能定义算子的属性。属性是一种静态参数,在创建算子的时候就固定了,不参与计算图的数据流。

# 定义算子属性
op_desc = acl.create_op_desc("my_op")
acl.set_attr_bool(op_desc, "use_relu", True)
acl.set_attr_int(op_desc, "threshold", 128)

这些属性在 metadef 里也要声明。如果代码里设置了一个属性但 metadef 里没有定义,ACL 会报 unknown attribute 错误。

调试 metadef 问题的小技巧

当报错信息不够明确的时候,可以用昇腾提供的工具校验 metadef 文件:

# 校验 metadef 文件的合法性
python -m metadef.validator my_op_schema.json

# 检查 dtype、format、shape 定义是否完整
# 输出可能的问题列表

另外,很多 metadef 相关的报错其实是加载顺序问题:算子的实现库(.so)没有先加载,metadef 里定义的算子找不到对应的实现。

# 确保先加载算子库,再注册 metadef
import acl

# 加载自定义算子的实现
acl.rt.load_addon("/path/to/libmy_custom_op.so")

# 然后才能通过名字找到算子
acl.op.set_addon_op_type("my_custom_op", "ACL_ENGINE_OP_TYPE_USER_DEF")

加载顺序搞反的话,报错往往也是 schema 相关的问题,很容易误判。

仓库在 https://atomgit.com/cann/metadef,仓库里有一些标准算子的 metadef 定义可以参考。

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐