1

我想知道如何在 Python 中实现二维数组切片?

例如,

arr是自定义类二维数组的一个实例。

如果我想在这个对象上启用 2D 切片语法,如下所示:

arr[:,1:3] #retrieve the 1 and 2 column values of every row

或者

arr[,:3] #retrieve the 1 and 2 column values of every row

用法和语法就像 numpy.array。但是我们自己如何才能实现这种功能呢?

PS:

我的想法是:

对于第一种情况,该[:,1:3]部分就像两个切片的元组

但是,对于第二种情况[,1:3],显得相当神秘。

4

4 回答 4

10

如果您想了解数组切片的规则,下图可能会有所帮助:

在此处输入图像描述

于 2013-04-19T03:59:55.180 回答
4

对于读取访问,您需要覆盖该__getitem__方法:

class ArrayLike(object):
    def __init__(self):
        pass
    def __getitem__(self, arg):
        (rows,cols) = arg # unpack, assumes that we always pass in 2-arguments
        # TODO: parse/interpret the rows/cols parameters,
        # for single indices, they will be integers, for slices, they'll be slice objects
        # here's a dummy implementation as a placeholder 
        return numpy.eye(10)[rows, cols]

一个棘手的问题是__getitem__始终只使用一个参数(除了 self),当您将多个逗号分隔的项目放在方括号内时,您实际上提供了一个元组作为__getitem__调用的参数;因此需要在函数内部解包这个元组(并可选地验证元组的长度是否合适)。

现在给出a = ArrayLike(),你最终得到

  • a[2,3]意味着rows=2cols=3
  • a[:3,2]意味着rows=slice(None, 3, None)cols=3

等等; 您必须查看有关切片对象的文档,以决定如何使用切片信息从类中提取您需要的数据。

为了使它更像一个 numpy 数组,您还需要覆盖__setitem__,以允许分配给元素/切片。

于 2014-01-03T21:26:22.920 回答
2

obj[,:3]不是有效的python,所以它会引发SyntaxError- 因此,您的源文件中不能有该语法。numpy(当您尝试在阵列上使用它时也会失败)

于 2013-04-19T03:39:30.030 回答
0

如果它是您自己的课程并且您愿意传入一个字符串,那么这是一个 hack。

如何覆盖 [] 运算符?

class Array(object):

    def __init__(self, m, n):
        """Create junk demo array."""
        self.m = m
        self.n = n
        row = list(range(self.n))
        self.array = map(lambda x:row, range(self.m))

    def __getitem__(self, index_string):
        """Implement slicing/indexing."""

        row_index, _, col_index = index_string.partition(",")

        if row_index == '' or row_index==":":
            row_start = 0
            row_stop = self.m
        elif ':' in row_index:
            row_start, _, row_stop = row_index.partition(":")
            try:
                row_start = int(row_start)
                row_stop = int(row_stop)
            except ValueError:
                print "Bad Data"
        else:
            try:
                row_start = int(row_index)
                row_stop = int(row_index) + 1
            except ValueError:
                print "Bad Data"

        if col_index == '' or col_index == ":":
            col_start = 0
            col_stop = self.n
        elif ':' in col_index:
            col_start, _, col_stop = col_index.partition(":")
            try:
                col_start = int(col_start)
                col_stop = int(col_stop)
            except ValueError:
                print "Bad Data"
        else:
            try:
                col_start = int(col_index)
                col_stop = int(col_index) + 1
            except ValueError:
                print "Bad Data"

        return map(lambda x: self.array[x][col_start:col_stop],
                       range(row_start, row_stop))

    def __str__(self):
        return str(self.array)

    def __repr__(self):
        return str(self.array)


array = Array(4, 5)
print array
out: [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]

array[",1:3"]
out: [[1, 2], [1, 2], [1, 2], [1, 2]]

array[":,1:3"]
out: [[1, 2], [1, 2], [1, 2], [1, 2]]
于 2013-04-19T05:54:04.397 回答