深入理解Python之正确的操作符重载

概述

文本将包括以下内容

  • Python如何支持不同类型的中缀操作符
  • 使用鸭子类型或者显式类型检查来处理各种类型的操作数
  • 中缀操作符方法如果知道不处理操作数
  • 各种比较操作符的特殊行为

操作符101

操作符重载在某些圈子有不好的名声。它是一种容易被滥用的语言特性,导致别人困惑,bugs和意想不到的性能瓶颈。但是如果用好它,能产生令人愉快的APIs和阅读性很好的代码。Python在灵活,可用,安全之间做了平衡。使用以下的限制:

  • 不能重载内置类型操作符
  • 不能创建新的操作符,只能重载原有的
  • 几个操作符不能重载: is, and, or, not

下面开始讨论最简单的一元操作符

一元操作符

  • -(__neg__)
  • +(__pos__)
  • ~(__invert__)
1
2
3
4
5
6
7
8
9
10

def __abs__(self):
return math.sqrt(sum(x * x for x in self))

def __neg__(self):
return Vector(-x for x in self)

def __pos__(self):
return Vector(self)

什么时候 x != +x?

正常情况下x==+x。以下情况两者不相等

第一种情况, decimal.Decimal中,如果x在数学上下文中创建,而+x在不一样的上下文中创建。

1
2
3
4
5
6
7
8
9
10
11
12
13
ctx = decimal.getcontext()
ctx.prec = 40
one_third = decimal.Decimal('1') / decimal.Decimal('3')
print(one_third)
print(one_third == +one_third)
ctx.prec = 28
print(one_third == +one_third)
print(one_third)
print(+one_third)

result:
0.3333333333333333333333333333333333333333
0.3333333333333333333333333333

Decimal的精度被改变。one_third+one_third值不相等。原因是+one_third用新精度会产生一个新的Decimal

第二种情况是collections.Counter,其实现了几个数学操作符。如果使用+操作符,那么会丢弃0及负数

1
2
3
4
5
6
7
8
ct = Counter('abracadabra')
print(ct)

ct['r'] = -3
ct['d'] = 0
print(ct)
print(+ct)

为向量相加重载+

1
2
3
4
5
6
7
8
9
def __add__(self, other):
pairs = itertools.zip_longest(self, other, fillvalue=0.0)
return Vector(a + b for a, b in pairs)

v1 = Vector([3, 4, 5])
v2 = Vector([6, 7, 8])
print(v1 + v2)
print((v1 + v2) == Vector([3 + 6, 4 + 7, 5 + 8]))

pairs 是个生成器,能产生tuple(a, b),aself,bother,如果两者的长度不一致,那么会使用fillvalue填充

1
(10, 20, 30) + v1

然而使用上述表达式会出错。

为了支持不同类型的操作符,Python实现了特殊的下发机制.例如a+b会执行以下步骤

  1. 如果a__add__,调用a.__add__(b),然后返回结果
  2. 如果a没有__add__,或者调用它返回NotImplemented,检查b是否有__radd__,然后调用b.__radd__(a),然后返回结果
  3. 如果b也不含有__radd__,或者调用它返回NotImplemented。抛出TypeError

因此为了使上述的表达式能够正确运行,我们需要实现__radd__方法

1
2
3
def __radd__(self, other):
return self + other

重载*

1
2
3
4
5
6
7
8
9
10
def __mul__(self, scalar):
if isinstance(scalar, numbers.Real):
return Vector(n * scalar for n in self)
return NotImplemented

def __rmul__(self, other):
return self * other

v1 = Vector([1, 2, 3])
print(v1 * 10)

判断被乘数是否是实数

1
2
3
4
5
6
7
8
9
10

def __matmul__(self, other):
try:
return sum(a * b for a, b in zip(self, other))
except TypeError:
return NotImplemented

def __rmatmul__(self, other):
return self @ other

丰富的比较操作符

1
2
3
4
5
def __eq__(self, other):
if isinstance(other, Vector):
return len(self) == len(other) and all(a == b for a, b in zip(self, other))
return NotImplemented

1
2
3
4
5
def __ne__(self, other):
eq_result = self == other
if eq_result is NotImplemented:
return NotImplemented

增量赋值操作符

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

class AddableBingoCage(BingoCage):
def __add__(self, other):
if isinstance(other, Tombola):
return AddableBingoCage(self.inspect() + other.inspect())
else:
return NotImplemented

def __iadd__(self, other):
if isinstance(other, Tombola):
other_iterable = other.inspect()
else:
try:
other_iterable = iter(other)
except TypeError:
self_cls = type(self).__name__
msg = "right operand in += must be {!r} or an iterable"
raise TypeError(msg.format(self_cls))

self.load(other_iterable)
return self