我有一个大的输入特征,如大小的 3D 数组500x500x500和10000此类样本。和尺寸的标签500x500x500x500。我创建了一个输入形状的模型,输入形状500x500x500仅使用Conv3D一层Dense,输出层仅使用一层(我有自己的原因在输出处使用密集层),网络的输出形状为500x500x500x500.
以下是我使用的最低限度模型:
ip = Input(shape=(500,500,500,1))
x = Conv3D(100,3,activation="relu",padding='same')(ip)
x = Dense(500,activation="softmax")(x)
nn = Model(inputs=ip, outputs=x)
以下是摘要:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_5 (InputLayer) (None, 500, 500, 500, 1) 0
_________________________________________________________________
conv3d_4 (Conv3D) (None, 500, 500, 500, 100 2800
_________________________________________________________________
dense_4 (Dense) (None, 500, 500, 500, 500 50500
=================================================================
Total params: 53,300
Trainable params: 53,300
Non-trainable params: 0
_________________________________________________________________
当我运行模型时出现内存错误,因为我有 64 GB RAM 和 quadroP5000 nvidia GPU。
使其工作的另一种方法是将输入拆分100s为5x500x500块,从而使网络输入为 size 5x500x500。现在我有10000x100=1000000size 的样本5x500x500。下面是修改后的网络:
ip = Input(shape=(5,500,500,1))
x = Conv3D(100,3,activation="relu",padding='same')(ip)
x = Dense(500,activation="softmax")(x)
nn = Model(inputs=ip, outputs=x)
以下是摘要:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_6 (InputLayer) (None, 5, 500, 500, 1) 0
_________________________________________________________________
conv3d_5 (Conv3D) (None, 5, 500, 500, 100) 2800
_________________________________________________________________
dense_5 (Dense) (None, 5, 500, 500, 500) 50500
=================================================================
Total params: 53,300
Trainable params: 53,300
Non-trainable params: 0
_________________________________________________________________
显然参数总数是相同的,但现在我能够训练网络,因为我能够将数据加载到 RAM 中。但网络无法学习,因为它无法一次看到所有信息只看5那些。信息分布在整个大小数组中500x500x500,因此网络无法仅查看一个大小的块来找出任何内容5x500x500。
请建议我如何克服这个问题。我希望我的网络使用所有信息进行预测,而不仅仅是一个块。