《Effective Python》读书笔记(三) 类和继承
23.尽量使用挂钩函数来作为简单接口的接收函数
挂钩(hook)函数是指作为参数传给某个API、由该API在运行过程中回调的函数。例如下面传给defaultdict的log_missing,每当访问到缺失的键时就会被调用:
import collections


def log_missing():
    """Hook for ``defaultdict``: log the missing key and return 0 as its default."""
    print('Add key')
    return 0


current = {'green': 12, 'blue': 3}
increments = [('red', 5), ('blue', 17), ('orange', 9)]
# log_missing is the hook: defaultdict calls it (with no arguments) to obtain
# the default value each time a missing key is accessed.
result = collections.defaultdict(log_missing, current)
print(dict(result))
for key, amount in increments:
    result[key] += amount
print(dict(result))
{'blue': 3, 'green': 12}
Add key
Add key
{'blue': 20, 'orange': 9, 'green': 12, 'red': 5}
如果挂钩函数需要记录状态(例如统计缺失键的个数),可以创建一个辅助类,把状态保存在实例里,再把实例方法作为挂钩函数传入。
import collections

current = {'green': 12, 'blue': 3}
increments = [('red', 5), ('blue', 17), ('orange', 9)]


# A helper class lets the hook keep state between calls.
class CountMissing(object):
    """Stateful hook that counts how many missing keys were requested."""

    def __init__(self):
        self.added = 0

    def missing(self):
        """Count one missing key and return 0 as its default value."""
        self.added += 1
        return 0


cm = CountMissing()
# The bound method carries its instance, so the count accumulates on ``cm``.
result = collections.defaultdict(cm.missing, current)
print(dict(result))
for key, amount in increments:
    result[key] += amount
print(dict(result))
>>>
{'blue': 3, 'green': 12}
{'blue': 20, 'orange': 9, 'green': 12, 'red': 5}
当然也可以使用__call__函数,来让类的实例也能够像函数一样被调用。当使用callable()时,会返回True。
import collections

current = {'green': 12, 'blue': 3}
increments = [('red', 5), ('blue', 17), ('orange', 9)]


# Same stateful hook, but the instance itself is callable via __call__.
class CountMissing(object):
    """Callable object that counts how many missing keys were requested."""

    def __init__(self):
        self.added = 0

    def __call__(self, *args, **kwargs):
        """Count one missing key and return 0; any arguments are ignored."""
        self.added += 1
        return 0


cm = CountMissing()
# The instance is passed directly as the default factory.
result = collections.defaultdict(cm, current)
print(dict(result))
for key, amount in increments:
    result[key] += amount
print(dict(result))
print(callable(cm))
>>>
{'blue': 3, 'green': 12}
{'blue': 20, 'orange': 9, 'green': 12, 'red': 5}
True
24.以@classmethod形式的多态去通用地构建对象
@classmethod放在类中某个方法前,将其变为类方法,从而可以直接被类调用。通过@classmethod可以实现多态,即在调用函数的时候,可以直接将类作为参数传递,而在函数内部可以直接调用类方法。
下面通过MapReduce的例子来看。首先看一下不用@classmethod的情况:
# Example 1
class InputData(object):
    """Base interface for an input source; subclasses must implement read()."""

    def read(self):
        """Return the input data.

        Raises:
            NotImplementedError: always, until overridden by a subclass.
        """
        raise NotImplementedError
# Example 2
class PathInputData(InputData):
    """Input source that reads its data from a file path."""

    def __init__(self, path):
        super().__init__()
        self.path = path

    def read(self):
        """Return the full contents of the file at ``self.path``."""
        # Use a context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(self.path) as f:
            return f.read()
# Example 3
class Worker(object):
    """Base interface for a MapReduce worker bound to one input source."""

    def __init__(self, input_data):
        self.input_data = input_data
        # Filled in by map() and combined by reduce() in subclasses.
        self.result = None

    def map(self):
        """Process ``self.input_data``; must be overridden by subclasses."""
        raise NotImplementedError

    def reduce(self, other):
        """Combine ``other``'s result into this worker; must be overridden."""
        raise NotImplementedError
# Example 4
class LineCountWorker(Worker):
    """Worker that counts newline characters in its input."""

    def map(self):
        """Read the input and store its number of newline characters."""
        data = self.input_data.read()
        self.result = data.count('\n')

    def reduce(self, other):
        """Add ``other.result`` to this worker's result."""
        self.result += other.result
# Example 5
import os
def generate_inputs(data_dir):
    """Yield a PathInputData for every entry listed in ``data_dir``."""
    for name in os.listdir(data_dir):
        yield PathInputData(os.path.join(data_dir, name))
# Example 6
def create_workers(input_list):
    """Return a list with one LineCountWorker per item of ``input_list``."""
    return [LineCountWorker(input_data) for input_data in input_list]
# Example 7
from threading import Thread
def execute(workers):
    """Run every worker's map() in its own thread, then reduce into one result.

    Args:
        workers: sequence of objects providing ``map()``, ``reduce(other)``
            and a ``result`` attribute.

    Returns:
        The ``result`` of the first worker after all others were reduced into it.

    Raises:
        IndexError: if ``workers`` is empty.
    """
    threads = [Thread(target=w.map) for w in workers]
    for thread in threads:
        thread.start()
    # Wait for every map step to finish before reducing.
    for thread in threads:
        thread.join()

    first, rest = workers[0], workers[1:]
    for worker in rest:
        first.reduce(worker)
    return first.result
# Example 8
def mapreduce(data_dir):
    """Build inputs and workers for ``data_dir`` and return the reduced result."""
    inputs = generate_inputs(data_dir)
    workers = create_workers(inputs)
    return execute(workers)
# Example 9
from tempfile import TemporaryDirectory
import random
def write_test_files(tmpdir):
    """Create 100 files named '0'..'99' in ``tmpdir``, each with 0-100 newlines."""
    for i in range(100):
        with open(os.path.join(tmpdir, str(i)), 'w') as f:
            f.write('\n' * random.randint(0, 100))
# Generate the test files in a throwaway directory and count their lines.
with TemporaryDirectory() as tmpdir:
    write_test_files(tmpdir)
    result = mapreduce(tmpdir)
print('There are', result, 'lines')
>>>
There are 4098 lines
上面的Worker这一大类和InputData这一大类需要通过mapreduce这个方法连接起来,并且如果我需要创建新的Worker或InputData,那么扩展起来就不太方便了。
现在看看使用@classmethod来实现多态。
# Example 10
class GenericInputData(object):
    """Base interface for input sources that can also enumerate themselves."""

    def read(self):
        """Return the input data; must be overridden by subclasses."""
        raise NotImplementedError

    @classmethod
    def generate_inputs(cls, config):
        """Create the input objects described by ``config``; must be overridden."""
        raise NotImplementedError
# Example 11
class PathInputData(GenericInputData):
    """Input source that reads its data from a file path."""

    def __init__(self, path):
        super().__init__()
        self.path = path

    def read(self):
        """Return the full contents of the file at ``self.path``."""
        # Use a context manager so the file handle is closed deterministically.
        with open(self.path) as f:
            return f.read()

    @classmethod
    def generate_inputs(cls, config):
        """Yield one instance of ``cls`` per entry in ``config['data_dir']``."""
        data_dir = config['data_dir']
        for name in os.listdir(data_dir):
            # cls (not a hard-coded class) so subclasses produce their own type.
            yield cls(os.path.join(data_dir, name))
# Example 12
class GenericWorker(object):
    """Base interface for workers that can build themselves from an input class."""

    def __init__(self, input_data):
        self.input_data = input_data
        # Filled in by map() and combined by reduce() in subclasses.
        self.result = None

    def map(self):
        """Process ``self.input_data``; must be overridden by subclasses."""
        raise NotImplementedError

    def reduce(self, other):
        """Combine ``other``'s result into this worker; must be overridden."""
        raise NotImplementedError

    @classmethod
    def create_workers(cls, input_class, config):
        """Return one worker of type ``cls`` per input from ``input_class``.

        Args:
            input_class: class whose ``generate_inputs(config)`` yields inputs.
            config: passed through unchanged to ``generate_inputs``.
        """
        return [cls(input_data)
                for input_data in input_class.generate_inputs(config)]
# Example 13
class LineCountWorker(GenericWorker):
    """Worker that counts newline characters in its input."""

    def map(self):
        """Read the input and store its number of newline characters."""
        data = self.input_data.read()
        self.result = data.count('\n')

    def reduce(self, other):
        """Add ``other.result`` to this worker's result."""
        self.result += other.result
# Example 14
def mapreduce(worker_class, input_class, config):
    """Create workers via ``worker_class.create_workers`` and execute them."""
    workers = worker_class.create_workers(input_class, config)
    return execute(workers)
# Example 15
# Same experiment as before, now wired together through class-method polymorphism.
with TemporaryDirectory() as tmpdir:
    write_test_files(tmpdir)
    config = {'data_dir': tmpdir}
    result = mapreduce(LineCountWorker, PathInputData, config)
print('There are', result, 'lines')
>>>
There are 4098 lines
在python中,每个类只能有一个构造器,即一个__init__方法。可以用@classmethod来仿造构造器,从而构造类的对象。
25.用super初始化父类
直接调用类的__init__函数来初始化父类,那么父类的初始化顺序是按照子类的__init__里对各个超类的__init__调用顺序来进行。
下面看一下不用super来初始化的时候,在钻石型继承中出现的问题。
class MyBaseClass(object):
    """Common base of the diamond; stores ``value``."""

    def __init__(self, value):
        self.value = value


class TimesFive(MyBaseClass):
    """Initializes the base directly, then multiplies ``value`` by 5."""

    def __init__(self, value):
        MyBaseClass.__init__(self, value)
        self.value *= 5


class PlusTwo(MyBaseClass):
    """Initializes the base directly, then adds 2 to ``value``."""

    def __init__(self, value):
        MyBaseClass.__init__(self, value)
        self.value += 2


# Multiple inheritance: ThisWay inherits from both parents.
class ThisWay(TimesFive, PlusTwo):
    """Deliberately flawed: calls each parent's __init__ explicitly."""

    def __init__(self, value):
        TimesFive.__init__(self, value)
        # PlusTwo.__init__ runs MyBaseClass.__init__ a second time, which
        # resets self.value to 5 and discards the multiplication above.
        PlusTwo.__init__(self, value)


foo = ThisWay(5)
print('Should be (5*5)+2 but is', foo.value)
>>>
Should be (5*5)+2 but is 7
由于TimesFive和PlusTwo的__init__都各自调用了一次MyBaseClass的__init__,第二次调用(发生在PlusTwo.__init__中)又把self.value重置为5,抵消了TimesFive的乘5操作,因此结果为5+2=7。显然,这种方式是错误的。
现在考虑使用super,其定义了“方法解析顺序”(method resolution order,MRO)。MRO是以标准的流程来安排超类的初始化顺序(深度优先,从左往右)。
# python2风格
class MyBaseClass(object):
    """Common base of the diamond; stores ``value``."""

    def __init__(self, value):
        self.value = value


class TimesFive(MyBaseClass):
    """Delegates to the next class in the MRO, then multiplies by 5."""

    def __init__(self, value):
        super(TimesFive, self).__init__(value)
        self.value *= 5


class PlusTwo(MyBaseClass):
    """Delegates to the next class in the MRO, then adds 2."""

    def __init__(self, value):
        super(PlusTwo, self).__init__(value)
        self.value += 2


class ThisWay(TimesFive, PlusTwo):
    """Diamond subclass initialized through a single super() chain."""

    def __init__(self, value):
        super(ThisWay, self).__init__(value)


foo = ThisWay(5)
# PlusTwo's body runs before TimesFive's, so the result is 5 * (5 + 2).
# %-formatting keeps this line valid on both Python 2 and Python 3.
print('Should be (5*(5+2)) and it is %d' % foo.value)
# Show the method resolution order.
from pprint import pprint
pprint(ThisWay.mro())
# python3风格
class MyBaseClass(object):
    """Common base of the diamond; stores ``value``."""

    def __init__(self, value):
        self.value = value


class TimesFive(MyBaseClass):
    """Delegates to the next class in the MRO, then multiplies by 5."""

    def __init__(self, value):
        # Equivalent to the zero-argument form: super().__init__(value)
        super(__class__, self).__init__(value)
        self.value *= 5


class PlusTwo(MyBaseClass):
    """Delegates to the next class in the MRO, then adds 2."""

    def __init__(self, value):
        super(__class__, self).__init__(value)
        self.value += 2


class ThisWay(TimesFive, PlusTwo):
    """Diamond subclass initialized through a single super() chain."""

    def __init__(self, value):
        super().__init__(value)


foo = ThisWay(5)
# PlusTwo's body runs before TimesFive's, so the result is 5 * (5 + 2).
print('Should be (5*(5+2)) and it is', foo.value)
# Show the method resolution order.
from pprint import pprint
pprint(ThisWay.mro())
>>>
Should be (5*(5+2)) and it is 35
[<class '__main__.ThisWay'>,
<class '__main__.TimesFive'>,
<class '__main__.PlusTwo'>,
<class '__main__.MyBaseClass'>,
<type 'object'>]
输出结果为35,是正确的。使用super,其初始化的逻辑顺序为:要初始化ThisWay,先初始化TimesFive(深度优先);要初始化TimesFive,先初始化PlusTwo(从左到右);要初始化PlusTwo,先初始化MyBaseClass(深度优先)。所以初始化顺序为:
MyBaseClass->PlusTwo->TimesFive->ThisWay。
26.只在使用mix-in组件制作工具类时进行多重继承
mix-in是一种小型的类,它只定义其他类可能需要提供的一套附加方法(通常只实现单个功能),不定义自己的实例属性,也不要求使用者调用它的__init__构造器。
下面以用于类的序列化的mix-in组件为例。
isinstance函数可以动态检测对象类型;__dict__以键值对的形式保存对象的属性(实例的__dict__保存实例的成员值,类的__dict__还包含__module__等默认成员);hasattr函数可以判定某个类实例里有没有某个成员或方法。
class ToDictMixin(object):
    """Mix-in that converts an object's instance attributes into a plain dict."""

    def to_dict(self):
        """Return ``self.__dict__`` converted recursively to plain containers."""
        return self._traverse_dict(self.__dict__)

    def _traverse_dict(self, instance_dict):
        """Convert every value of ``instance_dict``, omitting ``None`` results."""
        output = {}
        # items() exists on both Python 2 and 3; iteritems() is Python 2 only.
        for key, value in instance_dict.items():
            cur = self._traverse(key, value)
            # Do not show empty (None) values.
            if cur is not None:
                output[key] = cur
        return output

    def _traverse(self, key, value):
        """Convert a single attribute value to a dict, list or plain value."""
        # Test against the mix-in rather than a concrete class such as
        # BinaryTree so that every class using the mix-in is handled.
        if isinstance(value, ToDictMixin):
            return value.to_dict()
        # The container branches are required whenever an attribute holds a
        # dict or a list of mix-in objects (e.g. a list of machines); without
        # them such values would be returned unconverted.
        elif isinstance(value, dict):
            return self._traverse_dict(value)
        elif isinstance(value, list):
            return [self._traverse(key, i) for i in value]
        elif hasattr(value, '__dict__'):
            return self._traverse_dict(value.__dict__)
        else:
            return value
class BinaryTree(ToDictMixin):
    """Binary tree node that gains ``to_dict()`` from ToDictMixin."""

    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


tree = BinaryTree(10,
                  left=BinaryTree(7, right=BinaryTree(10)),
                  right=BinaryTree(3, left=BinaryTree(11)))
print(tree.to_dict())
>>>
{'right': {'value': 3, 'left': {'value': 11}},
'value': 10,
'left': {'right': {'value': 10}, 'value': 7}}
下面实现一下多个mix-in的搭配。比如先用JsonMixin的from_json把json字符串load进来,用加载后的值初始化DatacenterRack对象(相当于反序列化,deserialize),再通过ToDictMixin把该对象序列化成字典,最后dump成json字符串。
import json
class JsonMixin(object):
    """Mix-in adding JSON (de)serialization.

    The host class must accept the decoded JSON object's keys as keyword
    arguments to its constructor and must provide ``to_dict()``.
    """

    @classmethod
    def from_json(cls, data):
        """Build an instance of ``cls`` from the JSON object text ``data``."""
        kwargs = json.loads(data)
        return cls(**kwargs)

    def to_json(self):
        """Return ``self.to_dict()`` encoded as a JSON string."""
        return json.dumps(self.to_dict())
class DatacenterRack(ToDictMixin, JsonMixin):
    """A rack made of one switch and a list of machines.

    Args:
        switch: dict of keyword arguments for ``Switch``, or None.
        machines: iterable of keyword-argument dicts for ``Machine``, or None.
    """

    def __init__(self, switch=None, machines=None):
        # The None defaults previously crashed (``**None`` and iterating
        # None); treat them as "no switch" and "no machines" instead.
        self.switch = Switch(**switch) if switch is not None else None
        self.machines = [Machine(**kwargs) for kwargs in (machines or [])]
class Switch(ToDictMixin, JsonMixin):
    """Plain record of a switch's ``ports`` and ``speed``."""

    def __init__(self, ports=None, speed=None):
        self.ports = ports
        self.speed = speed
class Machine(ToDictMixin, JsonMixin):
    """Plain record of a machine's ``cores``, ``ram`` and ``disk``."""

    def __init__(self, cores=None, ram=None, disk=None):
        self.cores = cores
        self.ram = ram
        self.disk = disk
# Round-trip check: JSON text -> DatacenterRack -> JSON text must keep the data.
serialized = """{
"switch": {"ports": 5, "speed": 1e9},
"machines": [
{"cores": 8, "ram": 32e9, "disk": 5e12},
{"cores": 4, "ram": 16e9, "disk": 1e12},
{"cores": 2, "ram": 4e9, "disk": 500e9}
]
}"""
deserialized = DatacenterRack.from_json(serialized)
roundtrip = deserialized.to_json()
# Compare decoded values rather than raw text so formatting differences
# between the two JSON strings do not matter.
expected = json.loads(serialized)
actual = json.loads(roundtrip)
assert expected == actual
27.多用public属性,少用private属性
各个属性值的含义:
self.field表示public成员;self._field表示protected成员;self.__field表示private成员,它可以被类内部方法访问,在类外可以通过instance._Class__field被访问,Class就是该类对应的名称,因此python无法保证private成员的私密性;self.__len__()这种前后都带双下划线的名称表示类中的特殊成员或方法。
一般情况下,不要在类内定义private成员,应多用protected代替。在下面一种情况下,可以使用private,来防止子类的属性覆盖同名的超类属性。
class ApiClass(object):
    """Parent class that keeps its state in a protected attribute."""

    def __init__(self):
        self._value = 5

    def get(self):
        """Return the stored ``_value``."""
        return self._value


class Child(ApiClass):
    """Subclass that accidentally reuses the parent's attribute name."""

    def __init__(self):
        super().__init__()
        self._value = 'hello'  # Conflicts with ApiClass._value


a = Child()
print(a.get(), 'and', a._value, 'should be different')
>>>
hello and hello should be different
class ApiClass(object):
    """Parent class that keeps its state in a private (name-mangled) attribute."""

    def __init__(self):
        # Stored on the instance as _ApiClass__value, so a subclass's own
        # _value attribute cannot overwrite it.
        self.__value = 5

    def get(self):
        """Return the stored private value."""
        return self.__value


class Child(ApiClass):
    """Subclass with its own ``_value`` that no longer clashes with the parent."""

    def __init__(self):
        super().__init__()
        self._value = 'hello'  # OK!


a = Child()
print(a.get(), 'and', a._value, 'are different')
>>>
5 and hello are different
28.继承collections.abc(在python3里有)以实现自定义的容器类型
collections.abc中定义了很多容器的抽象基类,如果要自定义容器,最好就是继承需要的抽象基类,然后实现抽象基类当中的某些特殊方法(如__getitem__和__len__都是特殊方法),那么自定义类就具备了抽象基类提供的其他方法,如count和index方法。
下面给出一个使用collections.abc中的Sequence抽象基类实现自定义容器的实例。
索引访问(比如foo[0])其实就是调用__getitem__()方法(比如foo.__getitem__(0));使用len(a)相当于调用a.__len__()。
# 实现__getitem__()方法
# BinaryNode is the base class used below, but these notes never define it.
# This minimal version matches how the nodes are constructed and read here
# (value, left, right).
class BinaryNode(object):
    """Plain binary tree node."""

    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


class IndexableNode(BinaryNode):
    """Binary tree node whose subtree can be indexed like a sequence."""

    def _search(self, count, index):
        """Walk the subtree in-order (left, self, right) looking for ``index``.

        Args:
            count: number of nodes already visited before this subtree.
            index: position to find; a value that never matches (e.g. None)
                makes the walk visit and count every node.

        Returns:
            A ``(found, count)`` tuple: the node at ``index`` (or None) and
            the updated visit count.
        """
        found = None
        # Compare with None explicitly: subclasses that define __len__ would
        # otherwise run a full traversal just to evaluate a node's truthiness.
        if self.left is not None:
            found, count = self.left._search(count, index)
        if found is None and count == index:
            found = self
        else:
            count += 1
        if found is None and self.right is not None:
            found, count = self.right._search(count, index)
        return found, count

    def __getitem__(self, index):
        """Return the value of the node at in-order position ``index``.

        Raises:
            IndexError: if no node exists at ``index``.
        """
        found, _ = self._search(0, index)
        if found is None:
            raise IndexError('Index out of range')
        return found.value
# 实现__len__()方法
class SequenceNode(IndexableNode):
    """Indexable tree node that also supports ``len()``."""

    def __len__(self):
        """Return the number of nodes in this subtree."""
        # An index of None never equals the running count, so _search visits
        # every node and the returned count is the subtree size.
        _, count = self._search(0, None)
        return count
# 载入模块
from collections.abc import Sequence
class BetterNode(SequenceNode, Sequence):
    """Tree node that inherits index() and count() from the Sequence ABC."""
    pass


tree = BetterNode(
    10,
    left=BetterNode(
        5,
        left=BetterNode(2),
        right=BetterNode(
            6, right=BetterNode(7))),
    right=BetterNode(
        15, left=BetterNode(11))
)
print('Index of 7 is', tree.index(7))
print('Count of 10 is', tree.count(10))
>>>
Index of 7 is 3
Count of 10 is 1
当然,如果自定义的容器比较简单,可以直接继承像list、dict、set这样的类,然后加入自己的方法。实例如下:
class FrequencyList(list):
    """A list that can report how often each of its items occurs."""

    def __init__(self, data):
        super(FrequencyList, self).__init__(data)

    def frequency(self):
        """Return a dict mapping each distinct item to its number of occurrences."""
        # Imported locally because this snippet never imports collections.
        import collections
        return dict(collections.Counter(self))


fl = FrequencyList(['a', 'b', 'c', 'c', 'a', 'd', 'f', 'b'])
print(repr(fl.frequency()))
>>>
{'a': 2, 'c': 2, 'b': 2, 'd': 1, 'f': 1}