banner
CedricXu

CedricXu

计科学生 / 摄影爱好者

[RISCV]手寫數字識別

介紹#

qwii3

最近正在學習伯克利的 CS61C 這門課,其中 Project2 是使用 RISC-V 實現手寫數字識別。

聽起來很複雜,但做起來其實還好,主要考驗的是如何高效利用寄存器,使用組合語言編寫和調用函數,如何從堆棧上手動分配內存,以及使用 Venus 調試組合語言程序的能力。

最後,只需要把它們連接起來,就能組成一個能對手寫數字進行分類的人工神經網絡 (ANN)

RISC-V 調用約定#

t1xa2

要說什麼是 RISC-V 中編程最重要的部分,那非 Calling Convention 莫屬了。它讓函數能夠自由地使用寄存器而不用擔心產生錯誤,試想一下如果 A 函數要使用 s1 寄存器,但是此時 s1 寄存器裡保存了 B 函數的重要的值,如果 A 函數把它改掉了那麼 B 函數在使用它時就會發生錯誤。如果我們編程時需要時刻小心翼翼地檢查我們要使用的寄存器是不是還被別的函數需要,那我們的編程會變成一場災難。這時 Calling Convention 就能讓我們的生活變得更美好!

在函數調用時我們把發起調用的函數稱為 Caller,把被調用的函數稱為 Callee。它們對寄存器都有不同的保存責任

Saved Registers (Callee Saved)#

sb11b

  • s0-s11(saved registers)
  • sp(stack pointer)

這些寄存器對 Caller 來說十分重要,它在裡面保存了一些重要的值,所以 Caller 希望在調用函數後這些值不會變化

這就對 Callee 產生了要求,它需要:

  1. 在堆棧上分配空間 (減 sp 寄存器)
  2. 把它要用的 s 寄存器保存在堆棧上
  3. 隨意調用 s 寄存器
  4. 從堆棧上把原先的 s 寄存器值恢復
  5. 釋放堆棧上的空間 (加 sp 寄存器)

這樣 Callee 既向 Caller 保證了在調用它前後 s 和 sp 寄存器的值都是不變的,又讓自己可以隨意使用這些寄存器

Volatile Registers (Caller Saved)#

g71ox

  • t0-t6 (temporary registers)
  • a0-a7 (arguments & return values)
  • ra (return address)

這些寄存器 Callee 沒有義務保存,因為它只會保存最重要的 s 寄存器。所以如果 Caller 需要它們,只能在調用函數前自行保存,在調用函數後自行恢復了

Structure of a Function#

所以根據我們的 Calling Convention,一個函數的結構應該是這樣的:

xnes0

首先這個函數作為 Callee,要保存會使用到的 s 寄存器。如果這個函數要調用別的函數,那麼作為 Caller,它要保存自己需要的一些其他寄存器,在調用結束後恢復這些值,在函數結束時恢復 s 寄存器的值,最後跳轉回被調用的地方

很簡單的思想:所有的 Callee 都對 Caller 負責,那麼當 Callee 自己當 Caller 時也不用擔心自己的重要寄存器被篡改

神經網絡#

5vldo

我們要用 RISC-V 寫一個神經網絡來識別數字。簡單來說,神經網絡想要近似一個將輸入映射到輸出的非線性函數。在這個項目中,我們已經有了預訓練的矩陣 $m_0$ 和 $m_1$,因此我們只需使用它們進行推理。我們的輸入是 MNIST 數據集,其中包含 60,000 個 28x28 像素的圖像,涵蓋了手寫數字 0-9

我們需要編寫以下函數:

  • relu:激活函數 $f (x)=max (0,x)$
  • argmax: 返回向量中最大元素索引
  • dot:向量點乘
  • matmul:矩陣乘法
  • read_matrix:讀取矩陣文件
  • write_matrix:寫入矩陣文件
  • calssify:調用以上函數連接各層

同時我們需要編寫測試文件來測試程序的正確性,讓我們用向量點乘舉個例子

dot.s#

功能:將兩個向量點乘

輸入:

  • a0 (int*) 指向 v0 第一個元素的指針
  • a1 (int*) 指向 v1 第一個元素的指針
  • a2 (int) 向量長度
  • a3 (int) v0 的步長
  • a4 (int) v1 的步長

返回值:a0 (int) 點乘的結果

代碼:

dot:
    bge x0, a2, exit_5
    bge x0, a3, exit_6
    bge x0, a4, exit_6

    li t0 0 # loop counter
    li t4 0 # dot product accumulator

    #Multiply stride by 4 to get byte offset
    slli a3 a3 2 
    slli a4 a4 2
 
loop_start:
    beq t0, a2, loop_end

    lw t1, 0(a0)
    lw t2, 0(a1)

    mul t3, t1, t2
    add t4, t4, t3

    add a0, a0, a3
    add a1, a1, a4
    addi t0, t0, 1

    j loop_start

loop_end:
    mv a0 t4

    ret

這個點乘函數不需要調用任何函數,所以只作為 Callee,我們可以不使用任何 s 寄存器來取消保存寄存器到堆棧的這個環節,從而提高程序的運行速度

測試代碼:

# Set vector values for testing
.data
vector0: .word 1 2 3 4 5 6 7 8 9
vector1: .word 1 2 3 4 5 6 7 8 9

.text
# main function for testing
main:
    # Load vector addresses into registers
    la s0 vector0
    la s1 vector1

    # Set vector attributes
    addi a2, x0, 9
    addi a3, x0, 1
    addi a4, x0, 1

    # Call dot function
    mv a0, s0
    mv a1, s1
    jal ra, dot

    # Print integer result
    mv a1, a0
    jal ra, print_int

    # Print newline
    li a1 '\n'
    jal ra print_char

    # Exit
    jal exit

測試結果:

ndi5v

最終效果#

就這樣實現一個一個的函數,並且在最後把它們統統放到一起,我們就實現了一個能分類手寫數字的神經網絡

當我們輸入這張圖片:

j8mfi

運行程序後就得到了我們的結果:

ltdb0

總結#

這次用 RISC-V 編寫手寫數字識別的項目提高了我編寫組合語言的能力,同時也鍛煉了測試代碼的編寫能力。特別有意思的是編寫完向量內積、矩陣乘法等簡單的函數後可以實現一個功能強大的神經網絡,十分有成就感

9irzv

最後還是想感慨一下伯克利的分數設置,考試佔 40% 的分數,分散為三次,可以很檢測各階段的學習情況而不是到期末前一周瘋狂復習,四個有意思的項目 (組合手寫數字識別、畫一個自己的 CPU……) 佔 40%,作業實驗加考勤佔 20%,多元化且有意思,很好

活在當下!

載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。