ref link: FizzBuzz – SIMD Style

周末看了个文章,大意是说Java有了SIMD了,我来教大家怎么用AVX来算FizzBuzz。 说来我第一次知道FizzBuzz还是在Coding Horror上面。

早些年我面试实习生的时候用过几次……怎么说呢,脑子不清楚确实挺容易写出问题的。 不过后来觉得这个题就算做出来也没啥用,反而担心第二天直接上V2EX:“XX公司面试官就这水平?是不是在侮辱我?”

刚刚搜了下,甚至连LeetCode上都有这道题了。

SIMD

先说下那个帖子里的解法。因为AVX256的宽度限制,一次只能跑8个数,而FizzBuzz的循环宽度是15,互质。 所以老哥做了一点循环的展开,准备了一个15个模板,每个模板里有8个数。

主要思想是,AVX里有一个blend方法,长这样:

1
__m256i _mm256_blend_epi32(__m256i s1, __m256i s2, const int mask);

大体来说就是,按照mask里面每个bit的0或1,决定输出的8个数是来自s1或者s2中的哪一个。本质是个 if-else 操作。 对应epi16的版本,感觉这里 vpblendw说的更清楚一些。

靠这个blend运算,s1是输入的1~N的数组,s2是fizz或者buzz的对应值,mask是预先准备的对应位置是否需要出Fizz或者Buzz。 所以如果mask有值,就出s2里的fizz或者buzz,而如果对应bit是0的话,就照抄s1。

.NET 实验

Gunnar的主要内容是讲讲Java里新的SIMD接口。考虑到.NET在近期的版本里也有了Intrinsics API,而更加通用的Vector API则是在更早就已经实现了。出于学习SIMD Intrinsics API的目的,我试着用.NET复现了一下SIMD FizzBuzz。下面是几个不同实现算法的跑分:

Method Mean Error StdDev Median Ratio RatioSD
Simple_15 1,589.93 ns 31.728 ns 88.446 ns 1,553.42 ns 4.06 0.55
Simple_3_5 1,243.74 ns 21.567 ns 16.838 ns 1,244.81 ns 3.07 0.43
SIMD 397.41 ns 17.203 ns 50.724 ns 389.53 ns 1.00 0.00
SIMD_aligned 417.61 ns 10.178 ns 28.874 ns 411.72 ns 1.06 0.16
Unroll_15 258.90 ns 4.408 ns 5.885 ns 257.15 ns 0.65 0.10
SIMD_Unroll_16 44.59 ns 0.799 ns 0.747 ns 44.42 ns 0.11 0.02
SIMD_NoRead 434.46 ns 18.955 ns 55.888 ns 436.64 ns 1.11 0.20
Unroll_16 453.34 ns 8.193 ns 9.753 ns 450.42 ns 1.15 0.19

具体的实验设置后面慢慢说,

简而言之:

  1. 尽量避免用除法取模
  2. unroll 简单粗暴效果好。
  3. SIMD还挺难写的。

Simple-15 vs Simple-3-5

先说一个我之前面试的时候遇到过的一个常见错误:

1
2
3
4
5
6
# Wrong code

if       (n %  3 == 0) => Fizz
else if  (n %  5 == 0) => Buzz
else if  (n % 15 == 0) => FizzBuzz
else                   => n

事实上这个实现是永远不可能输出“FizzBuzz”的。我怀疑很多人写成这样,是由于文本到代码的直接翻译导致的,有点类似“如果看到卖西瓜的,买一个”。

一个简单的修复是把那个 %15 的版本放到最上面,变成:

1
2
3
4
5
6
# Simple-15

if       (n % 15 == 0) => FizzBuzz
else if  (n %  3 == 0) => Fizz
else if  (n %  5 == 0) => Buzz
else                   => n

或者按照Gunnar的做法,如果确定是3的倍数以后,多加一个5的倍数的检查:

1
2
3
4
5
6
7
# Simple-3-5

if        (n % 3 == 0) 
    if    (n % 5 == 0)   => FizzBuzz
    else                 => Fizz  
else if   (n % 5 == 0)   => Buzz
else                     => n

其实我之前从来没想过这两种写法会有差不多25%的性能差异。不过仔细想想-15的版本确实要慢很多。

大概一半(8/15=0.53)的数是与3,5无关的。在-15的版本中,他们都需要3次mod运算,而在-3-5中则只需要2次。对于占26%的Fizz组而言,都是2次,没有区别。而占13%的Buzz组,-15里是3次mod,在-3-5中依然是2次。最后是6%的FizzBuzz组,分别是1次和2次。

综合来看,-15的版本中平均每个数要做2.6次mod运算,而-3-5中是2次。这个比例和实际跑出来的性能差距基本相同。

Simple vs SIMD

这个对比,Gunnar的文章里说的挺详细了,这里就不多说了。不过似乎Gunnar的测试里,SIMD的版本是simple的4倍,而我这里只有3倍。具体原因不是很清楚。

Load Aligned

在我跑SIMD的跑分的时候,系统每次都会提示我一个类似的错误:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
// * Warnings *
MultimodalDistribution
  FizzBuzzBenchmark.SIMD: Default         -> It seems that the distribution is bimodal (mValue = 3.65)

FizzBuzzBenchmark.SIMD: DefaultJob
Runtime = .NET Core 5.0.4 (CoreCLR 5.0.421.11614, CoreFX 5.0.421.11614), X64 RyuJIT; GC = Concurrent Workstation
Mean = 379.348 ns, StdErr = 4.012 ns (1.06%), N = 99, StdDev = 39.918 ns
Min = 307.527 ns, Q1 = 348.640 ns, Median = 373.750 ns, Q3 = 406.518 ns, Max = 472.894 ns
IQR = 57.878 ns, LowerFence = 261.823 ns, UpperFence = 493.334 ns
ConfidenceInterval = [365.737 ns; 392.959 ns] (CI 99.9%), Margin = 13.611 ns (3.59% of Mean)
Skewness = 0.34, Kurtosis = 2.29, MValue = 3.65
-------------------- Histogram --------------------
[307.146 ns ; 332.744 ns) | @@@@@@@@@@
[332.744 ns ; 355.395 ns) | @@@@@@@@@@@@@@@@@@@@@@@
[355.395 ns ; 377.707 ns) | @@@@@@@@@@@@@@@@@@@@
[377.707 ns ; 386.188 ns) | @@
[386.188 ns ; 408.839 ns) | @@@@@@@@@@@@@@@@@@@@@
[408.839 ns ; 434.039 ns) | @@@@@@@@@@@@@@
[434.039 ns ; 456.646 ns) | @@@@@@
[456.646 ns ; 478.479 ns) | @@@
---------------------------------------------------

思来想去,怀疑是load数据的时候,input data的地址对齐造成的问题。然后写了个自己对齐输入地址的实验。结果倒是挺出乎意料的。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
// Init:
input = new int[N+8];
fixed (int* start = input)
{
    var aligned = (int*)(((ulong)start + 31UL) & ~31UL);
    alignOffset = (int)(aligned - start);
}

// Load data:
var s1 = Avx2.LoadAlignedVector256(start + alignOffset + i);

其实这个代码单纯跑一两次倒是没啥问题。但是用 BenchmarkDotNet 跑的话,几乎一定会挂在WorkloadJitting 1这个阶段。 一番debug以后发现,如果开始我new出来的input是不对齐的,在 load 的时候,start会突然神秘的变成对齐的形状。 然后如果我继续加那个之前算的 align offset 就会 AccessViolationException。

开始我还以为是GC的问题,不过在有限的几次观测之下,地址都是正好往后挪了几个位置,到最近的对齐的位置上。 感觉进一步分析就需要找JIT的人问问了。或者用Marshal.AllocHGlobal再试试了。

至于最上面的表里跑出来的分,是我加了一个更奇怪的workaround的结果。大体就是在LoadAlignedVector256之前重新算一次offset。 不知道为什么还不如之前的SIMD的效果。

SIMD unroll 16

这个方法是HN 上的一个评论提到的方案。

大意是既然你都unroll了,完全可以做的更简单一点。虽然8和15互质,但是完全可以一次写16个数,然后忽略掉最后一个数。 按照这个思路,就是最上面的表里面跑的最快的SIMD_Unroll_16。具体的代码可以参考HN里的代码。

参照这个思路,我又跑了几个类似的实现。

比较不行的Unroll_16是把SIMD里SIMD相关指令直接拿掉的版本。v1v2从Vector256改成一个Int[16],d1d2类似。Avx.Add()直接展开成16个加法运算。没想到这么菜!

SIMD_Unroll_16和原版的SIMD的另一个非常显著的区别是,原版需要从input data里把数据加载到xmm register里,而SIMD_Unroll_16是直接靠模板迭代出输入数值。所以在SIMD_NoRead里,我使用类似的数据加载方式,结果没想到反而更慢了。怀疑是xmm register的编排出了问题。

Unroll-15SIMD_Unroll_16去掉SIMD的另一个实验:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
for (; i < bound; i += 15)
{
    result[i]      = input[i];
    result[i + 1]  = input[i + 1];
    result[i + 2]  = FIZZ;
    result[i + 3]  = input[i + 3];
    result[i + 4]  = BUZZ;
    result[i + 5]  = FIZZ;
    result[i + 6]  = input[i + 6];
    result[i + 7]  = input[i + 7];
    result[i + 8]  = FIZZ;
    result[i + 9]  = BUZZ;
    result[i + 10] = input[i + 10];
    result[i + 11] = FIZZ;
    result[i + 12] = input[i + 12];
    result[i + 13] = input[i + 13];
    result[i + 14] = FIZZBUZZ;
}

没别的,就一句话,简单,贼快。

其他

HN的另一个评论里提到了一个编译器的AVX展开技巧:https://godbolt.org/z/q4bcfj

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
void mod15( int arr[ 8 ] )
{
    for( int i = 0; i < 8; i++ )
        arr[ i ] = arr[ i ] % 15;
}

==>

vmovdqu ymm0, ymmword ptr [rdi]
vpbroadcastd    ymm2, dword ptr [rip + .LCPI0_0] # ymm2 = [2290649225,2290649225,2290649225,2290649225,2290649225,290649225,2290649225,2290649225]
vpbroadcastd    ymm3, dword ptr [rip + .LCPI0_1] # ymm3 = [15,15,15,15,15,15,15,15]
vpshufd ymm1, ymm0, 245                 # ymm1 = ymm0[1,1,3,3,5,5,7,7]
vpmuldq ymm1, ymm1, ymm2
vpmuldq ymm2, ymm0, ymm2
vpshufd ymm2, ymm2, 245                 # ymm2 = ymm2[1,1,3,3,5,5,7,7]
vpblendd        ymm1, ymm2, ymm1, 170           # ymm1 = ymm2[0],ymm1[1],ymm2[2],ymm1[3],ymm2[4],ymm1[5],ymm2[6],ymm1[7]
vpaddd  ymm1, ymm1, ymm0
vpsrld  ymm2, ymm1, 31
vpsrad  ymm1, ymm1, 3
vpaddd  ymm1, ymm1, ymm2
vpmulld ymm1, ymm1, ymm3
vpsubd  ymm0, ymm0, ymm1
vmovdqu ymmword ptr [rdi], ymm0
vzeroupper
ret

说实话,没看懂……

试了下BenchmarkDotNet的教程里关于hash的sample:

Method Mean Error StdDev
Md5 22.72 μs 0.452 μs 0.903 μs
Sha1 26.04 μs 0.604 μs 1.742 μs
Sha256 69.93 μs 2.015 μs 5.941 μs

我宣布我未来一段时间都会是MD5粉了,除非谁送我一块支持SHA extension的CPU。

天下算法,唯快不破!