arch/riscv: Add SMP support for exception handler

Change-Id: Ia1f97b82e329f6358061072f98278cf56b503618
Signed-off-by: Xiang Wang <merle@hardenedlinux.org>
Reviewed-on: https://review.coreboot.org/c/coreboot/+/68841
Reviewed-by: Philipp Hug <philipp@hug.cx>
Tested-by: build bot (Jenkins) <no-reply@coreboot.org>
Reviewed-by: ron minnich <rminnich@gmail.com>
diff --git a/src/arch/riscv/payload.c b/src/arch/riscv/payload.c
index ee2ee8e..443975b 100644
--- a/src/arch/riscv/payload.c
+++ b/src/arch/riscv/payload.c
@@ -6,6 +6,7 @@
 #include <arch/encoding.h>
 #include <arch/smp/atomic.h>
 #include <console/console.h>
+#include <mcall.h>
 #include <vm.h>
 
 /* Run OpenSBI and let OpenSBI hand over control to the payload */
@@ -47,6 +48,8 @@
 		write_csr(sie, 0);
 		/* disable MMU */
 		write_csr(satp, 0);
+		/* save stack to mscratch so trap_entry can use that as exception stack */
+		write_csr(mscratch, MACHINE_STACK_TOP());
 		break;
 	case RISCV_PAYLOAD_MODE_M:
 		status = INSERT_FIELD(status, MSTATUS_MPP, PRV_M);
diff --git a/src/arch/riscv/trap_util.S b/src/arch/riscv/trap_util.S
index d7b1250..d6a93b0 100644
--- a/src/arch/riscv/trap_util.S
+++ b/src/arch/riscv/trap_util.S
@@ -7,129 +7,129 @@
 #include <mcall.h>
 
 .macro restore_regs
-    # restore x registers
-    LOAD  x1,1*REGBYTES(a0)
-    LOAD  x2,2*REGBYTES(a0)
-    LOAD  x3,3*REGBYTES(a0)
-    LOAD  x4,4*REGBYTES(a0)
-    LOAD  x5,5*REGBYTES(a0)
-    LOAD  x6,6*REGBYTES(a0)
-    LOAD  x7,7*REGBYTES(a0)
-    LOAD  x8,8*REGBYTES(a0)
-    LOAD  x9,9*REGBYTES(a0)
-    LOAD  x11,11*REGBYTES(a0)
-    LOAD  x12,12*REGBYTES(a0)
-    LOAD  x13,13*REGBYTES(a0)
-    LOAD  x14,14*REGBYTES(a0)
-    LOAD  x15,15*REGBYTES(a0)
-    LOAD  x16,16*REGBYTES(a0)
-    LOAD  x17,17*REGBYTES(a0)
-    LOAD  x18,18*REGBYTES(a0)
-    LOAD  x19,19*REGBYTES(a0)
-    LOAD  x20,20*REGBYTES(a0)
-    LOAD  x21,21*REGBYTES(a0)
-    LOAD  x22,22*REGBYTES(a0)
-    LOAD  x23,23*REGBYTES(a0)
-    LOAD  x24,24*REGBYTES(a0)
-    LOAD  x25,25*REGBYTES(a0)
-    LOAD  x26,26*REGBYTES(a0)
-    LOAD  x27,27*REGBYTES(a0)
-    LOAD  x28,28*REGBYTES(a0)
-    LOAD  x29,29*REGBYTES(a0)
-    LOAD  x30,30*REGBYTES(a0)
-    LOAD  x31,31*REGBYTES(a0)
-    # restore a0 last
-    LOAD  x10,10*REGBYTES(a0)
+	# restore x registers
+	LOAD	 x1, 1 * REGBYTES(sp)
+	LOAD	 x3, 3 * REGBYTES(sp)
+	LOAD	 x4, 4 * REGBYTES(sp)
+	LOAD	 x5, 5 * REGBYTES(sp)
+	LOAD	 x6, 6 * REGBYTES(sp)
+	LOAD	 x7, 7 * REGBYTES(sp)
+	LOAD	 x8, 8 * REGBYTES(sp)
+	LOAD	 x9, 9 * REGBYTES(sp)
+	LOAD	x10, 10 * REGBYTES(sp)
+	LOAD	x11, 11 * REGBYTES(sp)
+	LOAD	x12, 12 * REGBYTES(sp)
+	LOAD	x13, 13 * REGBYTES(sp)
+	LOAD	x14, 14 * REGBYTES(sp)
+	LOAD	x15, 15 * REGBYTES(sp)
+	LOAD	x16, 16 * REGBYTES(sp)
+	LOAD	x17, 17 * REGBYTES(sp)
+	LOAD	x18, 18 * REGBYTES(sp)
+	LOAD	x19, 19 * REGBYTES(sp)
+	LOAD	x20, 20 * REGBYTES(sp)
+	LOAD	x21, 21 * REGBYTES(sp)
+	LOAD	x22, 22 * REGBYTES(sp)
+	LOAD	x23, 23 * REGBYTES(sp)
+	LOAD	x24, 24 * REGBYTES(sp)
+	LOAD	x25, 25 * REGBYTES(sp)
+	LOAD	x26, 26 * REGBYTES(sp)
+	LOAD	x27, 27 * REGBYTES(sp)
+	LOAD	x28, 28 * REGBYTES(sp)
+	LOAD	x29, 29 * REGBYTES(sp)
+	LOAD	x30, 30 * REGBYTES(sp)
+	LOAD	x31, 31 * REGBYTES(sp)
+.endm
 
-
-    .endm
 .macro save_tf
-  # save gprs
-  STORE  x1,1*REGBYTES(x2)
-  STORE  x3,3*REGBYTES(x2)
-  STORE  x4,4*REGBYTES(x2)
-  STORE  x5,5*REGBYTES(x2)
-  STORE  x6,6*REGBYTES(x2)
-  STORE  x7,7*REGBYTES(x2)
-  STORE  x8,8*REGBYTES(x2)
-  STORE  x9,9*REGBYTES(x2)
-  STORE  x10,10*REGBYTES(x2)
-  STORE  x11,11*REGBYTES(x2)
-  STORE  x12,12*REGBYTES(x2)
-  STORE  x13,13*REGBYTES(x2)
-  STORE  x14,14*REGBYTES(x2)
-  STORE  x15,15*REGBYTES(x2)
-  STORE  x16,16*REGBYTES(x2)
-  STORE  x17,17*REGBYTES(x2)
-  STORE  x18,18*REGBYTES(x2)
-  STORE  x19,19*REGBYTES(x2)
-  STORE  x20,20*REGBYTES(x2)
-  STORE  x21,21*REGBYTES(x2)
-  STORE  x22,22*REGBYTES(x2)
-  STORE  x23,23*REGBYTES(x2)
-  STORE  x24,24*REGBYTES(x2)
-  STORE  x25,25*REGBYTES(x2)
-  STORE  x26,26*REGBYTES(x2)
-  STORE  x27,27*REGBYTES(x2)
-  STORE  x28,28*REGBYTES(x2)
-  STORE  x29,29*REGBYTES(x2)
-  STORE  x30,30*REGBYTES(x2)
-  STORE  x31,31*REGBYTES(x2)
+	# save general purpose registers
+	# no point in saving x0 since it is always 0
+	STORE	 x1, 1 * REGBYTES(sp)
+	# x2 is our stack pointer and is saved further below
+	STORE	 x3, 3 * REGBYTES(sp)
+	STORE	 x4, 4 * REGBYTES(sp)
+	STORE	 x5, 5 * REGBYTES(sp)
+	STORE	 x6, 6 * REGBYTES(sp)
+	STORE	 x7, 7 * REGBYTES(sp)
+	STORE	 x8, 8 * REGBYTES(sp)
+	STORE	 x9, 9 * REGBYTES(sp)
+	STORE	x10, 10 * REGBYTES(sp)
+	STORE	x11, 11 * REGBYTES(sp)
+	STORE	x12, 12 * REGBYTES(sp)
+	STORE	x13, 13 * REGBYTES(sp)
+	STORE	x14, 14 * REGBYTES(sp)
+	STORE	x15, 15 * REGBYTES(sp)
+	STORE	x16, 16 * REGBYTES(sp)
+	STORE	x17, 17 * REGBYTES(sp)
+	STORE	x18, 18 * REGBYTES(sp)
+	STORE	x19, 19 * REGBYTES(sp)
+	STORE	x20, 20 * REGBYTES(sp)
+	STORE	x21, 21 * REGBYTES(sp)
+	STORE	x22, 22 * REGBYTES(sp)
+	STORE	x23, 23 * REGBYTES(sp)
+	STORE	x24, 24 * REGBYTES(sp)
+	STORE	x25, 25 * REGBYTES(sp)
+	STORE	x26, 26 * REGBYTES(sp)
+	STORE	x27, 27 * REGBYTES(sp)
+	STORE	x28, 28 * REGBYTES(sp)
+	STORE	x29, 29 * REGBYTES(sp)
+	STORE	x30, 30 * REGBYTES(sp)
+	STORE	x31, 31 * REGBYTES(sp)
 
-  # get sr, epc, badvaddr, cause
-  csrrw  t0,mscratch,x0
-  csrr   s0,mstatus
-  csrr   t1,mepc
-  csrr   t2,mtval
-  csrr   t3,mcause
-  STORE  t0,2*REGBYTES(x2)
-  STORE  s0,32*REGBYTES(x2)
-  STORE  t1,33*REGBYTES(x2)
-  STORE  t2,34*REGBYTES(x2)
-  STORE  t3,35*REGBYTES(x2)
+	# get old sp, mstatus, mepc, mtval (badvaddr) and mcause
+	csrr	t0, mscratch
+	bnez	t0, 1f	# t0 != 0: trap came from the payload, t0 is the saved old sp
+			# t0 == 0: trap came from coreboot, old sp is sp + frame size
+	add	t0, sp, MENTRY_FRAME_SIZE
+1:
+	csrr	s0, mstatus
+	csrr	t1, mepc
+	csrr	t2, mtval
+	csrr	t3, mcause
+	STORE	t0, 2 * REGBYTES(sp)
+	STORE	s0, 32 * REGBYTES(sp)
+	STORE	t1, 33 * REGBYTES(sp)
+	STORE	t2, 34 * REGBYTES(sp)
+	STORE	t3, 35 * REGBYTES(sp)
 
-  # get faulting insn, if it wasn't a fetch-related trap
-  li x5,-1
-  STORE x5,36*REGBYTES(x2)
+	# the faulting insn is not fetched here; mark its trapframe slot with -1
+	li	x5, -1
+	STORE	x5, 36 * REGBYTES(sp)
+.endm
 
-  .endm
-
-.globl estack
-  .text
-
-    .global  trap_entry
-    .align 2	# four byte alignment, as required by mtvec
+	.text
+	.global  trap_entry
+	.align 2 # four byte alignment, as required by mtvec
 trap_entry:
-  csrw mscratch, sp
+	# mscratch is initialized to 0.
+	# When exiting coreboot, the exception stack top is written to mscratch.
+	# Before jumping to m-mode firmware we always set the trap vector to the entry point of
+	# the payload and we don't care about mscratch anymore. mscratch is only ever used as
+	# exception stack if whatever coreboot jumps to is in s-mode.
+	#TODO We could check the MPP field in mstatus to see if we came from unprivileged code.
+	#     That way we could still use mscratch for other purposes inside the code base.
+	#TODO In case we got called from s-mode firmware we need to protect our stack and trap
+	#     handler with a PMP region.
+	csrrw	sp, mscratch, sp
+	# sp == 0 => trap came from coreboot
+	# sp != 0 => trap came from s-mode payload
+	bnez	sp, 1f
+	csrrw	sp, mscratch, sp
+1:
+	addi	sp, sp, -MENTRY_FRAME_SIZE
+	save_tf
 
-  # SMP isn't supported yet, to avoid overwriting the same stack with different
-  # harts that handle traps at the same time.
-  # someday this gets fixed.
-  //csrr sp, mhartid
-  csrr sp, 0xf14
-.Lsmp_hang:
-  bnez sp, .Lsmp_hang
+	mv a0,sp # put trapframe as first argument
 
-  # Use a different stack than in the main context, to avoid overwriting
-  # stack data.
-  # TODO: Maybe use the old stack pointer (plus an offset) instead. But only if
-  # the previous mode was M, because it would be a very bad idea to use a stack
-  # pointer provided by unprivileged code!
-  la	sp, _estack
-  addi	sp, sp, -2048	# 2 KiB is half of the stack space
-  addi	sp, sp, -MENTRY_FRAME_SIZE
-
-  save_tf
-  move  a0,sp
-
-  # store pointer to stack frame (moved out from trap_handler)
-  csrw  mscratch, sp
-
-  LOAD	t0, trap_handler
-  jalr	t0
+	LOAD	t0, trap_handler
+	jalr	t0
 
 trap_return:
-	csrr	a0, mscratch
 	restore_regs
-	# go back to the previous mode
+	addi	sp, sp, MENTRY_FRAME_SIZE
+
+	# restore the pre-trap sp: from mscratch (payload trap) or keep sp (coreboot trap)
+	csrrw	sp, mscratch, sp
+	bnez	sp, 1f
+	csrrw	sp, mscratch, sp
+1:
 	mret