9

默认情况下,酸洗一个 numpy 视图数组会丢失视图关系,即使数组基础也被酸洗了。我的情况是我有一些复杂的容器对象被腌制。在某些情况下,一些包含的数据是其他一些数据的视图。保存每个视图的独立数组不仅是空间的损失,而且重新加载的数据也失去了视图关系。

一个简单的例子是(但在我的情况下,容器比字典更复杂):

import numpy as np
import cPickle

tmp = np.zeros(2)
d1 = dict(a=tmp,b=tmp[:])    # d1 to be saved: b is a view on a

pickled = cPickle.dumps(d1)
d2 = cPickle.loads(pickled)  # d2 reloaded copy of d1 container

print 'd1 before:', d1
d1['b'][:] = 1
print 'd1 after: ', d1

print 'd2 before:', d2
d2['b'][:] = 1
print 'd2 after: ', d2

这将打印:

d1 before: {'a': array([ 0.,  0.]), 'b': array([ 0.,  0.])}
d1 after:  {'a': array([ 1.,  1.]), 'b': array([ 1.,  1.])}
d2 before: {'a': array([ 0.,  0.]), 'b': array([ 0.,  0.])}
d2 after:  {'a': array([ 0.,  0.]), 'b': array([ 1.,  1.])} # not a view anymore

我的问题:

(1) 有办法保存吗?(2)(甚至更好)有没有办法只有在基础被腌制的情况下才能做到这一点

对于 (1) 我认为可能有一些方法可以通过更改视图数组的__setstate__,__reduce_ex_等...。但我现在对这些没有信心。对于(2)我不知道。

4

1 回答 1

7

这不是在 NumPy 中完成的,因为腌制基本数组并不总是有意义的,并且腌制不会公开检查另一个对象是否也作为其 API 的一部分被腌制的能力。

但是这种检查可以在 NumPy 数组的自定义容器中完成。例如:

import numpy as np
import pickle

def byte_offset(array, source):
    return array.__array_interface__['data'][0] - np.byte_bounds(source)[0]

class SharedPickleList(object):
    def __init__(self, arrays):
        self.arrays = list(arrays)

    def __getstate__(self):
        unique_ids = {id(array) for array in self.arrays}
        source_arrays = {}
        view_tuples = {}
        for array in self.arrays:
            if array.base is None or id(array.base) not in unique_ids:
                # only use views if the base is also being pickled
                source_arrays[id(array)] = array
            else:
                view_tuples[id(array)] = (array.shape,
                                          array.dtype,
                                          id(array.base),
                                          byte_offset(array, array.base),
                                          array.strides)
        order = [id(array) for array in self.arrays]
        return (source_arrays, view_tuples, order)

    def __setstate__(self, state):
        source_arrays, view_tuples, order = state
        view_arrays = {}
        for k, view_state in view_tuples.items():
            (shape, dtype, source_id, offset, strides) = view_state
            buffer = source_arrays[source_id].data
            array = np.ndarray(shape, dtype, buffer, offset, strides)
            view_arrays[k] = array
        self.arrays = [source_arrays[i]
                       if i in source_arrays
                       else view_arrays[i]
                       for i in order]

# unit tests
def check_roundtrip(arrays):
    unpickled_arrays = pickle.loads(pickle.dumps(
        SharedPickleList(arrays))).arrays
    assert all(a.shape == b.shape and (a == b).all()
               for a, b in zip(arrays, unpickled_arrays))

indexers = [0, None, slice(None), slice(2), slice(None, -1),
            slice(None, None, -1), slice(None, 6, 2)]

source0 = np.random.randint(100, size=10)
arrays0 = [np.asarray(source0[k1]) for k1 in indexers]
check_roundtrip([source0] + arrays0)

source1 = np.random.randint(100, size=(8, 10))
arrays1 = [np.asarray(source1[k1, k2]) for k1 in indexers for k2 in indexers]
check_roundtrip([source1] + arrays1)

这可以显着节省空间:

source = np.random.rand(1000)
arrays = [source] + [source[n:] for n in range(99)]
print(len(pickle.dumps(arrays, protocol=-1)))
# 766372
print(len(pickle.dumps(SharedPickleList(arrays), protocol=-1)))
# 11833
于 2016-10-25T19:04:21.797 回答