| File: | src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/X86/X86LowerAMXType.cpp |
| Warning: | line 425, column 20: Called C++ object pointer is null |
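The warning points into replaceWithTileLoad() (source lines 418-441 below): when IsPHI is true, dyn_cast<PHINode>(V) returns nullptr if V is not actually a PHINode, and that result is immediately dereferenced by ->getIncomingValue(0). As a reminder of the LLVM casting semantics the analyzer relies on, here is a minimal standalone sketch; the function name and the use of the value are hypothetical, not code from this file:

    #include "llvm/IR/Instructions.h"
    #include "llvm/IR/Value.h"
    using namespace llvm;

    // Hypothetical illustration: dyn_cast<> returns nullptr on a type mismatch
    // and must be checked, while cast<> asserts the type and never returns null.
    void inspectFirstIncoming(Value *V) {
      if (auto *PN = dyn_cast<PHINode>(V))   // checked form: no null dereference
        (void)PN->getIncomingValue(0);
      // auto *PN2 = cast<PHINode>(V);       // asserting form: fails loudly instead
    }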
| 1 | //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===// | |||
| 2 | // | |||
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
| 4 | // See https://llvm.org/LICENSE.txt for license information. | |||
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
| 6 | // | |||
| 7 | //===----------------------------------------------------------------------===// | |||
| 8 | // | |||
| 9 | /// \file Pass to transform <256 x i32> load/store | |||
| 10 | /// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only | |||
| 11 | /// provides simple operations on x86_amx. Basic elementwise operations | |||
| 12 | /// are not supported by AMX. Since x86_amx is bitcast from vector <256 x i32> | |||
| 13 | /// and only AMX intrinsics can operate on the type, we need to transform | |||
| 14 | /// load/store <256 x i32> instructions into AMX load/store. If the bitcast | |||
| 15 | /// cannot be combined with the load/store, we transform the bitcast into an | |||
| 16 | /// AMX load/store plus a <256 x i32> store/load. | |||
| 17 | /// | |||
| 18 | /// If the front end does not use O0 but the mid/back end does (e.g. "Clang -O2 -S | |||
| 19 | /// -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is volatile, | |||
| 20 | /// because that is necessary for AMX fast register allocation. (In fast | |||
| 21 | /// register allocation, registers are allocated before spill/reload, so | |||
| 22 | /// there is no additional register for AMX to identify the step in a spill.) | |||
| 23 | /// The volatileTileData() will handle this case. | |||
| 24 | /// e.g. | |||
| 25 | /// ---------------------------------------------------------- | |||
| 26 | /// | def %td = ... | | |||
| 27 | /// | ... | | |||
| 28 | /// | "use %td" | | |||
| 29 | /// ---------------------------------------------------------- | |||
| 30 | /// will transform to --> | |||
| 31 | /// ---------------------------------------------------------- | |||
| 32 | /// | def %td = ... | | |||
| 33 | /// | call void @llvm.x86.tilestored64.internal(mem, %td) | | |||
| 34 | /// | ... | | |||
| 35 | /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)| | |||
| 36 | /// | "use %td2" | | |||
| 37 | /// ---------------------------------------------------------- | |||
| 38 | // | |||
| 39 | //===----------------------------------------------------------------------===// | |||
| 40 | // | |||
| 41 | #include "X86.h" | |||
| 42 | #include "llvm/ADT/PostOrderIterator.h" | |||
| 43 | #include "llvm/ADT/SmallSet.h" | |||
| 44 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" | |||
| 45 | #include "llvm/Analysis/TargetTransformInfo.h" | |||
| 46 | #include "llvm/CodeGen/Passes.h" | |||
| 47 | #include "llvm/CodeGen/TargetPassConfig.h" | |||
| 48 | #include "llvm/CodeGen/ValueTypes.h" | |||
| 49 | #include "llvm/IR/DataLayout.h" | |||
| 50 | #include "llvm/IR/Function.h" | |||
| 51 | #include "llvm/IR/IRBuilder.h" | |||
| 52 | #include "llvm/IR/Instructions.h" | |||
| 53 | #include "llvm/IR/IntrinsicInst.h" | |||
| 54 | #include "llvm/IR/IntrinsicsX86.h" | |||
| 55 | #include "llvm/IR/PatternMatch.h" | |||
| 56 | #include "llvm/InitializePasses.h" | |||
| 57 | #include "llvm/Pass.h" | |||
| 58 | #include "llvm/Target/TargetMachine.h" | |||
| 59 | ||||
| 60 | using namespace llvm; | |||
| 61 | using namespace PatternMatch; | |||
| 62 | ||||
| 63 | #define DEBUG_TYPE "lower-amx-type" | |||
| 64 | ||||
| 65 | static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, | |||
| 66 | BasicBlock *BB) { | |||
| 67 | Function &F = *BB->getParent(); | |||
| 68 | Module *M = BB->getModule(); | |||
| 69 | const DataLayout &DL = M->getDataLayout(); | |||
| 70 | ||||
| 71 | Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false); | |||
| 72 | LLVMContext &Ctx = Builder.getContext(); | |||
| 73 | auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx)); | |||
| 74 | unsigned AllocaAS = DL.getAllocaAddrSpace(); | |||
| 75 | AllocaInst *AllocaRes = | |||
| 76 | new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front()); | |||
| 77 | AllocaRes->setAlignment(AllocaAlignment); | |||
| 78 | return AllocaRes; | |||
| 79 | } | |||
| 80 | ||||
| 81 | namespace { | |||
| 82 | class X86LowerAMXType { | |||
| 83 | Function &Func; | |||
| 84 | TargetMachine *TM = nullptr; | |||
| 85 | ||||
| 86 | // In AMX intrinsics we let Shape = {Row, Col}, but the | |||
| 87 | // RealCol = Col / ElementSize. We may use the RealCol | |||
| 88 | // as a new Row for other newly created AMX intrinsics. | |||
| 89 | std::map<Value *, Value *> Col2Row; | |||
| 90 | ||||
| 91 | public: | |||
| 92 | X86LowerAMXType(Function &F, TargetMachine *TargetM) : Func(F), TM(TargetM) {} | |||
| 93 | bool visit(); | |||
| 94 | void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast); | |||
| 95 | void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST); | |||
| 96 | bool transformBitcast(BitCastInst *Bitcast); | |||
| 97 | std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo); | |||
| 98 | Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity); | |||
| 99 | }; | |||
| 100 | ||||
| 101 | Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V, | |||
| 102 | unsigned Granularity) { | |||
| 103 | if (Col2Row.count(V)) | |||
| 104 | return Col2Row[V]; | |||
| 105 | IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt()); | |||
| 106 | if (auto *I = dyn_cast<Instruction>(V)) { | |||
| 107 | BasicBlock::iterator Iter = I->getIterator(); | |||
| 108 | ++Iter; | |||
| 109 | Builder.SetInsertPoint(&*Iter); | |||
| 110 | } | |||
| 111 | ConstantInt *Gran = Builder.getInt16(Granularity); | |||
| 112 | Value *RealRow = Builder.CreateUDiv(V, Gran); | |||
| 113 | Col2Row[V] = RealRow; | |||
| 114 | return RealRow; | |||
| 115 | } | |||
| 116 | ||||
| 117 | std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II, | |||
| 118 | unsigned OpNo) { | |||
| 119 | Value *Row = nullptr, *Col = nullptr; | |||
| 120 | switch (II->getIntrinsicID()) { | |||
| 121 | default: | |||
| 122 | llvm_unreachable("Expect amx intrinsics")__builtin_unreachable(); | |||
| 123 | case Intrinsic::x86_tileloadd64_internal: | |||
| 124 | case Intrinsic::x86_tileloaddt164_internal: | |||
| 125 | case Intrinsic::x86_tilestored64_internal: { | |||
| 126 | Row = II->getArgOperand(0); | |||
| 127 | Col = II->getArgOperand(1); | |||
| 128 | break; | |||
| 129 | } | |||
| 130 | // a * b + c | |||
| 131 | // The shape depends on which operand. | |||
| 132 | case Intrinsic::x86_tdpbssd_internal: | |||
| 133 | case Intrinsic::x86_tdpbsud_internal: | |||
| 134 | case Intrinsic::x86_tdpbusd_internal: | |||
| 135 | case Intrinsic::x86_tdpbuud_internal: | |||
| 136 | case Intrinsic::x86_tdpbf16ps_internal: { | |||
| 137 | switch (OpNo) { | |||
| 138 | case 3: | |||
| 139 | Row = II->getArgOperand(0); | |||
| 140 | Col = II->getArgOperand(1); | |||
| 141 | break; | |||
| 142 | case 4: | |||
| 143 | Row = II->getArgOperand(0); | |||
| 144 | Col = II->getArgOperand(2); | |||
| 145 | break; | |||
| 146 | case 5: | |||
| 147 | Row = II->getArgOperand(2); | |||
| 148 | // FIXME: There is a design bug for AMX shape: the Col should be | |||
| 149 | // Col/4 if it will be used as a Row, but the current greedy RA can't handle | |||
| 150 | // this case well; it may fail if we generate a new shape definition. | |||
| 151 | // So let's just do it at O0 for now. | |||
| 152 | // Row = Row / 4 | |||
| 153 | if (TM->getOptLevel() == CodeGenOpt::None) | |||
| 154 | Row = getRowFromCol(II, Row, 4); | |||
| 155 | Col = II->getArgOperand(1); | |||
| 156 | break; | |||
| 157 | } | |||
| 158 | break; | |||
| 159 | } | |||
| 160 | } | |||
| 161 | ||||
| 162 | return std::make_pair(Row, Col); | |||
| 163 | } | |||
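As a side note, the operand-to-shape mapping in the dot-product cases above can be restated in isolation. The sketch below mirrors the switch over plain integers, assuming the a(m x k) * b(k x n) + c(m x n) operand layout that the switch implies (operand 3 is the accumulator, operands 4 and 5 the multiplicands); the helper name is made up and this is illustrative only, not part of the pass:

    #include <cassert>
    #include <utility>

    // Hypothetical restatement of the shape selection in getShape() above.
    static std::pair<unsigned, unsigned>
    tdpShapeForOperand(unsigned OpNo, unsigned M, unsigned N, unsigned K) {
      switch (OpNo) {
      case 3: return {M, N}; // accumulator c: m x n
      case 4: return {M, K}; // multiplicand a: m x k
      case 5: return {K, N}; // multiplicand b: k x n (the FIXME above notes the
                             // row should really be K/4 when used as a row)
      default: assert(false && "not a tile operand"); return {0, 0};
      }
    }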
| 164 | ||||
| 165 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
| 166 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
| 167 | // --> | |||
| 168 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
| 169 | // i8* %addr, i64 %stride64) | |||
| 170 | void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) { | |||
| 171 | Value *Row = nullptr, *Col = nullptr; | |||
| 172 | Use &U = *(Bitcast->use_begin()); | |||
| 173 | unsigned OpNo = U.getOperandNo(); | |||
| 174 | auto *II = cast<IntrinsicInst>(U.getUser()); | |||
| 175 | std::tie(Row, Col) = getShape(II, OpNo); | |||
| 176 | IRBuilder<> Builder(Bitcast); | |||
| 177 | // Use the maximum column as stride. | |||
| 178 | Value *Stride = Builder.getInt64(64); | |||
| 179 | Value *I8Ptr = | |||
| 180 | Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy()); | |||
| 181 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; | |||
| 182 | ||||
| 183 | Value *NewInst = | |||
| 184 | Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args); | |||
| 185 | Bitcast->replaceAllUsesWith(NewInst); | |||
| 186 | } | |||
| 187 | ||||
| 188 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, | |||
| 189 | // %stride); | |||
| 190 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
| 191 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
| 192 | // --> | |||
| 193 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
| 194 | // %stride64, %13) | |||
| 195 | void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) { | |||
| 196 | ||||
| 197 | Value *Tile = Bitcast->getOperand(0); | |||
| 198 | auto *II = cast<IntrinsicInst>(Tile); | |||
| 199 | // Tile is output from AMX intrinsic. The first operand of the | |||
| 200 | // intrinsic is row, the second operand of the intrinsic is column. | |||
| 201 | Value *Row = II->getOperand(0); | |||
| 202 | Value *Col = II->getOperand(1); | |||
| 203 | IRBuilder<> Builder(ST); | |||
| 204 | // Use the maximum column as stride. It must be the same as the load | |||
| 205 | // stride. | |||
| 206 | Value *Stride = Builder.getInt64(64); | |||
| 207 | Value *I8Ptr = | |||
| 208 | Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy()); | |||
| 209 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile}; | |||
| 210 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
| 211 | if (Bitcast->hasOneUse()) | |||
| 212 | return; | |||
| 213 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
| 214 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
| 215 | // %add = <256 x i32> %13, <256 x i32> %src2 | |||
| 216 | // --> | |||
| 217 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
| 218 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
| 219 | // %stride64, %13) | |||
| 220 | // %14 = load <256 x i32>, %addr | |||
| 221 | // %add = <256 x i32> %14, <256 x i32> %src2 | |||
| 222 | Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1)); | |||
| 223 | Bitcast->replaceAllUsesWith(Vec); | |||
| 224 | } | |||
| 225 | ||||
| 226 | // transform bitcast to <store, load> instructions. | |||
| 227 | bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) { | |||
| 228 | IRBuilder<> Builder(Bitcast); | |||
| 229 | AllocaInst *AllocaAddr; | |||
| 230 | Value *I8Ptr, *Stride; | |||
| 231 | auto *Src = Bitcast->getOperand(0); | |||
| 232 | ||||
| 233 | auto Prepare = [&]() { | |||
| 234 | AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent()); | |||
| 235 | I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy()); | |||
| 236 | Stride = Builder.getInt64(64); | |||
| 237 | }; | |||
| 238 | ||||
| 239 | if (Bitcast->getType()->isX86_AMXTy()) { | |||
| 240 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
| 241 | // --> | |||
| 242 | // %addr = alloca <256 x i32>, align 64 | |||
| 243 | // store <256 x i32> %src, <256 x i32>* %addr, align 64 | |||
| 244 | // %addr2 = bitcast <256 x i32>* to i8* | |||
| 245 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
| 246 | // i8* %addr2, | |||
| 247 | // i64 64) | |||
| 248 | Use &U = *(Bitcast->use_begin()); | |||
| 249 | unsigned OpNo = U.getOperandNo(); | |||
| 250 | auto *II = dyn_cast<IntrinsicInst>(U.getUser()); | |||
| 251 | if (!II) | |||
| 252 | return false; // May be bitcast from x86amx to <256 x i32>. | |||
| 253 | Prepare(); | |||
| 254 | Builder.CreateStore(Src, AllocaAddr); | |||
| 255 | // TODO: we can pick a constant operand for the shape. | |||
| 256 | Value *Row = nullptr, *Col = nullptr; | |||
| 257 | std::tie(Row, Col) = getShape(II, OpNo); | |||
| 258 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; | |||
| 259 | Value *NewInst = Builder.CreateIntrinsic( | |||
| 260 | Intrinsic::x86_tileloadd64_internal, None, Args); | |||
| 261 | Bitcast->replaceAllUsesWith(NewInst); | |||
| 262 | } else { | |||
| 263 | // %2 = bitcast x86_amx %src to <256 x i32> | |||
| 264 | // --> | |||
| 265 | // %addr = alloca <256 x i32>, align 64 | |||
| 266 | // %addr2 = bitcast <256 x i32>* to i8* | |||
| 267 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, | |||
| 268 | // i8* %addr2, i64 %stride) | |||
| 269 | // %2 = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
| 270 | auto *II = dyn_cast<IntrinsicInst>(Src); | |||
| 271 | if (!II) | |||
| 272 | return false; // May be bitcast from <256 x i32> to x86amx. | |||
| 273 | Prepare(); | |||
| 274 | Value *Row = II->getOperand(0); | |||
| 275 | Value *Col = II->getOperand(1); | |||
| 276 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src}; | |||
| 277 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
| 278 | Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr); | |||
| 279 | Bitcast->replaceAllUsesWith(NewInst); | |||
| 280 | } | |||
| 281 | ||||
| 282 | return true; | |||
| 283 | } | |||
| 284 | ||||
| 285 | bool X86LowerAMXType::visit() { | |||
| 286 | SmallVector<Instruction *, 8> DeadInsts; | |||
| 287 | Col2Row.clear(); | |||
| 288 | ||||
| 289 | for (BasicBlock *BB : post_order(&Func)) { | |||
| 290 | for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend(); | |||
| 291 | II != IE;) { | |||
| 292 | Instruction &Inst = *II++; | |||
| 293 | auto *Bitcast = dyn_cast<BitCastInst>(&Inst); | |||
| 294 | if (!Bitcast) | |||
| 295 | continue; | |||
| 296 | ||||
| 297 | Value *Src = Bitcast->getOperand(0); | |||
| 298 | if (Bitcast->getType()->isX86_AMXTy()) { | |||
| 299 | if (Bitcast->user_empty()) { | |||
| 300 | DeadInsts.push_back(Bitcast); | |||
| 301 | continue; | |||
| 302 | } | |||
| 303 | LoadInst *LD = dyn_cast<LoadInst>(Src); | |||
| 304 | if (!LD) { | |||
| 305 | if (transformBitcast(Bitcast)) | |||
| 306 | DeadInsts.push_back(Bitcast); | |||
| 307 | continue; | |||
| 308 | } | |||
| 309 | // If the load has multiple users, duplicate a vector load. | |||
| 310 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
| 311 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
| 312 | // %add = add <256 x i32> %src, <256 x i32> %src2 | |||
| 313 | // --> | |||
| 314 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
| 315 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
| 316 | // i8* %addr, i64 %stride64) | |||
| 317 | // %add = add <256 x i32> %src, <256 x i32> %src2 | |||
| 318 | ||||
| 319 | // If load has one user, the load will be eliminated in DAG ISel. | |||
| 320 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
| 321 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
| 322 | // --> | |||
| 323 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
| 324 | // i8* %addr, i64 %stride64) | |||
| 325 | combineLoadBitcast(LD, Bitcast); | |||
| 326 | DeadInsts.push_back(Bitcast); | |||
| 327 | if (LD->hasOneUse()) | |||
| 328 | DeadInsts.push_back(LD); | |||
| 329 | } else if (Src->getType()->isX86_AMXTy()) { | |||
| 330 | if (Bitcast->user_empty()) { | |||
| 331 | DeadInsts.push_back(Bitcast); | |||
| 332 | continue; | |||
| 333 | } | |||
| 334 | StoreInst *ST = nullptr; | |||
| 335 | for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end(); | |||
| 336 | UI != UE;) { | |||
| 337 | Value *I = (UI++)->getUser(); | |||
| 338 | ST = dyn_cast<StoreInst>(I); | |||
| 339 | if (ST) | |||
| 340 | break; | |||
| 341 | } | |||
| 342 | if (!ST) { | |||
| 343 | if (transformBitcast(Bitcast)) | |||
| 344 | DeadInsts.push_back(Bitcast); | |||
| 345 | continue; | |||
| 346 | } | |||
| 347 | // If bitcast (%13) has one use, combine bitcast and store to amx store. | |||
| 348 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, | |||
| 349 | // %stride); | |||
| 350 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
| 351 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
| 352 | // --> | |||
| 353 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
| 354 | // %stride64, %13) | |||
| 355 | // | |||
| 356 | // If bitcast (%13) has multiple uses, transform as below. | |||
| 357 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
| 358 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
| 359 | // %add = <256 x i32> %13, <256 x i32> %src2 | |||
| 360 | // --> | |||
| 361 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
| 362 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
| 363 | // %stride64, %13) | |||
| 364 | // %14 = load <256 x i32>, %addr | |||
| 365 | // %add = <256 x i32> %14, <256 x i32> %src2 | |||
| 366 | // | |||
| 367 | combineBitcastStore(Bitcast, ST); | |||
| 368 | // Delete user first. | |||
| 369 | DeadInsts.push_back(ST); | |||
| 370 | DeadInsts.push_back(Bitcast); | |||
| 371 | } | |||
| 372 | } | |||
| 373 | } | |||
| 374 | ||||
| 375 | bool C = !DeadInsts.empty(); | |||
| 376 | ||||
| 377 | for (auto *Inst : DeadInsts) | |||
| 378 | Inst->eraseFromParent(); | |||
| 379 | ||||
| 380 | return C; | |||
| 381 | } | |||
| 382 | } // anonymous namespace | |||
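The visit() routine above collects every replaced bitcast (and single-use load or store) in DeadInsts and only erases them once the block walk finishes, which keeps the reverse iterators valid. A self-contained sketch of that deferred-erasure idiom, with a made-up function name and a trivial dead-bitcast filter standing in for the real combine logic:

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Instructions.h"
    using namespace llvm;

    // Collect instructions to delete while iterating, erase them afterwards so
    // no iterator ever points at freed memory.
    static bool eraseUnusedBitcasts(BasicBlock &BB) {
      SmallVector<Instruction *, 8> DeadInsts;
      for (Instruction &I : BB)
        if (auto *BC = dyn_cast<BitCastInst>(&I))
          if (BC->use_empty())
            DeadInsts.push_back(BC);
      for (Instruction *I : DeadInsts)
        I->eraseFromParent();
      return !DeadInsts.empty();
    }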
| 383 | ||||
| 384 | static Value *getAllocaPos(BasicBlock *BB) { | |||
| 385 | Module *M = BB->getModule(); | |||
| 386 | Function *F = BB->getParent(); | |||
| 387 | IRBuilder<> Builder(&F->getEntryBlock().front()); | |||
| 388 | const DataLayout &DL = M->getDataLayout(); | |||
| 389 | unsigned AllocaAS = DL.getAllocaAddrSpace(); | |||
| 390 | Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false); | |||
| 391 | AllocaInst *AllocaRes = | |||
| 392 | new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front()); | |||
| 393 | BasicBlock::iterator Iter = AllocaRes->getIterator(); | |||
| 394 | ++Iter; | |||
| 395 | Builder.SetInsertPoint(&*Iter); | |||
| 396 | Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy()); | |||
| 397 | return I8Ptr; | |||
| 398 | } | |||
| 399 | ||||
| 400 | static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) { | |||
| 401 | assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!"); | |||
| 402 | auto *II = cast<IntrinsicInst>(TileDef); | |||
| 403 | assert(II && "Not tile intrinsic!"); | |||
| 404 | Value *Row = II->getOperand(0); | |||
| 405 | Value *Col = II->getOperand(1); | |||
| 406 | ||||
| 407 | BasicBlock *BB = TileDef->getParent(); | |||
| 408 | BasicBlock::iterator Iter = TileDef->getIterator(); | |||
| 409 | IRBuilder<> Builder(BB, ++Iter); | |||
| 410 | Value *Stride = Builder.getInt64(64); | |||
| 411 | std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef}; | |||
| 412 | ||||
| 413 | Instruction *TileStore = | |||
| 414 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
| 415 | return TileStore; | |||
| 416 | } | |||
| 417 | ||||
| 418 | static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) { | |||
| 419 | Value *V = U.get(); | |||
| 420 | assert(V->getType()->isX86_AMXTy() && "Not define tile!"); | |||
| 421 | ||||
| 422 | // Get tile shape. | |||
| 423 | IntrinsicInst *II = nullptr; | |||
| 424 | if (IsPHI) { | |||
| 425 | Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0); | |||
| | <-- line 425, column 20: Called C++ object pointer is null | |||
| 426 | II = cast<IntrinsicInst>(PhiOp); | |||
| 427 | } else { | |||
| 428 | II = cast<IntrinsicInst>(V); | |||
| 429 | } | |||
| 430 | Value *Row = II->getOperand(0); | |||
| 431 | Value *Col = II->getOperand(1); | |||
| 432 | ||||
| 433 | Instruction *UserI = dyn_cast<Instruction>(U.getUser()); | |||
| 434 | IRBuilder<> Builder(UserI); | |||
| 435 | Value *Stride = Builder.getInt64(64); | |||
| 436 | std::array<Value *, 4> Args = {Row, Col, Ptr, Stride}; | |||
| 437 | ||||
| 438 | Value *TileLoad = | |||
| 439 | Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args); | |||
| 440 | UserI->replaceUsesOfWith(V, TileLoad); | |||
| 441 | } | |||
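Given the diagnostic at the top of this report, one defensive variant of the IsPHI branch above would make the expected invariant explicit instead of dereferencing an unchecked dyn_cast result. This is only a sketch of a possible hardening, reusing the function's own V, IsPHI and II, not the upstream fix:

    // Sketch: cast<> asserts on a type mismatch rather than returning nullptr,
    // so a violated "V is a PHI" assumption fails loudly instead of crashing on
    // a null pointer dereference at line 425.
    if (IsPHI) {
      auto *PN = cast<PHINode>(V);
      Value *PhiOp = PN->getIncomingValue(0);
      II = cast<IntrinsicInst>(PhiOp);
    } else {
      II = cast<IntrinsicInst>(V);
    }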
| 442 | ||||
| 443 | static bool isIncomingOfPHI(Instruction *I) { | |||
| 444 | for (Use &U : I->uses()) { | |||
| 445 | User *V = U.getUser(); | |||
| 446 | if (isa<PHINode>(V)) | |||
| 447 | return true; | |||
| 448 | } | |||
| 449 | return false; | |||
| 450 | } | |||
| 451 | ||||
| 452 | // Make all AMX tile data volatile, shortening the live range | |||
| 453 | // of each tile register before fast register allocation. | |||
| 454 | namespace { | |||
| 455 | class X86VolatileTileData { | |||
| 456 | Function &F; | |||
| 457 | ||||
| 458 | public: | |||
| 459 | X86VolatileTileData(Function &Func) : F(Func) {} | |||
| 460 | Value *updatePhiIncomings(BasicBlock *BB, | |||
| 461 | SmallVector<Instruction *, 2> &Incomings); | |||
| 462 | void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr); | |||
| 463 | bool volatileTileData(); | |||
| 464 | void volatileTilePHI(PHINode *Inst); | |||
| 465 | void volatileTileNonPHI(Instruction *I); | |||
| 466 | }; | |||
| 467 | ||||
| 468 | Value *X86VolatileTileData::updatePhiIncomings( | |||
| 469 | BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) { | |||
| 470 | Value *I8Ptr = getAllocaPos(BB); | |||
| 471 | ||||
| 472 | for (auto *I : Incomings) { | |||
| 473 | User *Store = createTileStore(I, I8Ptr); | |||
| 474 | ||||
| 475 | // All its uses (except phi) should load from stored mem. | |||
| 476 | for (Use &U : I->uses()) { | |||
| 477 | User *V = U.getUser(); | |||
| 478 | if (isa<PHINode>(V) || V == Store) | |||
| 479 | continue; | |||
| 480 | replaceWithTileLoad(U, I8Ptr); | |||
| 481 | } | |||
| 482 | } | |||
| 483 | return I8Ptr; | |||
| 484 | } | |||
| 485 | ||||
| 486 | void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI, | |||
| 487 | Value *StorePtr) { | |||
| 488 | for (Use &U : PHI->uses()) | |||
| 489 | replaceWithTileLoad(U, StorePtr, true); | |||
| 490 | PHI->eraseFromParent(); | |||
| 491 | } | |||
| 492 | ||||
| 493 | // Similar to volatileTileNonPHI, this function only handles PHI nodes | |||
| 494 | // and their related AMX intrinsics. | |||
| 495 | // 1) The PHI def should change to a tileload. | |||
| 496 | // 2) The PHI incoming values should be tile-stored just after their defs. | |||
| 497 | // 3) The memory of these tileloads and tilestores should be the same. | |||
| 498 | // e.g. | |||
| 499 | // ------------------------------------------------------ | |||
| 500 | // bb_dom: | |||
| 501 | // ... | |||
| 502 | // br i1 %bool.cond, label %if.else, label %if.then | |||
| 503 | // | |||
| 504 | // if.then: | |||
| 505 | // def %t0 = ... | |||
| 506 | // ... | |||
| 507 | // use %t0 | |||
| 508 | // ... | |||
| 509 | // br label %if.end | |||
| 510 | // | |||
| 511 | // if.else: | |||
| 512 | // def %t1 = ... | |||
| 513 | // br label %if.end | |||
| 514 | // | |||
| 515 | // if.end: | |||
| 516 | // %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ] | |||
| 517 | // ... | |||
| 518 | // use %td | |||
| 519 | // ------------------------------------------------------ | |||
| 520 | // --> | |||
| 521 | // ------------------------------------------------------ | |||
| 522 | // bb_entry: | |||
| 523 | // %mem = alloca <256 x i32>, align 1024 * | |||
| 524 | // ... | |||
| 525 | // bb_dom: | |||
| 526 | // ... | |||
| 527 | // br i1 %bool.cond, label %if.else, label %if.then | |||
| 528 | // | |||
| 529 | // if.then: | |||
| 530 | // def %t0 = ... | |||
| 531 | // call void @llvm.x86.tilestored64.internal(mem, %t0) * | |||
| 532 | // ... | |||
| 533 | // %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)* | |||
| 534 | // use %t0` * | |||
| 535 | // ... | |||
| 536 | // br label %if.end | |||
| 537 | // | |||
| 538 | // if.else: | |||
| 539 | // def %t1 = ... | |||
| 540 | // call void @llvm.x86.tilestored64.internal(mem, %t1) * | |||
| 541 | // br label %if.end | |||
| 542 | // | |||
| 543 | // if.end: | |||
| 544 | // ... | |||
| 545 | // %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) * | |||
| 546 | // use %td | |||
| 547 | // ------------------------------------------------------ | |||
| 548 | void X86VolatileTileData::volatileTilePHI(PHINode *PHI) { | |||
| 549 | BasicBlock *BB = PHI->getParent(); | |||
| 550 | SmallVector<Instruction *, 2> Incomings; | |||
| 551 | ||||
| 552 | for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) { | |||
| 553 | Value *Op = PHI->getIncomingValue(I); | |||
| 554 | Instruction *Inst = dyn_cast<Instruction>(Op); | |||
| 555 | assert(Inst && "We shouldn't fold AMX instruction!"); | |||
| 556 | Incomings.push_back(Inst); | |||
| 557 | } | |||
| 558 | ||||
| 559 | Value *StorePtr = updatePhiIncomings(BB, Incomings); | |||
| 560 | replacePhiDefWithLoad(PHI, StorePtr); | |||
| 561 | } | |||
| 562 | ||||
| 563 | // Store the defined tile and load it before use. | |||
| 564 | // All its users are not PHI. | |||
| 565 | // e.g. | |||
| 566 | // ------------------------------------------------------ | |||
| 567 | // def %td = ... | |||
| 568 | // ... | |||
| 569 | // "use %td" | |||
| 570 | // ------------------------------------------------------ | |||
| 571 | // --> | |||
| 572 | // ------------------------------------------------------ | |||
| 573 | // def %td = ... | |||
| 574 | // call void @llvm.x86.tilestored64.internal(mem, %td) | |||
| 575 | // ... | |||
| 576 | // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem) | |||
| 577 | // "use %td2" | |||
| 578 | // ------------------------------------------------------ | |||
| 579 | void X86VolatileTileData::volatileTileNonPHI(Instruction *I) { | |||
| 580 | BasicBlock *BB = I->getParent(); | |||
| 581 | Value *I8Ptr = getAllocaPos(BB); | |||
| 582 | User *Store = createTileStore(I, I8Ptr); | |||
| 583 | ||||
| 584 | // All its uses should load from stored mem. | |||
| 585 | for (Use &U : I->uses()) { | |||
| 586 | User *V = U.getUser(); | |||
| 587 | assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!"); | |||
| 588 | if (V != Store) | |||
| 589 | replaceWithTileLoad(U, I8Ptr); | |||
| 590 | } | |||
| 591 | } | |||
| 592 | ||||
| 593 | // Volatile Tile Model: | |||
| 594 | // 1) All uses of tile data come from a tileload just in time. | |||
| 595 | // 2) All defs of tile data are tilestored to memory immediately. | |||
| 596 | // For example: | |||
| 597 | // -------------------------------------------------------------------------- | |||
| 598 | // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key | |||
| 599 | // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) | |||
| 600 | // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx | |||
| 601 | // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3) | |||
| 602 | // call void @llvm.x86.tilestored64.internal(... td) area | |||
| 603 | // -------------------------------------------------------------------------- | |||
| 604 | // 3) No terminator, call or other amx instructions in the key amx area. | |||
| 605 | bool X86VolatileTileData::volatileTileData() { | |||
| 606 | bool Changed = false; | |||
| 607 | for (BasicBlock &BB : F) { | |||
| 608 | SmallVector<Instruction *, 2> PHIInsts; | |||
| 609 | SmallVector<Instruction *, 8> AMXDefInsts; | |||
| 610 | ||||
| 611 | for (Instruction &I : BB) { | |||
| 612 | if (!I.getType()->isX86_AMXTy()) | |||
| 613 | continue; | |||
| 614 | if (isa<PHINode>(&I)) | |||
| 615 | PHIInsts.push_back(&I); | |||
| 616 | else | |||
| 617 | AMXDefInsts.push_back(&I); | |||
| 618 | } | |||
| 619 | ||||
| 620 | // First we "volatile" the non-phi related amx intrinsics. | |||
| 621 | for (Instruction *I : AMXDefInsts) { | |||
| 622 | if (isIncomingOfPHI(I)) | |||
| 623 | continue; | |||
| 624 | volatileTileNonPHI(I); | |||
| 625 | Changed = true; | |||
| 626 | } | |||
| 627 | ||||
| 628 | for (Instruction *I : PHIInsts) { | |||
| 629 | volatileTilePHI(dyn_cast<PHINode>(I)); | |||
| 630 | Changed = true; | |||
| 631 | } | |||
| 632 | } | |||
| 633 | return Changed; | |||
| 634 | } | |||
| 635 | ||||
| 636 | } // anonymous namespace | |||
| 637 | ||||
| 638 | namespace { | |||
| 639 | ||||
| 640 | class X86LowerAMXTypeLegacyPass : public FunctionPass { | |||
| 641 | public: | |||
| 642 | static char ID; | |||
| 643 | ||||
| 644 | X86LowerAMXTypeLegacyPass() : FunctionPass(ID) { | |||
| 645 | initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry()); | |||
| 646 | } | |||
| 647 | ||||
| 648 | bool runOnFunction(Function &F) override { | |||
| 649 | TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); | |||
| 650 | ||||
| 651 | X86LowerAMXType LAT(F, TM); | |||
| 652 | bool C = LAT.visit(); | |||
| 653 | ||||
| 654 | // Prepare for fast register allocation at O0. | |||
| 655 | // TODO: It may be better to check the volatile model of the AMX code, not | |||
| 656 | // just Attribute::OptimizeNone and CodeGenOpt::None. | |||
| 657 | if (TM->getOptLevel() == CodeGenOpt::None) { | |||
| ||||
| 658 | // If the front end does not use O0 but the mid/back end does (e.g. | |||
| 659 | // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make | |||
| 660 | // sure the AMX data is volatile; that is necessary for AMX fast | |||
| 661 | // register allocation. | |||
| 662 | if (!F.hasFnAttribute(Attribute::OptimizeNone)) { | |||
| 663 | X86VolatileTileData VTD(F); | |||
| 664 | C = VTD.volatileTileData() || C; | |||
| 665 | } | |||
| 666 | } | |||
| 667 | ||||
| 668 | return C; | |||
| 669 | } | |||
| 670 | ||||
| 671 | void getAnalysisUsage(AnalysisUsage &AU) const override { | |||
| 672 | AU.setPreservesCFG(); | |||
| 673 | AU.addRequired<TargetPassConfig>(); | |||
| 674 | } | |||
| 675 | }; | |||
| 676 | ||||
| 677 | } // anonymous namespace | |||
| 678 | ||||
| 679 | static const char PassName[] = "Lower AMX type for load/store"; | |||
| 680 | char X86LowerAMXTypeLegacyPass::ID = 0; | |||
| 681 | INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, | |||
| 682 | false) | |||
| 683 | INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) | |||
| 684 | INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, | |||
| 685 | false) | |||
| 686 | ||||
| 687 | FunctionPass *llvm::createX86LowerAMXTypePass() { | |||
| 688 | return new X86LowerAMXTypeLegacyPass(); | |||
| 689 | } |