diff --git a/chisel/playground/src/pipeline/execute/ExecuteUnit.scala b/chisel/playground/src/pipeline/execute/ExecuteUnit.scala index fcf5bf5..a36bd41 100644 --- a/chisel/playground/src/pipeline/execute/ExecuteUnit.scala +++ b/chisel/playground/src/pipeline/execute/ExecuteUnit.scala @@ -16,7 +16,7 @@ class ExecuteUnit(implicit val config: CpuConfig) extends Module { val csr = Flipped(new CsrExecuteUnit()) val bpu = new ExecuteUnitBranchPredictor() val fetchUnit = Output(new Bundle { - val branch = Bool() + val flush = Bool() val target = UInt(PC_WID.W) }) val decoderUnit = new Bundle { @@ -57,8 +57,7 @@ class ExecuteUnit(implicit val config: CpuConfig) extends Module { io.ctrl.inst(0).reg_waddr := io.executeStage.inst0.info.reg_waddr io.ctrl.inst(1).mem_wreg := io.executeStage.inst1.info.mem_wreg io.ctrl.inst(1).reg_waddr := io.executeStage.inst1.info.reg_waddr - io.ctrl.branch := valid(0) && io.ctrl.allow_to_go && - (fu.branch.jump_regiser || fu.branch.pred_fail) + io.ctrl.branch := io.fetchUnit.flush io.csr.in.valid := is_csr.asUInt.orR io.csr.in.info := MuxCase( @@ -116,7 +115,8 @@ class ExecuteUnit(implicit val config: CpuConfig) extends Module { io.bpu.branch := fu.branch.branch io.bpu.branch_inst := io.executeStage.inst0.jb_info.branch_inst - io.fetchUnit.branch := io.ctrl.branch + io.fetchUnit.flush := valid(0) && io.ctrl.allow_to_go && + fu.branch.flush io.fetchUnit.target := fu.branch.target io.ctrl.fu_stall := fu.stall_req @@ -151,8 +151,8 @@ class ExecuteUnit(implicit val config: CpuConfig) extends Module { ) ) io.memoryStage.inst0.ex.exception(instrAddrMisaligned) := io.executeStage.inst0.ex.exception(instrAddrMisaligned) || - io.fetchUnit.branch && io.fetchUnit.target(1, 0).orR - when(io.fetchUnit.branch && io.fetchUnit.target(1, 0).orR) { + io.fetchUnit.flush && io.fetchUnit.target(1, 0).orR + when(io.fetchUnit.flush && io.fetchUnit.target(1, 0).orR) { io.memoryStage.inst0.ex.tval := io.fetchUnit.target } @@ -178,8 +178,8 @@ class ExecuteUnit(implicit val config: CpuConfig) extends Module { ) ) io.memoryStage.inst1.ex.exception(instrAddrMisaligned) := io.executeStage.inst1.ex.exception(instrAddrMisaligned) || - io.fetchUnit.branch && io.fetchUnit.target(1, 0).orR - when(io.fetchUnit.branch && io.fetchUnit.target(1, 0).orR) { + io.fetchUnit.flush && io.fetchUnit.target(1, 0).orR + when(io.fetchUnit.flush && io.fetchUnit.target(1, 0).orR) { io.memoryStage.inst1.ex.tval := io.fetchUnit.target } diff --git a/chisel/playground/src/pipeline/execute/Fu.scala b/chisel/playground/src/pipeline/execute/Fu.scala index dff5d7f..34b0b2b 100644 --- a/chisel/playground/src/pipeline/execute/Fu.scala +++ b/chisel/playground/src/pipeline/execute/Fu.scala @@ -31,7 +31,7 @@ class Fu(implicit val config: CpuConfig) extends Module { val jump_regiser = Input(Bool()) val branch_target = Input(UInt(PC_WID.W)) val branch = Output(Bool()) - val pred_fail = Output(Bool()) + val flush = Output(Bool()) val target = Output(UInt(PC_WID.W)) } }) @@ -46,7 +46,7 @@ class Fu(implicit val config: CpuConfig) extends Module { branchCtrl.in.jump_regiser := io.branch.jump_regiser branchCtrl.in.branch_target := io.branch.branch_target io.branch.branch := branchCtrl.out.branch - io.branch.pred_fail := branchCtrl.out.pred_fail + io.branch.flush := (branchCtrl.out.pred_fail || io.branch.jump_regiser) io.branch.target := branchCtrl.out.target for (i <- 0 until (config.fuNum)) { diff --git a/chisel/playground/src/pipeline/fetch/FetchUnit.scala b/chisel/playground/src/pipeline/fetch/FetchUnit.scala index 3388875..10bded8 100644 --- a/chisel/playground/src/pipeline/fetch/FetchUnit.scala +++ b/chisel/playground/src/pipeline/fetch/FetchUnit.scala @@ -5,20 +5,21 @@ import chisel3.util._ import cpu.defines.Const._ import cpu.CpuConfig -class FetchUnit(implicit - val config: CpuConfig, -) extends Module { +class FetchUnit( + implicit + val config: CpuConfig) + extends Module { val io = IO(new Bundle { val memory = new Bundle { - val flush = Input(Bool()) - val flush_pc = Input(UInt(PC_WID.W)) + val flush = Input(Bool()) + val target = Input(UInt(PC_WID.W)) } val decoder = new Bundle { val branch = Input(Bool()) val target = Input(UInt(PC_WID.W)) } val execute = new Bundle { - val branch = Input(Bool()) + val flush = Input(Bool()) val target = Input(UInt(PC_WID.W)) } val instFifo = new Bundle { @@ -48,10 +49,10 @@ class FetchUnit(implicit io.iCache.pc_next := MuxCase( pc_next_temp, Seq( - io.memory.flush -> io.memory.flush_pc, - io.execute.branch -> io.execute.target, + io.memory.flush -> io.memory.target, + io.execute.flush -> io.execute.target, io.decoder.branch -> io.decoder.target, - io.instFifo.full -> pc, - ), + io.instFifo.full -> pc + ) ) } diff --git a/chisel/playground/src/pipeline/memory/MemoryUnit.scala b/chisel/playground/src/pipeline/memory/MemoryUnit.scala index 7099d7d..7959fc2 100644 --- a/chisel/playground/src/pipeline/memory/MemoryUnit.scala +++ b/chisel/playground/src/pipeline/memory/MemoryUnit.scala @@ -14,8 +14,8 @@ class MemoryUnit(implicit val config: CpuConfig) extends Module { val ctrl = new MemoryCtrl() val memoryStage = Input(new ExecuteUnitMemoryUnit()) val fetchUnit = Output(new Bundle { - val flush = Bool() - val flush_pc = UInt(PC_WID.W) + val flush = Bool() + val target = UInt(PC_WID.W) }) val decoderUnit = Output(Vec(config.fuNum, new RegWrite())) val csr = Flipped(new CsrMemoryUnit()) @@ -25,7 +25,7 @@ class MemoryUnit(implicit val config: CpuConfig) extends Module { val dataMemoryAccess = Module(new DataMemoryAccess()).io dataMemoryAccess.memoryUnit.in.mem_en := io.memoryStage.inst0.mem.en - dataMemoryAccess.memoryUnit.in.info := io.memoryStage.inst0.mem.info + dataMemoryAccess.memoryUnit.in.info := io.memoryStage.inst0.mem.info dataMemoryAccess.memoryUnit.in.mem_wdata := io.memoryStage.inst0.mem.wdata dataMemoryAccess.memoryUnit.in.mem_addr := io.memoryStage.inst0.mem.addr dataMemoryAccess.memoryUnit.in.mem_sel := io.memoryStage.inst0.mem.sel @@ -43,7 +43,7 @@ class MemoryUnit(implicit val config: CpuConfig) extends Module { io.decoderUnit(1).wdata := io.writeBackStage.inst1.rd_info.wdata(io.writeBackStage.inst1.info.fusel) io.writeBackStage.inst0.pc := io.memoryStage.inst0.pc - io.writeBackStage.inst0.info := io.memoryStage.inst0.info + io.writeBackStage.inst0.info := io.memoryStage.inst0.info io.writeBackStage.inst0.rd_info.wdata := io.memoryStage.inst0.rd_info.wdata io.writeBackStage.inst0.rd_info.wdata(FuType.lsu) := dataMemoryAccess.memoryUnit.out.rdata io.writeBackStage.inst0.ex := io.memoryStage.inst0.ex @@ -54,7 +54,7 @@ class MemoryUnit(implicit val config: CpuConfig) extends Module { io.writeBackStage.inst0.commit := io.memoryStage.inst0.info.valid io.writeBackStage.inst1.pc := io.memoryStage.inst1.pc - io.writeBackStage.inst1.info := io.memoryStage.inst1.info + io.writeBackStage.inst1.info := io.memoryStage.inst1.info io.writeBackStage.inst1.rd_info.wdata := io.memoryStage.inst1.rd_info.wdata io.writeBackStage.inst1.rd_info.wdata(FuType.lsu) := dataMemoryAccess.memoryUnit.out.rdata io.writeBackStage.inst1.ex := io.memoryStage.inst1.ex @@ -80,8 +80,8 @@ class MemoryUnit(implicit val config: CpuConfig) extends Module { 0.U.asTypeOf(new InstInfo()) ) - io.fetchUnit.flush := io.csr.out.flush && io.ctrl.allow_to_go - io.fetchUnit.flush_pc := Mux(io.csr.out.flush, io.csr.out.flush_pc, io.writeBackStage.inst0.pc + 4.U) + io.fetchUnit.flush := io.csr.out.flush && io.ctrl.allow_to_go + io.fetchUnit.target := Mux(io.csr.out.flush, io.csr.out.flush_pc, io.writeBackStage.inst0.pc + 4.U) io.ctrl.flush_req := io.fetchUnit.flush }