Bug Summary

File: src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/X86/X86LowerAMXType.cpp
Warning: line 425, column 20
Called C++ object pointer is null
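
The report flags the unguarded result of dyn_cast<>: dyn_cast<PHINode>(V) returns nullptr when V is not actually a PHINode, and calling getIncomingValue(0) through that pointer is the null dereference the checker describes. A minimal sketch of the flagged pattern and a checked variant (hypothetical helpers, not code from this file):

    // Sketch only: illustrates the checker's complaint, not the code under analysis.
    #include "llvm/IR/Instructions.h"
    using namespace llvm;

    Value *firstIncomingUnsafe(Value *V) {
      // If V is not a PHINode, dyn_cast yields nullptr and the call dereferences it.
      return dyn_cast<PHINode>(V)->getIncomingValue(0);
    }

    Value *firstIncomingChecked(Value *V) {
      // Only dereference when the cast succeeded.
      if (auto *PN = dyn_cast<PHINode>(V))
        return PN->getIncomingValue(0);
      return nullptr;
    }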

Annotated Source Code


clang -cc1 -cc1 -triple amd64-unknown-openbsd7.0 -analyze -disable-free -disable-llvm-verifier -discard-value-names -main-file-name X86LowerAMXType.cpp -analyzer-store=region -analyzer-opt-analyze-nested-blocks -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -mrelocation-model pic -pic-level 1 -fhalf-no-semantic-interposition -mframe-pointer=all -relaxed-aliasing -fno-rounding-math -mconstructor-aliases -munwind-tables -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -fcoverage-compilation-dir=/usr/src/gnu/usr.bin/clang/libLLVM/obj -resource-dir /usr/local/lib/clang/13.0.0 -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include/llvm/AMDGPU -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/AMDGPU -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include/llvm/AMDGPU -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/AMDGPU -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include/llvm/AMDGPU -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/AMDGPU -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include/llvm/AMDGPU -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/AMDGPU -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include/llvm/AMDGPU -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/AMDGPU -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include/llvm/AMDGPU -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/AMDGPU -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Analysis -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/ASMParser -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/BinaryFormat -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Bitcode -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Bitcode -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Bitstream -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms -I /include/llvm/CodeGen -I /include/llvm/CodeGen/PBQP -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include/llvm/IR -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/IR -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms/Coroutines -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/ProfileData/Coverage -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/DebugInfo/CodeView -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/DebugInfo/DWARF -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/DebugInfo -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/DebugInfo/MSF -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/DebugInfo/PDB -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Demangle -I 
/usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/ExecutionEngine -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/ExecutionEngine/JITLink -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/ExecutionEngine/Orc -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Frontend -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Frontend/OpenACC -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Frontend -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Frontend/OpenMP -I /include/llvm/CodeGen/GlobalISel -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/IRReader -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms/InstCombine -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include/llvm/Transforms/InstCombine -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/LTO -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Linker -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/MC -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/MC/MCParser -I /include/llvm/CodeGen/MIRParser -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Object -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Option -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Passes -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/ -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/ProfileData -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms/Scalar -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/ADT -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Support -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/DebugInfo/Symbolize -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Target -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms/Utils -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms/Vectorize -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include/llvm/X86 -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/X86 -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include/llvm/X86 -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/X86 -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include/llvm/X86 -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/X86 -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include/llvm/X86 -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/X86 -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include/llvm/X86 -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/X86 -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms -I /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include/llvm/Transforms/IPO -I 
/usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/include -I /usr/src/gnu/usr.bin/clang/libLLVM/../include -I /usr/src/gnu/usr.bin/clang/libLLVM/obj -I /usr/src/gnu/usr.bin/clang/libLLVM/obj/../include -D NDEBUG -D __STDC_LIMIT_MACROS -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D LLVM_PREFIX="/usr" -D PIC -internal-isystem /usr/include/c++/v1 -internal-isystem /usr/local/lib/clang/13.0.0/include -internal-externc-isystem /usr/include -O2 -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-comment -std=c++14 -fdeprecated-macro -fdebug-compilation-dir=/usr/src/gnu/usr.bin/clang/libLLVM/obj -ferror-limit 19 -fvisibility-inlines-hidden -fwrapv -D_RET_PROTECTOR -ret-protector -fno-rtti -fgnuc-version=4.2.1 -vectorize-loops -vectorize-slp -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-valloc -fno-builtin-free -fno-builtin-strdup -fno-builtin-strndup -analyzer-output=html -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /home/ben/Projects/vmm/scan-build/2022-01-12-194120-40624-1 -x c++ /usr/src/gnu/usr.bin/clang/libLLVM/../../../llvm/llvm/lib/Target/X86/X86LowerAMXType.cpp
1//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9/// \file Pass to transform <256 x i32> load/store
10/// <256 x i32> is bitcasted to x86_amx on X86, and AMX instruction set only
11/// provides simple operation on x86_amx. The basic elementwise operation
12/// is not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32>
13/// and only AMX intrinsics can operate on the type, we need to transform
14/// load/store <256 x i32> instructions to AMX load/store. If the bitcast
15/// cannot be combined with the load/store, we transform the bitcast to amx load/store
16/// and <256 x i32> store/load.
17///
18/// If the front end does not use O0 but the mid/back end does (e.g. "Clang -O2 -S
19/// -emit-llvm t.c" + "llc t.ll"), we should make sure the amx data is volatile,
20/// because that is necessary for AMX fast register allocation. (In fast
21/// register allocation, registers are allocated before spill/reload, so
22/// there is no additional register for amx to identify the step in spill.)
23/// The volatileTileData() will handle this case.
24/// e.g.
25/// ----------------------------------------------------------
26/// | def %td = ... |
27/// | ... |
28/// | "use %td" |
29/// ----------------------------------------------------------
30/// will transfer to -->
31/// ----------------------------------------------------------
32/// | def %td = ... |
33/// | call void @llvm.x86.tilestored64.internal(mem, %td) |
34/// | ... |
35/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36/// | "use %td2" |
37/// ----------------------------------------------------------
38//
39//===----------------------------------------------------------------------===//
40//
41#include "X86.h"
42#include "llvm/ADT/PostOrderIterator.h"
43#include "llvm/ADT/SmallSet.h"
44#include "llvm/Analysis/OptimizationRemarkEmitter.h"
45#include "llvm/Analysis/TargetTransformInfo.h"
46#include "llvm/CodeGen/Passes.h"
47#include "llvm/CodeGen/TargetPassConfig.h"
48#include "llvm/CodeGen/ValueTypes.h"
49#include "llvm/IR/DataLayout.h"
50#include "llvm/IR/Function.h"
51#include "llvm/IR/IRBuilder.h"
52#include "llvm/IR/Instructions.h"
53#include "llvm/IR/IntrinsicInst.h"
54#include "llvm/IR/IntrinsicsX86.h"
55#include "llvm/IR/PatternMatch.h"
56#include "llvm/InitializePasses.h"
57#include "llvm/Pass.h"
58#include "llvm/Target/TargetMachine.h"
59
60using namespace llvm;
61using namespace PatternMatch;
62
63#define DEBUG_TYPE "lower-amx-type"
64
65static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder,
66 BasicBlock *BB) {
67 Function &F = *BB->getParent();
68 Module *M = BB->getModule();
69 const DataLayout &DL = M->getDataLayout();
70
71 Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
72 LLVMContext &Ctx = Builder.getContext();
73 auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
74 unsigned AllocaAS = DL.getAllocaAddrSpace();
75 AllocaInst *AllocaRes =
76 new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
77 AllocaRes->setAlignment(AllocaAlignment);
78 return AllocaRes;
79}
80
81namespace {
82class X86LowerAMXType {
83 Function &Func;
84 TargetMachine *TM = nullptr;
85
86 // In AMX intrinsics we let Shape = {Row, Col}, but the
87 // RealCol = Col / ElementSize. We may use the RealCol
 88 // as a new Row for other newly created AMX intrinsics.
89 std::map<Value *, Value *> Col2Row;
90
91public:
92 X86LowerAMXType(Function &F, TargetMachine *TargetM) : Func(F), TM(TargetM) {}
93 bool visit();
94 void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
95 void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
96 bool transformBitcast(BitCastInst *Bitcast);
97 std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
98 Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
99};
100
101Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
102 unsigned Granularity) {
103 if (Col2Row.count(V))
104 return Col2Row[V];
105 IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
106 if (auto *I = dyn_cast<Instruction>(V)) {
107 BasicBlock::iterator Iter = I->getIterator();
108 ++Iter;
109 Builder.SetInsertPoint(&*Iter);
110 }
111 ConstantInt *Gran = Builder.getInt16(Granularity);
112 Value *RealRow = Builder.CreateUDiv(V, Gran);
113 Col2Row[V] = RealRow;
114 return RealRow;
115}
116
117std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
118 unsigned OpNo) {
119 Value *Row = nullptr, *Col = nullptr;
120 switch (II->getIntrinsicID()) {
121 default:
122 llvm_unreachable("Expect amx intrinsics");
123 case Intrinsic::x86_tileloadd64_internal:
124 case Intrinsic::x86_tileloaddt164_internal:
125 case Intrinsic::x86_tilestored64_internal: {
126 Row = II->getArgOperand(0);
127 Col = II->getArgOperand(1);
128 break;
129 }
130 // a * b + c
131 // The shape depends on which operand.
132 case Intrinsic::x86_tdpbssd_internal:
133 case Intrinsic::x86_tdpbsud_internal:
134 case Intrinsic::x86_tdpbusd_internal:
135 case Intrinsic::x86_tdpbuud_internal:
136 case Intrinsic::x86_tdpbf16ps_internal: {
137 switch (OpNo) {
138 case 3:
139 Row = II->getArgOperand(0);
140 Col = II->getArgOperand(1);
141 break;
142 case 4:
143 Row = II->getArgOperand(0);
144 Col = II->getArgOperand(2);
145 break;
146 case 5:
147 Row = II->getArgOperand(2);
148 // FIXME: There is a design bug in the AMX shape: the Col should be
149 // Col/4 if it will be used as a Row, but the current Greedy RA can't handle
150 // this case well; it may fail if we generate a new Shape definition.
151 // So let's just do it at O0 first.
152 // Row = Row / 4
153 if (TM->getOptLevel() == CodeGenOpt::None)
154 Row = getRowFromCol(II, Row, 4);
155 Col = II->getArgOperand(1);
156 break;
157 }
158 break;
159 }
160 }
161
162 return std::make_pair(Row, Col);
163}
164
165// %src = load <256 x i32>, <256 x i32>* %addr, align 64
166// %2 = bitcast <256 x i32> %src to x86_amx
167// -->
168// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
169// i8* %addr, i64 %stride64)
170void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
171 Value *Row = nullptr, *Col = nullptr;
172 Use &U = *(Bitcast->use_begin());
173 unsigned OpNo = U.getOperandNo();
174 auto *II = cast<IntrinsicInst>(U.getUser());
175 std::tie(Row, Col) = getShape(II, OpNo);
176 IRBuilder<> Builder(Bitcast);
177 // Use the maximum column as stride.
178 Value *Stride = Builder.getInt64(64);
179 Value *I8Ptr =
180 Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
181 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
182
183 Value *NewInst =
184 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
185 Bitcast->replaceAllUsesWith(NewInst);
186}
187
188// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
189// %stride);
190// %13 = bitcast x86_amx %src to <256 x i32>
191// store <256 x i32> %13, <256 x i32>* %addr, align 64
192// -->
193// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
194// %stride64, %13)
195void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
196
197 Value *Tile = Bitcast->getOperand(0);
198 auto *II = cast<IntrinsicInst>(Tile);
199 // Tile is output from AMX intrinsic. The first operand of the
200 // intrinsic is row, the second operand of the intrinsic is column.
201 Value *Row = II->getOperand(0);
202 Value *Col = II->getOperand(1);
203 IRBuilder<> Builder(ST);
204 // Use the maximum column as stride. It must be the same as the load
205 // stride.
206 Value *Stride = Builder.getInt64(64);
207 Value *I8Ptr =
208 Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
209 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
210 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
211 if (Bitcast->hasOneUse())
212 return;
213 // %13 = bitcast x86_amx %src to <256 x i32>
214 // store <256 x i32> %13, <256 x i32>* %addr, align 64
215 // %add = <256 x i32> %13, <256 x i32> %src2
216 // -->
217 // %13 = bitcast x86_amx %src to <256 x i32>
218 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
219 // %stride64, %13)
220 // %14 = load <256 x i32>, %addr
221 // %add = <256 x i32> %14, <256 x i32> %src2
222 Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
223 Bitcast->replaceAllUsesWith(Vec);
224}
225
226// transform bitcast to <store, load> instructions.
227bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
228 IRBuilder<> Builder(Bitcast);
229 AllocaInst *AllocaAddr;
230 Value *I8Ptr, *Stride;
231 auto *Src = Bitcast->getOperand(0);
232
233 auto Prepare = [&]() {
234 AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent());
235 I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
236 Stride = Builder.getInt64(64);
237 };
238
239 if (Bitcast->getType()->isX86_AMXTy()) {
240 // %2 = bitcast <256 x i32> %src to x86_amx
241 // -->
242 // %addr = alloca <256 x i32>, align 64
243 // store <256 x i32> %src, <256 x i32>* %addr, align 64
244 // %addr2 = bitcast <256 x i32>* to i8*
245 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
246 // i8* %addr2,
247 // i64 64)
248 Use &U = *(Bitcast->use_begin());
249 unsigned OpNo = U.getOperandNo();
250 auto *II = dyn_cast<IntrinsicInst>(U.getUser());
251 if (!II)
252 return false; // May be bitcast from x86amx to <256 x i32>.
253 Prepare();
254 Builder.CreateStore(Src, AllocaAddr);
255 // TODO: we can pick a constant operand for the shape.
256 Value *Row = nullptr, *Col = nullptr;
257 std::tie(Row, Col) = getShape(II, OpNo);
258 std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
259 Value *NewInst = Builder.CreateIntrinsic(
260 Intrinsic::x86_tileloadd64_internal, None, Args);
261 Bitcast->replaceAllUsesWith(NewInst);
262 } else {
263 // %2 = bitcast x86_amx %src to <256 x i32>
264 // -->
265 // %addr = alloca <256 x i32>, align 64
266 // %addr2 = bitcast <256 x i32>* to i8*
267 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
268 // i8* %addr2, i64 %stride)
269 // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
270 auto *II = dyn_cast<IntrinsicInst>(Src);
271 if (!II)
272 return false; // May be bitcast from <256 x i32> to x86amx.
273 Prepare();
274 Value *Row = II->getOperand(0);
275 Value *Col = II->getOperand(1);
276 std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
277 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
278 Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
279 Bitcast->replaceAllUsesWith(NewInst);
280 }
281
282 return true;
283}
284
285bool X86LowerAMXType::visit() {
286 SmallVector<Instruction *, 8> DeadInsts;
287 Col2Row.clear();
288
289 for (BasicBlock *BB : post_order(&Func)) {
290 for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
291 II != IE;) {
292 Instruction &Inst = *II++;
293 auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
294 if (!Bitcast)
295 continue;
296
297 Value *Src = Bitcast->getOperand(0);
298 if (Bitcast->getType()->isX86_AMXTy()) {
299 if (Bitcast->user_empty()) {
300 DeadInsts.push_back(Bitcast);
301 continue;
302 }
303 LoadInst *LD = dyn_cast<LoadInst>(Src);
304 if (!LD) {
305 if (transformBitcast(Bitcast))
306 DeadInsts.push_back(Bitcast);
307 continue;
308 }
309 // If the load has multiple users, duplicate a vector load.
310 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
311 // %2 = bitcast <256 x i32> %src to x86_amx
312 // %add = add <256 x i32> %src, <256 x i32> %src2
313 // -->
314 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
315 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
316 // i8* %addr, i64 %stride64)
317 // %add = add <256 x i32> %src, <256 x i32> %src2
318
319 // If load has one user, the load will be eliminated in DAG ISel.
320 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
321 // %2 = bitcast <256 x i32> %src to x86_amx
322 // -->
323 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
324 // i8* %addr, i64 %stride64)
325 combineLoadBitcast(LD, Bitcast);
326 DeadInsts.push_back(Bitcast);
327 if (LD->hasOneUse())
328 DeadInsts.push_back(LD);
329 } else if (Src->getType()->isX86_AMXTy()) {
330 if (Bitcast->user_empty()) {
331 DeadInsts.push_back(Bitcast);
332 continue;
333 }
334 StoreInst *ST = nullptr;
335 for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
336 UI != UE;) {
337 Value *I = (UI++)->getUser();
338 ST = dyn_cast<StoreInst>(I);
339 if (ST)
340 break;
341 }
342 if (!ST) {
343 if (transformBitcast(Bitcast))
344 DeadInsts.push_back(Bitcast);
345 continue;
346 }
347 // If bitcast (%13) has one use, combine bitcast and store to amx store.
348 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
349 // %stride);
350 // %13 = bitcast x86_amx %src to <256 x i32>
351 // store <256 x i32> %13, <256 x i32>* %addr, align 64
352 // -->
353 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
354 // %stride64, %13)
355 //
356 // If bitcast (%13) has multiple uses, transform as below.
357 // %13 = bitcast x86_amx %src to <256 x i32>
358 // store <256 x i32> %13, <256 x i32>* %addr, align 64
359 // %add = <256 x i32> %13, <256 x i32> %src2
360 // -->
361 // %13 = bitcast x86_amx %src to <256 x i32>
362 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
363 // %stride64, %13)
364 // %14 = load <256 x i32>, %addr
365 // %add = <256 x i32> %14, <256 x i32> %src2
366 //
367 combineBitcastStore(Bitcast, ST);
368 // Delete user first.
369 DeadInsts.push_back(ST);
370 DeadInsts.push_back(Bitcast);
371 }
372 }
373 }
374
375 bool C = !DeadInsts.empty();
376
377 for (auto *Inst : DeadInsts)
378 Inst->eraseFromParent();
379
380 return C;
381}
382} // anonymous namespace
383
384static Value *getAllocaPos(BasicBlock *BB) {
385 Module *M = BB->getModule();
386 Function *F = BB->getParent();
387 IRBuilder<> Builder(&F->getEntryBlock().front());
388 const DataLayout &DL = M->getDataLayout();
389 unsigned AllocaAS = DL.getAllocaAddrSpace();
390 Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
391 AllocaInst *AllocaRes =
392 new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
393 BasicBlock::iterator Iter = AllocaRes->getIterator();
394 ++Iter;
395 Builder.SetInsertPoint(&*Iter);
396 Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
397 return I8Ptr;
398}
399
400static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
401 assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
402 auto *II = cast<IntrinsicInst>(TileDef);
403 assert(II && "Not tile intrinsic!");
404 Value *Row = II->getOperand(0);
405 Value *Col = II->getOperand(1);
406
407 BasicBlock *BB = TileDef->getParent();
408 BasicBlock::iterator Iter = TileDef->getIterator();
409 IRBuilder<> Builder(BB, ++Iter);
410 Value *Stride = Builder.getInt64(64);
411 std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
412
413 Instruction *TileStore =
414 Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
415 return TileStore;
416}
417
418static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
419 Value *V = U.get();
420 assert(V->getType()->isX86_AMXTy() && "Not define tile!");
421
422 // Get tile shape.
423 IntrinsicInst *II = nullptr;
424 if (IsPHI) {
    13.1: 'IsPHI' is true
    14: Taking true branch
425 Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
    15: Assuming 'V' is not a 'PHINode'
    16: Called C++ object pointer is null
426 II = cast<IntrinsicInst>(PhiOp);
427 } else {
428 II = cast<IntrinsicInst>(V);
429 }
430 Value *Row = II->getOperand(0);
431 Value *Col = II->getOperand(1);
432
433 Instruction *UserI = dyn_cast<Instruction>(U.getUser());
434 IRBuilder<> Builder(UserI);
435 Value *Stride = Builder.getInt64(64);
436 std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
437
438 Value *TileLoad =
439 Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
440 UserI->replaceUsesOfWith(V, TileLoad);
441}
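
Note on the path above: replaceWithTileLoad is only called with IsPHI = true from replacePhiDefWithLoad, which iterates the uses of a PHI definition, so V should always be a PHINode there; the analyzer just cannot prove that invariant across the call. One way to make it explicit (a sketch under that assumption, with a hypothetical helper name, not the upstream fix) is to use cast<>, which asserts on a type mismatch instead of returning nullptr:

    // Hypothetical helper, for illustration only.
    #include "llvm/IR/Instructions.h"
    #include "llvm/IR/IntrinsicInst.h"
    using namespace llvm;

    static IntrinsicInst *getShapeSource(Value *V, bool IsPHI) {
      if (IsPHI) {
        // cast<> asserts (rather than returning nullptr) when V is not a PHINode,
        // documenting the caller's invariant at the point of use.
        auto *PN = cast<PHINode>(V);
        return cast<IntrinsicInst>(PN->getIncomingValue(0));
      }
      return cast<IntrinsicInst>(V);
    }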
442
443static bool isIncomingOfPHI(Instruction *I) {
444 for (Use &U : I->uses()) {
445 User *V = U.getUser();
446 if (isa<PHINode>(V))
447 return true;
448 }
449 return false;
450}
451
452 // Let all AMX tile data become volatile data, shortening the live range
453 // of each tile register before fast register allocation.
454namespace {
455class X86VolatileTileData {
456 Function &F;
457
458public:
459 X86VolatileTileData(Function &Func) : F(Func) {}
460 Value *updatePhiIncomings(BasicBlock *BB,
461 SmallVector<Instruction *, 2> &Incomings);
462 void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
463 bool volatileTileData();
464 void volatileTilePHI(PHINode *Inst);
465 void volatileTileNonPHI(Instruction *I);
466};
467
468Value *X86VolatileTileData::updatePhiIncomings(
469 BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
470 Value *I8Ptr = getAllocaPos(BB);
471
472 for (auto *I : Incomings) {
473 User *Store = createTileStore(I, I8Ptr);
474
475 // All its uses (except phi) should load from stored mem.
476 for (Use &U : I->uses()) {
477 User *V = U.getUser();
478 if (isa<PHINode>(V) || V == Store)
479 continue;
480 replaceWithTileLoad(U, I8Ptr);
481 }
482 }
483 return I8Ptr;
484}
485
486void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
487 Value *StorePtr) {
488 for (Use &U : PHI->uses())
489 replaceWithTileLoad(U, StorePtr, true);
    13: Calling 'replaceWithTileLoad'
490 PHI->eraseFromParent();
491}
492
493// Similar to volatileTileNonPHI, this function only handles PHI nodes
494// and their related AMX intrinsics.
495// 1) PHI Def should change to tileload.
496// 2) PHI Incoming Values should be tilestored just after their def.
497// 3) The mem of these tileloads and tilestores should be the same.
498// e.g.
499// ------------------------------------------------------
500// bb_dom:
501// ...
502// br i1 %bool.cond, label %if.else, label %if.then
503//
504// if.then:
505// def %t0 = ...
506// ...
507// use %t0
508// ...
509// br label %if.end
510//
511// if.else:
512// def %t1 = ...
513// br label %if.end
514//
515// if.end:
516// %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
517// ...
518// use %td
519// ------------------------------------------------------
520// -->
521// ------------------------------------------------------
522// bb_entry:
523// %mem = alloca <256 x i32>, align 1024 *
524// ...
525// bb_dom:
526// ...
527// br i1 %bool.cond, label %if.else, label %if.then
528//
529// if.then:
530// def %t0 = ...
531// call void @llvm.x86.tilestored64.internal(mem, %t0) *
532// ...
533// %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
534// use %t0` *
535// ...
536// br label %if.end
537//
538// if.else:
539// def %t1 = ...
540// call void @llvm.x86.tilestored64.internal(mem, %t1) *
541// br label %if.end
542//
543// if.end:
544// ...
545// %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
546// use %td
547// ------------------------------------------------------
548void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
549 BasicBlock *BB = PHI->getParent();
550 SmallVector<Instruction *, 2> Incomings;
551
552 for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    10: Assuming 'I' is equal to 'E'
    11: Loop condition is false. Execution continues on line 559
553 Value *Op = PHI->getIncomingValue(I);
554 Instruction *Inst = dyn_cast<Instruction>(Op);
555 assert(Inst && "We shouldn't fold AMX instruction!");
556 Incomings.push_back(Inst);
557 }
558
559 Value *StorePtr = updatePhiIncomings(BB, Incomings);
560 replacePhiDefWithLoad(PHI, StorePtr);
    12: Calling 'X86VolatileTileData::replacePhiDefWithLoad'
561}
562
563// Store the defined tile and load it before use.
564// All its users are not PHI.
565// e.g.
566// ------------------------------------------------------
567// def %td = ...
568// ...
569// "use %td"
570// ------------------------------------------------------
571// -->
572// ------------------------------------------------------
573// def %td = ...
574// call void @llvm.x86.tilestored64.internal(mem, %td)
575// ...
576// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
577// "use %td2"
578// ------------------------------------------------------
579void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
580 BasicBlock *BB = I->getParent();
581 Value *I8Ptr = getAllocaPos(BB);
582 User *Store = createTileStore(I, I8Ptr);
583
584 // All its uses should load from stored mem.
585 for (Use &U : I->uses()) {
586 User *V = U.getUser();
587 assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
588 if (V != Store)
589 replaceWithTileLoad(U, I8Ptr);
590 }
591}
592
593// Volatile Tile Model:
594// 1) All the uses of tile data come from a tileload in time.
595// 2) All the defs of tile data are tilestored into mem immediately.
596// For example:
597// --------------------------------------------------------------------------
598// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
599// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
600// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
601// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
602// call void @llvm.x86.tilestored64.internal(... td) area
603// --------------------------------------------------------------------------
604// 3) No terminator, call or other amx instructions in the key amx area.
605bool X86VolatileTileData::volatileTileData() {
606 bool Changed = false;
607 for (BasicBlock &BB : F) {
608 SmallVector<Instruction *, 2> PHIInsts;
609 SmallVector<Instruction *, 8> AMXDefInsts;
610
611 for (Instruction &I : BB) {
612 if (!I.getType()->isX86_AMXTy())
613 continue;
614 if (isa<PHINode>(&I))
615 PHIInsts.push_back(&I);
616 else
617 AMXDefInsts.push_back(&I);
618 }
619
620 // First we "volatile" the non-phi related amx intrinsics.
621 for (Instruction *I : AMXDefInsts) {
    6: Assuming '__begin2' is equal to '__end2'
622 if (isIncomingOfPHI(I))
623 continue;
624 volatileTileNonPHI(I);
625 Changed = true;
626 }
627
628 for (Instruction *I : PHIInsts) {
    7: Assuming '__begin2' is not equal to '__end2'
629 volatileTilePHI(dyn_cast<PHINode>(I));
    8: Assuming 'I' is a 'PHINode'
    9: Calling 'X86VolatileTileData::volatileTilePHI'
630 Changed = true;
631 }
632 }
633 return Changed;
634}
635
636} // anonymous namespace
637
638namespace {
639
640class X86LowerAMXTypeLegacyPass : public FunctionPass {
641public:
642 static char ID;
643
644 X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
645 initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
646 }
647
648 bool runOnFunction(Function &F) override {
649 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
650
651 X86LowerAMXType LAT(F, TM);
652 bool C = LAT.visit();
653
654 // Prepare for fast register allocation at O0.
655 // Todo: It may be better to check the volatile model of AMX code, not just
656 // by checking Attribute::OptimizeNone and CodeGenOpt::None.
657 if (TM->getOptLevel() == CodeGenOpt::None) {
    1: Assuming the condition is true
    2: Taking true branch
658 // If the front end does not use O0 but the mid/back end does (e.g.
659 // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
660 // sure the amx data is volatile; that is necessary for AMX fast
661 // register allocation.
662 if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
    3: Assuming the condition is true
    4: Taking true branch
663 X86VolatileTileData VTD(F);
664 C = VTD.volatileTileData() || C;
    5: Calling 'X86VolatileTileData::volatileTileData'
665 }
666 }
667
668 return C;
669 }
670
671 void getAnalysisUsage(AnalysisUsage &AU) const override {
672 AU.setPreservesCFG();
673 AU.addRequired<TargetPassConfig>();
674 }
675};
676
677} // anonymous namespace
678
679static const char PassName[] = "Lower AMX type for load/store";
680char X86LowerAMXTypeLegacyPass::ID = 0;
681INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
682 false)
683INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
684INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
685 false)
686
687FunctionPass *llvm::createX86LowerAMXTypePass() {
688 return new X86LowerAMXTypeLegacyPass();
689}