2

我目前正在使用一个数据集类加载我的数据。在数据集中,我分别拆分了训练、测试和验证数据。例如:

class Data():
    def __init__(self):
        self.load()

    def load(self):
        with open(file=file_name, mode='r') as f:
            self.data = f.readlines()

        self.train = self.data[:checkpoint]
        self.valid = self.data[checkpoint:halfway]
        self.test = self.data[halfway:]

为了便于阅读,许多细节被省略了。基本上,我读入一个大数据集并手动进行拆分。

我的问题是__len__当我的火车长度、有效数据和测试数据都不同时如何覆盖该方法?

我想这样做的原因是因为我想将拆分数据保留在一个类中,并且我还想为每个类创建单独的数据加载器,例如:

def __len__(self):
    return len(self.train)

不适合self.testand self.valid

也许我从根本上误解了 Dataloader,但我应该如何解决这个问题?提前致谢。

4

1 回答 1

0

我认为获取每个拆分长度的最合适的方法是简单地使用:

# Number of training points
len(self.train)

# Number of testing points
len(self.test)

# Number of validation points
len(self.valid)

或者,如果您想参考对象特定实例的拆分长度:

data = Data()
print(len(data.train))
print(len(data.test))
print(len(data.valid))

__len__允许您实现您想要计算对象元素的方式。因此,我将按如下方式实现它,并使用上述调用来获取每个拆分的计数:

def __len__(self):
    return len(self.data)
于 2019-12-08T13:40:15.107 回答