blob: 7b9e9fb988bc68e25cf6256e247c718adfcc0236 [file] [log] [blame]
//===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file implements the machine function pass that inserts reads/writes of
// the RISC-V CSRs that instructions depend on.
//
// Currently the pass implements:
// - Saving frm and writing the required value before an RVV floating-point
//   instruction with a static rounding mode, and restoring the saved value
//   afterwards.
//
//===----------------------------------------------------------------------===//
#include "MCTargetDesc/RISCVBaseInfo.h"
#include "RISCV.h"
#include "RISCVSubtarget.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
using namespace llvm;
#define DEBUG_TYPE "riscv-insert-read-write-csr"
#define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass"
// Hidden escape hatch: when set, runOnMachineFunction uses the simple strategy
// that brackets every static-rounding-mode instruction with its own FRM
// save/restore, instead of the optimized insertion.
static cl::opt<bool> DisableFRMInsertOpt(
    "riscv-disable-frm-insert-opt",
    cl::desc("Disable optimized frm insertion."), cl::Hidden, cl::init(false));
namespace {

/// Machine function pass that inserts reads/writes of RISC-V CSRs required by
/// instructions. Currently it manages FRM for RVV floating-point instructions
/// that carry a static rounding mode.
class RISCVInsertReadWriteCSR : public MachineFunctionPass {
  /// Instruction info of the function being processed; assigned at the start
  /// of runOnMachineFunction before any instruction is built.
  const TargetInstrInfo *TII = nullptr;

public:
  static char ID;

  RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;

  // The pass only inserts instructions inside existing blocks, so the CFG is
  // preserved.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override {
    return RISCV_INSERT_READ_WRITE_CSR_NAME;
  }

private:
  /// Bracket every static-rounding-mode instruction in MBB with its own FRM
  /// save/restore. Returns true if MBB was modified.
  bool emitWriteRoundingMode(MachineBasicBlock &MBB);
  /// Insert FRM writes in MBB while skipping redundant writes and deferring
  /// restores until needed. Returns true if MBB was modified.
  bool emitWriteRoundingModeOpt(MachineBasicBlock &MBB);
};

} // end anonymous namespace
// Storage for the pass identifier that the constructor hands to
// MachineFunctionPass.
char RISCVInsertReadWriteCSR::ID = 0;
// Register the pass under the command-line name DEBUG_TYPE
// ("riscv-insert-read-write-csr") with the description
// RISCV_INSERT_READ_WRITE_CSR_NAME. The trailing false/false are
// INITIALIZE_PASS's "CFG only" and "is analysis" flags.
INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE,
RISCV_INSERT_READ_WRITE_CSR_NAME, false, false)
// Optimized FRM insertion for RVV floating-point instructions that carry a
// static rounding mode.
//
// Instead of bracketing every such instruction with its own save/restore, the
// block is scanned while tracking two pieces of state:
//  - CurrentRM: the rounding mode FRM is known to hold (DYN = unknown, i.e.
//    whatever value FRM had on entry or after an unknown operation);
//  - SavedFRM: a virtual register holding the FRM value from before this pass
//    first changed it (invalid when no restore is pending).
//
// The scan applies these rules:
//  - The first mode change saves FRM with SwapFRMImm. Later changes in the same
//    region use WriteFRMImm. No write is emitted when the required mode is
//    already in place.
//  - The saved value is written back before anything that may observe FRM:
//    calls, inline asm, explicit FRM readers, SwapFRMImm (which returns the old
//    FRM value), and instructions whose rounding-mode operand is DYN.
//  - If a restore is still pending at the end of the block, it is inserted
//    right after the last instruction that needed a static mode.
//
// Every instruction with a rounding-mode operand gets an implicit use of FRM,
// making its dependency on FRM explicit to later passes.
//
// Returns true if MBB was modified.
//
// TODO: Use more accurate rounding mode at the start of MBB.
bool RISCVInsertReadWriteCSR::emitWriteRoundingModeOpt(MachineBasicBlock &MBB) {
  bool Changed = false;
  MachineInstr *LastFRMChanger = nullptr;
  unsigned CurrentRM = RISCVFPRndMode::DYN;
  Register SavedFRM;

  // If a saved FRM value is pending, write it back immediately before MI.
  // Either way, no restore is pending afterwards.
  auto RestoreFRMBefore = [&](MachineInstr &MI) {
    if (SavedFRM.isValid())
      BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRM))
          .addReg(SavedFRM);
    SavedFRM = Register();
  };

  for (MachineInstr &MI : MBB) {
    if (MI.getOpcode() == RISCV::SwapFRMImm) {
      // SwapFRMImm defines a GPR (operand 0) and takes the new mode as its
      // immediate (operand 1). It also returns the previous FRM value, so it
      // must observe the original value, not a mode installed by this pass.
      RestoreFRMBefore(MI);
      CurrentRM = MI.getOperand(1).getImm();
      continue;
    }

    if (MI.getOpcode() == RISCV::WriteFRMImm) {
      // A plain write replaces FRM entirely, so the saved value is no longer
      // needed.
      CurrentRM = MI.getOperand(0).getImm();
      SavedFRM = Register();
      continue;
    }

    if (MI.getOpcode() == RISCV::WriteFRM) {
      CurrentRM = RISCVFPRndMode::DYN;
      SavedFRM = Register();
      continue;
    }

    if (MI.isCall() || MI.isInlineAsm() ||
        MI.readsRegister(RISCV::FRM, /*TRI=*/nullptr)) {
      // Restore FRM before unknown operations.
      RestoreFRMBefore(MI);
      CurrentRM = RISCVFPRndMode::DYN;
      continue;
    }

    assert(!MI.modifiesRegister(RISCV::FRM, /*TRI=*/nullptr) &&
           "Expected that MI could not modify FRM.");

    int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
    if (FRMIdx < 0)
      continue;
    unsigned InstrRM = MI.getOperand(FRMIdx).getImm();

    // Make MI implicitly use FRM.
    MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
                                            /*IsImp*/ true));
    Changed = true;

    // DYN means "use whatever FRM holds": MI must see the original value, and
    // DYN itself must never be written into FRM.
    if (InstrRM == RISCVFPRndMode::DYN) {
      RestoreFRMBefore(MI);
      CurrentRM = RISCVFPRndMode::DYN;
      continue;
    }

    LastFRMChanger = &MI;

    // Skip if MI uses the rounding mode FRM already holds.
    if (InstrRM == CurrentRM)
      continue;

    if (!SavedFRM.isValid()) {
      // Save the current FRM value to SavedFRM while installing InstrRM.
      MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
      SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
      BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), SavedFRM)
          .addImm(InstrRM);
    } else {
      // The original value is already saved; just switch the mode.
      BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRMImm))
          .addImm(InstrRM);
    }
    CurrentRM = InstrRM;
  }

  // Restore FRM if a restore is still pending at the end of the block.
  if (SavedFRM.isValid()) {
    assert(LastFRMChanger && "Expected valid pointer.");
    MachineInstrBuilder MIB =
        BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
            .addReg(SavedFRM);
    MBB.insertAfter(LastFRMChanger, MIB);
  }

  return Changed;
}
// Unoptimized FRM insertion, selected when -riscv-disable-frm-insert-opt is
// set. Every RVV floating-point instruction with a static rounding mode gets
// its own save/restore pair:
//   %saved = SwapFRMImm <rm>
//   MI                        ; now with an implicit use of FRM
//   WriteFRM %saved
// Instructions whose rounding-mode operand is DYN are left untouched.
// Returns true if any instruction in MBB was bracketed.
bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) {
  MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();

  // Surround MI with SwapFRMImm RM / WriteFRM so FRM holds RM only around MI.
  auto BracketWithFRM = [&](MachineInstr &MI, unsigned RM) {
    Register OldFRM = MRI.createVirtualRegister(&RISCV::GPRRegClass);
    BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), OldFRM)
        .addImm(RM);
    MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
                                            /*IsImp*/ true));
    MachineInstrBuilder Restore =
        BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
            .addReg(OldFRM);
    MBB.insertAfter(MI, Restore);
  };

  bool Modified = false;
  for (MachineInstr &MI : MBB) {
    const int RMOpIdx = RISCVII::getFRMOpNum(MI.getDesc());
    if (RMOpIdx < 0)
      continue;
    const unsigned RM = MI.getOperand(RMOpIdx).getImm();
    // DYN is a hint to this pass to leave FRM alone.
    if (RM == RISCVFPRndMode::DYN)
      continue;
    BracketWithFRM(MI, RM);
    Modified = true;
  }
  return Modified;
}
// Pass entry point: insert FRM handling in every basic block of MF, using
// either the optimized or the simple per-instruction strategy. Returns true if
// any block was modified.
bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) {
  // Nothing to do unless the vector extension is enabled.
  const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
  if (!ST.hasVInstructions())
    return false;

  TII = ST.getInstrInfo();
  const bool UseOptimizedInsertion = !DisableFRMInsertOpt;

  bool Changed = false;
  for (MachineBasicBlock &MBB : MF)
    Changed |= UseOptimizedInsertion ? emitWriteRoundingModeOpt(MBB)
                                     : emitWriteRoundingMode(MBB);
  return Changed;
}
/// Creates a new instance of the RISC-V Insert Read/Write CSR pass.
FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() {
  auto *Pass = new RISCVInsertReadWriteCSR();
  return Pass;
}