--- title: Go 语言字符串相等函数的 SIMD 实现 tags: - Assembly - Course Project - Go categories: - - Coding - Programming Language date: 2021-12-30 17:24:00 --- # 背景说明 **Go 语言**是 Google 开发的一种静态强类型、编译型、并发型,并具有垃圾回收功能的编程语言。它用批判吸收的眼光,融合 C 语言、Java 等众家之长,将简洁、高效演绎得淋漓尽致。 **SIMD** 全称 Single Instruction Multiple Data,即单指令多数据流,用一条指令对多个数据进行操作。一般用向量寄存器来实现。常用于加速数据密集型运算,如数列求和、矩阵乘法。 本实验对 Go 语言自带的字符串相等函数的源码进行分析,用自己实现的函数进行替换,并比较性能。 # 探索过程 不同 CPU 对 SIMD 的支持不同。用 CPU-Z 查看当前 CPU 支持的 SIMD 指令集。最新的指令集为 **AVX2**。 Go 语言汇编器基于 Plan 9 汇编器的输入风格,与 GNU 汇编器不同。阅读和编写代码时需要注意。 Go 语言字符串相等函数的代码在 `GOROOT/src/internal/bytealg/equal_[arch].s` 文件中。当前 CPU 架构为 AMD64,对应文件为 `equal_amd64.s`。 文件中定义了三个函数:`runtime·memequal`,`runtime·memequal_varlen`,`memeqbody`。前两个函数为 ABI,供应用程序调用。它们将设置相应寄存器并跳转到第三个函数。部分定义如下。 ```x86asm // memequal(a, b unsafe.Pointer, size uintptr) bool TEXT runtime·memequal(SB),NOSPLIT,$0-25 // AX = a (want in SI) // BX = b (want in DI) // CX = size (want in BX) ... JMP memeqbody<>(SB) // memequal_varlen(a, b unsafe.Pointer) bool TEXT runtime·memequal_varlen(SB),NOSPLIT,$0-17 // AX = a (want in SI) // BX = b (want in DI) // 8(DX) = size (want in BX) ... JMP memeqbody<>(SB) // Input: // a in SI // b in DI // count in BX // Output: // result in AX TEXT memeqbody<>(SB),NOSPLIT,$0-0 ... RET ``` 我们只需要关注真正进行相等判断的 `memeqbody` 函数。它接收两个字符串的地址(`SI`、`DI`)以及字符串的长度(`BX`),返回 0 或 1(`AX`)。Go 语言的编译器保证这两个字符串的长度相等(否则可以直接判断不相等)。 原代码巧妙地利用了 SIMD。思路如下: 1. 如果字符串的长度不小于 64,则一轮循环比较 64 个字符(512 位),直到剩余长度小于 64; 2. 如果字符串的长度不小于 8,则一轮循环比较 8 个字符(64 位),直到剩余长度小于 8; 3. 比较剩余字符(不用循环)。 结合代码分析。在函数开头,根据字符串的长度进入不同的循环。 ```x86asm TEXT memeqbody<>(SB),NOSPLIT,$0-0 CMPQ BX, $8 JB small CMPQ BX, $64 JB bigloop CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1 JE hugeloop_avx2 ``` 第 6 行的代码判断 CPU 是否支持 AVX2 指令集,如果支持则用 `Y` 系列寄存器比较 64 个字符,否则用 `X` 系列寄存器比较 64 个字符。64 个字符的循环如下。 ```x86asm // 64 bytes at a time using xmm registers hugeloop: CMPQ BX, $64 JB bigloop MOVOU (SI), X0 MOVOU (DI), X1 MOVOU 16(SI), X2 MOVOU 16(DI), X3 MOVOU 32(SI), X4 MOVOU 32(DI), X5 MOVOU 48(SI), X6 MOVOU 48(DI), X7 PCMPEQB X1, X0 PCMPEQB X3, X2 PCMPEQB X5, X4 PCMPEQB X7, X6 PAND X2, X0 PAND X6, X4 PAND X4, X0 PMOVMSKB X0, DX ADDQ $64, SI ADDQ $64, DI SUBQ $64, BX CMPL DX, $0xffff JEQ hugeloop XORQ AX, AX // return 0 RET // 64 bytes at a time using ymm registers hugeloop_avx2: CMPQ BX, $64 JB bigloop_avx2 VMOVDQU (SI), Y0 VMOVDQU (DI), Y1 VMOVDQU 32(SI), Y2 VMOVDQU 32(DI), Y3 VPCMPEQB Y1, Y0, Y4 VPCMPEQB Y2, Y3, Y5 VPAND Y4, Y5, Y6 VPMOVMSKB Y6, DX ADDQ $64, SI ADDQ $64, DI SUBQ $64, BX CMPL DX, $0xffffffff JEQ hugeloop_avx2 VZEROUPPER XORQ AX, AX // return 0 RET ``` 如果发现不同,则直接返回 0,否则继续循环,直到剩余长度小于 64,进入 8 个字符的循环。一个通用寄存器刚好可以装下 8 个字符,所以不需要 SIMD。 ```x86asm bigloop_avx2: VZEROUPPER // 8 bytes at a time using 64-bit register bigloop: CMPQ BX, $8 JBE leftover MOVQ (SI), CX MOVQ (DI), DX ADDQ $8, SI ADDQ $8, DI SUBQ $8, BX CMPQ CX, DX JEQ bigloop XORQ AX, AX // return 0 RET ``` 最后比较剩余字符。不过这里并没有用循环,而是直接加载字符串末尾的 8 个字符(可能与之前判断过的字符重叠)。 ```x86asm // remaining 0-8 bytes leftover: MOVQ -8(SI)(BX*1), CX MOVQ -8(DI)(BX*1), DX CMPQ CX, DX SETEQ AX RET ``` 如果字符串的长度本来就小于 8,这么做会加载一些不属于字符串的字符。代码中对这种情况也做了处理。 可以看出,这个函数尽可能地使用了 SIMD,用一条指令判断多个字符是否相等,加快了处理速度,同时也保证了边界情况下的正确性,对较短的字符串进行特殊处理。 尝试用自己编写的函数进行替换。首先用最简单的方法,直接一个一个字符进行比较。 ```x86asm TEXT memeqbody<>(SB),NOSPLIT,$0-0 loop_1: CMPQ BX, $0 JEQ equal MOVB (SI), CX MOVB (DI), DX ADDQ $1, SI ADDQ $1, DI SUBQ $1, BX CMPB CX, DX JEQ loop_1 XORQ AX, AX RET equal: SETEQ AX RET ``` 代码很短,但显然效率不够高。添加 **8 个字符**的循环,剩余字符还是一个一个比较。 ```x86asm loop_8: CMPQ BX, $8 JB loop_1 MOVQ (SI), CX MOVQ (DI), DX ADDQ $8, SI ADDQ $8, DI SUBQ $8, BX CMPQ CX, DX JEQ loop_8 XORQ AX, AX RET ``` 当前 CPU 支持 MMX、SSE4.2、AVX2 指令集。**MMX** 的寄存器为 64 位寄存器,一个寄存器可以装下 8 个字符。将 8 个字符的循环改为使用 MMX 的寄存器。 ```x86asm loop_8_mmx: CMPQ BX, $8 JB loop_1 MOVQ (SI), M0 MOVQ (DI), M1 PCMPEQB M0, M1 MOVQ M1, CX ADDQ $8, SI ADDQ $8, DI SUBQ $8, BX CMPQ CX, $-1 JEQ loop_8_mmx XORQ AX, AX RET ``` 第 4、5 行将从地址 `SI`、`DI` 开始的 8 个字符分别存入 `M0`、`M1` 寄存器,即进行了打包。第 6 行用 `PCMPEQB` 指令(compare packed bytes for equal)比较 `M0`、`M1` 打包字节整数值的相等性,并将比较结果存入 `M1`。如果 `M0`、`M1` 中的某个字节相等,则 `M1` 中的这个字节会变成全 1(即有符号数的 -1),否则变成全 0。因为 MMX 指令不会修改状态寄存器,所以需要将 `M1` 的值存入 `CX`,再与 -1 比较。 **SSE4.2** 的寄存器为 128 位寄存器,一个寄存器可以装下 16 个字符。添加 **16 个字符**的循环。 ```x86asm loop_16_sse: CMPQ BX, $16 JB loop_8_mmx MOVUPD (SI), X0 MOVUPD (DI), X1 PCMPEQB X0, X1 PMOVMSKB X1, CX ADDQ $16, SI ADDQ $16, DI SUBQ $16, BX CMPW CX, $0xffff JEQ loop_16_sse XORQ AX, AX RET ``` 第 4、5 行用 `MOVUPD` 指令(move two unaligned packed double-precision floating-point values between XMM registers and memory)将从地址 `SI`、`DI` 开始的 16 个字符分别存入 `X0`、`X1` 寄存器。第 6 行将比较结果存入 `X1`。第 7 行用 `PMOVMSKB` 指令(move byte mask)将 `X1` 中每个字节的最高位提取出来存入 `CX` 的低位。这是因为 `X1` 寄存器有 128 位,无法直接存入通用寄存器进行判断。而比较某个字节时,如果相等则这个字节会变成全 1,所以只要提取每个字节的最高位即可判断是否全部相等。 **AVX2** 的寄存器为 256 位寄存器,一个寄存器可以装下 32 个字符。添加 **32 个字符**的循环。 ```x86asm loop_32_avx: CMPQ BX, $32 JB loop_16_sse VMOVUPD (SI), Y0 VMOVUPD (DI), Y1 VPCMPEQB Y0, Y1, Y1 VPMOVMSKB Y1, CX ADDQ $32, SI ADDQ $32, DI SUBQ $32, BX CMPL CX, $0xffffffff JEQ loop_32_avx XORQ AX, AX RET ``` AVX2 指令集与 SSE4.2 指令集类似,只需在指令前加字母 V。唯一不同的是第 6 行的 `VPCMPEQB` 指令,它需要三个操作数,将前两个操作数的比较结果存入第三个操作数。 由于 CPU 不支持 AVX512 指令集,因此无法使用 512 位寄存器。不过,还可以用**循环展开**来优化代码。AVX2 的寄存器有 16 个,所以可以展开 2 次、4 次、8 次循环。以展开 4 次循环为例,添加 **128 个字符**的循环。 ```x86asm loop_128_unroll: CMPQ BX, $128 JB loop_32_avx VMOVUPD (SI), Y0 VMOVUPD (DI), Y1 VMOVUPD 32(SI), Y2 VMOVUPD 32(DI), Y3 VMOVUPD 64(SI), Y4 VMOVUPD 64(DI), Y5 VMOVUPD 96(SI), Y6 VMOVUPD 96(DI), Y7 VPCMPEQB Y0, Y1, Y1 VPCMPEQB Y2, Y3, Y3 VPCMPEQB Y4, Y5, Y5 VPCMPEQB Y6, Y7, Y7 VPAND Y1, Y3, Y3 VPAND Y5, Y7, Y7 VPAND Y3, Y7, Y7 VPMOVMSKB Y7, CX ADDQ $128, SI ADDQ $128, DI SUBQ $128, BX CMPL CX, $0xffffffff JEQ loop_128_unroll XORQ AX, AX RET ``` 第 4 行到第 7 行将连续 128 个字符存入 `Y` 系列寄存器,第 12 行到第 15 行分别比较四对 `Y` 系列寄存器,第 16 行到第 18 行将比较结果作按位与(因为相等的比较结果为全 1),第 19 行将结果每个字节的最高位提取出来存入 `CX` 的低位。展开 2 次、8 次循环的代码类似。 至此,已基本实现了用 SIMD 优化的字符串相等函数。 # 效果分析 编写测试函数,用于测试字符串相等函数的正确性。 ```go func TestEqual(t *testing.T) { d, s := []byte("abcde"), []byte("abcde") if !bytes.Equal(d, s) { // len < 8 t.Errorf("Equal(d, s) = false; want true") } for i := 0; i < 50; i++ { d, s = append(d, 'f'), append(s, 'f') } if !bytes.Equal(d, s) { // len < 64 t.Errorf("Equal(d, s) = false; want true") } for i := 0; i < 500; i++ { d, s = append(d, 'g'), append(s, 'g') } if !bytes.Equal(d, s) { // len >= 64 t.Errorf("Equal(d, s) = false; want true") } d = append(d, 'h') if bytes.Equal(d, s) { // len(d) > len(s) t.Errorf("Equal(d, s) = true; want false") } s = append(s, 'i') if bytes.Equal(d, s) { // len(d) == len(s) && d[len-1] != s[len-1] t.Errorf("Equal(d, s) = true; want false") } s = append(s, 'j') if bytes.Equal(d, s) { // len(d) < len(s) t.Errorf("Equal(d, s) = true; want false") } d, s = []byte("k"), []byte("l") for i := 0; i < 5000; i++ { d, s = append(d, 'm'), append(s, 'm') } if bytes.Equal(d, s) { // len(d) == len(s) && d[0] != s[0] t.Errorf("Equal(d, s) = true; want false") } } ``` 用 `go test -run ^TestEqual$` 命令运行测试函数。实验中编写的每一种循环都可以通过测试。 编写基准函数,用于测试字符串相等函数的性能。 ```go func BenchmarkEqual(b *testing.B) { d, s := []byte(""), []byte("") for i := 0; i < 4096; i++ { // 4KB d = append(d, 'n') s = append(s, 'n') } b.ResetTimer() for i := 0; i < b.N; i++ { bytes.Equal(d, s) } } ``` 用 `go test -run ^$ -bench ^BenchmarkEqual$ -count 10` 命令运行基准函数,可以得到字符串相等函数的运行次数和平均运行时间。字符串大小均为 4KB。据此可以算出函数处理数据的速率。 在表格的说明中,**R1** 表示一轮比较 1 个字符,**R8** 表示用通用寄存器一轮比较 8 个字符,**M8** 表示用 MMX 寄存器一轮比较 8 个字符,**X16** 表示用 SSE4.2 寄存器一轮比较 16 个字符,**Y32** 表示用 AVX2 寄存器一轮比较 32 个字符,**Y64**、**Y128**、**Y256** 分别表示用 AVX2 寄存器和 2 次、4 次、8 次循环展开一轮比较 64 个、128 个、256 个字符。 | 测试程序 | 说明 | 运行次数 | 运行时间(ns/op) | 处理速率(GB/s) | | :-------------------: | :------------------: | :------: | :---------------: | :--------------: | | `original.asm` | Go 语言自带 | 13499010 | 87.92 | 43.39 | | `loop_1.asm` | R1 | 424518 | 2721 | 1.402 | | `loop_8.asm` | R8 + R1 | 3993097 | 298.0 | 12.80 | | `loop_8_mmx.asm` | M8 + R1 | 2980017 | 395.4 | 9.648 | | `loop_16_sse.asm` | X16 + R8 + R1 | 5688907 | 206.8 | 18.45 | | `loop_32_avx.asm` | Y32 + R8 + R1 | 10291144 | 115.2 | 33.11 | | `loop_64_unroll.asm` | Y64 + R8 + R1 | 13146477 | 88.00 | 43.35 | | `loop_128_unroll.asm` | Y128 + Y32 + R8 + R1 | 17871183 | 64.31 | 59.32 | | `loop_256_unroll.asm` | Y256 + Y32 + R8 + R1 | 21847575 | 54.56 | 69.92 | 可以看出: 1. 随着 SIMD 寄存器位数的增加,函数的运行时间会减少,处理数据的速率也会变快; 2. 用通用寄存器比较 8 个字符比用 MMX 寄存器比较 8 个字符要快,是因为 MMX 寄存器比较相等之后还需要移回通用寄存器进行跳转判断; 3. Go 语言自带的函数用 AVX2 的寄存器时只展开了 2 次循环,实验中展开了 2 次循环的函数与 Go 语言自带的函数速度相近,而展开了 4 次、8 次循环的函数速度更快; 4. 对于大小为 4KB 的字符串,实验中最快的函数相比最慢的函数速度提升 4887%,相比 Go 语言自带的函数速度提升 61.14%。 以上结果符合预期。 # 参考文献 1. A Quick Guide to Go's Assembler, https://go.dev/doc/asm. 2. A Manual for the Plan 9 assembler, https://9p.io/sys/doc/asm.html. 3. x64 Cheat Sheet, https://cs.brown.edu/courses/cs033/docs/guides/x64_cheatsheet.pdf. 4. x86 and amd64 instruction reference, https://www.felixcloutier.com/x86/index.html. 5. x86 Assembly Language Reference Manual, https://docs.oracle.com/cd/E37838_01/html/E61064/index.html. 6. Intel® Instruction Set Extensions Technology, https://www.intel.com/content/www/us/en/support/articles/000005779/processors.html. 7. equal_amd64.s, https://github.com/golang/go/blob/master/src/internal/bytealg/equal_amd64.s. 8. testing package, https://pkg.go.dev/testing.