NumPy 本身并没有这样做,因为对基数组(base array)进行 pickle 序列化并不总是有意义的,而且 pickle 的 API 也没有提供检查另一个对象是否同时被序列化的能力。
但是这种检查可以在 NumPy 数组的自定义容器中完成。例如:
import numpy as np
import pickle
def byte_offset(array, source):
    """Return how many bytes ``array``'s first element lies past the start of ``source``.

    Args:
        array: NumPy array whose data pointer is measured.
        source: NumPy array whose lowest memory address is used as the origin.

    Returns:
        int: ``array``'s data address minus the low end of ``source``'s
        byte bounds.
    """
    # ``np.byte_bounds`` was removed from the top-level namespace in NumPy 2.0
    # and now lives in ``numpy.lib.array_utils``; support both layouts.
    byte_bounds = getattr(np, 'byte_bounds', None)
    if byte_bounds is None:
        from numpy.lib.array_utils import byte_bounds
    return array.__array_interface__['data'][0] - byte_bounds(source)[0]
class SharedPickleList(object):
    """List-like holder of NumPy arrays that pickles views without copying data.

    An array whose ``base`` is also held by this container is stored in the
    pickled state as a small tuple (shape, dtype, base id, byte offset,
    strides) and rebuilt on top of the unpickled base's buffer. All other
    arrays are stored in full.
    """

    def __init__(self, arrays):
        """Store a list copy of ``arrays`` (any iterable of NumPy arrays)."""
        self.arrays = list(arrays)

    def __getstate__(self):
        """Build the picklable state.

        Returns:
            tuple: ``(source_arrays, view_tuples, order)`` where
            ``source_arrays`` maps ``id(array)`` to arrays stored in full,
            ``view_tuples`` maps ``id(array)`` to the description of a view,
            and ``order`` lists the ids in the original list order.
        """
        unique_ids = {id(array) for array in self.arrays}
        source_arrays = {}
        view_tuples = {}
        for array in self.arrays:
            if array.base is None or id(array.base) not in unique_ids:
                # only use views if the base is also being pickled
                source_arrays[id(array)] = array
            else:
                view_tuples[id(array)] = (array.shape,
                                          array.dtype,
                                          id(array.base),
                                          byte_offset(array, array.base),
                                          array.strides)
        # The ids only serve as labels linking entries of the three parts.
        order = [id(array) for array in self.arrays]
        return (source_arrays, view_tuples, order)

    def __setstate__(self, state):
        """Rebuild ``self.arrays`` from a state produced by ``__getstate__``.

        Args:
            state: ``(source_arrays, view_tuples, order)`` tuple.
        """
        source_arrays, view_tuples, order = state
        view_arrays = {}
        for k, view_state in view_tuples.items():
            (shape, dtype, source_id, offset, strides) = view_state
            # Re-create the view directly on the restored base's memory.
            buffer = source_arrays[source_id].data
            array = np.ndarray(shape, dtype, buffer, offset, strides)
            view_arrays[k] = array
        self.arrays = [source_arrays[i]
                       if i in source_arrays
                       else view_arrays[i]
                       for i in order]
# unit tests
def check_roundtrip(arrays):
    """Assert that ``arrays`` survive a pickle round trip through SharedPickleList.

    Args:
        arrays: iterable of NumPy arrays to wrap, pickle and unpickle.

    Raises:
        AssertionError: if the number of arrays, any shape, or any element
            differs after the round trip.
    """
    # Materialize once so the same sequence is pickled and compared.
    arrays = list(arrays)
    unpickled_arrays = pickle.loads(pickle.dumps(
        SharedPickleList(arrays))).arrays
    # zip() stops at the shorter input, so a dropped array would otherwise
    # go unnoticed.
    assert len(arrays) == len(unpickled_arrays)
    assert all(a.shape == b.shape and (a == b).all()
               for a, b in zip(arrays, unpickled_arrays))
# Index expressions exercised by the round-trip checks: an integer, a new
# axis, and several slices (full, prefix, trimmed tail, reversed, stepped).
indexers = [
    0,
    None,
    slice(None),
    slice(2),
    slice(None, -1),
    slice(None, None, -1),
    slice(None, 6, 2),
]

# 1-D case: the source array plus one selection per indexer.
source0 = np.random.randint(100, size=10)
arrays0 = [np.asarray(source0[index]) for index in indexers]
check_roundtrip([source0] + arrays0)

# 2-D case: every pairing of a row indexer with a column indexer.
source1 = np.random.randint(100, size=(8, 10))
arrays1 = [
    np.asarray(source1[row_index, col_index])
    for row_index in indexers
    for col_index in indexers
]
check_roundtrip([source1] + arrays1)
这可以显著节省空间:
# Compare the pickled size of a plain list of overlapping views with the
# size of the same arrays wrapped in SharedPickleList.
source = np.random.rand(1000)
arrays = [source]
arrays.extend(source[n:] for n in range(99))

plain_size = len(pickle.dumps(arrays, protocol=-1))
print(plain_size)
# 766372

shared_size = len(pickle.dumps(SharedPickleList(arrays), protocol=-1))
print(shared_size)
# 11833