Learning Ascend C Operator Development from Scratch, Part 17: The Broadcast Operator

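17.1 What Is Broadcast
17.1.1 The Concept of Broadcast
Broadcast is a common operation in deep learning that expands a tensor with a small shape into one with a larger shape. For example, with an input shape of (2, 1) and an output shape of (2, 16), Broadcast expands the single original column into 16 identical columns.
Broadcast rules:
- Align the dimensions starting from the rightmost one
- A dimension of size 1 can be broadcast to any size
- Dimensions of equal size are left unchanged
- If the sizes differ and neither is 1, the shapes cannot be broadcast
Example:
Input shape: (2, 1)
Output shape: (2, 16)
Result: the second dimension is expanded from 1 to 16, so the single element of each row is repeated 16 times
Input: [[1], [2]]
Output: [[1, 1, 1, ..., 1], [2, 2, 2, ..., 2]] // 16 elements per row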
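The four rules above can be written down as a small shape check. The following is a minimal host-side sketch in plain C++; BroadcastShape is a hypothetical helper introduced here for illustration and is not part of Ascend C.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an Ascend C API): applies the broadcast rules to
// two shapes. Returns true and writes the resulting shape to `out` when the
// shapes are compatible; returns false and leaves `out` untouched otherwise.
bool BroadcastShape(const std::vector<int64_t>& a,
                    const std::vector<int64_t>& b,
                    std::vector<int64_t>& out)
{
    const std::size_t rank = std::max(a.size(), b.size());
    std::vector<int64_t> result(rank, 1);
    for (std::size_t i = 0; i < rank; ++i) {
        // Walk from the rightmost dimension; a missing dimension counts as 1.
        const int64_t da = (i < a.size()) ? a[a.size() - 1 - i] : 1;
        const int64_t db = (i < b.size()) ? b[b.size() - 1 - i] : 1;
        if (da != db && da != 1 && db != 1) {
            return false;  // sizes differ and neither is 1
        }
        result[rank - 1 - i] = (da == 1) ? db : da;
    }
    out = result;
    return true;
}
```

With this sketch, BroadcastShape({2, 1}, {2, 16}, out) succeeds and yields (2, 16), while (2, 3) against (2, 16) is rejected.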
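17.1.2 Use Cases for Broadcast
Operations on mismatched shapes: when two tensors have different but broadcast-compatible shapes, Broadcast makes their shapes match.
Dimension expansion: expanding a low-dimensional tensor to a higher-dimensional one, for example a scalar to a vector.
Data replication: copying data multiple times along a given dimension.
17.1.3 How Broadcast Is Implemented
Broadcast can be implemented in two ways:
- Memory copy: copy the data directly to the target locations
- Index calculation: compute indices that map back to the source data, avoiding an actual copy
The Ascend C BroadCast API chooses and optimizes the implementation internally.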
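To make the difference between the two strategies concrete, here is a minimal host-side sketch for the (rows, 1) to (rows, cols) case. BroadcastByCopy and BroadcastView are hypothetical names used only for this illustration; they say nothing about how the Ascend C API is implemented internally.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Strategy 1, memory copy: materialize a (rows, cols) buffer from a
// (rows, 1) source by writing every element to its target location.
std::vector<int32_t> BroadcastByCopy(const std::vector<int32_t>& src,
                                     std::size_t rows, std::size_t cols)
{
    std::vector<int32_t> dst(rows * cols);
    for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t c = 0; c < cols; ++c) {
            dst[r * cols + c] = src[r];
        }
    }
    return dst;
}

// Strategy 2, index calculation: nothing is copied. A logical (r, c) access
// is mapped back to the source element; the broadcast column index is ignored.
struct BroadcastView {
    const std::vector<int32_t>* src;
    std::size_t rows;
    std::size_t cols;
    int32_t At(std::size_t r, std::size_t c) const
    {
        (void)c;  // every column of row r reads the same source element
        return (*src)[r];
    }
};
```

Both produce the same logical tensor; the copy costs rows * cols writes up front, while the view defers the cost to each access.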
17.2 Implementing the Broadcast Operator
17.2.1 JSON Description File
BroadcastCustom.json defines the operator specification:
[
    {
        "op": "BroadcastCustom",
        "input_desc": [
            {
                "name": "x",
                "param_type": "required",
                "format": ["ND"],
                "type": ["float16"]
            }
        ],
        "output_desc": [
            {
                "name": "y",
                "param_type": "required",
                "format": ["ND"],
                "type": ["float16"]
            }
        ]
    }
]
The Broadcast operator has one input and one output, and both use the same data type.
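17.2.2 Operator Attributes
The Broadcast operator receives its parameters through attributes (Attr):
this->Attr("bufferMode").AttrType(REQUIRED).Int(0);     // temporary buffer mode
this->Attr("dim").AttrType(REQUIRED).Int(0);            // number of dimensions
this->Attr("isReuseSource").AttrType(REQUIRED).Int(0);  // whether to reuse the source data
this->Attr("axis").AttrType(REQUIRED).Int(0);           // broadcast axis
this->Attr("num").AttrType(REQUIRED).Int(0);            // broadcast factor
bufferMode: temporary buffer mode
- 0: no extra buffer is allocated
- 1: allocate the minimum temporary space
- 2: allocate the maximum temporary space
- any other value: allocate a size between the minimum and the maximum
dim: number of dimensions of the input tensor (1 or 2)
isReuseSource: whether the memory of the source operand may be reused
axis: the axis to broadcast along (0 or 1)
num: the broadcast factor, that is, the output size divided by the input size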
17.3 Host-Side Implementation
17.3.1 Tiling Function
The tiling function computes the parameters that Broadcast needs:
namespace optiling {
static ge::graphStatus TilingFunc(gert::TilingContext *context)
{
    BroadcastTilingData tiling;
    // 1. Read the operator attributes (indices follow the registration order)
    const gert::RuntimeAttrs *broadcastattrs = context->GetAttrs();
    const uint32_t bufferMode = *(broadcastattrs->GetAttrPointer<uint32_t>(0));
    const uint32_t dim = *(broadcastattrs->GetAttrPointer<uint32_t>(1));
    const uint32_t isReuseSource = *(broadcastattrs->GetAttrPointer<uint32_t>(2));
    const uint32_t axis = *(broadcastattrs->GetAttrPointer<uint32_t>(3));
    const uint32_t num = *(broadcastattrs->GetAttrPointer<uint32_t>(4));
    // 2. Read the input information
    uint32_t totalLength = context->GetInputShape(0)->GetOriginShape().GetShapeSize();
    auto dt = context->GetInputDesc(0)->GetDataType();
    uint32_t dtypesize;
    if (dt == ge::DT_FLOAT16) {
        dtypesize = 2;
    } else if (dt == ge::DT_INT8 || dt == ge::DT_UINT8) {
        dtypesize = 1;  // the registration in 17.3.2 also lists these 1-byte types
    } else {
        dtypesize = 4;  // float
    }
    // 3. Compute the input and output shapes
    const gert::StorageShape *src_shape = context->GetInputShape(0);
    uint32_t bLength = 0, sLength = 0;
    ge::Shape inputShape, outputShape;
    if (dim == 1) {
        // 1-D: shape (1) is broadcast to (num)
        bLength = src_shape->GetStorageShape().GetDim(0);
        std::vector<int64_t> inputShapeDim = {1};
        inputShape = ge::Shape(inputShapeDim);
        std::vector<int64_t> outputShapeDim = {num};
        outputShape = ge::Shape(outputShapeDim);
    } else {
        // 2-D: shape is (bLength, sLength)
        bLength = src_shape->GetStorageShape().GetDim(0);
        sLength = src_shape->GetStorageShape().GetDim(1);
        std::vector<int64_t> inputShapeDim = {bLength, sLength};
        inputShape = ge::Shape(inputShapeDim);
        if (axis == 0) {
            // Broadcast along dimension 0: (bLength, sLength) -> (bLength * num, sLength)
            std::vector<int64_t> outputShapeDim = {bLength * num, sLength};
            outputShape = ge::Shape(outputShapeDim);
        } else {
            // Broadcast along dimension 1: (bLength, sLength) -> (bLength, sLength * num)
            std::vector<int64_t> outputShapeDim = {bLength, sLength * num};
            outputShape = ge::Shape(outputShapeDim);
        }
    }
    // 4. Compute the temporary buffer size
    uint32_t tmpSize;
    uint32_t maxsize = 0, minsize = 0;
    auto platformInfo = context->GetPlatformInfo();
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
    // isReuseSource is forwarded as a bool
    AscendC::GetBroadCastMaxMinTmpSize(ascendcPlatform, inputShape, outputShape,
                                       dtypesize, isReuseSource != 0, maxsize, minsize);
    // Choose tmpSize according to bufferMode
    if (bufferMode == 0) {
        tmpSize = 0;
    } else if (bufferMode == 1) {
        tmpSize = minsize;
    } else if (bufferMode == 2) {
        tmpSize = maxsize;
    } else {
        tmpSize = (maxsize + minsize) / 2;
    }
    // 5. Set the tiling parameters
    tiling.set_tmpSize(tmpSize);
    context->SetBlockDim(1);
    tiling.set_totalLength(totalLength);
    tiling.set_tilenum(1);
    tiling.set_isReuseSource(isReuseSource);
    tiling.set_axis(axis);
    tiling.set_num(num);
    tiling.set_bLength(bLength);
    tiling.set_dim(dim);
    context->SetTilingKey(1);
    tiling.SaveToBuffer(context->GetRawTilingData()->GetData(),
                        context->GetRawTilingData()->GetCapacity());
    context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
    size_t *currentWorkSpace = context->GetWorkspaceSizes(1);
    currentWorkSpace[0] = 0;
    return ge::GRAPH_SUCCESS;
}
} // namespace optiling
Key steps:
- Read the operator attributes (bufferMode, dim, axis, and so on)
- Compute the input and output shapes
- Call GetBroadCastMaxMinTmpSize to obtain the temporary buffer size range
- Choose tmpSize according to bufferMode
- Set the tiling parameters
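The shape step can be exercised on its own, without the CANN headers. The sketch below mirrors step 3 of the tiling function using std::vector in place of ge::Shape; BroadcastOutputShape and ShapeSize are hypothetical helpers introduced here, and the sketch assumes the input shape has exactly dim entries.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical host-only mirror of the shape logic in TilingFunc.
// dim == 1: (1) -> (num)
// dim == 2: axis 0 gives (b * num, s), axis 1 gives (b, s * num)
std::vector<int64_t> BroadcastOutputShape(const std::vector<int64_t>& in,
                                          uint32_t dim, uint32_t axis,
                                          uint32_t num)
{
    if (dim == 1) {
        return {static_cast<int64_t>(num)};
    }
    if (axis == 0) {
        return {in[0] * num, in[1]};
    }
    return {in[0], in[1] * num};
}

// Number of elements described by a shape.
int64_t ShapeSize(const std::vector<int64_t>& shape)
{
    int64_t size = 1;
    for (int64_t d : shape) {
        size *= d;
    }
    return size;
}
```

A quick sanity check for any attribute combination is that ShapeSize of the output equals ShapeSize of the input multiplied by num.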
17.3.2 Operator Registration
The attributes are defined when the operator is registered:
namespace ops {
class BroadcastCustom : public OpDef {
public:
    BroadcastCustom(const char *name) : OpDef(name)
    {
        this->Input("x")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_UINT8})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Output("y")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_UINT8})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Define the operator attributes
        this->Attr("bufferMode").AttrType(REQUIRED).Int(0);
        this->Attr("dim").AttrType(REQUIRED).Int(0);
        this->Attr("isReuseSource").AttrType(REQUIRED).Int(0);
        this->Attr("axis").AttrType(REQUIRED).Int(0);
        this->Attr("num").AttrType(REQUIRED).Int(0);
        this->AICore()
            .SetTiling(optiling::TilingFunc)
            .AddConfig("ascend310p")
            .AddConfig("ascend910b");
    }
};
OP_ADD(BroadcastCustom);
} // namespace ops
17.4 Kernel-Side Implementation
17.4.1 Kernel Class Design
The KernelBroadcastCustom class implements the Broadcast logic (Init and Process are shown here; Compute follows in 17.4.2):
class KernelBroadcastCustom {
public:
    __aicore__ inline KernelBroadcastCustom() {}
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, uint32_t totalLength,
                                uint32_t tilenum, uint32_t tmpSize,
                                uint32_t dim, uint32_t axis, uint32_t num,
                                uint32_t bLength)
    {
        this->blockLength = totalLength / AscendC::GetBlockNum();
        this->tilenum = tilenum;
        this->tileLength = this->blockLength / tilenum / BUFFER_NUM;
        this->dim = dim;
        this->tmpSize = tmpSize;
        this->axis = axis;
        this->num = num;
        this->bLength = bLength;
        // Compute the output tile length
        if (this->dim == 1) {
            this->tileLength2 = num;
        } else {
            this->tileLength2 = this->tileLength * num;
        }
        // Set up the GlobalTensors
        xGm.SetGlobalBuffer((__gm__ DTYPE_X *)x +
                                this->blockLength * AscendC::GetBlockIdx(),
                            this->blockLength);
        // The output is num times larger than the input, so its length is
        // tileLength2 rather than the input blockLength
        yGm.SetGlobalBuffer((__gm__ DTYPE_Y *)y +
                                this->tileLength2 * AscendC::GetBlockIdx(),
                            this->tileLength2);
        // Initialize the queues
        pipe.InitBuffer(inQueueX, BUFFER_NUM,
                        this->tileLength * sizeof(DTYPE_X));
        pipe.InitBuffer(outQueueY, BUFFER_NUM,
                        this->tileLength2 * sizeof(DTYPE_Y));
        // Initialize the temporary buffer (only if one is needed)
        if (this->tmpSize != 0) {
            pipe.InitBuffer(tmpQueue, this->tmpSize);
        }
    }
    __aicore__ inline void Process()
    {
        int32_t loopCount = this->tilenum * BUFFER_NUM;
        for (int32_t i = 0; i < loopCount; i++) {
            CopyIn(i);
            Compute(i);
            CopyOut(i);
        }
    }
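The length arithmetic in Init is easy to check on the host. The sketch below repeats it in plain C++; TileLengths and ComputeTileLengths are hypothetical names, and bufferNum stands in for BUFFER_NUM, whose value is not shown in this article and is therefore passed as a parameter here.

```cpp
#include <cstdint>

// Hypothetical host-side mirror of the length arithmetic in Init().
struct TileLengths {
    uint32_t blockLength;  // input elements handled by one core
    uint32_t tileLength;   // input elements per tile
    uint32_t tileLength2;  // output elements per tile
};

TileLengths ComputeTileLengths(uint32_t totalLength, uint32_t blockNum,
                               uint32_t tilenum, uint32_t bufferNum,
                               uint32_t dim, uint32_t num)
{
    TileLengths t{};
    t.blockLength = totalLength / blockNum;
    t.tileLength = t.blockLength / tilenum / bufferNum;
    // 1-D input has a single element, so the output tile is just num elements.
    t.tileLength2 = (dim == 1) ? num : t.tileLength * num;
    return t;
}
```

For the (16, 1) to (16, 3) case used later in the article, with one core, one tile and bufferNum assumed to be 1, this gives an input tile of 16 elements and an output tile of 48.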
17.4.2 The Compute Function
The Compute function performs the broadcast with the BroadCast API:
    __aicore__ inline void Compute(int32_t progress)
    {
        AscendC::LocalTensor<DTYPE_X> xLocal = inQueueX.DeQue<DTYPE_X>();
        AscendC::LocalTensor<DTYPE_Y> yLocal = outQueueY.AllocTensor<DTYPE_Y>();
        if (this->tmpSize == 0) {
            // No temporary buffer
            if (this->dim == 1) {
                // 1-D broadcast
                const uint32_t srcShape[] = {1};
                const uint32_t dstShape[] = {this->num};
                AscendC::BroadCast<DTYPE_X, 1, 0>(yLocal, xLocal, dstShape, srcShape);
            } else {
                // 2-D broadcast
                const uint32_t srcShape[] = {this->bLength, this->tileLength / this->bLength};
                if (this->axis == 0) {
                    // Broadcast along dimension 0
                    const uint32_t dstShape[] = {this->bLength * this->num,
                                                 this->tileLength / this->bLength};
                    AscendC::BroadCast<DTYPE_X, 2, 0>(yLocal, xLocal, dstShape, srcShape);
                } else {
                    // Broadcast along dimension 1
                    const uint32_t dstShape[] = {this->bLength,
                                                 this->tileLength / this->bLength * this->num};
                    AscendC::BroadCast<DTYPE_X, 2, 1>(yLocal, xLocal, dstShape, srcShape);
                }
            }
        } else {
            // With a temporary buffer
            AscendC::LocalTensor<uint8_t> tmpTensor = tmpQueue.Get<uint8_t>();
            if (this->dim == 1) {
                const uint32_t srcShape[] = {1};
                const uint32_t dstShape[] = {this->num};
                AscendC::BroadCast<DTYPE_X, 1, 0>(yLocal, xLocal, dstShape, srcShape, tmpTensor);
            } else {
                const uint32_t srcShape[] = {this->bLength, this->tileLength / this->bLength};
                if (this->axis == 0) {
                    const uint32_t dstShape[] = {this->bLength * this->num,
                                                 this->tileLength / this->bLength};
                    AscendC::BroadCast<DTYPE_X, 2, 0>(yLocal, xLocal, dstShape, srcShape, tmpTensor);
                } else {
                    const uint32_t dstShape[] = {this->bLength,
                                                 this->tileLength / this->bLength * this->num};
                    AscendC::BroadCast<DTYPE_X, 2, 1>(yLocal, xLocal, dstShape, srcShape, tmpTensor);
                }
            }
            tmpQueue.FreeTensor(tmpTensor);
        }
        outQueueY.EnQue<DTYPE_Y>(yLocal);
        inQueueX.FreeTensor(xLocal);
    }
The BroadCast API:
AscendC::BroadCast<DataType, Dim, Axis>(dst, src, dstShape, srcShape);
AscendC::BroadCast<DataType, Dim, Axis>(dst, src, dstShape, srcShape, tmpTensor);
- DataType: element data type
- Dim: number of dimensions (1 or 2)
- Axis: broadcast axis (0 or 1)
- dst: output LocalTensor
- src: input LocalTensor
- dstShape: output shape array
- srcShape: input shape array
- tmpTensor: temporary buffer (optional)
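When validating kernel output, a host-side reference with the same template parameters is handy. The following is a minimal sketch, not the Ascend C implementation: ReferenceBroadCast is a hypothetical function operating on std::vector, and it assumes the source has size 1 along the broadcast axis, in line with the rules in 17.1.1.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host reference for BroadCast<T, Dim, Axis> on row-major data.
// Assumption: srcShape[Axis] == 1 and all other dimensions match dstShape.
template <typename T, int Dim, int Axis>
std::vector<T> ReferenceBroadCast(const std::vector<T>& src,
                                  const uint32_t (&dstShape)[Dim],
                                  const uint32_t (&srcShape)[Dim])
{
    static_assert(Dim == 1 || Dim == 2, "only 1-D and 2-D are sketched");
    static_assert(Axis >= 0 && Axis < Dim, "axis out of range");
    assert(srcShape[Axis] == 1);

    const uint32_t dstRows = (Dim == 2) ? dstShape[0] : 1;
    const uint32_t dstCols = dstShape[Dim - 1];
    const uint32_t srcCols = srcShape[Dim - 1];

    std::vector<T> dst(static_cast<std::size_t>(dstRows) * dstCols);
    for (uint32_t r = 0; r < dstRows; ++r) {
        for (uint32_t c = 0; c < dstCols; ++c) {
            // A source dimension of size 1 always reads index 0.
            const uint32_t sr = (Dim == 2 && srcShape[0] != 1) ? r : 0;
            const uint32_t sc = (srcCols != 1) ? c : 0;
            dst[static_cast<std::size_t>(r) * dstCols + c] =
                src[static_cast<std::size_t>(sr) * srcCols + sc];
        }
    }
    return dst;
}
```

For example, ReferenceBroadCast<int, 2, 1> with srcShape {2, 1} and dstShape {2, 4} turns {1, 2} into {1, 1, 1, 1, 2, 2, 2, 2}, which can be compared element by element against the data copied back from the device.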
17.4.3 Kernel Function
extern "C" __global__ __aicore__ void broadcast_custom(
    GM_ADDR x, GM_ADDR y, GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KernelBroadcastCustom op;
    op.Init(x, y, tilingData.totalLength, tilingData.tilenum,
            tilingData.tmpSize, tilingData.dim, tilingData.axis,
            tilingData.num, tilingData.bLength);
    if (TILING_KEY_IS(1)) {
        op.Process();
    }
}
17.5 The BroadCast API in Detail
17.5.1 1-D Broadcast
A 1-D Broadcast expands a tensor of shape (1) to shape (num):
const uint32_t srcShape[] = {1};
const uint32_t dstShape[] = {num};
AscendC::BroadCast<DTYPE_X, 1, 0>(yLocal, xLocal, dstShape, srcShape);
Example:
Input: [a]
Output: [a, a, a, ..., a] // num copies of a
17.5.2 2-D Broadcast (axis=0)
Broadcasting along dimension 0 expands (bLength, sLength) to (bLength * num, sLength). By the rules in 17.1.1, only a dimension of size 1 can be broadcast, so this case expects bLength = 1:
const uint32_t srcShape[] = {bLength, sLength};
const uint32_t dstShape[] = {bLength * num, sLength};
AscendC::BroadCast<DTYPE_X, 2, 0>(yLocal, xLocal, dstShape, srcShape);
Example:
Input: [[1, 2]] // shape (1, 2)
Output: [[1, 2], [1, 2], [1, 2], ...] // shape (num, 2)
17.5.3 2-D Broadcast (axis=1)
Broadcasting along dimension 1 expands (bLength, sLength) to (bLength, sLength * num); by the same rule, this case expects sLength = 1:
const uint32_t srcShape[] = {bLength, sLength};
const uint32_t dstShape[] = {bLength, sLength * num};
AscendC::BroadCast<DTYPE_X, 2, 1>(yLocal, xLocal, dstShape, srcShape);
Example:
Input: [[1], [2]] // shape (2, 1)
Output: [[1, 1, 1, ...], [2, 2, 2, ...]] // shape (2, num)
17.5.4 Temporary Buffer
A Broadcast call may need scratch space for intermediate data; the overload that takes tmpTensor lets the caller supply it:
AscendC::LocalTensor<uint8_t> tmpTensor = tmpQueue.Get<uint8_t>();
AscendC::BroadCast<DTYPE_X, 2, 0>(yLocal, xLocal, dstShape, srcShape, tmpTensor);
tmpQueue.FreeTensor(tmpTensor);
The size of the temporary buffer is obtained from GetBroadCastMaxMinTmpSize.
17.6 Invocation Example
17.6.1 Call Flow
AclNNInvocationNaive/main.cpp shows how to call the Broadcast operator:
int main(int argc, char **argv)
{
    // 1. Initialize ACL
    int32_t deviceId = 0;
    aclrtStream stream;
    Init(deviceId, &stream);
    // 2. Create the input and output tensors
    std::vector<int64_t> inputXShape = {16, 1};   // input shape
    std::vector<int64_t> outputYShape = {16, 3};  // output shape
    aclTensor *inputX = CreateAclTensor(...);
    aclTensor *outputY = CreateAclTensor(...);
    // 3. Set the Broadcast parameters
    uint32_t axis = 1;           // broadcast along dimension 1
    bool isReuseSource = false;  // do not reuse the source data
    uint32_t bufferMode = 1;     // minimum temporary space
    uint32_t dim = 2;            // 2-D
    uint32_t num = (outputYShape[0] * outputYShape[1]) /
                   (inputXShape[0] * inputXShape[1]);  // broadcast factor
    // 4. Call the operator API
    uint64_t workspaceSize = 0;
    aclOpExecutor *executor;
    aclnnBroadcastCustomGetWorkspaceSize(inputX, bufferMode, dim,
                                         isReuseSource, axis, num, outputY,
                                         &workspaceSize, &executor);
    void *workspaceAddr = nullptr;
    if (workspaceSize > 0) {
        aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    }
    aclnnBroadcastCustom(workspaceAddr, workspaceSize, executor, stream);
    // 5. Synchronize and fetch the result
    aclrtSynchronizeStream(stream);
    // ...
}
17.6.2 Parameter Calculation
Computing num:
uint32_t num = (outputYShape[0] * outputYShape[1]) /
               (inputXShape[0] * inputXShape[1]);
num is the ratio of the output element count to the input element count, which is the broadcast factor.
Example:
- Input shape: (16, 1), element count = 16
- Output shape: (16, 3), element count = 48
- num = 48 / 16 = 3
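The plain division above silently truncates if the two shapes are not actually broadcast-compatible. A slightly more defensive variant is sketched below; ComputeBroadcastNum is a hypothetical helper, not part of the sample, and it assumes both shapes have the same rank.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: computes the broadcast factor and rejects shape pairs
// where an input dimension neither matches the output nor equals 1.
bool ComputeBroadcastNum(const std::vector<int64_t>& in,
                         const std::vector<int64_t>& out,
                         uint32_t& num)
{
    if (in.size() != out.size()) {
        return false;
    }
    int64_t ratio = 1;
    for (std::size_t i = 0; i < in.size(); ++i) {
        if (in[i] <= 0 || out[i] <= 0) {
            return false;
        }
        if (in[i] != out[i] && in[i] != 1) {
            return false;  // not broadcastable along this dimension
        }
        ratio *= out[i] / in[i];
    }
    num = static_cast<uint32_t>(ratio);
    return true;
}
```

For the shapes in this example, ComputeBroadcastNum({16, 1}, {16, 3}, num) succeeds with num equal to 3, whereas an input of (16, 2) is rejected instead of yielding a truncated quotient.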
17.7 Temporary Buffer Management
17.7.1 Computing the Buffer Size
GetBroadCastMaxMinTmpSize reports the size range of the temporary buffer:
uint32_t maxsize = 0, minsize = 0;
auto platformInfo = context->GetPlatformInfo();
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
AscendC::GetBroadCastMaxMinTmpSize(ascendcPlatform, inputShape, outputShape,
                                   dtypesize, isReuseSource != 0, maxsize, minsize);
Parameters:
- ascendcPlatform: platform information, built from context->GetPlatformInfo()
- inputShape: input shape
- outputShape: output shape
- dtypesize: size of one element in bytes (2 for float16, 4 for float)
- isReuseSource: whether the source data is reused
- maxsize: maximum temporary space size (output parameter)
- minsize: minimum temporary space size (output parameter)
17.7.2 Choosing the Buffer Mode
tmpSize is chosen according to bufferMode:
if (bufferMode == 0) {
    tmpSize = 0;                        // no allocation
} else if (bufferMode == 1) {
    tmpSize = minsize;                  // minimum
} else if (bufferMode == 2) {
    tmpSize = maxsize;                  // maximum
} else {
    tmpSize = (maxsize + minsize) / 2;  // midpoint
}
Recommendations:
- bufferMode=0: no temporary buffer is passed to BroadCast; this may be the faster option when memory is plentiful
- bufferMode=1: use the minimum buffer when memory is constrained
- bufferMode=2: use the maximum buffer when performance matters most
- other values: a balance between performance and memory
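The mode selection can be factored into a small pure function, which also makes it testable on the host. The sketch below uses a hypothetical name, SelectTmpSize, and departs from the sample in one detail: the midpoint is written as minSize + (maxSize - minSize) / 2 so that the sum cannot overflow uint32_t. It assumes minSize <= maxSize.

```cpp
#include <cstdint>

// Hypothetical helper: maps bufferMode to a temporary buffer size.
// Assumes minSize <= maxSize.
uint32_t SelectTmpSize(uint32_t bufferMode, uint32_t minSize, uint32_t maxSize)
{
    switch (bufferMode) {
        case 0:
            return 0;        // no temporary buffer
        case 1:
            return minSize;  // smallest size that works
        case 2:
            return maxSize;  // largest useful size
        default:
            // Midpoint without computing minSize + maxSize directly.
            return minSize + (maxSize - minSize) / 2;
    }
}
```

For ordinary sizes the result matches the original expression; the two only differ when minSize + maxSize exceeds the uint32_t range.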
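17.8 Comparison with the Add Operator
17.8.1 Complexity
Add operator:
- Input and output shapes are identical
- A simple element-wise operation
- No complex shape handling is needed
Broadcast operator:
- Input and output shapes differ
- Shape transformation must be handled
- A temporary buffer size must be computed
17.8.2 Implementation
Add operator:
AscendC::Add(zLocal, xLocal, yLocal, tileLength);
Broadcast operator:
AscendC::BroadCast<DTYPE_X, 2, 1>(yLocal, xLocal, dstShape, srcShape);
BroadCast additionally requires the shapes and the axis, so the call is more involved.
17.8.3 Use Cases
Add operator: adds two tensors of the same shape.
Broadcast operator: expands a tensor when shapes do not match, usually in combination with other operators.
17.9 Key Points to Watch
17.9.1 Shape Calculation
The input and output shapes must be computed correctly so that they obey the Broadcast rules:
- Align from the rightmost dimension
- A dimension of size 1 can be broadcast
- Dimensions of equal size are left unchanged
17.9.2 Choosing the Axis
For 2-D input, axis must be 0 or 1:
- axis=0: broadcast along dimension 0
- axis=1: broadcast along dimension 1
17.9.3 Temporary Buffer
The temporary buffer size should be chosen to fit the situation:
- Too small, and the computation may fail
- Too large, and memory may be wasted
- Use GetBroadCastMaxMinTmpSize to obtain the recommended range
17.9.4 Data Types
Broadcast supports several data types (float16, float32, int8, uint8), but the input and output types must be identical.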
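Because the temporary buffer size depends on the element size, each supported type needs the right byte count on the host side. A minimal sketch of such a mapping follows; the DType enum and DTypeSize function are hypothetical stand-ins so the example compiles without the CANN headers, where the real code would switch on ge::DataType values instead.

```cpp
#include <cstdint>

// Hypothetical stand-in for the four data types listed above.
enum class DType { kFloat16, kFloat32, kInt8, kUint8 };

// Size of one element in bytes; returns 0 for an unknown value.
uint32_t DTypeSize(DType dt)
{
    switch (dt) {
        case DType::kFloat16:
            return 2;
        case DType::kFloat32:
            return 4;
        case DType::kInt8:
        case DType::kUint8:
            return 1;
    }
    return 0;
}
```

Treating every non-float16 type as 4 bytes would overstate the size for int8 and uint8, so an explicit mapping like this is the safer choice.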
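17.10 Extension: Supporting More Dimensions
17.10.1 Supporting 3-D Broadcast
Supporting 3-D Broadcast would require the following, provided the BroadCast API in use accepts Dim = 3 (section 17.4.2 lists only 1 and 2):
- Handle the 3-D shape in the tiling function:
if (dim == 3) {
    uint32_t d0 = src_shape->GetStorageShape().GetDim(0);
    uint32_t d1 = src_shape->GetStorageShape().GetDim(1);
    uint32_t d2 = src_shape->GetStorageShape().GetDim(2);
    // ...
}
- Call the 3-D BroadCast API in the kernel (axis is a template argument, so it must be a compile-time constant):
const uint32_t srcShape[] = {d0, d1, d2};
const uint32_t dstShape[] = {...};
AscendC::BroadCast<DTYPE_X, 3, axis>(yLocal, xLocal, dstShape, srcShape);
17.10.2 Supporting Multi-Axis Broadcast
Broadcasting along several axes at once requires more complex logic, possibly multiple calls to the BroadCast API or a higher-level API.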
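The "multiple calls" idea can be illustrated on the host: a (1, 1) source is first expanded along dimension 1 and the resulting row is then expanded along dimension 0. This is only a sketch of the decomposition in plain C++; ExpandCols, ExpandRows and BroadcastBothAxes are hypothetical names, and on the device each pass would correspond to one single-axis BroadCast call with an intermediate LocalTensor.

```cpp
#include <cstddef>
#include <vector>

// Pass 1: (rows, 1) -> (rows, cols), a broadcast along dimension 1.
std::vector<float> ExpandCols(const std::vector<float>& src,
                              std::size_t rows, std::size_t cols)
{
    std::vector<float> dst(rows * cols);
    for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t c = 0; c < cols; ++c) {
            dst[r * cols + c] = src[r];
        }
    }
    return dst;
}

// Pass 2: (1, cols) -> (rows, cols), a broadcast along dimension 0.
std::vector<float> ExpandRows(const std::vector<float>& src,
                              std::size_t rows, std::size_t cols)
{
    std::vector<float> dst(rows * cols);
    for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t c = 0; c < cols; ++c) {
            dst[r * cols + c] = src[c];
        }
    }
    return dst;
}

// Two single-axis passes: (1, 1) -> (1, cols) -> (rows, cols).
std::vector<float> BroadcastBothAxes(float value,
                                     std::size_t rows, std::size_t cols)
{
    const std::vector<float> row = ExpandCols(std::vector<float>{value}, 1, cols);
    return ExpandRows(row, rows, cols);
}
```

The cost of this approach is the intermediate (1, cols) buffer, which on the device would have to fit in local memory alongside the final output.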