0

我一直在尝试使用谷歌基于 RNN 的seq2seq 模型。

我一直在训练一个文本摘要模型,并且正在输入大约 1GB 大小的文本数据。该模型很快填满了我的整个 RAM(8GB),甚至开始填满交换内存(进一步的 8GB)和崩溃后我必须硬关机。

我的 LSTM 网络的配置如下:

model: AttentionSeq2Seq
model_params:
  attention.class: seq2seq.decoders.attention.AttentionLayerDot
  attention.params:
    num_units: 128
  bridge.class: seq2seq.models.bridges.ZeroBridge
  embedding.dim: 128
  encoder.class: seq2seq.encoders.BidirectionalRNNEncoder
  encoder.params:
    rnn_cell:
      cell_class: GRUCell
      cell_params:
        num_units: 128
      dropout_input_keep_prob: 0.8
      dropout_output_keep_prob: 1.0
      num_layers: 1
  decoder.class: seq2seq.decoders.AttentionDecoder
  decoder.params:
    rnn_cell:
      cell_class: GRUCell
      cell_params:
        num_units: 128
      dropout_input_keep_prob: 0.8
      dropout_output_keep_prob: 1.0
      num_layers: 1
  optimizer.name: Adam
  optimizer.params:
    epsilon: 0.0000008
  optimizer.learning_rate: 0.0001
  source.max_seq_len: 50
  source.reverse: false
  target.max_seq_len: 50

我尝试将批量大小从 32 减少到 16,但它仍然没有帮助。为了防止我的模型占用整个 RAM 并崩溃,我应该进行哪些具体更改?(如减少数据大小、减少堆叠 LSTM 单元的数量、进一步减少批量大小等)

我的系统运行 Python 2.7x、TensorFlow 版本 1.1.0 和 CUDA 8.0。该系统有一个 Nvidia Geforce GTX-1050Ti(768 个 CUDA 内核)和 4GB 内存,系统有 8GB RAM 和另外 8GB 交换内存。

4

1 回答 1

0

你的模型看起来很小。唯一大的就是火车数据。请检查以确保您的get_batch()功能没有错误。有可能每个批次实际上都在加载整个数据集进行训练,以防那里出现错误。

为了快速证明这一点,只需将您的训练数据大小缩减到非常小的值(例如当前大小的 1/10),看看是否有帮助。请注意,它不应该有帮助,因为您使用的是小批量。但是,如果这样可以解决问题,请修复您的get_batch()功能。

于 2017-06-19T13:36:28.867 回答