
17.1 What Is Broadcast

17.1.1 The Concept of Broadcast

Broadcast is a common operation in deep learning that expands a tensor of a smaller shape into a larger shape. For example, with an input shape of (2, 1) and an output shape of (2, 16), Broadcast expands the original single column into 16 identical columns.

Broadcast rules

  • Align dimensions starting from the rightmost one
  • A dimension of size 1 can be broadcast to any size
  • Dimensions of equal size are kept unchanged
  • Dimensions of different sizes where neither is 1 cannot be broadcast

Example

Input shape: (2, 1)
Output shape: (2, 16)
Result: the second dimension is expanded from 1 to 16; each row's value is repeated 16 times

Input: [[1], [2]]
Output: [[1, 1, 1, ..., 1], [2, 2, 2, ..., 2]]  // 16 elements per row

17.1.2 Application Scenarios of Broadcast

Operations on mismatched shapes: when two tensors have different but broadcast-compatible shapes, Broadcast is used to make their shapes match.

Dimension expansion: expanding a lower-dimensional tensor to a higher dimension, for example expanding a scalar into a vector.

Data replication: replicating data multiple times along a given dimension.

17.1.3 Implementation Approaches for Broadcast

Broadcast can be implemented in two ways:

  • Memory copy: directly copy the data to the target locations
  • Index computation: access the source data through computed indices, avoiding an actual copy

The Ascend C BroadCast API chooses between these implementations internally as an optimization.


17.2 Implementing the Broadcast Operator

17.2.1 JSON Description File

BroadcastCustom.json defines the operator specification:

[
    {
        "op": "BroadcastCustom",
        "input_desc": [
            {
                "name": "x",
                "param_type": "required",
                "format": ["ND"],
                "type": ["float16"]
            }
        ],
        "output_desc": [
            {
                "name": "y",
                "param_type": "required",
                "format": ["ND"],
                "type": ["float16"]
            }
        ]
    }
]

The Broadcast operator has one input and one output, and the input and output share the same data type.

17.2.2 Operator Attributes

The Broadcast operator receives its parameters through attributes (Attr):

this->Attr("bufferMode").AttrType(REQUIRED).Int(0);     // buffer mode
this->Attr("dim").AttrType(REQUIRED).Int(0);            // number of dimensions
this->Attr("isReuseSource").AttrType(REQUIRED).Int(0);  // whether to reuse the source data
this->Attr("axis").AttrType(REQUIRED).Int(0);           // broadcast axis
this->Attr("num").AttrType(REQUIRED).Int(0);            // broadcast factor

bufferMode: temporary buffer mode

  • 0: allocate no extra buffer
  • 1: allocate the minimum optimal temporary space
  • 2: allocate the maximum optimal temporary space
  • other values: allocate a size between the minimum and the maximum

dim: number of dimensions of the input tensor (1 or 2)

isReuseSource: whether to reuse the memory of the source operand

axis: the axis to broadcast along (0 or 1)

num: the broadcast factor, i.e. the output size divided by the input size


17.3 Host-Side Implementation

17.3.1 Tiling Function

The Tiling function computes the parameters required by Broadcast:

namespace optiling {
static ge::graphStatus TilingFunc(gert::TilingContext *context)
{
    BroadcastTilingData tiling;
    
    // 1. Read the operator attributes
    const gert::RuntimeAttrs *broadcastattrs = context->GetAttrs();
    const uint32_t bufferMode = *(broadcastattrs->GetAttrPointer<uint32_t>(0));
    const uint32_t dim = *(broadcastattrs->GetAttrPointer<uint32_t>(1));
    const uint32_t isReuseSource = *(broadcastattrs->GetAttrPointer<uint32_t>(2));
    const uint32_t axis = *(broadcastattrs->GetAttrPointer<uint32_t>(3));
    const uint32_t num = *(broadcastattrs->GetAttrPointer<uint32_t>(4));
    
    // 2. Read the input information
    uint32_t totalLength = context->GetInputShape(0)->GetOriginShape().GetShapeSize();
    auto dt = context->GetInputDesc(0)->GetDataType();
    uint32_t dtypesize;
    if (dt == ge::DT_FLOAT16) {
        dtypesize = 2;
    } else {
        dtypesize = 4;
    }
    
    // 3. Compute the input and output shapes
    const gert::StorageShape *src_shape = context->GetInputShape(0);
    uint32_t bLength = 0, sLength = 0;
    ge::Shape inputShape, outputShape;
    
    if (dim == 1) {
        // 1-D: shape (1) is broadcast to (num)
        bLength = src_shape->GetStorageShape().GetDim(0);
        std::vector<int64_t> inputShapeDim = {1};
        inputShape = ge::Shape(inputShapeDim);
        std::vector<int64_t> outputShapeDim = {num};
        outputShape = ge::Shape(outputShapeDim);
    } else {
        // 2-D: shape (bLength, sLength)
        bLength = src_shape->GetStorageShape().GetDim(0);
        sLength = src_shape->GetStorageShape().GetDim(1);
        std::vector<int64_t> inputShapeDim = {bLength, sLength};
        inputShape = ge::Shape(inputShapeDim);
        
        if (axis == 0) {
            // Broadcast along dimension 0: (bLength, sLength) -> (bLength * num, sLength)
            std::vector<int64_t> outputShapeDim = {bLength * num, sLength};
            outputShape = ge::Shape(outputShapeDim);
        } else {
            // Broadcast along dimension 1: (bLength, sLength) -> (bLength, sLength * num)
            std::vector<int64_t> outputShapeDim = {bLength, sLength * num};
            outputShape = ge::Shape(outputShapeDim);
        }
    }
    
    // 4. Compute the temporary buffer size
    uint32_t tmpSize;
    uint32_t maxsize = 0, minsize = 0;
    auto platformInfo = context->GetPlatformInfo();
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
    
    if (isReuseSource == 0) {
        AscendC::GetBroadCastMaxMinTmpSize(ascendcPlatform, inputShape, outputShape, 
                                          dtypesize, false, maxsize, minsize);
    } else {
        AscendC::GetBroadCastMaxMinTmpSize(ascendcPlatform, inputShape, outputShape, 
                                          dtypesize, true, maxsize, minsize);
    }
    
    // Select tmpSize according to bufferMode
    if (bufferMode == 0) {
        tmpSize = 0;
    } else if (bufferMode == 1) {
        tmpSize = minsize;
    } else if (bufferMode == 2) {
        tmpSize = maxsize;
    } else {
        tmpSize = (maxsize + minsize) / 2;
    }
    
    // 5. Set the tiling parameters
    tiling.set_tmpSize(tmpSize);
    context->SetBlockDim(1);
    tiling.set_totalLength(totalLength);
    tiling.set_tilenum(1);
    tiling.set_isReuseSource(isReuseSource);
    tiling.set_axis(axis);
    tiling.set_num(num);
    tiling.set_bLength(bLength);
    tiling.set_dim(dim);
    
    context->SetTilingKey(1);
    tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), 
                       context->GetRawTilingData()->GetCapacity());
    context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
    size_t *currentWorkSpace = context->GetWorkspaceSizes(1);
    currentWorkSpace[0] = 0;
    return ge::GRAPH_SUCCESS;
}
} // namespace optiling

Key steps

  1. Read the operator attributes (bufferMode, dim, axis, etc.)
  2. Compute the input and output shapes
  3. Call GetBroadCastMaxMinTmpSize to compute the temporary buffer size
  4. Select tmpSize according to bufferMode
  5. Set the tiling parameters

17.3.2 Operator Registration

The attributes are defined when the operator is registered:

namespace ops {
class BroadcastCustom : public OpDef {
public:
    BroadcastCustom(const char *name) : OpDef(name)
    {
        this->Input("x")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_UINT8})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        
        this->Output("y")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT8, ge::DT_UINT8})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        
        // Define the operator attributes
        this->Attr("bufferMode").AttrType(REQUIRED).Int(0);
        this->Attr("dim").AttrType(REQUIRED).Int(0);
        this->Attr("isReuseSource").AttrType(REQUIRED).Int(0);
        this->Attr("axis").AttrType(REQUIRED).Int(0);
        this->Attr("num").AttrType(REQUIRED).Int(0);
        
        this->AICore()
            .SetTiling(optiling::TilingFunc)
            .AddConfig("ascend310p")
            .AddConfig("ascend910b");
    }
};

OP_ADD(BroadcastCustom);
} // namespace ops

17.4 Kernel-Side Implementation

17.4.1 Kernel Class Design

The KernelBroadcastCustom class implements the Broadcast logic:

class KernelBroadcastCustom {
public:
    __aicore__ inline KernelBroadcastCustom() {}
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, uint32_t totalLength, 
                                uint32_t tilenum, uint32_t tmpSize,
                                uint32_t dim, uint32_t axis, uint32_t num, 
                                uint32_t bLength)
    {
        this->blockLength = totalLength / AscendC::GetBlockNum();
        this->tilenum = tilenum;
        this->tileLength = this->blockLength / tilenum / BUFFER_NUM;
        this->dim = dim;
        this->tmpSize = tmpSize;
        this->axis = axis;
        this->num = num;
        this->bLength = bLength;
        
        // Compute the output tile length
        if (this->dim == 1) {
            this->tileLength2 = num;
        } else {
            this->tileLength2 = this->tileLength * num;
        }
        
        // Set up the GlobalTensors
        xGm.SetGlobalBuffer((__gm__ DTYPE_X *)x + 
                           this->blockLength * AscendC::GetBlockIdx(), 
                           this->blockLength);
        yGm.SetGlobalBuffer((__gm__ DTYPE_Y *)y + 
                           this->tileLength2 * AscendC::GetBlockIdx(), 
                           this->blockLength);
        
        // Initialize the queues
        pipe.InitBuffer(inQueueX, BUFFER_NUM, 
                       this->tileLength * sizeof(DTYPE_X));
        pipe.InitBuffer(outQueueY, BUFFER_NUM, 
                       this->tileLength2 * sizeof(DTYPE_Y));
        
        // Initialize the temporary buffer (if needed)
        if (this->tmpSize != 0) {
            pipe.InitBuffer(tmpQueue, this->tmpSize);
        }
    }
    
    __aicore__ inline void Process()
    {
        int32_t loopCount = this->tilenum * BUFFER_NUM;
        for (int32_t i = 0; i < loopCount; i++) {
            CopyIn(i);
            Compute(i);
            CopyOut(i);
        }
    }

17.4.2 Compute Function

The Compute function performs the broadcast using the BroadCast API:

__aicore__ inline void Compute(int32_t progress)
{
    AscendC::LocalTensor<DTYPE_X> xLocal = inQueueX.DeQue<DTYPE_X>();
    AscendC::LocalTensor<DTYPE_Y> yLocal = outQueueY.AllocTensor<DTYPE_Y>();
    
    if (this->tmpSize == 0) {
        // No temporary buffer
        if (this->dim == 1) {
            // 1-D broadcast
            const uint32_t srcShape[] = {1};
            const uint32_t dstShape[] = {this->num};
            AscendC::BroadCast<DTYPE_X, 1, 0>(yLocal, xLocal, dstShape, srcShape);
        } else {
            // 2-D broadcast
            const uint32_t srcShape[] = {this->bLength, this->tileLength / this->bLength};
            if (this->axis == 0) {
                // Broadcast along dimension 0
                const uint32_t dstShape[] = {this->bLength * this->num, 
                                            this->tileLength / this->bLength};
                AscendC::BroadCast<DTYPE_X, 2, 0>(yLocal, xLocal, dstShape, srcShape);
            } else {
                // Broadcast along dimension 1
                const uint32_t dstShape[] = {this->bLength, 
                                            this->tileLength / this->bLength * this->num};
                AscendC::BroadCast<DTYPE_X, 2, 1>(yLocal, xLocal, dstShape, srcShape);
            }
        }
    } else {
        // Use the temporary buffer
        AscendC::LocalTensor<uint8_t> tmpTensor = tmpQueue.Get<uint8_t>();
        if (this->dim == 1) {
            const uint32_t srcShape[] = {1};
            const uint32_t dstShape[] = {this->num};
            AscendC::BroadCast<DTYPE_X, 1, 0>(yLocal, xLocal, dstShape, srcShape, tmpTensor);
        } else {
            const uint32_t srcShape[] = {this->bLength, this->tileLength / this->bLength};
            if (this->axis == 0) {
                const uint32_t dstShape[] = {this->bLength * this->num, 
                                            this->tileLength / this->bLength};
                AscendC::BroadCast<DTYPE_X, 2, 0>(yLocal, xLocal, dstShape, srcShape, tmpTensor);
            } else {
                const uint32_t dstShape[] = {this->bLength, 
                                            this->tileLength / this->bLength * this->num};
                AscendC::BroadCast<DTYPE_X, 2, 1>(yLocal, xLocal, dstShape, srcShape, tmpTensor);
            }
        }
        tmpQueue.FreeTensor(tmpTensor);
    }
    
    outQueueY.EnQue<DTYPE_Y>(yLocal);
    inQueueX.FreeTensor(xLocal);
}

BroadCast API

AscendC::BroadCast<DataType, Dim, Axis>(dst, src, dstShape, srcShape);
AscendC::BroadCast<DataType, Dim, Axis>(dst, src, dstShape, srcShape, tmpTensor);

  • DataType: data type
  • Dim: number of dimensions (1 or 2)
  • Axis: broadcast axis (0 or 1)
  • dst: output LocalTensor
  • src: input LocalTensor
  • dstShape: output shape array
  • srcShape: input shape array
  • tmpTensor: temporary buffer (optional)

17.4.3 Kernel Function

extern "C" __global__ __aicore__ void broadcast_custom(
    GM_ADDR x, GM_ADDR y, GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KernelBroadcastCustom op;
    op.Init(x, y, tilingData.totalLength, tilingData.tilenum, 
           tilingData.tmpSize, tilingData.dim, tilingData.axis,
           tilingData.num, tilingData.bLength);
    if (TILING_KEY_IS(1)) {
        op.Process();
    }
}

17.5 The Broadcast API in Detail

17.5.1 1-D Broadcast

A 1-D Broadcast expands a tensor of shape (1) to shape (num):

const uint32_t srcShape[] = {1};
const uint32_t dstShape[] = {num};
AscendC::BroadCast<DTYPE_X, 1, 0>(yLocal, xLocal, dstShape, srcShape);

Example:

Input: [a]
Output: [a, a, a, ..., a]  // num copies of a

17.5.2 2-D Broadcast (axis=0)

Broadcasting along dimension 0 expands (bLength, sLength) to (bLength * num, sLength):

const uint32_t srcShape[] = {bLength, sLength};
const uint32_t dstShape[] = {bLength * num, sLength};
AscendC::BroadCast<DTYPE_X, 2, 0>(yLocal, xLocal, dstShape, srcShape);

Example:

Input: [[1, 2], [3, 4]]  // shape (2, 2)
Output: [[1, 2], [3, 4], [1, 2], [3, 4], ...]  // shape (2*num, 2)

17.5.3 2-D Broadcast (axis=1)

Broadcasting along dimension 1 expands (bLength, sLength) to (bLength, sLength * num):

const uint32_t srcShape[] = {bLength, sLength};
const uint32_t dstShape[] = {bLength, sLength * num};
AscendC::BroadCast<DTYPE_X, 2, 1>(yLocal, xLocal, dstShape, srcShape);

Example:

Input: [[1], [2]]  // shape (2, 1)
Output: [[1, 1, 1, ...], [2, 2, 2, ...]]  // shape (2, num)

17.5.4 Temporary Buffer

A complex Broadcast may require a temporary buffer:

AscendC::LocalTensor<uint8_t> tmpTensor = tmpQueue.Get<uint8_t>();
AscendC::BroadCast<DTYPE_X, 2, 0>(yLocal, xLocal, dstShape, srcShape, tmpTensor);
tmpQueue.FreeTensor(tmpTensor);

The size of the temporary buffer is computed by GetBroadCastMaxMinTmpSize.


17.6 Invocation Example

17.6.1 Invocation Flow

AclNNInvocationNaive/main.cpp shows how to invoke the Broadcast operator:

int main(int argc, char **argv)
{
    // 1. Initialize ACL
    int32_t deviceId = 0;
    aclrtStream stream;
    Init(deviceId, &stream);
    
    // 2. Create the input and output tensors
    std::vector<int64_t> inputXShape = {16, 1};   // input shape
    std::vector<int64_t> outputYShape = {16, 3};  // output shape
    
    aclTensor *inputX = CreateAclTensor(...);
    aclTensor *outputY = CreateAclTensor(...);
    
    // 3. Set the Broadcast parameters
    uint32_t axis = 1;              // broadcast along dimension 1
    bool isReuseSource = false;     // do not reuse the source data
    uint32_t bufferMode = 1;        // minimum optimal temporary space
    uint32_t dim = 2;               // 2-D
    uint32_t num = (outputYShape[0] * outputYShape[1]) / 
                   (inputXShape[0] * inputXShape[1]);  // broadcast factor
    
    // 4. Call the operator API
    uint64_t workspaceSize = 0;
    aclOpExecutor *executor;
    aclnnBroadcastCustomGetWorkspaceSize(inputX, bufferMode, dim, 
                                        isReuseSource, axis, num, outputY,
                                        &workspaceSize, &executor);
    
    void *workspaceAddr = nullptr;
    if (workspaceSize > 0) {
        aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    }
    
    aclnnBroadcastCustom(workspaceAddr, workspaceSize, executor, stream);
    
    // 5. Synchronize and fetch the result
    aclrtSynchronizeStream(stream);
    // ...
}

17.6.2 Parameter Calculation

Computing num

uint32_t num = (outputYShape[0] * outputYShape[1]) / 
               (inputXShape[0] * inputXShape[1]);

num is the ratio of the output size to the input size and determines the broadcast factor.

Example

  • Input shape: (16, 1), element count = 16
  • Output shape: (16, 3), element count = 48
  • num = 48 / 16 = 3

17.7 Temporary Buffer Management

17.7.1 Buffer Size Calculation

The GetBroadCastMaxMinTmpSize function computes the size of the temporary buffer:

uint32_t maxsize = 0, minsize = 0;
auto platformInfo = context->GetPlatformInfo();
auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);

if (isReuseSource == 0) {
    AscendC::GetBroadCastMaxMinTmpSize(ascendcPlatform, inputShape, outputShape, 
                                      dtypesize, false, maxsize, minsize);
} else {
    AscendC::GetBroadCastMaxMinTmpSize(ascendcPlatform, inputShape, outputShape, 
                                      dtypesize, true, maxsize, minsize);
}

Parameters

  • ascendcPlatform: the platform object constructed from platformInfo
  • inputShape: input shape
  • outputShape: output shape
  • dtypesize: data type size (2 or 4 bytes)
  • isReuseSource: whether to reuse the source data
  • maxsize: maximum optimal temporary space size (output)
  • minsize: minimum optimal temporary space size (output)

17.7.2 Buffer Mode Selection

tmpSize is selected according to bufferMode:

if (bufferMode == 0) {
    tmpSize = 0;  // no buffer
} else if (bufferMode == 1) {
    tmpSize = minsize;  // minimum
} else if (bufferMode == 2) {
    tmpSize = maxsize;  // maximum
} else {
    tmpSize = (maxsize + minsize) / 2;  // midpoint
}

Selection guidelines

  • bufferMode=0: when memory is plentiful elsewhere, skipping the temporary buffer may be faster
  • bufferMode=1: when memory is tight, use the minimum buffer
  • bufferMode=2: when performance matters most, use the maximum buffer
  • other values: balance performance against memory

17.8 Comparison with the Add Operator

17.8.1 Complexity Comparison

Add operator

  • Input and output shapes are identical
  • A simple element-wise operation
  • No complex shape handling required

Broadcast operator

  • Input and output shapes differ
  • Shape transformations must be handled
  • A temporary buffer size must be computed

17.8.2 Implementation Comparison

Add operator

AscendC::Add(zLocal, xLocal, yLocal, tileLength);

Broadcast operator

AscendC::BroadCast<DTYPE_X, 2, 1>(yLocal, xLocal, dstShape, srcShape);

Broadcast must specify shapes and an axis, which makes it more complex.

17.8.3 Use Case Comparison

Add operator: adds two tensors of the same shape.

Broadcast operator: expands tensors whose shapes do not match, typically in combination with other operators.


17.9 Key Considerations

17.9.1 Shape Calculation

The input and output shapes must be computed correctly so that the Broadcast rules hold:

  • Align dimensions from the rightmost one
  • A dimension of size 1 can be broadcast
  • Dimensions of equal size are kept unchanged

17.9.2 Axis Selection

axis must be 0 or 1 (for the 2-D case):

  • axis=0: broadcast along dimension 0
  • axis=1: broadcast along dimension 1

17.9.3 Temporary Buffer

The temporary buffer size must be chosen for the situation at hand:

  • Too small, and the computation may fail
  • Too large, and memory is wasted
  • Use GetBroadCastMaxMinTmpSize to obtain recommended values

17.9.4 Data Types

Broadcast supports multiple data types (float16, float32, int8, uint8), but the input and output types must be identical.


17.10 Extension: Supporting More Dimensions

17.10.1 Supporting 3-D Broadcast

To support a 3-D Broadcast, you need to:

  1. Handle 3-D shapes in the Tiling function:
if (dim == 3) {
    uint32_t d0 = src_shape->GetStorageShape().GetDim(0);
    uint32_t d1 = src_shape->GetStorageShape().GetDim(1);
    uint32_t d2 = src_shape->GetStorageShape().GetDim(2);
    // ...
}
  2. Use the 3-D BroadCast API in the kernel:
const uint32_t srcShape[] = {d0, d1, d2};
const uint32_t dstShape[] = {...};
AscendC::BroadCast<DTYPE_X, 3, axis>(yLocal, xLocal, dstShape, srcShape);

17.10.2 Supporting Multi-Axis Broadcast

Broadcasting along several axes at once requires more elaborate logic; it may take multiple calls to the BroadCast API or a higher-level API.


Season 2 of the 2025 Ascend CANN training camp, built on CANN's open-source, all-scenario ecosystem, offers beginner series, advanced specials, and developer case-study courses to help developers at every stage quickly improve their operator development skills. Earn the Ascend C operator intermediate certification to receive a certificate, and complete community tasks for a chance to win Huawei phones, tablets, development boards, and other prizes.

Registration: https://www.hiascend.com/developer/activities/cann20252

Community: https://www.hiascend.com/developer

更多推荐