前言

在昇腾 CANN 的算子开发流程中,“算子原型定义” 是连接框架(如 TensorFlow)与核函数的关键环节 —— 它负责描述算子的输入输出、推导 Shape 与数据类型,是算子能够被框架调用的前提。很多开发者入门时会忽略这一步,导致算子无法被正确解析或执行。本文基于 CANN 训练营的内容,从算子原型的核心组件出发,讲清 InferShape(Shape 推导)与 InferType(数据类型推导)的实现逻辑,同时给出完整的算子原型定义案例。

一、算子原型的核心作用

算子原型是算子的 “元信息描述”,主要完成 3 件事:

  1. 定义输入输出:指定算子的输入数量、输出数量,以及每个 Tensor 的属性(如是否为可选输入);
  2. 推导输出 Shape:通过InferShape函数,根据输入 Shape 计算输出 Shape;
  3. 推导输出数据类型:通过InferType函数,根据输入数据类型确定输出数据类型。

算子原型是 CANN 编译框架(如 AscendCL)解析算子的依据,必须与核函数的输入输出一致。

二、算子原型的基础结构

在 CANN 中,算子原型需继承Operator类,并实现InferShapeInferType方法,基础结构如下:

c

运行

#include "ascendc/ascendc_operator.h"

class CustomOperator : public Operator {
public:
    // 构造函数:定义输入输出
    CustomOperator() {
        // 添加输入(名称、是否可选)
        AddInput("input1", false);
        AddInput("input2", false);
        // 添加输出(名称)
        AddOutput("output");
    }

    // Shape推导
    Status InferShape() override {
        // 实现Shape推导逻辑
        return Status::SUCCESS;
    }

    // 数据类型推导
    Status InferType() override {
        // 实现数据类型推导逻辑
        return Status::SUCCESS;
    }
};

// 注册算子原型(框架通过名称识别)
REGISTER_OPERATOR(CustomOperator, "CustomOp");
三、InferShape:输出 Shape 的推导逻辑

InferShape是算子原型的核心方法,需根据输入 Shape 的规则推导输出 Shape。常见的推导场景包括:

3.1 元素级算子的 Shape 推导

元素级算子(如加法、乘法)的输出 Shape 与输入 Shape 一致,需先检查输入 Shape 的一致性:

c

运行

Status InferShape() override {
    // 获取输入Tensor
    const Tensor& input1 = GetInput(0);
    const Tensor& input2 = GetInput(1);

    // 检查输入Shape是否一致
    if (input1.GetShape() != input2.GetShape()) {
        AERROR << "Input shapes are not consistent: input1=" << input1.GetShape() 
               << ", input2=" << input2.GetShape();
        return Status::ERROR;
    }

    // 输出Shape与输入一致
    SetOutputShape(0, input1.GetShape());
    return Status::SUCCESS;
}
3.2 广播算子的 Shape 推导

广播算子(如[N, 1][1, C]广播为[N, C])需实现广播规则:

c

运行

Status InferShape() override {
    const Tensor& input1 = GetInput(0);
    const Tensor& input2 = GetInput(1);
    Shape shape1 = input1.GetShape();
    Shape shape2 = input2.GetShape();

    // 广播规则:从后往前匹配维度,维度不同时需有一个为1
    Shape outputShape;
    int32_t dimCount = std::max(shape1.GetDimCount(), shape2.GetDimCount());
    for (int32_t i = 0; i < dimCount; i++) {
        int32_t dim1 = (i < shape1.GetDimCount()) ? shape1.GetDim(i) : 1;
        int32_t dim2 = (i < shape2.GetDimCount()) ? shape2.GetDim(i) : 1;
        if (dim1 != 1 && dim2 != 1 && dim1 != dim2) {
            AERROR << "Broadcast not supported: dim1=" << dim1 << ", dim2=" << dim2;
            return Status::ERROR;
        }
        outputShape.AddDim(std::max(dim1, dim2));
    }

    SetOutputShape(0, outputShape);
    return Status::SUCCESS;
}
3.3 降维算子的 Shape 推导

降维算子(如求和、求均值)需根据降维轴推导输出 Shape:

c

运行

Status InferShape() override {
    const Tensor& input = GetInput(0);
    const Tensor& axis = GetInput(1);  // 降维轴(如[1]表示对第1维降维)

    // 获取降维轴的值(假设axis是标量)
    int32_t axisValue = *reinterpret_cast<const int32_t*>(axis.GetData());
    Shape inputShape = input.GetShape();
    int32_t dimCount = inputShape.GetDimCount();

    // 检查轴的有效性
    if (axisValue < 0 || axisValue >= dimCount) {
        AERROR << "Invalid axis: " << axisValue << ", dimCount=" << dimCount;
        return Status::ERROR;
    }

    // 输出Shape:移除降维轴
    Shape outputShape;
    for (int32_t i = 0; i < dimCount; i++) {
        if (i != axisValue) {
            outputShape.AddDim(inputShape.GetDim(i));
        }
    }

    SetOutputShape(0, outputShape);
    return Status::SUCCESS;
}
四、InferType:输出数据类型的推导逻辑

InferType负责确定输出数据类型,常见规则包括:

4.1 与输入类型一致

元素级算子通常采用与输入一致的数据类型:

c

运行

Status InferType() override {
    const Tensor& input1 = GetInput(0);
    SetOutputType(0, input1.GetDataType());
    return Status::SUCCESS;
}
4.2 精度提升

部分算子(如求和)会提升输出精度(如 float16 输入→float32 输出):

c

运行

Status InferType() override {
    const Tensor& input = GetInput(0);
    DataType inputType = input.GetDataType();
    if (inputType == DT_FLOAT16) {
        SetOutputType(0, DT_FLOAT32);
    } else {
        SetOutputType(0, inputType);
    }
    return Status::SUCCESS;
}
五、实战:完整的算子原型定义案例

以 “带广播的元素级加法算子” 为例,完整的算子原型定义如下:

c

运行

#include "ascendc/ascendc_operator.h"

class BroadcastAddOperator : public Operator {
public:
    BroadcastAddOperator() {
        AddInput("x", false);
        AddInput("y", false);
        AddOutput("z");
    }

    Status InferShape() override {
        const Tensor& x = GetInput(0);
        const Tensor& y = GetInput(1);
        Shape xShape = x.GetShape();
        Shape yShape = y.GetShape();

        // 广播Shape推导
        Shape outputShape;
        int32_t maxDim = std::max(xShape.GetDimCount(), yShape.GetDimCount());
        for (int32_t i = 0; i < maxDim; i++) {
            int32_t xDim = (i < xShape.GetDimCount()) ? xShape.GetDim(i) : 1;
            int32_t yDim = (i < yShape.GetDimCount()) ? yShape.GetDim(i) : 1;
            if (xDim != 1 && yDim != 1 && xDim != yDim) {
                AERROR << "Broadcast failed: xDim=" << xDim << ", yDim=" << yDim;
                return Status::ERROR;
            }
            outputShape.AddDim(std::max(xDim, yDim));
        }

        SetOutputShape(0, outputShape);
        return Status::SUCCESS;
    }

    Status InferType() override {
        const Tensor& x = GetInput(0);
        SetOutputType(0, x.GetDataType());
        return Status::SUCCESS;
    }
};

REGISTER_OPERATOR(BroadcastAddOperator, "BroadcastAdd");
六、算子原型的验证方法

算子原型定义完成后,需验证其正确性:

  1. Shape 推导验证:构造不同输入 Shape 的用例,检查输出 Shape 是否符合预期;
  2. 数据类型验证:构造不同输入类型的用例,检查输出类型是否符合预期;
  3. 错误场景验证:构造 Shape 不匹配、轴无效等用例,检查是否返回错误状态。
结语

算子原型定义是昇腾 CANN 算子开发的 “前端入口”,其核心是 “规则化推导 Shape 与数据类型”—— 既需要贴合算子的数学逻辑(如元素级、广播、降维),也需要考虑框架的兼容性(如输入输出的属性定义)。掌握这一能力后,算子才能被正确解析并调用核函数。后续进阶内容中,算子原型还会涉及 “属性参数”“动态 Shape 适配” 等复杂场景,建议大家结合 CANN 官方的算子原型示例,多做不同类型算子的推导练习。

2025年昇腾CANN训练营第二季,基于CANN开源开放全场景,推出0基础入门系列、码力全开特辑、开发者案例等专题课程,助力不同阶段开发者快速提升算子开发技能。获得Ascend C算子中级认证,即可领取精美证书,完成社区任务更有机会赢取华为手机,平板、开发板等大奖。

报名链接:https://www.hiascend.com/developer/activities/cann20252

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐