昇腾 CANN 初级算子开发:算子原型定义与 InferShape/InferType 实现
在昇腾 CANN 的算子开发流程中,“算子原型定义” 是连接框架(如 TensorFlow)与核函数的关键环节 —— 它负责描述算子的输入输出、推导 Shape 与数据类型,是算子能够被框架调用的前提。很多开发者入门时会忽略这一步,导致算子无法被正确解析或执行。本文基于 CANN 训练营的内容,从算子原型的核心组件出发,讲清 InferShape(Shape 推导)与 InferType(数据类型
前言
在昇腾 CANN 的算子开发流程中,“算子原型定义” 是连接框架(如 TensorFlow)与核函数的关键环节 —— 它负责描述算子的输入输出、推导 Shape 与数据类型,是算子能够被框架调用的前提。很多开发者入门时会忽略这一步,导致算子无法被正确解析或执行。本文基于 CANN 训练营的内容,从算子原型的核心组件出发,讲清 InferShape(Shape 推导)与 InferType(数据类型推导)的实现逻辑,同时给出完整的算子原型定义案例。
一、算子原型的核心作用
算子原型是算子的 “元信息描述”,主要完成 3 件事:
- 定义输入输出:指定算子的输入数量、输出数量,以及每个 Tensor 的属性(如是否为可选输入);
- 推导输出 Shape:通过
InferShape函数,根据输入 Shape 计算输出 Shape; - 推导输出数据类型:通过
InferType函数,根据输入数据类型确定输出数据类型。
算子原型是 CANN 编译框架(如 AscendCL)解析算子的依据,必须与核函数的输入输出一致。
二、算子原型的基础结构
在 CANN 中,算子原型需继承Operator类,并实现InferShape与InferType方法,基础结构如下:
c
运行
#include "ascendc/ascendc_operator.h"
class CustomOperator : public Operator {
public:
// 构造函数:定义输入输出
CustomOperator() {
// 添加输入(名称、是否可选)
AddInput("input1", false);
AddInput("input2", false);
// 添加输出(名称)
AddOutput("output");
}
// Shape推导
Status InferShape() override {
// 实现Shape推导逻辑
return Status::SUCCESS;
}
// 数据类型推导
Status InferType() override {
// 实现数据类型推导逻辑
return Status::SUCCESS;
}
};
// 注册算子原型(框架通过名称识别)
REGISTER_OPERATOR(CustomOperator, "CustomOp");
三、InferShape:输出 Shape 的推导逻辑
InferShape是算子原型的核心方法,需根据输入 Shape 的规则推导输出 Shape。常见的推导场景包括:
3.1 元素级算子的 Shape 推导
元素级算子(如加法、乘法)的输出 Shape 与输入 Shape 一致,需先检查输入 Shape 的一致性:
c
运行
Status InferShape() override {
// 获取输入Tensor
const Tensor& input1 = GetInput(0);
const Tensor& input2 = GetInput(1);
// 检查输入Shape是否一致
if (input1.GetShape() != input2.GetShape()) {
AERROR << "Input shapes are not consistent: input1=" << input1.GetShape()
<< ", input2=" << input2.GetShape();
return Status::ERROR;
}
// 输出Shape与输入一致
SetOutputShape(0, input1.GetShape());
return Status::SUCCESS;
}
3.2 广播算子的 Shape 推导
广播算子(如[N, 1]与[1, C]广播为[N, C])需实现广播规则:
c
运行
Status InferShape() override {
const Tensor& input1 = GetInput(0);
const Tensor& input2 = GetInput(1);
Shape shape1 = input1.GetShape();
Shape shape2 = input2.GetShape();
// 广播规则:从后往前匹配维度,维度不同时需有一个为1
Shape outputShape;
int32_t dimCount = std::max(shape1.GetDimCount(), shape2.GetDimCount());
for (int32_t i = 0; i < dimCount; i++) {
int32_t dim1 = (i < shape1.GetDimCount()) ? shape1.GetDim(i) : 1;
int32_t dim2 = (i < shape2.GetDimCount()) ? shape2.GetDim(i) : 1;
if (dim1 != 1 && dim2 != 1 && dim1 != dim2) {
AERROR << "Broadcast not supported: dim1=" << dim1 << ", dim2=" << dim2;
return Status::ERROR;
}
outputShape.AddDim(std::max(dim1, dim2));
}
SetOutputShape(0, outputShape);
return Status::SUCCESS;
}
3.3 降维算子的 Shape 推导
降维算子(如求和、求均值)需根据降维轴推导输出 Shape:
c
运行
Status InferShape() override {
const Tensor& input = GetInput(0);
const Tensor& axis = GetInput(1); // 降维轴(如[1]表示对第1维降维)
// 获取降维轴的值(假设axis是标量)
int32_t axisValue = *reinterpret_cast<const int32_t*>(axis.GetData());
Shape inputShape = input.GetShape();
int32_t dimCount = inputShape.GetDimCount();
// 检查轴的有效性
if (axisValue < 0 || axisValue >= dimCount) {
AERROR << "Invalid axis: " << axisValue << ", dimCount=" << dimCount;
return Status::ERROR;
}
// 输出Shape:移除降维轴
Shape outputShape;
for (int32_t i = 0; i < dimCount; i++) {
if (i != axisValue) {
outputShape.AddDim(inputShape.GetDim(i));
}
}
SetOutputShape(0, outputShape);
return Status::SUCCESS;
}
四、InferType:输出数据类型的推导逻辑
InferType负责确定输出数据类型,常见规则包括:
4.1 与输入类型一致
元素级算子通常采用与输入一致的数据类型:
c
运行
Status InferType() override {
const Tensor& input1 = GetInput(0);
SetOutputType(0, input1.GetDataType());
return Status::SUCCESS;
}
4.2 精度提升
部分算子(如求和)会提升输出精度(如 float16 输入→float32 输出):
c
运行
Status InferType() override {
const Tensor& input = GetInput(0);
DataType inputType = input.GetDataType();
if (inputType == DT_FLOAT16) {
SetOutputType(0, DT_FLOAT32);
} else {
SetOutputType(0, inputType);
}
return Status::SUCCESS;
}
五、实战:完整的算子原型定义案例
以 “带广播的元素级加法算子” 为例,完整的算子原型定义如下:
c
运行
#include "ascendc/ascendc_operator.h"
class BroadcastAddOperator : public Operator {
public:
BroadcastAddOperator() {
AddInput("x", false);
AddInput("y", false);
AddOutput("z");
}
Status InferShape() override {
const Tensor& x = GetInput(0);
const Tensor& y = GetInput(1);
Shape xShape = x.GetShape();
Shape yShape = y.GetShape();
// 广播Shape推导
Shape outputShape;
int32_t maxDim = std::max(xShape.GetDimCount(), yShape.GetDimCount());
for (int32_t i = 0; i < maxDim; i++) {
int32_t xDim = (i < xShape.GetDimCount()) ? xShape.GetDim(i) : 1;
int32_t yDim = (i < yShape.GetDimCount()) ? yShape.GetDim(i) : 1;
if (xDim != 1 && yDim != 1 && xDim != yDim) {
AERROR << "Broadcast failed: xDim=" << xDim << ", yDim=" << yDim;
return Status::ERROR;
}
outputShape.AddDim(std::max(xDim, yDim));
}
SetOutputShape(0, outputShape);
return Status::SUCCESS;
}
Status InferType() override {
const Tensor& x = GetInput(0);
SetOutputType(0, x.GetDataType());
return Status::SUCCESS;
}
};
REGISTER_OPERATOR(BroadcastAddOperator, "BroadcastAdd");
六、算子原型的验证方法
算子原型定义完成后,需验证其正确性:
- Shape 推导验证:构造不同输入 Shape 的用例,检查输出 Shape 是否符合预期;
- 数据类型验证:构造不同输入类型的用例,检查输出类型是否符合预期;
- 错误场景验证:构造 Shape 不匹配、轴无效等用例,检查是否返回错误状态。
结语
算子原型定义是昇腾 CANN 算子开发的 “前端入口”,其核心是 “规则化推导 Shape 与数据类型”—— 既需要贴合算子的数学逻辑(如元素级、广播、降维),也需要考虑框架的兼容性(如输入输出的属性定义)。掌握这一能力后,算子才能被正确解析并调用核函数。后续进阶内容中,算子原型还会涉及 “属性参数”“动态 Shape 适配” 等复杂场景,建议大家结合 CANN 官方的算子原型示例,多做不同类型算子的推导练习。
2025年昇腾CANN训练营第二季,基于CANN开源开放全场景,推出0基础入门系列、码力全开特辑、开发者案例等专题课程,助力不同阶段开发者快速提升算子开发技能。获得Ascend C算子中级认证,即可领取精美证书,完成社区任务更有机会赢取华为手机,平板、开发板等大奖。
报名链接:https://www.hiascend.com/developer/activities/cann20252
更多推荐




所有评论(0)