Skip to content

Commit 002114a

Browse files
committed
AMDGPU: Handle rewriting VGPR MFMA fed from AGPR copy
Previously we handled the inverse situation only.
1 parent dd7d411 commit 002114a

File tree

3 files changed

+249
-290
lines changed

3 files changed

+249
-290
lines changed

llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp

Lines changed: 191 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
/// MFMA opcode.
1515
///
1616
/// TODO:
17+
/// - Handle rewrites of phis. This must be more careful than normal about the
18+
/// reassignment. We do not want to introduce an AGPR-to-AGPR copy inside of a
19+
/// loop, so it depends on the exact assignment of the copy.
20+
///
1721
/// - Update LiveIntervals incrementally instead of recomputing from scratch
1822
///
1923
//===----------------------------------------------------------------------===//
@@ -60,6 +64,32 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
6064
return TII.isMAI(MI) && AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode()) != -1;
6165
}
6266

67+
/// Find AV_* registers assigned to AGPRs (or virtual registers which were
68+
/// already required to be AGPR).
69+
///
70+
/// \return the assigned physical register that \p VReg is assigned to if it
71+
/// is an AGPR, otherwise MCRegister().
72+
MCRegister getAssignedAGPR(Register VReg) const {
73+
MCRegister PhysReg = VRM.getPhys(VReg);
74+
if (!PhysReg)
75+
return MCRegister();
76+
77+
const TargetRegisterClass *VirtRegRC = MRI.getRegClass(VReg);
78+
if (!TRI.hasAGPRs(VirtRegRC))
79+
return MCRegister();
80+
81+
if (!TRI.hasVGPRs(VirtRegRC))
82+
return PhysReg;
83+
84+
// If this is an AV register, we have to check if the actual assignment is
85+
// to an AGPR
86+
const TargetRegisterClass *AssignedRC = TRI.getPhysRegBaseClass(PhysReg);
87+
return TRI.isAGPRClass(AssignedRC) ? PhysReg : MCRegister();
88+
}
89+
90+
bool tryReassigningMFMAChain(MachineInstr &MFMA, unsigned HintOpIdx,
91+
MCPhysReg PhysRegHint) const;
92+
6393
/// Compute the register class constraints based on the uses of \p Reg,
6494
/// excluding MFMA uses from which can be rewritten to change the register
6595
/// class constraint. This should be nearly identical to
@@ -74,6 +104,8 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
74104
Register Reg, SmallVectorImpl<MachineInstr *> &RewriteCandidates,
75105
SmallSetVector<Register, 4> &RewriteRegs) const;
76106

107+
bool tryFoldCopiesToAGPR(Register VReg, MCRegister AssignedAGPR) const;
108+
bool tryFoldCopiesFromAGPR(Register VReg, MCRegister AssignedAGPR) const;
77109
bool run(MachineFunction &MF) const;
78110
};
79111

@@ -152,6 +184,88 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::recomputeRegClassExceptRewritable(
152184
return true;
153185
}
154186

187+
/// Try to move the register found at operand \p HintOpIdx of \p MFMA, together
/// with every register collected by recomputeRegClassExceptRewritable, onto
/// AGPRs, preferring \p PhysRegHint for the assignment.
///
/// On success the collected virtual registers are switched to their equivalent
/// AGPR register classes and every collected MFMA is given the opcode returned
/// by getMFMASrcCVDstAGPROp.
///
/// \return true if the rewrite was committed. Returns false if the register
/// class constraints could not be recomputed, or if the reassignment attempt
/// failed; in the latter case the original assignments are restored first.
bool AMDGPURewriteAGPRCopyMFMAImpl::tryReassigningMFMAChain(
    MachineInstr &MFMA, unsigned HintOpIdx, MCPhysReg PhysRegHint) const {
  // src2 and dst share one physical class constraint; try to keep the
  // original src2 subclass if there is one.
  SmallVector<MachineInstr *, 4> MFMAsToRewrite;
  MFMAsToRewrite.push_back(&MFMA);
  SmallSetVector<Register, 4> RegsToReassign;

  // The register belonging to the MFMA we started from goes in first, so that
  // it is the one most likely to land in the physreg the copy targeted.
  const Register HintReg = MFMA.getOperand(HintOpIdx).getReg();
  RegsToReassign.insert(HintReg);

  // Having found av = COPY (MFMA), confirm that src2 can trivially be
  // rewritten to use the new AGPR. Otherwise we would induce as many copies
  // as would have been emitted originally, need yet another register, and
  // have to decide where to place them. Live range splitting is smarter than
  // this pass, so trust that it did something reasonable.
  //
  // recomputeRegClassExceptRewritable takes into account the constraints of
  // this MFMA's src2 as well as the src2/dst of any transitive MFMA users.
  if (!recomputeRegClassExceptRewritable(HintReg, MFMAsToRewrite,
                                         RegsToReassign)) {
    LLVM_DEBUG(dbgs() << "Could not recompute the regclass of dst reg "
                      << printReg(HintReg, &TRI) << '\n');
    return false;
  }

  // When src2 and dst are distinct registers, the input must also move to an
  // available AGPR, provided that is compatible with all of its other uses.
  // Failing that we would have to add a different copy, which is likely worse
  // than the one being removed.
  //
  // MFMAs are typically chained; unless the whole use/def chain can migrate,
  // intermediate copies would be needed somewhere. So the transform is
  // all-or-nothing: free every collected register, attempt the AGPR
  // assignment, and restore the saved state if any register cannot be placed.
  SmallVector<std::pair<const LiveInterval *, MCRegister>, 8> SavedAssignments;
  for (Register Reg : RegsToReassign) {
    LiveInterval &Interval = LIS.getInterval(Reg);
    SavedAssignments.emplace_back(&Interval, VRM.getPhys(Reg));
    LRM.unassign(Interval);
  }

  if (attemptReassignmentsToAGPR(RegsToReassign, PhysRegHint)) {
    // Committed: update the virtual register classes to match the new
    // assignments.
    for (Register Reg : RegsToReassign) {
      const TargetRegisterClass *AGPRClass =
          TRI.getEquivalentAGPRClass(MRI.getRegClass(Reg));
      MRI.setRegClass(Reg, AGPRClass);
    }

    // Switch each collected MFMA over to its AGPR src2/dst opcode.
    for (MachineInstr *Candidate : MFMAsToRewrite) {
      const int AGPROpcode =
          AMDGPU::getMFMASrcCVDstAGPROp(Candidate->getOpcode());
      Candidate->setDesc(TII.get(AGPROpcode));
    }

    return true;
  }

  // Reassignment failed: put every register back where it was. A register may
  // have been tentatively assigned during the attempt, so clear that first.
  for (const auto &[Interval, PreviousPhys] : SavedAssignments) {
    if (VRM.hasPhys(Interval->reg()))
      LRM.unassign(*Interval);
    LRM.assign(*Interval, PreviousPhys);
  }

  return false;
}
268+
155269
/// Attempt to reassign the registers in \p InterferingRegs to be AGPRs, with a
156270
/// preference to use \p PhysReg first. Returns false if the reassignments
157271
/// cannot be trivially performed.
@@ -204,6 +318,78 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::attemptReassignmentsToAGPR(
204318
return true;
205319
}
206320

321+
/// Identify copies that look like:
322+
/// %vdst:vgpr = V_MFMA_.. %src0:av, %src1:av, %src2:vgpr
323+
/// %agpr = COPY %vgpr
324+
///
325+
/// Then try to replace the transitive uses of %src2 and %vdst with the AGPR
326+
/// versions of the MFMA. This should cover the common case.
327+
bool AMDGPURewriteAGPRCopyMFMAImpl::tryFoldCopiesToAGPR(
328+
Register VReg, MCRegister AssignedAGPR) const {
329+
bool MadeChange = false;
330+
for (MachineInstr &UseMI : MRI.def_instructions(VReg)) {
331+
if (!UseMI.isCopy())
332+
continue;
333+
334+
Register CopySrcReg = UseMI.getOperand(1).getReg();
335+
if (!CopySrcReg.isVirtual())
336+
continue;
337+
338+
// TODO: Handle loop phis copied to AGPR. e.g.
339+
//
340+
// loop:
341+
// %phi:vgpr = COPY %mfma:vgpr
342+
// %mfma:vgpr = V_MFMA_xxx_vgprcd_e64 %a, %b, %phi
343+
// s_cbranch_vccnz loop
344+
//
345+
// endloop:
346+
// %agpr = mfma
347+
//
348+
// We need to be sure that %phi is assigned to the same physical register as
349+
// %mfma, or else we will just be moving copies into the loop.
350+
351+
for (MachineInstr &CopySrcDefMI : MRI.def_instructions(CopySrcReg)) {
352+
if (isRewriteCandidate(CopySrcDefMI) &&
353+
tryReassigningMFMAChain(CopySrcDefMI, 0, AssignedAGPR))
354+
MadeChange = true;
355+
}
356+
}
357+
358+
return MadeChange;
359+
}
360+
361+
/// Identify copies that look like:
362+
/// %src:vgpr = COPY %src:agpr
363+
/// %vdst:vgpr = V_MFMA_... %src0:av, %src1:av, %src:vgpr
364+
///
365+
/// Then try to replace the transitive uses of %src2 and %vdst with the AGPR
366+
/// versions of the MFMA. This should cover rarer cases, and will generally be
367+
/// redundant with tryFoldCopiesToAGPR.
368+
bool AMDGPURewriteAGPRCopyMFMAImpl::tryFoldCopiesFromAGPR(
369+
Register VReg, MCRegister AssignedAGPR) const {
370+
bool MadeChange = false;
371+
for (MachineInstr &UseMI : MRI.use_instructions(VReg)) {
372+
if (!UseMI.isCopy())
373+
continue;
374+
375+
Register CopyDstReg = UseMI.getOperand(0).getReg();
376+
if (!CopyDstReg.isVirtual())
377+
continue;
378+
379+
for (MachineInstr &CopyUseMI : MRI.use_instructions(CopyDstReg)) {
380+
if (isRewriteCandidate(CopyUseMI)) {
381+
const MachineOperand *Op =
382+
CopyUseMI.findRegisterUseOperand(CopyDstReg, /*TRI=*/nullptr);
383+
if (tryReassigningMFMAChain(CopyUseMI, Op->getOperandNo(),
384+
VRM.getPhys(Op->getReg())))
385+
MadeChange = true;
386+
}
387+
}
388+
}
389+
390+
return MadeChange;
391+
}
392+
207393
bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
208394
// This only applies on subtargets that have a configurable AGPR vs. VGPR
209395
// allocation.
@@ -220,124 +406,14 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
220406

221407
for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
222408
Register VReg = Register::index2VirtReg(I);
223-
Register PhysReg = VRM.getPhys(VReg);
224-
if (!PhysReg)
225-
continue;
226-
227-
// Find AV_* registers assigned to AGPRs.
228-
const TargetRegisterClass *VirtRegRC = MRI.getRegClass(VReg);
229-
if (!TRI.hasAGPRs(VirtRegRC))
409+
MCRegister AssignedAGPR = getAssignedAGPR(VReg);
410+
if (!AssignedAGPR)
230411
continue;
231412

232-
const TargetRegisterClass *AssignedRC = VirtRegRC;
233-
if (TRI.hasVGPRs(VirtRegRC)) {
234-
// If this is an AV register, we have to check if the actual assignment is
235-
// to an AGPR
236-
AssignedRC = TRI.getPhysRegBaseClass(PhysReg);
237-
if (!TRI.isAGPRClass(AssignedRC))
238-
continue;
239-
}
240-
241-
LiveInterval &LI = LIS.getInterval(VReg);
242-
243-
for (VNInfo *VNI : LI.vnis()) {
244-
if (VNI->isPHIDef() || VNI->isUnused())
245-
continue;
246-
247-
MachineInstr *DefMI = LIS.getInstructionFromIndex(VNI->def);
248-
if (!DefMI || !DefMI->isCopy())
249-
continue;
250-
251-
Register MFMADstReg = DefMI->getOperand(1).getReg();
252-
if (!MFMADstReg.isVirtual())
253-
continue;
254-
255-
LiveInterval &CopySrcLI = LIS.getInterval(MFMADstReg);
256-
LiveQueryResult LRQ = CopySrcLI.Query(VNI->def.getRegSlot());
257-
MachineInstr *MFMA = LIS.getInstructionFromIndex(LRQ.valueIn()->def);
258-
if (!MFMA || !isRewriteCandidate(*MFMA))
259-
continue;
260-
261-
// src2 and dst have the same physical class constraint; try to preserve
262-
// the original src2 subclass if one were to exist.
263-
SmallVector<MachineInstr *, 4> RewriteCandidates = {MFMA};
264-
SmallSetVector<Register, 4> RewriteRegs;
265-
266-
// Make sure we reassign the MFMA we found the copy from first. We want
267-
// to ensure dst ends up in the physreg we were originally copying to.
268-
RewriteRegs.insert(MFMADstReg);
269-
270-
// We've found av = COPY (MFMA), and need to verify that we can trivially
271-
// rewrite src2 to use the new AGPR. If we can't trivially replace it,
272-
// we're going to induce as many copies as we would have emitted in the
273-
// first place, as well as need to assign another register, and need to
274-
// figure out where to put them. The live range splitting is smarter than
275-
// anything we're doing here, so trust it did something reasonable.
276-
//
277-
// Note recomputeRegClassExceptRewritable will consider the constraints of
278-
// this MFMA's src2 as well as the src2/dst of any transitive MFMA users.
279-
if (!recomputeRegClassExceptRewritable(MFMADstReg, RewriteCandidates,
280-
RewriteRegs)) {
281-
LLVM_DEBUG(dbgs() << "Could not recompute the regclass of dst reg "
282-
<< printReg(MFMADstReg, &TRI) << '\n');
283-
continue;
284-
}
285-
286-
// If src2 and dst are different registers, we need to also reassign the
287-
// input to an available AGPR if it is compatible with all other uses.
288-
//
289-
// If we can't reassign it, we'd need to introduce a different copy
290-
// which is likely worse than the copy we'd be saving.
291-
//
292-
// It's likely that the MFMA is used in sequence with other MFMAs; if we
293-
// cannot migrate the full use/def chain of MFMAs, we would need to
294-
// introduce intermediate copies somewhere. So we only make the
295-
// transform if all the interfering MFMAs can also be migrated. Collect
296-
// the set of rewritable MFMAs and check if we can assign an AGPR at
297-
// that point.
298-
//
299-
// If any of the MFMAs aren't reassignable, we give up and rollback to
300-
// the original register assignments.
301-
302-
using RecoloringStack =
303-
SmallVector<std::pair<const LiveInterval *, MCRegister>, 8>;
304-
RecoloringStack TentativeReassignments;
305-
306-
for (Register RewriteReg : RewriteRegs) {
307-
LiveInterval &LI = LIS.getInterval(RewriteReg);
308-
TentativeReassignments.push_back({&LI, VRM.getPhys(RewriteReg)});
309-
LRM.unassign(LI);
310-
}
311-
312-
if (!attemptReassignmentsToAGPR(RewriteRegs, PhysReg)) {
313-
// Roll back the register assignments to the original state.
314-
for (auto [LI, OldAssign] : TentativeReassignments) {
315-
if (VRM.hasPhys(LI->reg()))
316-
LRM.unassign(*LI);
317-
LRM.assign(*LI, OldAssign);
318-
}
319-
320-
continue;
321-
}
322-
323-
// Fixup the register classes of the virtual registers now that we've
324-
// committed to the reassignments.
325-
for (Register InterferingReg : RewriteRegs) {
326-
const TargetRegisterClass *EquivalentAGPRRegClass =
327-
TRI.getEquivalentAGPRClass(MRI.getRegClass(InterferingReg));
328-
MRI.setRegClass(InterferingReg, EquivalentAGPRRegClass);
329-
}
330-
331-
for (MachineInstr *RewriteCandidate : RewriteCandidates) {
332-
int NewMFMAOp =
333-
AMDGPU::getMFMASrcCVDstAGPROp(RewriteCandidate->getOpcode());
334-
RewriteCandidate->setDesc(TII.get(NewMFMAOp));
335-
}
336-
337-
// We likely left an identity copy behind after assignment; let
338-
// VirtRegRewriter deal with it later.
413+
if (tryFoldCopiesToAGPR(VReg, AssignedAGPR))
414+
MadeChange = true;
415+
if (tryFoldCopiesFromAGPR(VReg, AssignedAGPR))
339416
MadeChange = true;
340-
}
341417
}
342418

343419
return MadeChange;

0 commit comments

Comments
 (0)