Skip to content

Commit 43fb73f

Browse files
authored
[Backport to 18] SPIRVReader: handle direct types with CooperativeMatrixLengthKHR (#2695) (#2706)
Translation of the attached test would currently fail due to the SPIRVReader attempting to process the `%matTy` operand as a regular value instead of a type. `OpCooperativeMatrixLengthKHR` seems to be pretty unique in taking an additional type operand beyond the result type, so special-case it in the reader. The translator currently accepts a non-type operand for `OpCooperativeMatrixLengthKHR` too, even though that's not within the specification; see various TODOs in the existing SPV_KHR_cooperative_matrix tests. Leave that relaxation in place, by only translating the operand as a type when it is an `OpTypeCooperativeMatrixKHR`. (cherry picked from commit 2b5f15d)
1 parent 242df2c commit 43fb73f

File tree

3 files changed

+37
-2
lines changed

3 files changed

+37
-2
lines changed

lib/SPIRV/SPIRVReader.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3343,8 +3343,16 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName,
33433343
isSplitBarrierINTELOpCode(OC) || OC == OpControlBarrier)
33443344
Func->addFnAttr(Attribute::Convergent);
33453345
}
3346-
auto *Call =
3347-
CallInst::Create(Func, transValue(Ops, BB->getParent(), BB), "", BB);
3346+
CallInst *Call;
3347+
if (BI->getOpCode() == OpCooperativeMatrixLengthKHR &&
3348+
Ops[0]->getOpCode() == OpTypeCooperativeMatrixKHR) {
3349+
// OpCooperativeMatrixLengthKHR needs special handling as its operand is
3350+
// a Type instead of a Value.
3351+
llvm::Type *MatTy = transType(reinterpret_cast<SPIRVType *>(Ops[0]));
3352+
Call = CallInst::Create(Func, Constant::getNullValue(MatTy), "", BB);
3353+
} else {
3354+
Call = CallInst::Create(Func, transValue(Ops, BB->getParent(), BB), "", BB);
3355+
}
33483356
setName(Call, BI);
33493357
setAttrByCalledFunc(Call);
33503358
SPIRVDBG(spvdbgs() << "[transInstToBuiltinCall] " << *BI << " -> ";

lib/SPIRV/libSPIRV/SPIRVInstruction.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ SPIRVInstruction::getOperandTypes(const std::vector<SPIRVValue *> &Ops) {
146146
SPIRVType *Ty = nullptr;
147147
if (I->getOpCode() == OpFunction)
148148
Ty = reinterpret_cast<SPIRVFunction *>(I)->getFunctionType();
149+
else if (I->getOpCode() == OpTypeCooperativeMatrixKHR)
150+
Ty = reinterpret_cast<SPIRVType *>(I);
149151
else
150152
Ty = I->getType();
151153

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
; RUN: spirv-as --target-env spv1.0 -o %t.spv %s
2+
; RUN: spirv-val %t.spv
3+
; RUN: llvm-spirv -r -o - %t.spv | llvm-dis | FileCheck %s
4+
5+
OpCapability Addresses
6+
OpCapability Kernel
7+
OpCapability CooperativeMatrixKHR
8+
OpExtension "SPV_KHR_cooperative_matrix"
9+
OpMemoryModel Physical64 OpenCL
10+
OpEntryPoint Kernel %1 "testCoopMat"
11+
%void = OpTypeVoid
12+
%float = OpTypeFloat 32
13+
%fnTy = OpTypeFunction %void
14+
%uint = OpTypeInt 32 0
15+
%uint_3 = OpConstant %uint 3
16+
%uint_0 = OpConstant %uint 0
17+
%uint_8 = OpConstant %uint 8
18+
%matTy = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0
19+
%1 = OpFunction %void None %fnTy
20+
%2 = OpLabel
21+
%3 = OpCooperativeMatrixLengthKHR %uint %matTy
22+
OpReturn
23+
OpFunctionEnd
24+
25+
; CHECK: call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHRPU3AS143__spirv_CooperativeMatrixKHR__float_3_8_8_0(target("spirv.CooperativeMatrixKHR", float, 3, 8, 8, 0) zeroinitializer)

0 commit comments

Comments
 (0)