//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First, the elements are reduced within each subgroup. The first
  /// invocation of each subgroup writes the intermediate result to workgroup
  /// memory. After synchronizing the workgroup, the first subgroup reduces the
  /// values from workgroup memory. The result is then broadcast to all
  /// invocations through workgroup memory.
  ///
  ///   %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///   cf.cond_br %is_first_lane, ^then1, ^continue1
  /// ^then1:
  ///   store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///   cf.br ^continue1
  /// ^continue1:
  ///   gpu.barrier
  ///   %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///   cf.cond_br %is_valid_subgroup, ^then2, ^continue2
  /// ^then2:
  ///   %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///   %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///   store %all_reduce, %workgroup_buffer[%zero]
  ///   cf.br ^continue2
  /// ^continue2:
  ///   gpu.barrier
  ///   %result = load %workgroup_buffer[%zero]
  ///   return %result
  ///
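  /// For example, a uniform all-reduce such as
  ///
  ///   %sum = gpu.all_reduce add %operand uniform {} : (f32) -> (f32)
  ///
  /// is expanded in place into the structure above (illustrative IR; the
  /// operand name is invented).
  ///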
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
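    // With row-major linearization:
    //   invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    //   workgroupSize = dimX * dimY * dimZ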
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
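    // Since kSubgroupSize is a power of two, masking with kSubgroupSize - 1
    // is equivalent to invocationIdx % kSubgroupSize.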
    Value subgroupMask =
        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
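    // For example, with workgroupSize = 48, the invocations of the second
    // subgroup (starting at invocation index 32) get activeWidth = 16.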
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create a factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate
    // results.
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory. The final result is written to workgroup
    // memory again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T, with the result casted to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup memory attribution of buffer type to funcOp and returns
  /// it.
  Value createWorkgroupBuffer() {
    // TODO: Pick a proper location for the attribution.
    auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
        funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
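    // One buffer slot per subgroup. This implicitly bounds the supported
    // workgroup size: the second reduction step handles at most kSubgroupSize
    // partial results, i.e. kSubgroupSize * kSubgroupSize invocations.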
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield terminator returns
  /// the accumulated value of the same type.
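  ///
  /// For example, a reduction body such as (illustrative IR)
  ///
  ///   ^bb0(%lhs : f32, %rhs : f32):
  ///     %sum = arith.addf %lhs, %rhs : f32
  ///     gpu.yield %sum : f32
  ///
  /// is cloned once per accumulation, with the two block arguments mapped to
  /// the values being combined.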
  AccumulatorFactory getFactory(Region &body) {
    return [&body, this](Value lhs, Value rhs) -> Value {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert the accumulator body between the split blocks.
      IRMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add a branch from before the inserted body into the body.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    };
  }

  /// Returns an accumulator factory that creates an op specified by opName.
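  /// For example, `add` over a float operand materializes an `arith.addf`
  /// through vector::makeArithReduction.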
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    return [opName, this](Value lhs, Value rhs) {
      return vector::makeArithReduction(rewriter, loc,
                                        convertReductionKind(opName), lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///   cf.cond_br %condition, ^then, ^else
  /// ^then:
  ///   %then_operands = `thenOpsFactory()`
  ///   cf.br ^continue(%then_operands)
  /// ^else:
  ///   %else_operands = `elseOpsFactory()`
  ///   cf.br ^continue(%else_operands)
  /// ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
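  ///
  /// The reduction is a butterfly of xor shuffles with offsets 1, 2, 4, ...,
  /// kSubgroupSize / 2, so lane 0 accumulates the values of all active lanes
  /// in log2(kSubgroupSize) steps.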
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
          // lane is within the active range. The accumulated value is available
          // in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceRewrite : public RewritePattern {
  explicit GpuAllReduceRewrite(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);

    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();

      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };

    if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
      return rewriter.notifyMatchFailure(
          op, "Non-uniform reductions are not supported yet.");

    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceRewrite>(patterns.getContext());
}
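
// Typical usage (a sketch; assumes the standard greedy driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(ctx);
//   populateGpuAllReducePatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));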