前言

你写了个Ascend C算子,编译成.o文件,往GE图引擎里一塞——GE怎么知道这个算子叫什么名字、要几个输入、输出什么类型、有哪些属性?

答案是metadef。它给每个算子发"身份证",GE编译计算图时查这个身份证来理解你的算子。没有身份证,GE不认识你的算子,编译直接报错。

这篇文章讲metadef是什么、里面有什么、怎么给你的算子写metadef。

metadef是什么?

打个比方:metadef是餐厅的"菜单",算子是"菜",GE是"厨师"。

厨师(GE)要做菜(编译计算图),必须先看菜单(metadef)知道有哪些菜、每道菜要什么食材(输入)、用什么调料(属性)、出什么成品(输出)。没有菜单,厨师不知道怎么下单。

metadef定义了CANN生态中所有算子的元数据

  • 算子叫什么名字(比如"MatMul")
  • 有几个输入,每个输入的类型和shape约束(比如"输入A是2D FP16 tensor")
  • 有几个输出,每个输出的类型和shape(比如"输出C是2D FP16 tensor")
  • 有哪些属性,每个属性的类型和默认值(比如"transpose_a: bool, 默认false")
  • 输出的shape和dtype怎么从输入推导出来(InferShape/InferDataType)

metadef本身不包含算子的计算逻辑(那是Ascend C代码的事),只包含算子的"描述信息"。就像菜单只告诉你菜名和食材,不教你炒菜。

metadef里有什么?

metadef的核心内容分三部分:算子原型(OpProto)、算子注册(REGISTER_OP)、推导函数(InferShape/InferDataType)。

算子原型(OpProto)

OpProto定义了算子的"外貌"——名字、输入、输出、属性:

// ops-math仓库中MatMul算子的OpProto定义
namespace ge {
class MatMulOp : public OpProto {
public:
    MatMulOp() {
        // 算子名字
        SetName("MatMul");
        
        // 输入定义
        Input("x1")
            .SetDataType(DT_FLOAT16)    // 输入x1的类型:FP16
            .SetShape({-1, -1});        // 输入x1的shape:2D,维度不限
        Input("x2")
            .SetDataType(DT_FLOAT16)
            .SetShape({-1, -1});
        
        // 输出定义
        Output("y")
            .SetDataType(DT_FLOAT16)
            .SetShape({-1, -1});
        
        // 属性定义
        Attr("transpose_x1")
            .SetType(ATTR_BOOL)         // 属性类型:bool
            .SetDefaultValue(false);    // 默认值:false
        Attr("transpose_x2")
            .SetType(ATTR_BOOL)
            .SetDefaultValue(false);
    }
};
}

这段代码的意思:MatMul算子有2个FP16输入(x1, x2)、1个FP16输出(y)、2个bool属性(transpose_x1, transpose_x2)。

算子注册(REGISTER_OP)

写好OpProto后,用REGISTER_OP宏把它注册到全局注册表:

// 注册MatMul算子到全局注册表
REGISTER_OP(MatMulOp);

GE编译计算图时,遇到一个叫"MatMul"的算子节点,就去全局注册表里查MatMulOp,拿到输入/输出/属性的定义,然后做类型检查和shape推导。

推导函数(InferShape / InferDataType)

InferShape根据输入的shape推导输出的shape。这是GE编译的关键——GE需要知道每个算子输出的shape,才能给输出tensor分配内存。

// MatMul的InferShape
// 输入x1: [M, K], x2: [K, N] → 输出y: [M, N]
graphStatus MatMulInferShape(const ge::OpDescPtr& op_desc) {
    auto x1_shape = op_desc->GetInputDesc(0).GetShape();
    auto x2_shape = op_desc->GetInputDesc(1).GetShape();
    
    int64_t M = x1_shape.GetDim(0);  // x1的第0维
    int64_t K1 = x1_shape.GetDim(1); // x1的第1维
    int64_t K2 = x2_shape.GetDim(0); // x2的第0维
    int64_t N = x2_shape.GetDim(1);  // x2的第1维
    
    // 检查K维度匹配
    if (K1 != K2) {
        return GRAPH_FAILED;  // K不匹配,编译失败
    }
    
    // 设置输出shape
    ge::GeShape y_shape({M, N});
    op_desc->GetOutputDesc(0).SetShape(y_shape);
    
    return GRAPH_SUCCESS;
}

InferDataType根据输入的dtype推导输出的dtype。对于MatMul,输出类型跟输入相同:

graphStatus MatMulInferDataType(const ge::OpDescPtr& op_desc) {
    auto x1_dtype = op_desc->GetInputDesc(0).GetDataType();
    op_desc->GetOutputDesc(0).SetDataType(x1_dtype);  // 输出类型=输入类型
    return GRAPH_SUCCESS;
}

实战:给你的Ascend C算子写metadef

假设你写了一个自定义融合算子MatMulRelu(矩阵乘+ReLU融合),需要给它写metadef让GE认识。

Step 1:定义OpProto

// custom_op_proto.h
namespace ge {
class MatMulReluOp : public OpProto {
public:
    MatMulReluOp() {
        SetName("MatMulRelu");
        
        // 输入:跟MatMul一样,2个FP16矩阵
        Input("x1")
            .SetDataType(DT_FLOAT16)
            .SetShape({-1, -1});
        Input("x2")
            .SetDataType(DT_FLOAT16)
            .SetShape({-1, -1});
        
        // 输出:1个FP16矩阵(ReLU后)
        Output("y")
            .SetDataType(DT_FLOAT16)
            .SetShape({-1, -1});
        
        // 属性:跟MatMul一样
        Attr("transpose_x1")
            .SetType(ATTR_BOOL)
            .SetDefaultValue(false);
        Attr("transpose_x2")
            .SetType(ATTR_BOOL)
            .SetDefaultValue(false);
    }
};
}

Step 2:实现InferShape和InferDataType

// custom_op_infer.cc
graphStatus MatMulReluInferShape(const ge::OpDescPtr& op_desc) {
    auto x1_shape = op_desc->GetInputDesc(0).GetShape();
    auto x2_shape = op_desc->GetInputDesc(1).GetShape();
    
    int64_t M = x1_shape.GetDim(0);
    int64_t K1 = x1_shape.GetDim(1);
    int64_t K2 = x2_shape.GetDim(0);
    int64_t N = x2_shape.GetDim(1);
    
    if (K1 != K2) {
        return GRAPH_FAILED;
    }
    
    // 输出shape跟MatMul一样,ReLU不改变shape
    ge::GeShape y_shape({M, N});
    op_desc->GetOutputDesc(0).SetShape(y_shape);
    return GRAPH_SUCCESS;
}

graphStatus MatMulReluInferDataType(const ge::OpDescPtr& op_desc) {
    auto x1_dtype = op_desc->GetInputDesc(0).GetDataType();
    op_desc->GetOutputDesc(0).SetDataType(x1_dtype);
    return GRAPH_SUCCESS;
}

Step 3:注册算子

// custom_op_register.cc
REGISTER_OP(MatMulReluOp)
    .InferShape(MatMulReluInferShape)
    .InferDataType(MatMulReluInferDataType);

Step 4:编译并安装

# 编译metadef动态库
g++ -shared -fPIC -o libcustom_op_proto.so \
    custom_op_proto.cc custom_op_infer.cc custom_op_register.cc \
    -I/usr/local/Ascend/ascend-toolkit/latest/metadef/include \
    -L/usr/local/Ascend/ascend-toolkit/latest/metadef/lib \
    -lgraph -lge_common

# 安装到CANN的算子目录
sudo cp libcustom_op_proto.so \
    /usr/local/Ascend/ascend-toolkit/latest/opp/built-in/op_impl/ai_core/tbe/op_tiling/

安装后,GE编译时就能识别"MatMulRelu"这个算子了。

metadef在CANN架构中的位置

metadef位于CANN五层架构的第3层(昇腾计算编译层),被GE图编译器引用:

第3层:昇腾计算编译层
  ├─ Graph Compiler(GE图编译器)
  │   ├─ 查metadef获取算子元数据
  │   ├─ 调InferShape推导输出shape
  │   └─ 分配输出内存
  ├─ metadef(算子元数据定义)← 你在这里
  └─ BiSheng / ATC 编译器

依赖关系:metadef ← ge ← 所有算子仓库。每个算子仓库(ops-math、ops-nn等)都依赖metadef来注册自己的算子。

踩坑实录

坑1:InferShape写错了,GE编译报"output shape mismatch"

问题:InferShape推导的输出shape跟算子实际输出的shape不一致,GE报错。

原因:InferShape是静态推导(编译时),算子执行是动态计算(运行时)。如果InferShape推导错了,GE给输出分配的内存大小不对,运行时会出问题。

解决方案:InferShape要跟算子的实际输出严格一致。不确定时,用-1标记动态维度(表示运行时才能确定):

//  如果输出某个维度是动态的,用-1
ge::GeShape y_shape({-1, N});  // M维度动态
op_desc->GetOutputDesc(0).SetShape(y_shape);

坑2:属性类型必须是metadef支持的

问题:想定义一个自定义结构体类型的属性,编译报错。

原因:metadef只支持以下属性类型:int、float、string、bool、list、list、list。不支持自定义结构体。

解决方案:把结构体拆成多个属性,或者序列化成string:

//  错误写法(自定义结构体)
Attr("config").SetType(ATTR_CUSTOM_STRUCT);

//  正确写法(拆成多个属性)
Attr("config_batch_size").SetType(ATTR_INT);
Attr("config_seq_len").SetType(ATTR_INT);
Attr("config_dtype").SetType(ATTR_STRING);

坑3:算子名字不能跟已有算子重复

问题:注册了一个叫"MatMul"的算子,但ops-math已经注册了同名算子,冲突。

原因:全局注册表里算子名字必须唯一。重复注册会覆盖已有定义。

解决方案:自定义算子加前缀或后缀避免冲突:

SetName("Custom_MatMulRelu");  // 加Custom_前缀

结尾

metadef不起眼,但没它整个编译链就断了。你写的Ascend C算子再好,GE不认识它,就无法编译执行。metadef就是给算子发身份证——告诉GE这个算子叫什么、要什么输入、出什么输出、有哪些属性。身份证写对了,GE才能正确编译和调度你的算子。

写自定义算子时,别只写Ascend C代码,把metadef也一起写了。这是算子开发的基本功。

https://atomgit.com/cann/metadef

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐