学习SIMD编程(AVX2)

Last Updated: 2024-02-11 02:30:21 Sunday

-- TOC --

开一篇笔记,专门用来学习和总结SIMD编程。

SIMD代码的基本套路,先将数据Load进寄存器,计算,然后Store出来分析结果。使用SSE指令不需要特别的编译器选项,因为SSE已经存在很久了,gcc编译浮点数计算的代码,就是使用SSE指令。但AVX2需要使用-mavx2编译选项。

本文代码以AVX2的标准Intrinsics为主,这些Intrinsics都是编译器的builtin接口,大多数情况下,这些接口编译后都直接对应一条AVX2指令。它们只是看起来像是函数调用,其实不是,没有调用开销(在头文件中,都是static inline申明)。使用Intrinsics的好处,可以有效利用编译器的Type System,也避免了直接写Inline Assembly的复杂。

Intel Intrinsics:https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html

LOAD

load int

// load 256bits data, mem_addr must be 32-byte aligned.
__m256i _mm256_load_si256(__m256i const *mem_addr);
// load 256bits data, mem_addr could not be aligned. u means unaligned.
__m256i _mm256_loadu_si256(__m256i const *mem_addr);

load float

// single precision, aligned or unaligned, 4*8=32
__m256 _mm256_load_ps(float const *mem_addr);
__m256 _mm256_loadu_ps(float const *mem_addr);
// double precision, aligned or unaligned, 8*4=32
__m256d _mm256_load_pd(double const *mem_addr);
__m256d _mm256_loadu_pd(double const *mem_addr);

STORE

store int

// store 256bits data a into mem_addr, which must be 32-byte aligned.
void _mm256_store_si256(__m256i *mem_addr, __m256i a);
// store 256bits data a into mem_addr, which could not be aligned. u means unaligned.
void _mm256_storeu_si256(__m256i *mem_addr, __m256i a);

store float

// single precision
void _mm256_store_ps(float *mem_addr, __m256 a);
void _mm256_storeu_ps(float *mem_addr, __m256 a);
// double precision
void _mm256_store_pd(double *mem_addr, __m256d a);
void _mm256_storeu_pd(double *mem_addr, __m256d a);

SET和BROADCAST

给256bits长的寄存器填充满需要的数据,是个技术活。

set1/broadcast

用一个相同的值,填满整个256bits的寄存器。注意这一组接口的参数,都是具体的值,而不是指针:

__m256i _mm256_set1_epi8(char a);
__m256i _mm256_set1_epi16(short a);
__m256i _mm256_set1_epi32(int a);
__m256i _mm256_set1_epi64x(long long a);
__m256d _mm256_set1_pd(double a);
__m256  _mm256_set1_ps(float a);

也有用指针的接口:

__m256  _mm256_broadcast_ss(float const *mem_addr);
__m256d _mm256_broadcast_sd(double const *mem_addr);

setzero

我们都知道,用xor操作来给寄存器置零,是推荐做法。AVX2指令集也一样,intrinsics中有如下3条指令可用来置零,对应3条xor指令。

__m256d _mm256_setzero_pd(void);
__m256  _mm256_setzero_ps(void);
__m256i _mm256_setzero_si256(void);

setr

这一组set register接口,对应的是instruction sequence。

/*
dst[7:0] := e31
dst[15:8] := e30
...
dst[255:248] := e0
*/
__m256i _mm256_setr_epi8(char e31, char e30, char e29, char e28, char e27, char e26, char e25, char e24, char e23, char e22, char e21, char e20, char e19, char e18, char e17, char e16, char e15, char e14, char e13, char e12, char e11, char e10, char e9, char e8, char e7, char e6, char e5, char e4, char e3, char e2, char e1, char e0);
/*
dst[15:0] := e15
dst[31:16] := e14
...
dst[255:240] := e0
*/
__m256i _mm256_setr_epi16(short e15, short e14, short e13, short e12, short e11, short e10, short e9, short e8, short e7, short e6, short e5, short e4, short e3, short e2, short e1, short e0);
/*
dst[31:0] := e7
dst[63:32] := e6
...
dst[255:224] := e0
*/
__m256i _mm256_setr_epi32(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0);
/*
dst[63:0] := e3
dst[127:64] := e2
dst[191:128] := e1
dst[255:192] := e0
*/
__m256i _mm256_setr_epi64x(__int64 e3, __int64 e2, __int64 e1, __int64 e0);
/*
dst[31:0] := e7
dst[63:32] := e6
...
dst[255:224] := e0
*/
__m256 _mm256_setr_ps(float e7, float e6, float e5, float e4, float e3, float e2, float e1, float e0);
/*
dst[63:0] := e3
dst[127:64] := e2
dst[191:128] := e1
dst[255:192] := e0
*/
__m256d _mm256_setr_pd(double e3, double e2, double e1, double e0);

测试:

#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>

int main() {
    int64_t result[4] = {};

    __m256i a = _mm256_setr_epi64x(1,1,1,1);
    __m256i b = _mm256_setr_epi64x(1,2,3,4);
    __m256i c = _mm256_add_epi64(a,b);
    _mm256_storeu_si256((__m256i*)result, c);

    for(int i=0; i<4; ++i)
        printf("%ld ", result[i]); // 2,3,4,5
    printf("\n");

    return 0;
}

Extract/Insert

store将整个256bits寄存器都copy到内存,extract指令则是有选择的。extract接口编译之后,是一个sequence。

int _mm256_extract_epi8(__m256i a, const int index);
int _mm256_extract_epi16(__m256i a, const int index);
__int32 _mm256_extract_epi32(__m256i a, const int index);
__int64 _mm256_extract_epi64(__m256i a, const int index);

与extract对应的,就是insert,很好理解。

__m256i _mm256_insert_epi8(__m256i a, __int8 i, const int index);
__m256i _mm256_insert_epi16(__m256i a, __int16 i, const int index);
__m256i _mm256_insert_epi32(__m256i a, __int32 i, const int index);
__m256i _mm256_insert_epi64(__m256i a, __int64 i, const int index);

Zero

将所有YMM寄存器清零。

void _mm256_zeroall (void);

将所有YMM寄存器的高128bits清零。

void _mm256_zeroupper (void);

申请地址对齐内存

AVX2很多指令需要32字节对齐,一般情况下编译器操作调用栈的地址,以及libc中的malloc得到的地址,都是8字节对齐。因此,intrinsics中有专用的接口,用来申请特定对齐长度的内存:

// 比libc中的malloc,对了一个align参数,用来指定对齐字节数
void* _mm_malloc(size_t size, size_t align);
// 对应的free
void _mm_free(void * mem_addr);

这两个接口是SSE提供的,AVX时代一样方便使用。

Add/Sub

Int

LeetCode第1题,两数和,最后测试了一组SIMD实现。这组代码主要测试了加法和减法两个intrinsics接口。整数加减法接口如下:

// add
__m256i _mm256_add_epi8(__m256i a, __m256i b);
__m256i _mm256_add_epi16(__m256i a, __m256i b);
__m256i _mm256_add_epi32(__m256i a, __m256i b);
__m256i _mm256_add_epi64(__m256i a, __m256i b);
// sub
__m256i _mm256_sub_epi8(__m256i a, __m256i b);
__m256i _mm256_sub_epi16(__m256i a, __m256i b);
__m256i _mm256_sub_epi32(__m256i a, __m256i b);
__m256i _mm256_sub_epi64(__m256i a, __m256i b);

Float

// single precision
__m256 _mm256_add_ps(__m256 a, __m256 b);
__m256 _mm256_sub_ps(__m256 a, __m256 b);
// double precision
__m256d _mm256_add_pd(__m256d a, __m256d b);
__m256d _mm256_sub_pd(__m256d a, __m256d b);

Saturated Add/Sub

Adds

所谓饱和加,就是两个数相加后,如果overflow,则最后结果等于max value。

// signed
__m256i _mm256_adds_epi8(__m256i a, __m256i b);
__m256i _mm256_adds_epi16(__m256i a, __m256i b);
// unsigned
__m256i _mm256_adds_epu8(__m256i a, __m256i b);
__m256i _mm256_adds_epu16(__m256i a, __m256i b);

测试:

#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>

int main() {
    int16_t result[16] = {};

    __m256i a = _mm256_setr_epi16(1,2,3,4,5,6,7,8,
                                  1,2,3,4,5,6,7,8);
    __m256i b = _mm256_setr_epi16(1,2,3,4,5,6,7,8,
                                  32767,32766,32765,32764,
                                  32763,32762,32761,32760);
    __m256i c = _mm256_adds_epi16(a, b);
    _mm256_storeu_si256((__m256i*)result, c);

    // 0x2 0x4 0x6 0x8 0xa 0xc 0xe 0x10 0x7fff 0x7fff 0x7fff 0x7fff 0x7fff 0x7fff 0x7fff 0x7fff
    for(int i=0; i<16; ++i)
        printf("0x%x ", result[i]);
    printf("\n");

    return 0;
}

Subs

减到小于0,就等于0了。

__m256i _mm256_subs_epi8(__m256i a, __m256i b);
__m256i _mm256_subs_epi16(__m256i a, __m256i b);
__m256i _mm256_subs_epu8(__m256i a, __m256i b);
__m256i _mm256_subs_epu16(__m256i a, __m256i b);

又加又减

/*
FOR j := 0 to 7
    i := j*32
    IF ((j & 1) == 0)
        dst[i+31:i] := a[i+31:i] - b[i+31:i]
    ELSE
        dst[i+31:i] := a[i+31:i] + b[i+31:i]
    FI
ENDFOR
dst[MAX:256] := 0
奇数位置做加法,偶数位置做减法。
*/
__m256  _mm256_addsub_ps(__m256 a, __m256 b);
/*
FOR j := 0 to 3
    i := j*64
    IF ((j & 1) == 0)
        dst[i+63:i] := a[i+63:i] - b[i+63:i]
    ELSE
        dst[i+63:i] := a[i+63:i] + b[i+63:i]
    FI
ENDFOR
dst[MAX:256] := 0
*/
__m256d _mm256_addsub_pd(__m256d a, __m256d b);

均值

__m256i _mm256_avg_epu8(__m256i a, __m256i b);
__m256i _mm256_avg_epu16(__m256i a, __m256i b);

Sign

保持a与b对应的位置同号,如果b为0,a对应位置置零。

__m256i _mm256_sign_epi8(__m256i a, __m256i b);
__m256i _mm256_sign_epi16(__m256i a, __m256i b);
__m256i _mm256_sign_epi32(__m256i a, __m256i b);

Mul

Int

分别按16bits和32bits做乘法。lo表示只取乘法计算结果的低位,这与正常int乘法溢出后的处理手法一样。

__m256i _mm256_mullo_epi16(__m256i a, __m256i b);
__m256i _mm256_mullo_epi32(__m256i a, __m256i b);

还有两条指令,取计算结果的高位,hi,但仅支持16bits的乘法计算。

__m256i _mm256_mulhi_epi16(__m256i a, __m256i b);
__m256i _mm256_mulhi_epu16(__m256i a, __m256i b);

下面两个乘法指令,高位低位全要,取每个64bits的低32bits做乘法,保存64bits结果。因此,只能同时做4组数据的乘法。

__m256i _mm256_mul_epi32(__m256i a, __m256i b);
__m256i _mm256_mul_epu32(__m256i a, __m256i b);

测试:

#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>

int main() {
    int64_t result[4] = {};

    __m256i a = _mm256_setr_epi64x(0x7FFFFFFF,2,3,-4);
    __m256i b = _mm256_setr_epi64x(0x7FFFFFFF,2,-3,4);
    __m256i c = _mm256_mul_epi32(a, b);
    _mm256_storeu_si256((__m256i*)result, c);

    for(int i=0; i<4; ++i)
        printf("0x%lx ", result[i]);
    printf("\n");

    // the same with below, d1 must be cast first!!
    int32_t d1 = 0x7FFFFFFF;
    int64_t d2 = ((int64_t)d1)*d1;
    printf("0x%lx\n", d2);

    return 0;
}

最后是一个非常奇怪,不知道用在何处的乘法指令:

__m256i _mm256_mulhrs_epi16(__m256i a, __m256i b);

Float

float的乘法,就很直接了,ps同时乘8组数据,pd同时乘4组数据。

__m256d _mm256_mul_pd(__m256d a, __m256d b);
__m256  _mm256_mul_ps(__m256 a, __m256 b);

Div

没有packed int除法指令。

__m256d _mm256_div_pd(__m256d a, __m256d b);
__m256  _mm256_div_ps(__m256 a, __m256 b);

比较Int

比较BGRA格式的图片并找出具体差异时,每个pixel对应4个bytes,如果不使用SIMD,每次最多比较2个pixel,cast到long long类型后用xor比较。使用AVX2,每次可以比较32个字节,即8个pixel。

#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>

int main() {
    int32_t data1[8];
    int32_t data2[8];
    for(int i=0; i<8; ++i){
        data1[i] = i;
        if(i%4)
            data2[i] = i+1;
        else
            data2[i] = i;
    }
    int32_t result[8] = {};

    __m256i a = _mm256_loadu_si256((__m256i*)data1);
    __m256i b = _mm256_loadu_si256((__m256i*)data2);
    __m256i c = _mm256_cmpeq_epi32(a, b);
    _mm256_storeu_si256((__m256i*)result, c);

    for(int i=0; i<8; ++i)
        printf("%d ", result[i]);
    printf("\n");

    return 0;
}

编译执行:

$ gcc -mavx2 -Wall -Wextra simd.c -o simd && ./simd
-1 0 0 0 -1 0 0 0

_mm256_cmpeq_epi32按4bytes的方式进行比较,相同时设置为-1(全F),不相等时为0!这个返回值的设置,跟一般场景下是反着的....

关于cmpeq指令:

__m256i _mm256_cmpeq_epi8(__m256i a, __m256i b);
__m256i _mm256_cmpeq_epi16(__m256i a, __m256i b);
__m256i _mm256_cmpeq_epi32(__m256i a, __m256i b);
__m256i _mm256_cmpeq_epi64(__m256i a, __m256i b);

其实,我觉得更好的方式,是使用并行的xor比较指令,Intel官方文档上说明,avx的xor比cmpeq拥有更低的CPI。修改上例代码,除了使用xor之外,对result采用申请32字节对齐的内存地址,采用byte-wise逻辑(不再对应比较pixel场景),并相应调整其它代码:

#include <stdio.h>
#include <immintrin.h>

int main() {
    char data1[32];
    char data2[32];
    for(int i=0; i<32; ++i){
        data1[i] = i;
        if(i%4)
            data2[i] = i+1;
        else
            data2[i] = i;
    }

    char *result = _mm_malloc(32, 32);

    __m256i a = _mm256_loadu_si256((__m256i*)data1);
    __m256i b = _mm256_loadu_si256((__m256i*)data2);
    __m256i c = _mm256_xor_si256(a, b);
    _mm256_store_si256((__m256i*)result, c);

    for(int i=0; i<32; ++i)
        printf("%d ", result[i]);
    printf("\n");

    _mm_free(result);
    return 0;
}

编译执行:

$ gcc -mavx2 -Wall -Wextra simd.c -o simd && ./simd
0 3 1 7 0 3 1 15 0 3 1 7 0 3 1 31 0 3 1 7 0 3 1 15 0 3 1 7 0 3 1 63

此时的返回值,相等时为0,不相等时不等于0,返回值也符合常识了。虽然store时的地址32字节对齐了,但_mm_malloc和_mm_free都是函数调用,整体性能不一定比非对齐是使用栈内空间就更优。

Intrinsics还提供了一组比较大于的接口:

__m256i _mm256_cmpgt_epi8(__m256i a, __m256i b);
__m256i _mm256_cmpgt_epi16(__m256i a, __m256i b);
__m256i _mm256_cmpgt_epi32(__m256i a, __m256i b);
__m256i _mm256_cmpgt_epi64(__m256i a, __m256i b);

同样,如果大于条件满足,结果将被设置为-1(全F),为满足设置为0。

比较Float

__m256  _mm256_cmp_ps(__m256 a, __m256 b, const int imm8);
__m256d _mm256_cmp_pd(__m256d a, __m256d b, const int imm8);

imm8有32种可能的取值,其中:

CASE (imm8[4:0]) OF
0: OP := _CMP_EQ_OQ
1: OP := _CMP_LT_OS
2: OP := _CMP_LE_OS
3: OP := _CMP_UNORD_Q 
4: OP := _CMP_NEQ_UQ
5: OP := _CMP_NLT_US
6: OP := _CMP_NLE_US
7: OP := _CMP_ORD_Q
8: OP := _CMP_EQ_UQ
9: OP := _CMP_NGE_US
10: OP := _CMP_NGT_US
11: OP := _CMP_FALSE_OQ
12: OP := _CMP_NEQ_OQ
13: OP := _CMP_GE_OS
14: OP := _CMP_GT_OS
15: OP := _CMP_TRUE_UQ
16: OP := _CMP_EQ_OS
17: OP := _CMP_LT_OQ
18: OP := _CMP_LE_OQ
19: OP := _CMP_UNORD_S
20: OP := _CMP_NEQ_US
21: OP := _CMP_NLT_UQ
22: OP := _CMP_NLE_UQ
23: OP := _CMP_ORD_S
24: OP := _CMP_EQ_US
25: OP := _CMP_NGE_UQ 
26: OP := _CMP_NGT_UQ 
27: OP := _CMP_FALSE_OS 
28: OP := _CMP_NEQ_OS 
29: OP := _CMP_GE_OQ
30: OP := _CMP_GT_OQ
31: OP := _CMP_TRUE_US
ESAC
FOR j := 0 to 7
    i := j*32
    dst[i+31:i] := ( a[i+31:i] OP b[i+31:i] ) ? 0xFFFFFFFF : 0
ENDFOR
dst[MAX:256] := 0

测试:

#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>

int main() {
    int64_t result[4] = {};

    __m256i a = _mm256_set1_epi64x(-1); // NaN
    __m256i b = _mm256_set1_epi64x(1);

    __m256d c = _mm256_cmp_pd(a, b, _CMP_EQ_OS);
    _mm256_storeu_pd((double*)result, c);
    // 0 0 0 0
    for(int i=0; i<4; ++i)
        printf("%ld ", result[i]);
    printf("\n");

    c = _mm256_cmp_pd(a, b, _CMP_EQ_US);
    _mm256_storeu_pd((double*)result, c);
    // -1 -1 -1 -1
    for(int i=0; i<4; ++i)
        printf("%ld ", result[i]);
    printf("\n");

    return 0;
}

绝对值

以下测试代码,同时对4个int32值进行abs操作:

#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>


int main() {
    int32_t data[8] = {1,-1,2,-2,3,-3,0,0x80000000 };
    int32_t result[8] = {};

    __m256i a = _mm256_loadu_si256((__m256i*)data);
    __m256i b = _mm256_abs_epi32(a);
    _mm256_storeu_si256((__m256i*)result, b);

    for(int i=0; i<8; ++i)
        printf("0x%x ", result[i]);
    printf("\n");

    return 0;
}

结果符合预期,注意对0x80000000取绝对值,值不变,依然是最大负数。

$ gcc -mavx2 simd.c && ./a.out
0x1 0x1 0x2 0x2 0x3 0x3 0x0 0x80000000

总结abs指令:

// no epi64 available!!
// The AVX2 instruction set primarily focuses on
// providing operations for 8-bit, 16-bit, and 32-bit integers.
// But, AVX-512 has...
__m256i _mm256_abs_epi8(__m256i a);
__m256i _mm256_abs_epi16(__m256i a);
__m256i _mm256_abs_epi32(__m256i a);

自制 packed int64 abs 接口

负数变正数,就是取反加1。srai没有epi64接口,因此用permutevar8x32_epi32做一次错位的copy,得到的c就相当于srai epi64的效果。

static inline __m256i _mm256_abs_epi64(__m256i x){
    // self made _mm256_srai_epi64
    __m256i a = _mm256_srai_epi32(x, 31);
    __m256i b = _mm256_setr_epi32(1,1,3,3,5,5,7,7);
    __m256i c = _mm256_permutevar8x32_epi32(a, b);
    // ~ and +1
    __m256i y = _mm256_sub_epi64(_mm256_xor_si256(x,c), c);
    return y;
}

Shift

sll

slli: shift left logic with immediate

逻辑左移,补0,用立即数来表达左移位数。当位移数大于有效bit位数时,寄存器清零。

// no slli epi8 available
__m256i _mm256_slli_epi16(__m256i a, int imm8);
__m256i _mm256_slli_epi32(__m256i a, int imm8);
__m256i _mm256_slli_epi64(__m256i a, int imm8);

测试代码:

#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>


int main() {
    int64_t data[4] = {0,1,2,3};
    int64_t result[4] = {};

    for(int i=0; i<64; ++i){
        __m256i a = _mm256_loadu_si256((__m256i*)data);
        __m256i b = _mm256_slli_epi64(a,i);
        _mm256_storeu_si256((__m256i*)result, b);
        for(int i=0; i<4; ++i)
            printf("%lu ", result[i]);
        printf("\n");
    }

    return 0;
}

有一个按lane执行的slli指令:

// slli bit-lane by byte (imm8*8)
__m256i _mm256_bslli_epi128(__m256i a, const int imm8);
// synonym 
__m256i _mm256_slli_si256(__m256i a, const int imm8);

lane是AVX2架构中的一个概念,它表示128bits长的一段(double quadword),256bits长的ymm寄存器,就有左右两个lane。这条指令的imm8,表达的是byte的意思,按byte左移。如下测试:

#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>

int main() {
    int64_t data[4] = {0,1,2,3};
    int64_t result[4] = {};

    __m256i a = _mm256_loadu_si256((__m256i*)data);
    __m256i b = _mm256_bslli_epi128(a,1);
    _mm256_storeu_si256((__m256i*)result, b);

    for(int i=0; i<4; ++i)
        printf("%ld ", result[i]);
    printf("\n");

    return 0;
}

这里有个:如果将上述测试代码,位移参数设置为大于7,比如8或9,测试结果出乎意料。通过在gdb中查看ymm寄存器的值,感觉芯片是按照little endian的思路,在处理128bit的int。因此,0和1的int128,little endian时的字节排列,就成了1和0....被绕晕了...

下面是通过gdb看到的移位之前的ymm0中的数据,按int128的方式:
v2_int128 = {0x10000000000000000, 0x30000000000000002}}
store回result后,两个64bits又被交换了一次。

sll: shift left logic

参数由一个__m128i类型的变量提供,此变量低64位的值有效。

__m256i _mm256_sll_epi16(__m256i a, __m128i count);
__m256i _mm256_sll_epi32(__m256i a, __m128i count);
__m256i _mm256_sll_epi64(__m256i a, __m128i count);

sllv: shift left vectorized

每个packed int左移的位数不同,分别保存在另一个YMM寄存器中。

__m256i _mm256_sllv_epi32(__m256i a, __m256i count);
__m256i _mm256_sllv_epi64(__m256i a, __m256i count);

测试:

#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>

int main() {
    int64_t data[4] = {1,1,1,1};
    int64_t shift[4] = {1,2,3,4};
    int64_t result[4] = {};

    __m256i a = _mm256_loadu_si256((__m256i*)data);
    __m256i b = _mm256_loadu_si256((__m256i*)shift);
    __m256i c = _mm256_sllv_epi64(a,b);
    _mm256_storeu_si256((__m256i*)result, c);

    for(int i=0; i<4; ++i)
        printf("%ld ", result[i]);  // 2 4 8 16
    printf("\n");

    return 0;
}

srl

__m256i _mm256_srli_epi16(__m256i a, int imm8);
__m256i _mm256_srli_epi32(__m256i a, int imm8);
__m256i _mm256_srli_epi64(__m256i a, int imm8);
__m256i _mm256_srl_epi16(__m256i a, __m128i count);
__m256i _mm256_srl_epi32(__m256i a, __m128i count);
__m256i _mm256_srl_epi64(__m256i a, __m128i count);
__m256i _mm256_srlv_epi32(__m256i a, __m256i count);
__m256i _mm256_srlv_epi64(__m256i a, __m256i count);
__m256i _mm256_bsrli_epi128(__m256i a, const int imm8);
__m256i _mm256_srli_si256(__m256i a, const int imm8);

sra

shift right arithmetic,没有64bits接口。

__m256i _mm256_srai_epi16(__m256i a, int imm8);
__m256i _mm256_srai_epi32(__m256i a, int imm8);
__m256i _mm256_sra_epi16(__m256i a, __m128i count);
__m256i _mm256_sra_epi32(__m256i a, __m128i count);
__m256i _mm256_srav_epi32(__m256i a, __m256i count);

自制 srai epi64 接口

// suppose x is target register with type of __m256i
__m256i a = _mm256_srai_epi32(x, 31);
__m256i b = _mm256_setr_epi32(1,1,3,3,5,5,7,7);
__m256i x = _mm256_permutevar8x32_epi32(a, b);

And/Or/Not/XOR

// bit-wise and
__m256i _mm256_and_si256(__m256i a, __m256i b);
__m256  _mm256_and_ps(__m256 a, __m256 b);
__m256d _mm256_and_pd(__m256d a, __m256d b);
// bit-wise or
__m256i _mm256_or_si256(__m256i a, __m256i b);
__m256  _mm256_or_ps(__m256 a, __m256 b);
__m256d _mm256_or_pd(__m256d a, __m256d b);
// (not a) and b
__m256i _mm256_andnot_si256(__m256i a, __m256i b);
__m256  _mm256_andnot_ps(__m256 a, __m256 b);
__m256d _mm256_andnot_pd(__m256d a, __m256d b);
// XOR
__m256i _mm256_xor_si256(__m256i a, __m256i b);
__m256  _mm256_xor_ps(__m256 a, __m256 b);
__m256d _mm256_xor_pd(__m256d a, __m256d b);

Max

__m256i _mm256_max_epi8(__m256i a, __m256i b);
__m256i _mm256_max_epi16(__m256i a, __m256i b);
__m256i _mm256_max_epi32(__m256i a, __m256i b);
__m256i _mm256_max_epu8(__m256i a, __m256i b);
__m256i _mm256_max_epu16(__m256i a, __m256i b);
__m256i _mm256_max_epu32(__m256i a, __m256i b);
__m256  _mm256_max_ps(__m256 a, __m256 b);
__m256d _mm256_max_pd(__m256d a, __m256d b);

自制 max epi64 接口

#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>

static inline __m256i _mm256_max_epi64(__m256i a, __m256i b){
    __m256i c = _mm256_cmpgt_epi64(a, b);
    return _mm256_blendv_epi8(b, a, c);
}

int main() {
    int64_t result[4] = {};

    __m256i a = _mm256_setr_epi64x(1,20202,3,40404);
    __m256i b = _mm256_setr_epi64x(10101,2,30303,4);
    __m256i c = _mm256_max_epi64(a, b);
    _mm256_storeu_si256((__m256i*)result, c);

    // 10101 20202 30303 40404
    for(int i=0; i<4; ++i)
        printf("%ld ", result[i]);
    printf("\n");

    return 0;
}

自制 max epu64 接口

Min

__m256i _mm256_min_epi8(__m256i a, __m256i b);
__m256i _mm256_min_epi16(__m256i a, __m256i b);
__m256i _mm256_min_epi32(__m256i a, __m256i b);
__m256i _mm256_min_epu8(__m256i a, __m256i b);
__m256i _mm256_min_epu16(__m256i a, __m256i b);
__m256i _mm256_min_epu32(__m256i a, __m256i b);
__m256  _mm256_min_ps(__m256 a, __m256 b);
__m256d _mm256_min_pd(__m256d a, __m256d b);

Floor/Ceil

// floor
__m256  _mm256_floor_ps(__m256 a);
__m256d _mm256_floor_pd(__m256d a);
// ceil
__m256  _mm256_ceil_ps(__m256 a);
__m256d _mm256_ceil_pd(__m256d a);

Sqrt/Rsqrt

只有float type接口,如果是int,需要cast。

__m256d _mm256_sqrt_pd(__m256d a);
__m256  _mm256_sqrt_ps(__m256 a);
// 1/sqrt(a)
__m256  _mm256_rsqrt_ps(__m256 a);

Round

__m256  _mm256_round_ps(__m256 a, int rounding);
__m256d _mm256_round_pd(__m256d a, int rounding);

rouding的低4bits控制模式:

Rounding is done according to the rounding[3:0] parameter, which can be one of:
(_MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC) // round to nearest, and suppress exceptions
(_MM_FROUND_TO_NEG_INF |_MM_FROUND_NO_EXC)     // round down, and suppress exceptions
(_MM_FROUND_TO_POS_INF |_MM_FROUND_NO_EXC)     // round up, and suppress exceptions
(_MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC)        // truncate, and suppress exceptions
_MM_FROUND_CUR_DIRECTION                       // use MXCSR.RC; see _MM_SET_ROUNDING_MODE

Cast

cast接口在编译之后,并不会真的产生什么指令,这些接口的存在,主要是为了满足C/C++语言的type system的要求。

// cast from __m256 to __m256i, single to int
__m256i _mm256_castps_si256(__m256 a);
// cast from __m256d to __m256i, double to int
__m256i _mm256_castpd_si256(__m256d a);
// cast from __m256i to __m256, int to single
__m256  _mm256_castsi256_ps(__m256i a);
// cast from __m256i to __m256d, int to double
__m256d _mm256_castsi256_pd(__m256i a);
// cast from __m256d to __m256, double to single
__m256  _mm256_castpd_ps(__m256d a);
// cat from __m256 to __m256d, single to double
__m256d _mm256_castps_pd(__m256 a);

Cast和Convert的区别

Cast只是对内存数据的重新解释,Convert就是尽量保持值不变的转换了。下面的代码解释了这种差异:

float f1 =123.456;
int i1 = f1;  // convert implicitly
printf("%d\n", i1);
printf("%d\n", *(int*)&f1);  // cast to int

Convert

// rounding to int
__m256i _mm256_cvtps_epi32(__m256 a)
// truncate to int
__m256i _mm256_cvttps_epi32(__m256 a)
// int to single float
__m256 _mm256_cvtepi32_ps(__m256i a);

下面三条cvt指令,用来提取256bits寄存器的最后部分:

int    _mm256_cvtsi256_si32(__m256i a);
float  _mm256_cvtss_f32(__m256 a);
double _mm256_cvtsd_f64(__m256d a);

Permute

Permute指令的特点是,在一个256bits寄存器内部,对基础类型数据进行Shuffle操作。其实用Shuffle不太准确,因为最后的结果,有可能丢弃某些数据。

/*
DEFINE SELECT4(src, control) {
    CASE(control[1:0]) OF
    0:  tmp[31:0] := src[31:0]
    1:  tmp[31:0] := src[63:32]
    2:  tmp[31:0] := src[95:64]
    3:  tmp[31:0] := src[127:96]
    ESAC
    RETURN tmp[31:0]
}
dst[31:0] := SELECT4(a[127:0], imm8[1:0])
dst[63:32] := SELECT4(a[127:0], imm8[3:2])
dst[95:64] := SELECT4(a[127:0], imm8[5:4])
dst[127:96] := SELECT4(a[127:0], imm8[7:6])
dst[159:128] := SELECT4(a[255:128], imm8[1:0])
dst[191:160] := SELECT4(a[255:128], imm8[3:2])
dst[223:192] := SELECT4(a[255:128], imm8[5:4])
dst[255:224] := SELECT4(a[255:128], imm8[7:6])
dst[MAX:256] := 0
*/
__m256  _mm256_permute_ps(__m256 a, int imm8);
/*
IF (imm8[0] == 0) dst[63:0] := a[63:0]; FI
IF (imm8[0] == 1) dst[63:0] := a[127:64]; FI
IF (imm8[1] == 0) dst[127:64] := a[63:0]; FI
IF (imm8[1] == 1) dst[127:64] := a[127:64]; FI
IF (imm8[2] == 0) dst[191:128] := a[191:128]; FI
IF (imm8[2] == 1) dst[191:128] := a[255:192]; FI
IF (imm8[3] == 0) dst[255:192] := a[191:128]; FI
IF (imm8[3] == 1) dst[255:192] := a[255:192]; FI
dst[MAX:256] := 0
*/
__m256d _mm256_permute_pd(__m256d a, int imm8);

这两个permute指令,更像是在做选择。而且,两个128bits的Lane之间的数据不能交互。

下面这两条指令,与上面两条的区别,仅仅是参数类型不同:

/*
DEFINE SELECT4(src, control) {
    CASE(control[1:0]) OF
    0:  tmp[31:0] := src[31:0]
    1:  tmp[31:0] := src[63:32]
    2:  tmp[31:0] := src[95:64]
    3:  tmp[31:0] := src[127:96]
    ESAC
    RETURN tmp[31:0]
}
dst[31:0] := SELECT4(a[127:0], b[1:0])
dst[63:32] := SELECT4(a[127:0], b[33:32])
dst[95:64] := SELECT4(a[127:0], b[65:64])
dst[127:96] := SELECT4(a[127:0], b[97:96])
dst[159:128] := SELECT4(a[255:128], b[129:128])
dst[191:160] := SELECT4(a[255:128], b[161:160])
dst[223:192] := SELECT4(a[255:128], b[193:192])
dst[255:224] := SELECT4(a[255:128], b[225:224])
dst[MAX:256] := 0
*/
__m256  _mm256_permutevar_ps(__m256 a, __m256i b);
/*
IF (b[1] == 0) dst[63:0] := a[63:0]; FI
IF (b[1] == 1) dst[63:0] := a[127:64]; FI
IF (b[65] == 0) dst[127:64] := a[63:0]; FI
IF (b[65] == 1) dst[127:64] := a[127:64]; FI
IF (b[129] == 0) dst[191:128] := a[191:128]; FI
IF (b[129] == 1) dst[191:128] := a[255:192]; FI
IF (b[193] == 0) dst[255:192] := a[191:128]; FI
IF (b[193] == 1) dst[255:192] := a[255:192]; FI
dst[MAX:256] := 0
*/
__m256d _mm256_permutevar_pd(__m256d a, __m256i b);

下面这两条var8x32指令,真正实现了在整个256bits寄存器内的随意Shuffle,所以要用3个bits来定位。

/*
FOR j := 0 to 7
    i := j*32
    id := idx[i+2:i]*32
    dst[i+31:i] := a[id+31:id]
ENDFOR
dst[MAX:256] := 0
*/
__m256  _mm256_permutevar8x32_ps(__m256 a, __m256i idx);
/*
FOR j := 0 to 7
    i := j*32
    id := idx[i+2:i]*32
    dst[i+31:i] := a[id+31:id]
ENDFOR
dst[MAX:256] := 0
*/
__m256i _mm256_permutevar8x32_epi32(__m256i a, __m256i idx);

下面是两条4x64指令:

__m256d _mm256_permute4x64_pd(__m256d a, const int imm8);
__m256i _mm256_permute4x64_epi64(__m256i a, const int imm8);

有点晕了,感觉命名方式似乎有些不统一....不过,可以确定,permute的最小单位是32bits。

Blend

blend指令与permute的不同之处在于,blend是从2个256bits的寄存器中取数据,是混合的意思。

/*
FOR j := 0 to 15
    i := j*16
    IF imm8[j%8]
        dst[i+15:i] := b[i+15:i]
    ELSE
        dst[i+15:i] := a[i+15:i]
    FI
ENDFOR
dst[MAX:256] := 0
imm8只有8bits用来控制,两个Lane的pattern相同。
*/
__m256i _mm256_blend_epi16(__m256i a, __m256i b, const int imm8);
/*
FOR j := 0 to 7
    i := j*32
    IF imm8[j]
        dst[i+31:i] := b[i+31:i]
    ELSE
        dst[i+31:i] := a[i+31:i]
    FI
ENDFOR
dst[MAX:256] := 0
*/
__m256i _mm256_blend_epi32(__m256i a, __m256i b, const int imm8);
/*
FOR j := 0 to 7
    i := j*32
    IF imm8[j]
        dst[i+31:i] := b[i+31:i]
    ELSE
        dst[i+31:i] := a[i+31:i]
    FI
ENDFOR
dst[MAX:256] := 0
*/
__m256  _mm256_blend_ps(__m256 a, __m256 b, const int imm8);
/*
FOR j := 0 to 3
    i := j*64
    IF imm8[j]
        dst[i+63:i] := b[i+63:i]
    ELSE
        dst[i+63:i] := a[i+63:i]
    FI
ENDFOR
dst[MAX:256] := 0
*/
__m256d _mm256_blend_pd(__m256d a, __m256d b, const int imm8);

下面这一组blendv指令(v代表vector),采用256bits寄存器来控制:

/*
FOR j := 0 to 31
    i := j*8
    IF mask[i+7]
        dst[i+7:i] := b[i+7:i]
    ELSE
        dst[i+7:i] := a[i+7:i]
    FI
ENDFOR
dst[MAX:256] := 0
*/
__m256i _mm256_blendv_epi8(__m256i a, __m256i b, __m256i mask);
/*
FOR j := 0 to 7
    i := j*32
    IF mask[i+31]
        dst[i+31:i] := b[i+31:i]
    ELSE
        dst[i+31:i] := a[i+31:i]
    FI
ENDFOR
dst[MAX:256] := 0
*/
__m256  _mm256_blendv_ps(__m256 a, __m256 b, __m256 mask);
/*
FOR j := 0 to 3
    i := j*64
    IF mask[i+63]
        dst[i+63:i] := b[i+63:i]
    ELSE
        dst[i+63:i] := a[i+63:i]
    FI
ENDFOR
dst[MAX:256] := 0
*/
__m256d _mm256_blendv_pd(__m256d a, __m256d b, __m256d mask);

对BF16的支持

水平加

/*
dst[15:0] := a[31:16] + a[15:0]
dst[31:16] := a[63:48] + a[47:32]
dst[47:32] := a[95:80] + a[79:64]
dst[63:48] := a[127:112] + a[111:96]
dst[79:64] := b[31:16] + b[15:0]
dst[95:80] := b[63:48] + b[47:32]
dst[111:96] := b[95:80] + b[79:64]
dst[127:112] := b[127:112] + b[111:96]
dst[143:128] := a[159:144] + a[143:128]
dst[159:144] := a[191:176] + a[175:160]
dst[175:160] := a[223:208] + a[207:192]
dst[191:176] := a[255:240] + a[239:224]
dst[207:192] := b[159:144] + b[143:128]
dst[223:208] := b[191:176] + b[175:160]
dst[239:224] := b[223:208] + b[207:192]
dst[255:240] := b[255:240] + b[239:224]
dst[MAX:256] := 0
a和b是交替出现的!
*/
__m256i _mm256_hadd_epi16(__m256i a, __m256i b);
/* Saturated hadd */
__m256i _mm256_hadds_epi16(__m256i a, __m256i b);
__m256i _mm256_hadd_epi32(__m256i a, __m256i b);
/*
dst[31:0] := a[63:32] + a[31:0]
dst[63:32] := a[127:96] + a[95:64]
dst[95:64] := b[63:32] + b[31:0]
dst[127:96] := b[127:96] + b[95:64]
dst[159:128] := a[191:160] + a[159:128]
dst[191:160] := a[255:224] + a[223:192]
dst[223:192] := b[191:160] + b[159:128]
dst[255:224] := b[255:224] + b[223:192]
dst[MAX:256] := 0
*/
__m256  _mm256_hadd_ps(__m256 a, __m256 b);
__m256d _mm256_hadd_pd(__m256d a, __m256d b);

水平减

__m256i _mm256_hsub_epi16 (__m256i a, __m256i b);
__m256i _mm256_hsubs_epi16 (__m256i a, __m256i b);
/*
dst[31:0] := a[31:0] - a[63:32]
dst[63:32] := a[95:64] - a[127:96]
dst[95:64] := b[31:0] - b[63:32]
dst[127:96] := b[95:64] - b[127:96]
dst[159:128] := a[159:128] - a[191:160]
dst[191:160] := a[223:192] - a[255:224]
dst[223:192] := b[159:128] - b[191:160]
dst[255:224] := b[223:192] - b[255:224]
dst[MAX:256] := 0
*/
__m256i _mm256_hsub_epi32 (__m256i a, __m256i b);
__m256  _mm256_hsub_ps (__m256 a, __m256 b);
/*
dst[63:0] := a[63:0] - a[127:64]
dst[127:64] := b[63:0] - b[127:64]
dst[191:128] := a[191:128] - a[255:192]
dst[255:192] := b[191:128] - b[255:192]
dst[MAX:256] := 0
*/
__m256d _mm256_hsub_pd (__m256d a, __m256d b);

String(SSE)

本文链接:https://cs.pynote.net/hd/202401216/

-- EOF --

-- MORE --