Skip to content

Commit 9b7f810

Browse files
committed
xe2: jit: gemm: use atomic loads for fused post-op readback
1 parent 9df9fa5 commit 9b7f810

File tree

3 files changed: 27 additions and 16 deletions.

src/gpu/intel/gemm/jit/generator/pieces/atomic_fusions.cxx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,11 @@ bool Generator<hw>::gemmFusedPostOpsFinalize(Label &labelLateExit, GEMMProblem &
530530
modState.ra.safeRelease(zero);
531531

532532
status << "Load completed C tile" << status_stream::endl;
533-
modStrategy.C.atomic = modStrategy.CO.atomic = false;
534-
modState.Cext_strategy.atomic = false;
533+
if (hw < HW::Xe2) {
534+
/* Xe2 + later need atomic loads */
535+
modStrategy.C.atomic = modStrategy.CO.atomic = false;
536+
modState.Cext_strategy.atomic = false;
537+
}
535538
gemmAccessC(COperation::Load, modProblem, modStrategy, modState);
536539

537540
if (strategy.zeroTempC) {

src/gpu/intel/gemm/jit/generator/pieces/c_update.cxx

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,6 @@ bool Generator<hw>::gemmAccessC(COperation op, const GEMMProblem &problem, const
172172
bool stdCRemainder = !(altCRemainder && (strategy.remHandling[LoopM] == RemainderHandling::KnownRemainder)
173173
&& (strategy.remHandling[LoopN] == RemainderHandling::KnownRemainder));
174174

175-
if ((op != COperation::UpdateStore) && strategy.C.atomic) stub();
176-
177175
if (state.allowEmptyC && (remainderM || remainderN)) {
178176
if (!state.isNested) stub();
179177
int simt = strategy.fused ? 16 : 1;
@@ -846,6 +844,7 @@ void Generator<hw>::updateCLayout(const RegisterLayout &layoutExt, const GRFRang
846844
#define FOR_EACH_C for (int q = 0; q < C_count; q++)
847845
auto Tc = problem.Tc, Tc_ext = problem.Tc_ext, Ts = problem.Ts;
848846
bool loadOnly = (op == COperation::Load);
847+
bool atomicUpdate = (op == COperation::UpdateStore) && strategy.C.atomic;
849848
bool beta0 = problem.beta0();
850849
bool needLoad = (!beta0 && !loadOnly);
851850
bool copyC = state.copyC;
@@ -876,7 +875,7 @@ void Generator<hw>::updateCLayout(const RegisterLayout &layoutExt, const GRFRang
876875
}
877876

878877
// Prepare for late C conversion.
879-
bool lateCConvert = (!loadOnly && !strategy.C.atomic && problem.needsTsConvert() && state.Tacc != Ts);
878+
bool lateCConvert = (!loadOnly && !atomicUpdate && problem.needsTsConvert() && state.Tacc != Ts);
880879
bool copyCLoad = needLoad && (copyC || lateCConvert);
881880
if (lateCConvert && Tc.isComplex()) stub();
882881

@@ -986,7 +985,7 @@ void Generator<hw>::updateCLayout(const RegisterLayout &layoutExt, const GRFRang
986985
setupAddr(C_addrsWith0, state.effC[q], sublayoutWith0, state.inputs.ldc[q], strategy, state, C_params, state.ldcMultiples[q], 1);
987986
}
988987

989-
if (strategy.C.atomic) {
988+
if (atomicUpdate) {
990989
// Atomic update.
991990
// Alpha scaling is done earlier; beta scaling isn't supported.
992991
if (!problem.alpha1() || !problem.beta1()) stub();
@@ -1185,6 +1184,7 @@ bool Generator<hw>::doStdCRemainder(RegisterLayout &layoutExt, RegisterLayout &l
11851184
if (!C_blockUnmasked0 && !layoutExtUnmasked.empty()) C_blockUnmasked0 = &layoutExtUnmasked[0];
11861185

11871186
bool canEOT = !state.isNested && (op == COperation::UpdateStore);
1187+
bool atomicUpdate = strategy.C.atomic && (op == COperation::UpdateStore);
11881188

11891189
Label lEnd;
11901190

@@ -1199,7 +1199,7 @@ bool Generator<hw>::doStdCRemainder(RegisterLayout &layoutExt, RegisterLayout &l
11991199
status << status_stream::endl;
12001200

12011201
// Allocate temporaries for emulated atomic addition if needed.
1202-
if (!inside && strategy.C.atomic) allocEAtomicAddRegs(hw, Tc_ext, layoutExt, problem.C, strategy.C, state);
1202+
if (!inside && atomicUpdate) allocEAtomicAddRegs(hw, Tc_ext, layoutExt, problem.C, strategy.C, state);
12031203

12041204
// Handle a subproblem. Return true if successful.
12051205
auto descend = [&](RegisterLayout &sublayoutExt, RegisterLayout &sublayoutExtUnmasked, bool full = false) -> bool {
@@ -1586,7 +1586,7 @@ bool Generator<hw>::doStdCRemainder(RegisterLayout &layoutExt, RegisterLayout &l
15861586
mark(lEnd);
15871587
success ? appendCurrentStream() : discardStream();
15881588

1589-
if (!inside && strategy.C.atomic) freeEAtomicAddRegs(state);
1589+
if (!inside && atomicUpdate) freeEAtomicAddRegs(state);
15901590

15911591
return success;
15921592
}
@@ -1602,7 +1602,8 @@ void Generator<hw>::doAlternateCRemainder(COperation op, const GEMMProblem &prob
16021602
#define FOR_EACH_C_REV for (int q = C_count - 1; q >= 0; q--)
16031603

16041604
bool lateYLoopCheck = false;
1605-
bool atomic = strategy.C.atomic;
1605+
bool atomicUpdate = strategy.C.atomic && (op == COperation::UpdateStore);
1606+
bool atomicLoad = strategy.C.atomic && !atomicUpdate;
16061607

16071608
bool surface = !strategy.C.base.isStateless();
16081609
bool loadOnly = (op == COperation::Load);
@@ -1620,7 +1621,7 @@ void Generator<hw>::doAlternateCRemainder(COperation op, const GEMMProblem &prob
16201621
nec = nbytes >> 2;
16211622

16221623
// 8-byte+ types can use scattered qword. Only atomic for now.
1623-
bool nativeAtomic = atomic && hasNativeAtomicAdd(hw, Tc_ext.real(), problem.C, strategy.C);
1624+
bool nativeAtomic = atomicUpdate && hasNativeAtomicAdd(hw, Tc_ext.real(), problem.C, strategy.C);
16241625
bool qword = false;
16251626
int rshift = qword ? 3 : 2; // log2(data stride in regs)
16261627
int rsimd = 64 >> rshift;
@@ -1904,7 +1905,7 @@ void Generator<hw>::doAlternateCRemainder(COperation op, const GEMMProblem &prob
19041905

19051906
#undef IGNORE_SWSB
19061907

1907-
if (atomic) {
1908+
if (atomicUpdate) {
19081909
// Atomic update. Requires beta = 0/1, alpha prescaled.
19091910
if (!problem.alpha1() || !problem.beta1()) stub();
19101911
if (C_count > 1) stub();
@@ -1958,7 +1959,12 @@ void Generator<hw>::doAlternateCRemainder(COperation op, const GEMMProblem &prob
19581959
// Regular update.
19591960
if (loadOnly || !problem.beta0()) {
19601961
doReadSuppressionWA(strategy, state);
1961-
if (strategy.C.newDP) {
1962+
if (atomicLoad && hw >= HW::Xe2) {
1963+
if (!strategy.C.newDP) stub();
1964+
!byte_access ? atomic(AtomicOp::load, 16 | mod, Cload, D32 | strategy.C.cachingR, strategy.C.base, header[0]) :
1965+
(Tc_ext.size() == 2) ? atomic(AtomicOp::load, 16 | mod, Cload, D16U32 | strategy.C.cachingR, strategy.C.base, header[0])
1966+
: stub();
1967+
} else if (strategy.C.newDP) {
19621968
!byte_access ? load(16 | mod, Cload, D32 | strategy.C.cachingR, strategy.C.base, header[0]) :
19631969
(Tc_ext.size() == 2) ? load(16 | mod, Cload, D16U32 | strategy.C.cachingR, strategy.C.base, header[0])
19641970
: load(16 | mod, Cload, D8U32 | strategy.C.cachingR, strategy.C.base, header[0]);
@@ -2203,9 +2209,9 @@ void Generator<hw>::gemmAccessSums(COperation op, const GEMMProblem &problem, co
22032209
auto Tco = problem.Tco;
22042210
auto cor = sumA ? strategy.unroll[LoopM] : 1;
22052211
auto coc = sumB ? strategy.unroll[LoopN] : 1;
2206-
bool atomic = strategy.CO.atomic;
2212+
bool atomicUpdate = strategy.CO.atomic && (op == COperation::UpdateStore);
22072213
bool loadOnly = (op == COperation::Load);
2208-
bool load = (op != COperation::Store && !problem.beta0() && !(problem.beta1() && atomic));
2214+
bool load = (op != COperation::Store && !problem.beta0() && !(problem.beta1() && atomicUpdate));
22092215

22102216
auto CO = problem.CO;
22112217
auto CO_strategy = strategy.CO;
@@ -2282,7 +2288,7 @@ void Generator<hw>::gemmAccessSums(COperation op, const GEMMProblem &problem, co
22822288
}
22832289

22842290
auto &effCO_regs = share ? Xs_regs : CO_regs;
2285-
if (atomic) {
2291+
if (atomicUpdate) {
22862292
allocEAtomicAddRegs(hw, Tco, CO_layout, CO, CO_strategy, state, state.flagAP);
22872293
atomicAddMatrix(effCO_regs, CO_layout, CO_addrs, problem, strategy, state);
22882294
freeEAtomicAddRegs(state, state.flagAP);

src/gpu/intel/gemm/jit/generator/pieces/matrix_access.cxx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ void Generator<hw>::loadMatrixBlock(const Register &dest, const RegisterBlock &b
104104
case AccessType::Scattered:
105105
case AccessType::ChannelScattered: {
106106
auto spec = getDataSpecLSC(atype, astrategy, block, AccessClass::Read);
107-
if (block.descAssigned) {
107+
if (astrategy.atomic && hw >= HW::Xe2)
108+
atomic(AtomicOp::load, mod, dest, spec, astrategy.base, getAddress(addr, block, astrategy));
109+
else if (block.descAssigned) {
108110
MessageDescriptor desc;
109111
ExtendedMessageDescriptor exdesc;
110112
encodeLoadDescriptors(hw, desc, exdesc, block.simdSize, r0, spec, astrategy.base, null);

0 commit comments

Comments
 (0)