7

我正在实现我自己的迭代器。tqdm 不显示进度条,因为它不知道列表中元素的总数。我不想使用“total=”,因为它看起来很丑。相反,我更愿意在我的迭代器中添加一些 tqdm 可以用来计算总数的东西。

class Batches:
    def __init__(self, batches, target_input):
        self.batches = batches
        self.pos = 0
        self.target_input = target_input

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos < len(self.batches):
            minibatch = self.batches[self.pos]
            target = minibatch[:, :, self.target_input]
            self.pos += 1
            return minibatch, target
        else:
            raise StopIteration

    def __len__(self):
        return self.batches.len()

这甚至可能吗?在上面的代码中添加什么...

使用 tqdm 如下..

for minibatch, target in tqdm(Batches(test, target_input)):

    output = lstm(minibatch)
    loss = criterion(output, target)
    writer.add_scalar('loss', loss, tensorboard_step)
4

2 回答 2

19

我知道已经有一段时间了,但我一直在寻找相同的答案,这就是解决方案。而不是像这样用 tqdm 包装你的迭代

for i in tqdm(my_iterable):
    do_something()

改用“with”关闭,如:

with tqdm(total=len(my_iterable)) as progress_bar:
    for i in my_iterable:
        do_something()
        progress_bar.update(1) # update progress

对于您的批次,您可以将总数设置为批次数,然后更新为 1(如上)。或者,您可以将总计设置为实际的项目总数,并将更新设置为当前处理的批次的大小。

于 2018-06-28T13:10:09.313 回答
1

原始问题指出:

我不想使用“total=”,因为它看起来很丑。相反,我更愿意在我的迭代器中添加一些 tqdm 可以用来计算总数的东西。

但是,当前接受的答案明确指出要使用total

with tqdm(total=len(my_iterable)) as progress_bar:

事实上,给定的示例比它需要的更复杂,因为原始问题没有要求对条形进行复杂的更新。因此,

for i in tqdm(my_iterable, total=my_total):
    do_something()

实际上已经足够了(正如作者@emem 已经在评论中指出的那样)。


这个问题相对较旧(撰写本文时为 4 年),但查看 tqdm 的代码,可以看到从一开始(撰写本文时为 8 年前)行为是默认的total = len(iterable),以防万一total没有给出。

因此,该问题的正确答案是实施__len__. 正如问题中所述,原始示例已经实现。因此,它应该已经正常工作了。

可以在下面找到一个完整的玩具示例来测试行为(请注意__len__方法上方的注释):

from time import sleep
from tqdm import tqdm


class Iter:

    def __init__(self, n=10):
        self.n = n
        self.iter = iter(range(n))

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.iter)

    # commenting the next two lines disables showing the bar
    # due to tqdm not knowing the total number of elements:
    def __len__(self):
        return self.n


it = Iter()
for i in tqdm(it):
    sleep(0.2)

看看 tqdm 究竟做了什么:

try:
    total = len(iterable)
except (TypeError, AttributeError):
    total = None

...并且由于我们不确切知道@Duane 用作什么batches,我认为这基本上只是一个隐藏得很好的错字(self.batches.len()),这会导致AttributeError在 tqdm 中被捕获。

如果batches只是一个序列类型,那么这可能是预期的定义:

    def __len__(self):
        return len(self.batches)

__next__(using )的定义len(self.batches)也指向了这个方向。

于 2022-01-25T13:07:23.087 回答