File: src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/X86/X86LowerAMXType.cpp
Warning: line 549, column 20: Called C++ object pointer is null
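About the warning: the flagged dereference is the PHI->getParent() call at line 549 in X86VolatileTileData::volatileTilePHI(). The analyzer follows the dyn_cast<PHINode>(I) at the call site (line 629) and assumes it can return null. In this pass, the list being iterated there (PHIInsts) only ever holds instructions for which isa<PHINode> was true, so the pointer should never actually be null. A minimal sketch of one way to state that invariant so the analyzer can see it (an illustration, not a change present in this tree) is to use cast<>, which asserts on a type mismatch instead of returning null:

    for (Instruction *I : PHIInsts) {
      // cast<> asserts that I really is a PHINode; dyn_cast<> would silently
      // yield nullptr on a mismatch, which is what the analyzer reports.
      volatileTilePHI(cast<PHINode>(I));
      Changed = true;
    }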
1 | //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | /// \file Pass to transform <256 x i32> load/store | |||
10 | /// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
11 | /// provides simple operations on x86_amx. Basic elementwise operations
12 | /// are not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
13 | /// and only AMX intrinsics can operate on the type, we need to transform
14 | /// <256 x i32> load/store instructions to AMX load/store. If the bitcast
15 | /// cannot be combined with the load/store, we transform the bitcast to an amx
16 | /// load/store plus a <256 x i32> store/load.
17 | /// | |||
18 | /// If the front end does not use O0 but the mid/back end uses O0 (e.g. "Clang
19 | /// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the amx data is
20 | /// volatile, because that is necessary for AMX fast register allocation. (In
21 | /// fast register allocation, registers are allocated before spill/reload, so
22 | /// there is no additional register for amx to identify the step in a spill.)
23 | /// volatileTileData() handles this case.
24 | /// e.g. | |||
25 | /// ---------------------------------------------------------- | |||
26 | /// | def %td = ... | | |||
27 | /// | ... | | |||
28 | /// | "use %td" | | |||
29 | /// ---------------------------------------------------------- | |||
30 | /// is transformed to -->
31 | /// ---------------------------------------------------------- | |||
32 | /// | def %td = ... | | |||
33 | /// | call void @llvm.x86.tilestored64.internal(mem, %td) | | |||
34 | /// | ... | | |||
35 | /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)| | |||
36 | /// | "use %td2" | | |||
37 | /// ---------------------------------------------------------- | |||
38 | // | |||
39 | //===----------------------------------------------------------------------===// | |||
40 | // | |||
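// A sketch of the kind of source that produces this pattern (an editor's
// illustration assuming clang's __tile1024i API from amxintrin.h; buffer and
// stride names are placeholders, not code taken from this tree):
//
//   __tile1024i a = {16, 64}, b = {16, 64}, c = {16, 64};
//   __tile_loadd(&a, abuf, stride);
//   __tile_loadd(&b, bbuf, stride);
//   __tile_dpbssd(&c, a, b);
//   __tile_stored(cbuf, stride, c);
//
// The front end lowers the __tile1024i values to <256 x i32> vectors, so the
// IR contains <256 x i32> load/store plus bitcasts to/from x86_amx around the
// *.internal AMX intrinsics that this pass rewrites.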
41 | #include "X86.h" | |||
42 | #include "llvm/ADT/PostOrderIterator.h" | |||
43 | #include "llvm/ADT/SmallSet.h" | |||
44 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" | |||
45 | #include "llvm/Analysis/TargetTransformInfo.h" | |||
46 | #include "llvm/CodeGen/Passes.h" | |||
47 | #include "llvm/CodeGen/TargetPassConfig.h" | |||
48 | #include "llvm/CodeGen/ValueTypes.h" | |||
49 | #include "llvm/IR/DataLayout.h" | |||
50 | #include "llvm/IR/Function.h" | |||
51 | #include "llvm/IR/IRBuilder.h" | |||
52 | #include "llvm/IR/Instructions.h" | |||
53 | #include "llvm/IR/IntrinsicInst.h" | |||
54 | #include "llvm/IR/IntrinsicsX86.h" | |||
55 | #include "llvm/IR/PatternMatch.h" | |||
56 | #include "llvm/InitializePasses.h" | |||
57 | #include "llvm/Pass.h" | |||
58 | #include "llvm/Target/TargetMachine.h" | |||
59 | ||||
60 | using namespace llvm; | |||
61 | using namespace PatternMatch; | |||
62 | ||||
63 | #define DEBUG_TYPE "lower-amx-type"
64 | ||||
65 | static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, | |||
66 | BasicBlock *BB) { | |||
67 | Function &F = *BB->getParent(); | |||
68 | Module *M = BB->getModule(); | |||
69 | const DataLayout &DL = M->getDataLayout(); | |||
70 | ||||
71 | Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false); | |||
72 | LLVMContext &Ctx = Builder.getContext(); | |||
73 | auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx)); | |||
74 | unsigned AllocaAS = DL.getAllocaAddrSpace(); | |||
75 | AllocaInst *AllocaRes = | |||
76 | new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front()); | |||
77 | AllocaRes->setAlignment(AllocaAlignment); | |||
78 | return AllocaRes; | |||
79 | } | |||
80 | ||||
81 | namespace { | |||
82 | class X86LowerAMXType { | |||
83 | Function &Func; | |||
84 | TargetMachine *TM = nullptr; | |||
85 | ||||
86 | // In AMX intrinsics we let Shape = {Row, Col}, but the | |||
87 | // RealCol = Col / ElementSize. We may use the RealCol | |||
88 | // as a new Row for other newly created AMX intrinsics.
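// For example, a tile whose Col operand is 64 (bytes) yields RealCol = 64 / 4
// = 16 when that value has to be reused as the Row of a newly created tile
// intrinsic (see getRowFromCol below with Granularity == 4).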
89 | std::map<Value *, Value *> Col2Row; | |||
90 | ||||
91 | public: | |||
92 | X86LowerAMXType(Function &F, TargetMachine *TargetM) : Func(F), TM(TargetM) {} | |||
93 | bool visit(); | |||
94 | void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast); | |||
95 | void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST); | |||
96 | bool transformBitcast(BitCastInst *Bitcast); | |||
97 | std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo); | |||
98 | Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity); | |||
99 | }; | |||
100 | ||||
101 | Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V, | |||
102 | unsigned Granularity) { | |||
103 | if (Col2Row.count(V)) | |||
104 | return Col2Row[V]; | |||
105 | IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt()); | |||
106 | if (auto *I = dyn_cast<Instruction>(V)) { | |||
107 | BasicBlock::iterator Iter = I->getIterator(); | |||
108 | ++Iter; | |||
109 | Builder.SetInsertPoint(&*Iter); | |||
110 | } | |||
111 | ConstantInt *Gran = Builder.getInt16(Granularity); | |||
112 | Value *RealRow = Builder.CreateUDiv(V, Gran); | |||
113 | Col2Row[V] = RealRow; | |||
114 | return RealRow; | |||
115 | } | |||
116 | ||||
117 | std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II, | |||
118 | unsigned OpNo) { | |||
119 | Value *Row = nullptr, *Col = nullptr; | |||
120 | switch (II->getIntrinsicID()) { | |||
121 | default: | |||
122 | llvm_unreachable("Expect amx intrinsics");
123 | case Intrinsic::x86_tileloadd64_internal: | |||
124 | case Intrinsic::x86_tileloaddt164_internal: | |||
125 | case Intrinsic::x86_tilestored64_internal: { | |||
126 | Row = II->getArgOperand(0); | |||
127 | Col = II->getArgOperand(1); | |||
128 | break; | |||
129 | } | |||
130 | // a * b + c | |||
131 | // The shape depends on which operand. | |||
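// Reading the cases below with the usual tdp* argument order in mind (an
// assumption: args are (M, N, K, Acc, A, B) and Acc = A * B + Acc):
//   operand 3 (Acc, M x N) -> {arg0, arg1}
//   operand 4 (A,   M x K) -> {arg0, arg2}
//   operand 5 (B,   K x N) -> {arg2, arg1}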
132 | case Intrinsic::x86_tdpbssd_internal: | |||
133 | case Intrinsic::x86_tdpbsud_internal: | |||
134 | case Intrinsic::x86_tdpbusd_internal: | |||
135 | case Intrinsic::x86_tdpbuud_internal: | |||
136 | case Intrinsic::x86_tdpbf16ps_internal: { | |||
137 | switch (OpNo) { | |||
138 | case 3: | |||
139 | Row = II->getArgOperand(0); | |||
140 | Col = II->getArgOperand(1); | |||
141 | break; | |||
142 | case 4: | |||
143 | Row = II->getArgOperand(0); | |||
144 | Col = II->getArgOperand(2); | |||
145 | break; | |||
146 | case 5: | |||
147 | Row = II->getArgOperand(2); | |||
148 | // FIXME: There is a design bug for AMX shape: the Col should be
149 | // Col/4 if it will be used as a Row, but the current greedy RA can't
150 | // handle this case well; it may fail if we generate a new shape
151 | // definition. So let's just do it at O0 for now.
152 | // Row = Row / 4 | |||
153 | if (TM->getOptLevel() == CodeGenOpt::None) | |||
154 | Row = getRowFromCol(II, Row, 4); | |||
155 | Col = II->getArgOperand(1); | |||
156 | break; | |||
157 | } | |||
158 | break; | |||
159 | } | |||
160 | } | |||
161 | ||||
162 | return std::make_pair(Row, Col); | |||
163 | } | |||
164 | ||||
165 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
166 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
167 | // --> | |||
168 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
169 | // i8* %addr, i64 %stride64) | |||
170 | void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) { | |||
171 | Value *Row = nullptr, *Col = nullptr; | |||
172 | Use &U = *(Bitcast->use_begin()); | |||
173 | unsigned OpNo = U.getOperandNo(); | |||
174 | auto *II = cast<IntrinsicInst>(U.getUser()); | |||
175 | std::tie(Row, Col) = getShape(II, OpNo); | |||
176 | IRBuilder<> Builder(Bitcast); | |||
177 | // Use the maximum column as the stride.
178 | Value *Stride = Builder.getInt64(64); | |||
179 | Value *I8Ptr = | |||
180 | Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy()); | |||
181 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; | |||
182 | ||||
183 | Value *NewInst = | |||
184 | Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args); | |||
185 | Bitcast->replaceAllUsesWith(NewInst); | |||
186 | } | |||
187 | ||||
188 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, | |||
189 | // %stride); | |||
190 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
191 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
192 | // --> | |||
193 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
194 | // %stride64, %13) | |||
195 | void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) { | |||
196 | ||||
197 | Value *Tile = Bitcast->getOperand(0); | |||
198 | auto *II = cast<IntrinsicInst>(Tile); | |||
199 | // Tile is output from AMX intrinsic. The first operand of the | |||
200 | // intrinsic is row, the second operand of the intrinsic is column. | |||
201 | Value *Row = II->getOperand(0); | |||
202 | Value *Col = II->getOperand(1); | |||
203 | IRBuilder<> Builder(ST); | |||
204 | // Use the maximum column as the stride. It must be the same as the
205 | // load stride.
206 | Value *Stride = Builder.getInt64(64); | |||
207 | Value *I8Ptr = | |||
208 | Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy()); | |||
209 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile}; | |||
210 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
211 | if (Bitcast->hasOneUse()) | |||
212 | return; | |||
213 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
214 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
215 | // %add = <256 x i32> %13, <256 x i32> %src2 | |||
216 | // --> | |||
217 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
218 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
219 | // %stride64, %13) | |||
220 | // %14 = load <256 x i32>, %addr | |||
221 | // %add = <256 x i32> %14, <256 x i32> %src2 | |||
222 | Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1)); | |||
223 | Bitcast->replaceAllUsesWith(Vec); | |||
224 | } | |||
225 | ||||
226 | // Transform a bitcast into <store, load> instructions.
227 | bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) { | |||
228 | IRBuilder<> Builder(Bitcast); | |||
229 | AllocaInst *AllocaAddr; | |||
230 | Value *I8Ptr, *Stride; | |||
231 | auto *Src = Bitcast->getOperand(0); | |||
232 | ||||
233 | auto Prepare = [&]() { | |||
234 | AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent()); | |||
235 | I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy()); | |||
236 | Stride = Builder.getInt64(64); | |||
237 | }; | |||
238 | ||||
239 | if (Bitcast->getType()->isX86_AMXTy()) { | |||
240 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
241 | // --> | |||
242 | // %addr = alloca <256 x i32>, align 64 | |||
243 | // store <256 x i32> %src, <256 x i32>* %addr, align 64 | |||
244 | // %addr2 = bitcast <256 x i32>* to i8* | |||
245 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
246 | // i8* %addr2, | |||
247 | // i64 64) | |||
248 | Use &U = *(Bitcast->use_begin()); | |||
249 | unsigned OpNo = U.getOperandNo(); | |||
250 | auto *II = dyn_cast<IntrinsicInst>(U.getUser()); | |||
251 | if (!II) | |||
252 | return false; // May be bitcast from x86amx to <256 x i32>. | |||
253 | Prepare(); | |||
254 | Builder.CreateStore(Src, AllocaAddr); | |||
255 | // TODO: we can pick a constant operand for the shape.
256 | Value *Row = nullptr, *Col = nullptr; | |||
257 | std::tie(Row, Col) = getShape(II, OpNo); | |||
258 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; | |||
259 | Value *NewInst = Builder.CreateIntrinsic( | |||
260 | Intrinsic::x86_tileloadd64_internal, None, Args); | |||
261 | Bitcast->replaceAllUsesWith(NewInst); | |||
262 | } else { | |||
263 | // %2 = bitcast x86_amx %src to <256 x i32> | |||
264 | // --> | |||
265 | // %addr = alloca <256 x i32>, align 64 | |||
266 | // %addr2 = bitcast <256 x i32>* to i8* | |||
267 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, | |||
268 | // i8* %addr2, i64 %stride) | |||
269 | // %2 = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
270 | auto *II = dyn_cast<IntrinsicInst>(Src); | |||
271 | if (!II) | |||
272 | return false; // May be bitcast from <256 x i32> to x86amx. | |||
273 | Prepare(); | |||
274 | Value *Row = II->getOperand(0); | |||
275 | Value *Col = II->getOperand(1); | |||
276 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src}; | |||
277 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
278 | Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr); | |||
279 | Bitcast->replaceAllUsesWith(NewInst); | |||
280 | } | |||
281 | ||||
282 | return true; | |||
283 | } | |||
284 | ||||
285 | bool X86LowerAMXType::visit() { | |||
286 | SmallVector<Instruction *, 8> DeadInsts; | |||
287 | Col2Row.clear(); | |||
288 | ||||
289 | for (BasicBlock *BB : post_order(&Func)) { | |||
290 | for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend(); | |||
291 | II != IE;) { | |||
292 | Instruction &Inst = *II++; | |||
293 | auto *Bitcast = dyn_cast<BitCastInst>(&Inst); | |||
294 | if (!Bitcast) | |||
295 | continue; | |||
296 | ||||
297 | Value *Src = Bitcast->getOperand(0); | |||
298 | if (Bitcast->getType()->isX86_AMXTy()) { | |||
299 | if (Bitcast->user_empty()) { | |||
300 | DeadInsts.push_back(Bitcast); | |||
301 | continue; | |||
302 | } | |||
303 | LoadInst *LD = dyn_cast<LoadInst>(Src); | |||
304 | if (!LD) { | |||
305 | if (transformBitcast(Bitcast)) | |||
306 | DeadInsts.push_back(Bitcast); | |||
307 | continue; | |||
308 | } | |||
309 | // If the load has multiple users, duplicate the vector load.
310 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
311 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
312 | // %add = add <256 x i32> %src, <256 x i32> %src2 | |||
313 | // --> | |||
314 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
315 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
316 | // i8* %addr, i64 %stride64) | |||
317 | // %add = add <256 x i32> %src, <256 x i32> %src2 | |||
318 | ||||
319 | // If the load has a single user, the load will be eliminated in DAG ISel.
320 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 | |||
321 | // %2 = bitcast <256 x i32> %src to x86_amx | |||
322 | // --> | |||
323 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, | |||
324 | // i8* %addr, i64 %stride64) | |||
325 | combineLoadBitcast(LD, Bitcast); | |||
326 | DeadInsts.push_back(Bitcast); | |||
327 | if (LD->hasOneUse()) | |||
328 | DeadInsts.push_back(LD); | |||
329 | } else if (Src->getType()->isX86_AMXTy()) { | |||
330 | if (Bitcast->user_empty()) { | |||
331 | DeadInsts.push_back(Bitcast); | |||
332 | continue; | |||
333 | } | |||
334 | StoreInst *ST = nullptr; | |||
335 | for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end(); | |||
336 | UI != UE;) { | |||
337 | Value *I = (UI++)->getUser(); | |||
338 | ST = dyn_cast<StoreInst>(I); | |||
339 | if (ST) | |||
340 | break; | |||
341 | } | |||
342 | if (!ST) { | |||
343 | if (transformBitcast(Bitcast)) | |||
344 | DeadInsts.push_back(Bitcast); | |||
345 | continue; | |||
346 | } | |||
347 | // If bitcast (%13) has one use, combine the bitcast and store into an amx store.
348 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, | |||
349 | // %stride); | |||
350 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
351 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
352 | // --> | |||
353 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
354 | // %stride64, %13) | |||
355 | // | |||
356 | // If bitcast (%13) has multiple uses, transform it as below.
357 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
358 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 | |||
359 | // %add = <256 x i32> %13, <256 x i32> %src2 | |||
360 | // --> | |||
361 | // %13 = bitcast x86_amx %src to <256 x i32> | |||
362 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, | |||
363 | // %stride64, %13) | |||
364 | // %14 = load <256 x i32>, %addr | |||
365 | // %add = <256 x i32> %14, <256 x i32> %src2 | |||
366 | // | |||
367 | combineBitcastStore(Bitcast, ST); | |||
368 | // Delete user first. | |||
369 | DeadInsts.push_back(ST); | |||
370 | DeadInsts.push_back(Bitcast); | |||
371 | } | |||
372 | } | |||
373 | } | |||
374 | ||||
375 | bool C = !DeadInsts.empty(); | |||
376 | ||||
377 | for (auto *Inst : DeadInsts) | |||
378 | Inst->eraseFromParent(); | |||
379 | ||||
380 | return C; | |||
381 | } | |||
382 | } // anonymous namespace | |||
383 | ||||
384 | static Value *getAllocaPos(BasicBlock *BB) { | |||
385 | Module *M = BB->getModule(); | |||
386 | Function *F = BB->getParent(); | |||
387 | IRBuilder<> Builder(&F->getEntryBlock().front()); | |||
388 | const DataLayout &DL = M->getDataLayout(); | |||
389 | unsigned AllocaAS = DL.getAllocaAddrSpace(); | |||
390 | Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false); | |||
391 | AllocaInst *AllocaRes = | |||
392 | new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front()); | |||
393 | BasicBlock::iterator Iter = AllocaRes->getIterator(); | |||
394 | ++Iter; | |||
395 | Builder.SetInsertPoint(&*Iter); | |||
396 | Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy()); | |||
397 | return I8Ptr; | |||
398 | } | |||
399 | ||||
400 | static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) { | |||
401 | assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
402 | auto *II = cast<IntrinsicInst>(TileDef);
403 | assert(II && "Not tile intrinsic!");
404 | Value *Row = II->getOperand(0); | |||
405 | Value *Col = II->getOperand(1); | |||
406 | ||||
407 | BasicBlock *BB = TileDef->getParent(); | |||
408 | BasicBlock::iterator Iter = TileDef->getIterator(); | |||
409 | IRBuilder<> Builder(BB, ++Iter); | |||
410 | Value *Stride = Builder.getInt64(64); | |||
411 | std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef}; | |||
412 | ||||
413 | Instruction *TileStore = | |||
414 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); | |||
415 | return TileStore; | |||
416 | } | |||
417 | ||||
418 | static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) { | |||
419 | Value *V = U.get(); | |||
420 | assert(V->getType()->isX86_AMXTy() && "Not define tile!");
421 | ||||
422 | // Get tile shape. | |||
423 | IntrinsicInst *II = nullptr; | |||
424 | if (IsPHI) { | |||
425 | Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0); | |||
426 | II = cast<IntrinsicInst>(PhiOp); | |||
427 | } else { | |||
428 | II = cast<IntrinsicInst>(V); | |||
429 | } | |||
430 | Value *Row = II->getOperand(0); | |||
431 | Value *Col = II->getOperand(1); | |||
432 | ||||
433 | Instruction *UserI = dyn_cast<Instruction>(U.getUser()); | |||
434 | IRBuilder<> Builder(UserI); | |||
435 | Value *Stride = Builder.getInt64(64); | |||
436 | std::array<Value *, 4> Args = {Row, Col, Ptr, Stride}; | |||
437 | ||||
438 | Value *TileLoad = | |||
439 | Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args); | |||
440 | UserI->replaceUsesOfWith(V, TileLoad); | |||
441 | } | |||
442 | ||||
443 | static bool isIncomingOfPHI(Instruction *I) { | |||
444 | for (Use &U : I->uses()) { | |||
445 | User *V = U.getUser(); | |||
446 | if (isa<PHINode>(V)) | |||
447 | return true; | |||
448 | } | |||
449 | return false; | |||
450 | } | |||
451 | ||||
452 | // Make all AMX tile data volatile so that the live range of each tile
453 | // register is shortened before fast register allocation.
454 | namespace { | |||
455 | class X86VolatileTileData { | |||
456 | Function &F; | |||
457 | ||||
458 | public: | |||
459 | X86VolatileTileData(Function &Func) : F(Func) {} | |||
460 | Value *updatePhiIncomings(BasicBlock *BB, | |||
461 | SmallVector<Instruction *, 2> &Incomings); | |||
462 | void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr); | |||
463 | bool volatileTileData(); | |||
464 | void volatileTilePHI(PHINode *Inst); | |||
465 | void volatileTileNonPHI(Instruction *I); | |||
466 | }; | |||
467 | ||||
468 | Value *X86VolatileTileData::updatePhiIncomings( | |||
469 | BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) { | |||
470 | Value *I8Ptr = getAllocaPos(BB); | |||
471 | ||||
472 | for (auto *I : Incomings) { | |||
473 | User *Store = createTileStore(I, I8Ptr); | |||
474 | ||||
475 | // All its uses (except phi) should load from stored mem. | |||
476 | for (Use &U : I->uses()) { | |||
477 | User *V = U.getUser(); | |||
478 | if (isa<PHINode>(V) || V == Store) | |||
479 | continue; | |||
480 | replaceWithTileLoad(U, I8Ptr); | |||
481 | } | |||
482 | } | |||
483 | return I8Ptr; | |||
484 | } | |||
485 | ||||
486 | void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI, | |||
487 | Value *StorePtr) { | |||
488 | for (Use &U : PHI->uses()) | |||
489 | replaceWithTileLoad(U, StorePtr, true); | |||
490 | PHI->eraseFromParent(); | |||
491 | } | |||
492 | ||||
493 | // Similar to volatileTileNonPHI, this function only handles PHI nodes
494 | // and their related AMX intrinsics.
495 | // 1) The PHI def should be changed to a tileload.
496 | // 2) The PHI incoming values should be tilestored just after their defs.
497 | // 3) The mem of these tileloads and tilestores should be the same.
498 | // e.g. | |||
499 | // ------------------------------------------------------ | |||
500 | // bb_dom: | |||
501 | // ... | |||
502 | // br i1 %bool.cond, label %if.else, label %if.then | |||
503 | // | |||
504 | // if.then: | |||
505 | // def %t0 = ... | |||
506 | // ... | |||
507 | // use %t0 | |||
508 | // ... | |||
509 | // br label %if.end | |||
510 | // | |||
511 | // if.else: | |||
512 | // def %t1 = ... | |||
513 | // br label %if.end | |||
514 | // | |||
515 | // if.end: | |||
516 | // %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ] | |||
517 | // ... | |||
518 | // use %td | |||
519 | // ------------------------------------------------------ | |||
520 | // --> | |||
521 | // ------------------------------------------------------ | |||
522 | // bb_entry: | |||
523 | // %mem = alloca <256 x i32>, align 1024 * | |||
524 | // ... | |||
525 | // bb_dom: | |||
526 | // ... | |||
527 | // br i1 %bool.cond, label %if.else, label %if.then | |||
528 | // | |||
529 | // if.then: | |||
530 | // def %t0 = ... | |||
531 | // call void @llvm.x86.tilestored64.internal(mem, %t0) * | |||
532 | // ... | |||
533 | // %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)* | |||
534 | // use %t0` * | |||
535 | // ... | |||
536 | // br label %if.end | |||
537 | // | |||
538 | // if.else: | |||
539 | // def %t1 = ... | |||
540 | // call void @llvm.x86.tilestored64.internal(mem, %t1) * | |||
541 | // br label %if.end | |||
542 | // | |||
543 | // if.end: | |||
544 | // ... | |||
545 | // %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) * | |||
546 | // use %td | |||
547 | // ------------------------------------------------------ | |||
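// (The getParent() call below, at line 549, is the dereference reported in
// the warning at the top of this report; PHI comes from the dyn_cast at the
// call site in volatileTileData(), line 629.)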
548 | void X86VolatileTileData::volatileTilePHI(PHINode *PHI) { | |||
549 | BasicBlock *BB = PHI->getParent(); | |||
550 | SmallVector<Instruction *, 2> Incomings; | |||
551 | ||||
552 | for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) { | |||
553 | Value *Op = PHI->getIncomingValue(I); | |||
554 | Instruction *Inst = dyn_cast<Instruction>(Op); | |||
555 | assert(Inst && "We shouldn't fold AMX instrution!");
556 | Incomings.push_back(Inst); | |||
557 | } | |||
558 | ||||
559 | Value *StorePtr = updatePhiIncomings(BB, Incomings); | |||
560 | replacePhiDefWithLoad(PHI, StorePtr); | |||
561 | } | |||
562 | ||||
563 | // Store the defined tile and load it before use. | |||
564 | // None of its users is a PHI node.
565 | // e.g. | |||
566 | // ------------------------------------------------------ | |||
567 | // def %td = ... | |||
568 | // ... | |||
569 | // "use %td" | |||
570 | // ------------------------------------------------------ | |||
571 | // --> | |||
572 | // ------------------------------------------------------ | |||
573 | // def %td = ... | |||
574 | // call void @llvm.x86.tilestored64.internal(mem, %td) | |||
575 | // ... | |||
576 | // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem) | |||
577 | // "use %td2" | |||
578 | // ------------------------------------------------------ | |||
579 | void X86VolatileTileData::volatileTileNonPHI(Instruction *I) { | |||
580 | BasicBlock *BB = I->getParent(); | |||
581 | Value *I8Ptr = getAllocaPos(BB); | |||
582 | User *Store = createTileStore(I, I8Ptr); | |||
583 | ||||
584 | // All its uses should load from stored mem. | |||
585 | for (Use &U : I->uses()) { | |||
586 | User *V = U.getUser(); | |||
587 | assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
588 | if (V != Store) | |||
589 | replaceWithTileLoad(U, I8Ptr); | |||
590 | } | |||
591 | } | |||
592 | ||||
593 | // Volatile Tile Model: | |||
594 | // 1) All uses of tile data come from a tileload just in time.
595 | // 2) All defs of tile data are tilestored into mem immediately.
596 | // For example: | |||
597 | // -------------------------------------------------------------------------- | |||
598 | // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key | |||
599 | // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) | |||
600 | // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx | |||
601 | // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3) | |||
602 | // call void @llvm.x86.tilestored64.internal(... td) area | |||
603 | // -------------------------------------------------------------------------- | |||
604 | // 3) No terminator, call or other amx instructions in the key amx area. | |||
605 | bool X86VolatileTileData::volatileTileData() { | |||
606 | bool Changed = false; | |||
607 | for (BasicBlock &BB : F) { | |||
608 | SmallVector<Instruction *, 2> PHIInsts; | |||
609 | SmallVector<Instruction *, 8> AMXDefInsts; | |||
610 | ||||
611 | for (Instruction &I : BB) { | |||
612 | if (!I.getType()->isX86_AMXTy()) | |||
613 | continue; | |||
614 | if (isa<PHINode>(&I)) | |||
615 | PHIInsts.push_back(&I); | |||
616 | else | |||
617 | AMXDefInsts.push_back(&I); | |||
618 | } | |||
619 | ||||
620 | // First we "volatile" the non-phi related amx intrinsics. | |||
621 | for (Instruction *I : AMXDefInsts) { | |||
622 | if (isIncomingOfPHI(I)) | |||
623 | continue; | |||
624 | volatileTileNonPHI(I); | |||
625 | Changed = true; | |||
626 | } | |||
627 | ||||
628 | for (Instruction *I : PHIInsts) { | |||
629 | volatileTilePHI(dyn_cast<PHINode>(I)); | |||
630 | Changed = true; | |||
631 | } | |||
632 | } | |||
633 | return Changed; | |||
634 | } | |||
635 | ||||
636 | } // anonymous namespace | |||
637 | ||||
638 | namespace { | |||
639 | ||||
640 | class X86LowerAMXTypeLegacyPass : public FunctionPass { | |||
641 | public: | |||
642 | static char ID; | |||
643 | ||||
644 | X86LowerAMXTypeLegacyPass() : FunctionPass(ID) { | |||
645 | initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry()); | |||
646 | } | |||
647 | ||||
648 | bool runOnFunction(Function &F) override { | |||
649 | TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); | |||
650 | ||||
651 | X86LowerAMXType LAT(F, TM); | |||
652 | bool C = LAT.visit(); | |||
653 | ||||
654 | // Prepare for fast register allocation at O0. | |||
655 | // TODO: It may be better to check the volatile model of the AMX code, not
656 | // just Attribute::OptimizeNone and CodeGenOpt::None.
657 | if (TM->getOptLevel() == CodeGenOpt::None) { | |||
658 | // If the front end does not use O0 but the mid/back end uses O0 (e.g.
659 | // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
660 | // sure the amx data is volatile; that is necessary for AMX fast
661 | // register allocation.
662 | if (!F.hasFnAttribute(Attribute::OptimizeNone)) { | |||
663 | X86VolatileTileData VTD(F); | |||
664 | C = VTD.volatileTileData() || C; | |||
665 | } | |||
666 | } | |||
667 | ||||
668 | return C; | |||
669 | } | |||
670 | ||||
671 | void getAnalysisUsage(AnalysisUsage &AU) const override { | |||
672 | AU.setPreservesCFG(); | |||
673 | AU.addRequired<TargetPassConfig>(); | |||
674 | } | |||
675 | }; | |||
676 | ||||
677 | } // anonymous namespace | |||
678 | ||||
679 | static const char PassName[] = "Lower AMX type for load/store"; | |||
680 | char X86LowerAMXTypeLegacyPass::ID = 0; | |||
681 | INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
682 |                       false)
683 | INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
684 | INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
685 |                     false)
686 | ||||
687 | FunctionPass *llvm::createX86LowerAMXTypePass() { | |||
688 | return new X86LowerAMXTypeLegacyPass(); | |||
689 | } |
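// Usage note (a sketch based on how X86 IR passes are normally registered,
// not on code shown in this file): the factory above is expected to be added
// to the codegen pipeline from X86PassConfig::addIRPasses(), roughly:
//
//   void X86PassConfig::addIRPasses() {
//     addPass(createX86LowerAMXTypePass());  // lower <256 x i32> <-> x86_amx
//     TargetPassConfig::addIRPasses();
//     ...
//   }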