6

在 GitHub 上的官方 PyTorch C++ 示例中, 可以看到一个奇怪的类定义:

class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {...}

我的理解是,这定义了一个CustomDataset“继承自”或“扩展”的类torch::data::datasets::Dataset<CustomDataset>。这对我来说很奇怪,因为我们正在创建的类是从另一个由我们正在创建的类参数化的类继承的......这甚至是如何工作的?这是什么意思?在我看来,这就像一个Integer继承自 的类vector<Integer>,这似乎很荒谬。

4

1 回答 1

9

这是奇怪重复出现的模板模式,简称 CRTP。这种技术的一个主要优点是它启用了所谓的静态多态性,这意味着函数 intorch::data::datasets::Dataset可以调用 的函数CustomDataset,而无需使这些函数虚拟化(从而处理虚拟方法调度的运行时混乱等)。enable_if您还可以根据自定义数据集类型的属性执行编译时元编程,例如 compile-time s。

在 PyTorch 的情况下,BaseDataset(的超类Dataset)大量使用这种技术来支持诸如映射和过滤之类的操作:

  template <typename TransformType>
  MapDataset<Self, TransformType> map(TransformType transform) & {
    return datasets::map(static_cast<Self&>(*this), std::move(transform));
  }

注意派生类型的静态this转换(只要正确应用 CRTP 就合法);datasets::map构造一个MapDataset对象,该对象也由数据集类型参数化,允许MapDataset实现静态调用方法,例如(或者如果它们不存在,则会get_batch遇到编译时错误)。

此外,由于MapDataset接收自定义数据集类型作为类型参数,编译时元编程是可能的:

  /// The implementation of `get_batch()` for the stateless case, which simply
  /// applies the transform to the output of `get_batch()` from the dataset.
  template <
      typename D = SourceDataset,
      typename = torch::disable_if_t<D::is_stateful>>
  OutputBatchType get_batch_impl(BatchRequestType indices) {
    return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
  }

  /// The implementation of `get_batch()` for the stateful case. Here, we follow
  /// the semantics of `Optional.map()` in many functional languages, which
  /// applies a transformation to the optional's content when the optional
  /// contains a value, and returns a new optional (of a different type)  if the
  /// original optional returned by `get_batch()` was empty.
  template <typename D = SourceDataset>
  torch::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
      BatchRequestType indices) {
    if (auto batch = dataset_.get_batch(std::move(indices))) {
      return transform_.apply_batch(std::move(*batch));
    }
    return nullopt;
  }

请注意,条件启用依赖于SourceDataset,我们只能使用它,因为数据集是使用此 CRTP 模式参数化的。

于 2020-04-20T03:54:55.083 回答