1

我目前正在使用生成器来生成我的训练和验证数据集tf.data.Dataset.from_generator。我有一个类方法可以为我解决这个问题:

def build_dataset(self, batch_size=16, shuffle=16, validation=None):
    
    train_dataset = tf.data.Dataset.from_generator(import_images(validation=validation), (tf.float32, tf.float32))
    self.train_dataset = train_dataset.shuffle(shuffle).repeat(-1).batch(batch_size).prefetch(1)
    
    if validation is not None:
        val_dataset = tf.data.Dataset.from_generator(import_images(validation=validation), (tf.float32, tf.float32))
        self.val_dataset = val_dataset.repeat(1).batch(batch_size).prefetch(1)

问题是传递(validation=validation)给我import_images的生成器创建了 Tensorflow 不需要的生成器对象,它给了我错误:

TypeError: `generator` must be callable.

因为我必须传入validation告诉我的生成器生成单独的训练和验证版本,所以我需要创建同一个生成器的两个版本。它也不允许我传递其他参数来控制训练和验证示例的百分比——这意味着生成器必须是静态的。有什么建议么?

4

1 回答 1

0

我最近遇到了类似的问题,但我是初学者,所以不确定这是否会有所帮助。

尝试在您的课程中添加呼叫功能。

以下是提出的原始课程TypeError: `generator` must be callable.

class DataGen:
  def __init__(self, files, data_path):
    self.i = 0
    self.files=files
    self.data_path=data_path
  
  def __load__(self, files_name):
    data_path = os.path.join(self.data_path, files_name)
    arr_img, arr_mask = load_patch(data_path)
    return arr_img, arr_mask

  def getitem(self, index):
    _img, _mask = self.__load__(self.files[index])
    return _img, _mask

  def __iter__(self):
    return self

  def __next__(self):
    if self.i < len(self.files):
      img_arr, mask_arr = self.getitem(self.i)
      self.i += 1
    else:
      raise StopIteration()
    return img_arr, mask_arr

然后我修改了下面的代码,它对我有用。

class DataGen:
  def __init__(self, files, data_path):
    self.i = 0
    self.files=files
    self.data_path=data_path
  
  def __load__(self, files_name):
    data_path = os.path.join(self.data_path, files_name)
    arr_img, arr_mask = load_patch(data_path)
    return arr_img, arr_mask

  def getitem(self, index):
    _img, _mask = self.__load__(self.files[index])
    return _img, _mask

  def __iter__(self):
    return self

  def __next__(self):
    if self.i < len(self.files):
      img_arr, mask_arr = self.getitem(self.i)
      self.i += 1
    else:
      raise StopIteration()
    return img_arr, mask_arr
  
  def __call__(self):
    self.i = 0
    return self
于 2020-08-19T00:15:43.300 回答