我正在实现我自己的迭代器。tqdm 不显示进度条,因为它不知道列表中元素的总数。我不想使用“total=”,因为它看起来很丑。相反,我更愿意在我的迭代器中添加一些 tqdm 可以用来计算总数的东西。
class Batches:
def __init__(self, batches, target_input):
self.batches = batches
self.pos = 0
self.target_input = target_input
def __iter__(self):
return self
def __next__(self):
if self.pos < len(self.batches):
minibatch = self.batches[self.pos]
target = minibatch[:, :, self.target_input]
self.pos += 1
return minibatch, target
else:
raise StopIteration
def __len__(self):
return self.batches.len()
这甚至可能吗?在上面的代码中添加什么...
使用 tqdm 如下..
for minibatch, target in tqdm(Batches(test, target_input)):
output = lstm(minibatch)
loss = criterion(output, target)
writer.add_scalar('loss', loss, tensorboard_step)