位运算的妙用和坑

Last Updated: 2023-12-22 08:56:33 Friday

-- TOC --

代码中总有一些计算,可以通过位运算来实现,一般来说,位运算的速度更快,更能体现程序员的专业素养!本文总结自己遇到过的可以使用位运算的场景,以及使用位运算可能遇到的坑。

先说一个:我在反复测试本文部分代码的过程中,发现同一个位运算技巧,用C语言就有效果,但在Python中就没有效果,反而速度更慢。经过仔细的思考和分析,我认为这是合理的,可以这样来解释,一个位运算在C语言中,就直接对应一条汇编指令,而在Python中却不是,其对应的是Python虚拟机里面的指令,这两种指令大不相同!前者是CPU的一条指令,后者是Python虚拟机的指令。在Python中,使用位运算虽然也能够得到正确的结果,但是其速度快慢的原理,需要研究Python虚拟机的算法实现,而不是CPU架构。从本文很有限的测试结果来看,用Python的时候,那些位运算的技巧和思路就算了吧,用更自然的人类更容易理解的写法,速度反而更快(乘法别用shift,判断奇偶别用and...)!但在C语言中,这些位运算的技巧就很有价值,不仅速度快,而且还可以在某些场景下,帮助你消除branch(去掉CPU的branch hazard)。

判断奇偶性

判断一个数是奇数还是偶数,通常的做法就是让这个数做模2运算(用%符号),检查结果是否为0,是0就是偶数,不是0就是奇数。这个方法没错,但除了使用模2运算来判断数据的奇偶性,还可以使用位运算。

使用位运算判断数值的奇偶性,只需要对这个数进行 &1(与1)操作,1的二进制码在计算机中只有最后一个bit是1,与数据进行与操作,如果得到的结果是1,就是奇数,得到的结果是0,就是偶数。示例代码如下:

>>> assert (123&1) == (123%2)
>>> assert (124&1) == (124%2)
>>> assert (-123&1) == (-123%2)
>>> assert (-124&1) == (-124%2)

由于计算机都采用补码来表达负数,这种表达正好也正好应用&1判断奇偶性。-1奇-2偶-3奇-4偶。实际上CPU在执行AND操作时,只关心寄存器中的bits,没有正数负数补码这些概念。假设寄存器是4bit宽,-1=0xFFFF,奇,-2=0xFFFE,偶,-3=0xFFFD,奇,-4=0xFFFC,偶......

速度分析

用C语言随便写几行代码,分析其汇编,可以发现这两个运算所涉及的汇编语句和数量差异都很大,略了......但用Python做的测试,结果令人失望:

$ python -m timeit -p '[i%2 for i in range(10000)]'
500 loops, best of 5: 528 usec per loop
$ python -m timeit -p '[i&1 for i in range(10000)]'
500 loops, best of 5: 654 usec per loop

%2的速度反而要比&1更快一丢丢。我只能说:高级动态解释性编程语言提供的操作并不能像C语言一样,贴合CPU内部的机器指令。

左移1位,等于乘2(坑)

这个技巧是避免使用乘法的常用招数,编译器也常常将乘法编译成移位和加减操作的组合来提速!

>>> assert (12<<2) == 12*2**2
>>> assert (12<<3) == 12*2**3
>>> assert (12<<4) == 12*2**4
>>> assert (12<<5) == 12*2**5

左移N位,就相当于乘上2的N次方!对于负数也是一样的:

>>> assert (-12<<2) == -12*2**2
>>> assert (-12<<3) == -12*2**3
>>> assert (-12<<4) == -12*2**4
>>> assert (-12<<5) == -12*2**5

对于Python而言,int的大小不收寄存器bit位数的限制,因此可以无限左移来成倍放大整数,得到很大的正数或负数:

>>> 12 << 128
4083388403051261561560495289181218537472
>>> -12 << 128
-4083388403051261561560495289181218537472

x64下的最大左移位数(C语言)

对于C语言来说,整数的大小受到CPU内寄存器大小的限制,x64 CPU的左移指令有个最大位数限制:

下面这段文字,来自AMD的芯片手册:

Shifts the bits of a register or memory location (first operand) to the left through the CF bit by the number of bit positions in an unsigned immediate value or the CL register (second operand). The instruction discards bits shifted out of the CF flag. For each bit shift, the SAL instruction clears the least-significant bit to 0. At the end of the shift operation, the CF flag contains the last bit shifted out of the first operand.

The processor masks the upper three bits of the count operand, thus restricting the count to a number between 0 and 31. When the destination is 64 bits wide, the processor masks the upper two bits of the count, providing a count in the range of 0 to 63.

The effect of this instruction is multiplication by powers of two.

如果代码出现左移超过限制的情况,计算得到的结果是不对的!比如:

int a = 1;
int b = 34;
# 没有warning,32位的a左移34位,结果居然不是零!!
printf("%d\n", a<<b);

乘法计算技巧

编译器常用这个技巧优化代码速度。一些不希望引入乘法的场景,也是用这个技巧。

>>> bin(9)
'0b1001'
>>> assert 12345*9 == (12345<<3) + (12345<<0)
>>> bin(14)
'0b1110'
>>> assert 12345*14 == (12345<<3) + (12345<<2) + (12345<<1)
>>> assert 12345*14 == (12345<<4) - (12345<<1)

这个技巧,在Python下无效,反而更慢:

这是本文第2次发现位运算的技巧在Python下无效了,好无趣,难道在Python下,要避免位运算吗?用更自然的表达方式反而更快?

$ python -m timeit -p 'a=12345;b=a*14'
10000000 loops, best of 5: 29.5 nsec per loop
$ python -m timeit -p 'a=12345;b=(a<<4)-(a<<1)'
5000000 loops, best of 5: 68.2 nsec per loop

右移1位,不一定等于除2(坑)

左移乘2,正负数都适用,而右移除2,只是针对正数而言的!负数右移与除2的结果并不相同。

>>> assert (1234>>1) == int(1234/2)  # 整数除
>>> assert (1234>>3) == int(1234/2**3)
>>> assert (-1234>>1) == int(-1234/2)
>>> assert (-1234>>2) != int(-1234/2**2)
>>> assert (-1234>>3) != int(-1234/2**3)

int(-1234/2)这个计算是个巧合,因为刚好除尽了。如果除不尽,右移和除2是不相等的。因为右移执行round floor,向下取整,而整数除执行round to zero!向0方向取整。

>>> -9 >> 1  # round to floor
-5
>>> -9 // 2  # round to floor
-5
>>> int(-9/2)  # truncated, round to zero
-4

右移还有个特点:C语言和Python对于负数的右移,都选择Arithmetic Right Shift,即左边空出来的bit位不是补0,而是补MSB。因此,正数右移到最后得0,而负数右移到最后,就是-1:

>>> 1234 >> 100
0
>>> -1234 >> 100
-1

负数的右移

位运算的优先级比较低(坑)

搞定一个bug,几乎跟踪了一整天的代码,与正确的参考对比,最后发现错误的原因是位移运算没有括起来,导致计算表达式的优先级没有按照预期的执行!

看下面的测试代码:

>>> 0<<1 + 1
0
>>> (0<<1) + 1
1

位移计算如果不括起来,表达式会先计算1+1。+号两边的空格极具误导性,乍一看还反应不过来这个加法是要先计算的。

如果你的代码不是数字,而是有点长的变量名称,计算表达式除了位移,还有别的计算混合在一起,这个时候,如果不牢记将位运算括起来的经验教训,就很容易出现难以debug的bug。

除了Python,C等其它高级编程语言,在这个细节上的优先级处理应该都是一样的。再来一次示例,来自《CSAPP》:

>>> 1<<2 + 3<<4
512
>>> (1<<2) + (3<<4)
52

promote位运算中的常数

编译器会自动promote位运算中的常数,就是在前面补足够位数的0,然后再开始计算。

#include <stdio.h>

int main(void) {
    unsigned char a = 0x4F & ~1;
    printf("%X\n", a);
    unsigned short b = ~0b10101010;
    printf("%X\n", b);
    unsigned int c = ~0b1111;
    printf("%X\n", c);
    unsigned long d = ~0b11110000;
    printf("%lX\n", d);

    unsigned short e = 0xFFFF;
    e &= 0xEF;
    printf("%X\n", e);

    return 0;
}

输出:

4E
FF55
FFFFFFF0
FFFFFFFFFFFFFF0F
EF

在例如:e &= ~0x10,只清除看得见1的那些bit位,不管e有几个字节。

判断正整数是否为2的幂次

2的幂次整数的特征是:

def is_twos_power(n):
    return n!=0 and n&(n-1)==0

assert is_twos_power(0) == False
assert is_twos_power(1) == True
assert is_twos_power(2) == True
assert is_twos_power(3) == False
assert is_twos_power(1024) == True
assert is_twos_power(1025) == False

如果n是2的幂次,n-1的bit表达,就是n的bit=1的位变为0,后面的bit全变1。

这几个技巧在Hash表的实现中很常见,从key到hash值,再从hash值到bucket的index,后一步计算就是用bucket的总数-1,然后与hash值。因此,这些hash表的bucket的数量,总是2的幂次,按2的幂次增长。(一直按2的幂次增长下去,恐怕也是问题,这就是另一个话题了.....)

判断整数符号

判断int是否为负

// int v = ?;
int sign = v >> (sizeof(int)*8-1);

C语言int类型执行右移,采用arithmetic right shift,如果v为负,sign=-1,如果不为负,sign=0。

看到网上有资源说可能存在一个byte不等于8bit的设计,还可能有int右移不是arithmetic right shift的情况,反正我见识短,我从来没见过....

判断int正负或零

// int v = ?;
int sign = (v!=0) | v>>(sizeof(int)*8-1);
// Or
int sign = (v>0) - (v<0);

判断两整数符号是否相同

// int x,y
bool is_sign_same = ((x^y)>=0);

用到了异或(下面详细介绍)。(查看了这行代码的汇编,先对xor结果执行not,然后逻辑右移31位)

XOR异或技巧

这个XOR操作,是非常值得好好研究的!貌似整个加密解密,都建立在它的基础上。

XOR规则

>>> assert (0^0) == 0
>>> assert (0^1) == 1
>>> assert (1^0) == 1
>>> assert (1^1) == 0

如果把XOR看成bit反转操作,一个bit如果与0异或,no flip,如果与1异或,flip!这个操作是密码学的基础。

恒等律:X ^ 0 = X
归零律:X ^ X = 0
交换律:A ^ B = B ^ A
结合律:A ^ (B ^ C) = (A ^ B) ^ C

汇编代码常常见到使用XOR指令来对寄存器清零的情况!

用XOR反转bit

如上。

用XOR清零

如上。

快速比较两个数是否相等

判断a-b==0,不如(a^b)==0,后者更快!(一定要括起来)

下面是一段使用此技巧,比较IPv6地址是否相等的代码:

static inline int ipv6_addr_equal(const struct in6_addr *a1, const struct in6_addr *a2)
    {
    return (((a1->s6_addr32[0] ^ a2->s6_addr32[0]) |
             (a1->s6_addr32[1] ^ a2->s6_addr32[1]) |
             (a1->s6_addr32[2] ^ a2->s6_addr32[2]) |
             (a1->s6_addr32[3] ^ a2->s6_addr32[3])) == 0);
    }

用XOR交换数值(无需额外空间)

这个C代码常见的骚操作如下:

void swap(int *a, int *b) {
    assert(a != b);
    *a ^= *b ^= *a ^= *b;
}

这段代码的执行顺序是从右到左的。

先执行 *a ^= *b,此时 *a 值已变为 a^b,然后 *b ^= *a,执行后 *b == 最初的 *a,最后执行 *a ^= *b,此时 *a == 最初的 *b

a = a^b;
b = a^b;
a = a^b;

不过,这个swap接口有一个:如果传入的两个int的地址相同,就不再是swap了,而是清零!

用XOR实现奇偶校验

判断一个二进制数中 1 的数量是奇数还是偶数:

// 例如:求 10100001 中 1 的数量是奇数还是偶数
// 结果为 1 就是奇数个 1,结果为 0 就是偶数个 1
1 ^ 0 ^ 1 ^ 0 ^ 0 ^ 0 ^ 0 ^ 1 = 1   

这条性质可用于奇偶校验(Parity Check),每个字节的数据都计算一个校验位,数据和校验位一起发送出去,这样接收方可以根据校验位粗略地判断接收到的数据是否有误。

这个算法的性能可能存在问题,计算机最小按byte计算,如果要按bit计算,性能会受到影响。快速判断多个byte值的奇偶性,我觉得最快的方式,就是查表。对8bit的byte的所有可能取值制作一张表!

用XOR实现FEC

FEC:Forward Error Correction,前向纠错

FEC是一种通过在传输中增加冗余信息,使得接收端能够在传输出错或数据丢失时,利用这些冗余信息,直接在接收端修正或恢复出丢失数据的一种方法。FEC的理论基础,就是XOR。

假设在网络通信中,有 N 个 packet 需要发送,每 2 个 packet 生成一个 FEC packet,这样,连续的 3 个 packet 中的任意一个 packet 丢失,都能通过另外 2 个恢复出来的。但考虑到每 2 个 packet 就产生 1 个 FEC packet,冗余度可能有点高(比较浪费带宽),我们可以根据具体场景,每 3 个或者每 N 个 packet 再产生一个 FEC packet。我们以每 3 个 packet(A、B、C) 产生 1 个 FEC packet(D)为例来推导一下:

d = a ^ b ^ c
a = a ^ (b ^ b) ^ (c ^ c) = (b ^ c) ^ (a ^ b ^ c) = b ^ c ^ d
b = (a ^ a) ^ b ^ (c ^ c) = (a ^ c) ^ (a ^ b ^ c) = a ^ c ^ d
c = (a ^ a) ^ (b ^ b) ^ c = (a ^ b) ^ (a ^ b ^ c) = a ^ b ^ d

好优美的数学特性....

奇偶位互换

题目:写一个宏定义,实现的功能是将一个int型的数的奇偶位互换,例如6的2进制为00000110,(从右向左)第一位与第二位互换,第三位与第四位互换,其余都是0不需要交换,得到00001001,输出应该为9。

#define XEO_INT(n) ((n<<1)&(0xAAAAAAAA))|((n>>1)&(0x55555555))

Single Number

题目: Given an array of integers, every element appears twice except for one. Find that single one. Note: Your algorithm should have a linear runtime complexity. Could you implement it without using extra memory?

将所有数字进行XOR,最后的结果就是这个只出现了一次的数。

这种题的套路是,一个整型数组里除了N个数字之外,其他的数字都出现了两次,找出这N个数字。先XOR一遍,然后再具体分析。

绝对值计算

// int v = ?;
int mask = v >> (sizeof(int)*8-1);
int abs_v = (v^mask) - mask;
// Or
int abs_v = (v+mask) ^ mask; 

mask其实就是取v的MSB,如果MSB=1,mask=0xFFFFFFFF,即-1,如果MSB=0,mask=0。如果v是负数,与全1做异或,相当于取反,然后再减-1,等于+1,这就是取反加1,得到正数呀!如果v是正数,异或没效果,减0也没效果,啥都不干。位运算优先级低,要将异或操作括起来。另一个实现,对于v为正的情况一样,啥都不干。当v为负数时,先减去1,再取反。

用Python测试一下:

>>> v = -1234
>>> m = v >> 31  # m = -1
>>> (v^m) - m
1234
>>> (v+m) ^ m
1234

但...Python依然表现出位运算比调用内置的abs更慢的特点...

取反加1,反过来就是,减1取反:

-a = ~a + 1
-a-1 = ~a
~(-a-1) = a

此方法与abs接口处理最大负数的结果一致:

#include <stdio.h>
#include <stdlib.h>

int main(){
    int v = 0x80000000;
    int mask = v >> 31;
    int abs_v = (v^mask) - mask;
    printf("%X\n", abs_v);   // 0x80000000
    printf("%X\n", abs(v));  // 0x80000000
    return 0;
}

输出一致!

But,x86和gcc的配合,让这个小技巧没有了用武之地:

int func(int a){
    return abs(a);
}

得到汇编(-O0)如下:

func:
    push    rbp
    mov     rbp, rsp
    mov     DWORD PTR [rbp-4], edi
    mov     eax, DWORD PTR [rbp-4]
    mov     edx, eax
    neg     edx
    cmovns  eax, edx
    pop     rbp
    ret

inline,neg和cmovns配合....完美!

《CSAPP》Data Lab,位运算部分

这部分难度还是挺大的,最后一题不看答案完全没思路。如下实现,全部都测试通过:

$ ./btest
Score   Rating  Errors  Function
 1      1       0       bitXor
 1      1       0       tmin
 1      1       0       isTmax
 2      2       0       allOddBits
 2      2       0       negate
 3      3       0       isAsciiDigit
 3      3       0       conditional
 3      3       0       isLessOrEqual
 4      4       0       logicalNeg
 4      4       0       howManyBits
...

常用技巧:

代码:

//1
/*
 * bitXor - x^y using only ~ and &
 *   Example: bitXor(4, 5) = 1
 *   Legal ops: ~ &
 *   Max ops: 14
 *   Rating: 1
 */
int bitXor(int x, int y) {
    return (~(x&y))&(~((~x)&(~y)));
}
/*
 * tmin - return minimum two's complement integer
 *   Legal ops: ! ~ & ^ | + << >>
 *   Max ops: 4
 *   Rating: 1
 */
int tmin(void) {
    return 1 << 31;
}
//2
/*
 * isTmax - returns 1 if x is the maximum, two's complement number,
 *     and 0 otherwise
 *   Legal ops: ! ~ & ^ | +
 *   Max ops: 10
 *   Rating: 1
 */
int isTmax(int x) {
    return !((!(x+1))|((~x)^(x+1)));
}
/*
 * allOddBits - return 1 if all odd-numbered bits in word set to 1
 *   where bits are numbered from 0 (least significant) to 31 (most significant)
 *   Examples allOddBits(0xFFFFFFFD) = 0, allOddBits(0xAAAAAAAA) = 1
 *   Legal ops: ! ~ & ^ | + << >>
 *   Max ops: 12
 *   Rating: 2
 */
int allOddBits(int x) {
    int a = 0xAA + (0xAA<<8);
    int m = (a<<16) + a;
    return !((x&m)^m);
}
/*
 * negate - return -x
 *   Example: negate(1) = -1.
 *   Legal ops: ! ~ & ^ | + << >>
 *   Max ops: 5
 *   Rating: 2
 */
int negate(int x) {
    return (~x) + 1;
}
//3
/*
 * isAsciiDigit - return 1 if 0x30 <= x <= 0x39 (ASCII codes for characters '0' to '9')
 *   Example: isAsciiDigit(0x35) = 1.
 *            isAsciiDigit(0x3a) = 0.
 *            isAsciiDigit(0x05) = 0.
 *   Legal ops: ! ~ & ^ | + << >>
 *   Max ops: 15
 *   Rating: 3
 */
int isAsciiDigit(int x) {
    int m = (1<<31) >> 24;
    int a = !((x&(m|0xF0))^0x30);  // if 0x0011____
    int b = !!(x&0x08);            // if 0x____1___
    int c = !(x&0x06);             // if 0x_____00_
    return a & ((!b)|(b&c));
}
/*
 * conditional - same as x ? y : z
 *   Example: conditional(2,4,5) = 4
 *   Legal ops: ! ~ & ^ | + << >>
 *   Max ops: 16
 *   Rating: 3
 */
int conditional(int x, int y, int z) {
    int a = !!x;
    int b = (~a) + 1;
    return (b&y) | (~b&z);
}
/*
 * isLessOrEqual - if x <= y  then return 1, else return 0
 *   Example: isLessOrEqual(4,5) = 1.
 *   Legal ops: ! ~ & ^ | + << >>
 *   Max ops: 24
 *   Rating: 3
 */
int isLessOrEqual(int x, int y) {
    int max_neg = 1<<31;
    int a = !(x&max_neg);
    int b = !(y&max_neg);
    int x_sub_y = x + (~y) + 1;
    int c = !(x_sub_y&max_neg);
    return ((a^b)&((!a)&b)) | ((!(a^b))&(!c|!(x_sub_y^0)));
}
//4
/*
 * logicalNeg - implement the ! operator, using all of
 *              the legal operators except !
 *   Examples: logicalNeg(3) = 0, logicalNeg(0) = 1
 *   Legal ops: ~ & ^ | + << >>
 *   Max ops: 12
 *   Rating: 4
 */
int logicalNeg(int x) {
    return ((x|((~x)+1))>>31) + 1;
}
/* howManyBits - return the minimum number of bits required to represent x in
 *             two's complement
 *  Examples: howManyBits(12) = 5
 *            howManyBits(298) = 10
 *            howManyBits(-5) = 4
 *            howManyBits(0)  = 1
 *            howManyBits(-1) = 1
 *            howManyBits(0x80000000) = 32
 *  Legal ops: ! ~ & ^ | + << >>
 *  Max ops: 90
 *  Rating: 4
 */
int howManyBits(int x) {
    int b16, b8, b4, b2, b1, b0;
    int s = x>>31;
    x = (s&~x) | (~s&x);  // MSB of x must be zero

    b16 = (!!(x>>16)) << 4;
    x >>= b16;
    b8 = (!!(x>>8)) << 3;
    x >>= b8;
    b4 = (!!(x>>4)) << 2;
    x >>= b4;
    b2 = (!!(x>>2)) << 1;
    x >>= b2;
    b1 = !!(x>>1);
    x >>= b1;
    b0 = x;

    return b16+b8+b4+b2+b1+b0+1;  // +1 sign bit
}

LeetCode第29题

LeetCode第29题,整数除

本文链接:https://cs.pynote.net/ag/202110053/

-- EOF --

-- MORE --