0

我想知道如何在 Keras 中实现这篇论文（https://arxiv.org/abs/1506.03099）中描述的计划采样（Scheduled Sampling，可能还需结合课程学习）。

假设有一个如下所示的简单编码器-解码器模型。如果可行，应如何修改该模型以加入上述策略和行为？

# Baseline encoder-decoder from the question (no scheduled sampling).
# NOTE(review): lyrs, k, size_gru and the *_data arrays are defined outside this
# snippet (presumably keras.layers, keras, an int and NumPy arrays) — confirm.

# encoder architecture
# The encoder GRU returns only its final hidden state (return_sequences is not
# set), which becomes the decoder's initial state below.
encoder_inp = lyrs.Input(shape=(None, encoder_input_data.shape[-1]))
encoder_out = lyrs.GRU(size_gru)(encoder_inp)

# decoder architecture | training
# The decoder consumes a whole input sequence in one call; presumably this is the
# ground-truth sequence shifted by one step (plain teacher forcing) — confirm.
decoder_inp = lyrs.Input(shape=(None, decoder_input_data_categorical.shape[-1]))
decoder_gru_lyr = lyrs.GRU(size_gru, return_sequences=True, return_state=True)
# With return_sequences and return_state the GRU yields (per-step outputs, final state).
decoder_seq, decoder_states = decoder_gru_lyr(decoder_inp, initial_state=encoder_out)
decoder_dns_lyr = lyrs.Dense(decoder_input_data_categorical.shape[-1], activation='softmax')
decoder_out = decoder_dns_lyr(decoder_seq)

# encoder_decoder model | training
model = k.models.Model(inputs=[encoder_inp, decoder_inp], outputs=decoder_out)

# encoder model | inference
encoder_model = k.models.Model(encoder_inp, encoder_out)

# decoder architecture | inference
# Reuses the same GRU and Dense layer objects (shared weights), but the initial
# state comes from a separate input instead of the encoder output.
decoder_state_inp = lyrs.Input(shape=(size_gru,))
decoder_inf_out_seq, decoder_inf_out_state = decoder_gru_lyr(decoder_inp, initial_state=decoder_state_inp)
decoder_inf_out_preds = decoder_dns_lyr(decoder_inf_out_seq)

# decoder model | inference
# Returns both predictions and the final state so the caller can feed the state back in.
decoder_model = k.models.Model([decoder_inp, decoder_state_inp], [decoder_inf_out_preds, decoder_inf_out_state])
4

1 个回答

0

我通过以下几点解决了这个问题：使用 Lambda 层根据给定参数进行随机选择；手动逐步堆叠解码器层，使每一步的输出能重新输入解码器；以及使用带有 `on_epoch_end` 方法的自定义 `keras.utils.Sequence` 生成器，在每个 epoch 结束时更新采样计划。

下面是（可能有些取巧的）实现：


# GRU width and the number of decoder time steps unrolled in the training graph.
size_gru, steps = 32, 288

# Keep only the first feature channel of the training inputs and targets.
encoder_input_data = data['data_train'][0][:, :, :1]
decoder_output_data = data['data_train'][1][:, :, :1]
# Decoder input at step t is the target at step t-1; step 0 is seeded with the
# last encoder observation.
decoder_input_data = np.concatenate(
    (encoder_input_data[:, -1:], decoder_output_data[:, :-1]),
    axis=1,
)

np.random.seed(seed)

# ENCODER ARCHITECTURE
# A single GRU reads the encoder sequence and returns only its final hidden
# state (return_sequences is not set); that state seeds the training decoder.
encoder_inp = lyrs.Input(shape=(None, encoder_input_data.shape[-1]))
encoder_out = lyrs.GRU(size_gru)(encoder_inp)


# DECODER ARCHITECTURE
decoder_inp = lyrs.Input(shape=(None, decoder_input_data.shape[-1]))

first_inp = lyrs.Lambda(lambda x: x[:, 0:1, :])(decoder_inp)
decoder_gru_lyr = lyrs.GRU(size_gru, return_sequences=True, return_state=True)
decoder_dns_lyr = lyrs.Dense(1)

#### SCHEDULED SAMPLING INPUTS
# do_inp: per-sample copy of the teacher-forcing probability; only [0, 0] is read.
# ones_inp: a column of ones that do_lyr randomly zeroes to draw each sample's choice.
do_inp = lyrs.Input(shape=(1,), name='do_input')
ones_inp = lyrs.Input(shape=(1,), name='ones_input')
# K.dropout zeroes each entry with probability `level` and scales survivors by
# 1 / (1 - level); multiplying by (1 - level) undoes the scaling, leaving a 0/1
# mask (K.round below removes float noise). do_inp is passed in as a layer input
# rather than captured by the lambda, so Keras records the graph dependency.
do_lyr = lyrs.Lambda(
    lambda t: K.dropout(t[0], level=t[1][0, 0]) * (1 - t[1][0, 0]),
    name='do_lyr',
)


def _scheduled_sample(prediction, step):
    """Return the tensor fed to the decoder step after `step`.

    Per sample, a 0/1 mask is drawn from do_lyr: 1 keeps the model's own
    `prediction`, 0 substitutes the teacher input decoder_inp[:, step + 1].
    """
    # Bind `step` via Lambda `arguments` instead of closing over the loop
    # variable: a late-binding closure reads whatever value `step` holds when
    # Keras (re-)invokes the lambda, not the value at construction time.
    teacher = lyrs.Lambda(
        lambda x, s: x[:, s + 1:s + 2, :], arguments={'s': step})(decoder_inp)
    strategy = do_lyr([ones_inp, do_inp])
    strategy = lyrs.Lambda(lambda x: K.round(x))(strategy)
    negative_strategy = lyrs.Lambda(lambda x: 1 - x)(strategy)
    prediction = lyrs.Multiply()([prediction, strategy])
    teacher = lyrs.Multiply()([teacher, negative_strategy])
    return lyrs.Add()([prediction, teacher])


##############################
#### DECODER LOGIC |  training
# Step 0 is driven by the first decoder input and the encoder state.
decoder_out, decoder_state = decoder_gru_lyr(first_inp, initial_state=encoder_out)
decoder_out = decoder_dns_lyr(decoder_out)
# model_out collects the raw predictions (before mixing) along the time axis.
model_out = decoder_out
decoder_out = _scheduled_sample(decoder_out, 0)

# Unroll one step at a time so each step's (possibly teacher-replaced) output
# is fed back in as the next step's input.
for step in range(1, steps - 1):
    decoder_out, decoder_state = decoder_gru_lyr(decoder_out, initial_state=decoder_state)
    decoder_out = decoder_dns_lyr(decoder_out)
    model_out = lyrs.Concatenate(axis=-2)([model_out, decoder_out])
    decoder_out = _scheduled_sample(decoder_out, step)

# Final step: its prediction is recorded but never fed back, so no mixing.
decoder_out, decoder_state = decoder_gru_lyr(decoder_out, initial_state=decoder_state)
decoder_out = decoder_dns_lyr(decoder_out)
model_out = lyrs.Concatenate(axis=-2)([model_out, decoder_out])

##############################
#### DECODER LOGIC | inference

# Inference decoder: reuses the trained GRU and Dense layer objects (shared
# weights), takes its initial state from a separate input, and runs directly on
# decoder_inp with no scheduled-sampling mixing.
decoder_state_inp = lyrs.Input(shape=(size_gru,))
decoder_inf_out, decoder_inf_state = decoder_gru_lyr(decoder_inp, initial_state=decoder_state_inp)
decoder_inf_out = decoder_dns_lyr(decoder_inf_out)

# Training model: do_inp and ones_inp are the scheduled-sampling control inputs;
# the output is the concatenation of every unrolled step's prediction.
model = k.models.Model(inputs=[encoder_inp, decoder_inp, do_inp, ones_inp], outputs=[model_out])
encoder_model = k.models.Model(encoder_inp, encoder_out)
# Returns predictions plus the final state so the caller can feed the state back in.
decoder_model = k.models.Model([decoder_inp, decoder_state_inp], [decoder_inf_out, decoder_inf_state])

def strat_plot(m, b, epochs=100):
    """Plot 1 - sigmoid(m * epoch + b) for epochs 1..`epochs`.

    The curve is labelled as the chance of teacher forcing. The figure is shown
    immediately and nothing is returned.
    """
    epoch_axis = np.arange(1, epochs + 1, 1)
    sigmoid = 1 / (1 + np.exp(-(epoch_axis * m + b)))
    teacher_chance = 1 - sigmoid

    _, ax = plt.subplots(figsize=(16, 4))
    ax.plot(epoch_axis, teacher_chance)
    ax.grid()
    ax.set_title('Chosen sample schedule:')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Chance of teacher forcing')
    plt.show()

class test_gen(k.utils.Sequence):
    """Keras Sequence yielding training batches plus scheduled-sampling inputs.

    Each item is ``([x_enc, x_dec, do_inp, ones_inp], y)``, in the input order of
    the training model (encoder, decoder, do_input, ones_input). ``do_inp``
    repeats the current teacher-forcing probability ``do_par``. The model reads
    its first entry as the dropout level: a dropped mask entry feeds the teacher
    decoder input to the next step, a kept entry feeds the model's own
    prediction. ``do_par`` follows ``1 - sigmoid(m * epoch + b)``, the same
    curve ``strat_plot`` draws as the "Chance of teacher forcing".

    Args:
        x_set_a: encoder inputs, sliced along axis 0.
        x_set_b: decoder inputs, sliced along axis 0.
        y_set: targets, sliced along axis 0.
        batch_size: samples per batch; the last batch may be smaller.
        m: slope of the schedule per epoch.
        b: offset of the schedule.
    """

    # K.dropout rescales kept entries by 1 / (1 - level); a level of exactly 1
    # would divide by zero, so the probability is capped just below 1.
    _MAX_DROPOUT_LEVEL = 1.0 - 1e-6

    def __init__(self, x_set_a, x_set_b, y_set, batch_size, m=1, b=-6):
        self.x_enc, self.x_dec, self.y = x_set_a, x_set_b, y_set
        self.epoch = 0
        self.batch_size = batch_size
        self.m = m
        self.b = b
        # Start from the schedule's epoch-0 value (about 0.9975 for b=-6) so
        # early training is mostly teacher-forced, as the plot shows.
        self.do_par = self._teacher_forcing_chance(self.epoch)
        self.do_inp = np.ones(shape=(self.batch_size, 1)) * self.do_par
        self.ones_inp = np.ones(shape=(self.batch_size, 1))
        strat_plot(m, b)

    def _teacher_forcing_chance(self, epoch):
        """Return 1 - sigmoid(m * epoch + b), capped below 1 for K.dropout."""
        # Previously this was sigmoid(...) itself, which made teacher forcing
        # *increase* over training, contradicting the plotted schedule.
        chance = 1 - 1 / (1 + np.exp(-(epoch * self.m + self.b)))
        return float(min(chance, self._MAX_DROPOUT_LEVEL))

    def __len__(self):
        return int(np.ceil(len(self.x_enc) / float(self.batch_size)))

    def __getitem__(self, idx):
        window = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        batch_x_enc = self.x_enc[window]
        batch_x_dec = self.x_dec[window]
        batch_y = self.y[window]
        # __len__ rounds up, so the final batch can be shorter than batch_size;
        # the auxiliary inputs must have the same row count or Keras rejects it.
        n_rows = len(batch_x_enc)
        batch_do_inp = self.do_inp[:n_rows]
        batch_ones_inp = self.ones_inp[:n_rows]

        return [np.array(batch_x_enc),
                np.array(batch_x_dec),
                np.array(batch_do_inp),
                np.array(batch_ones_inp)
               ], np.array(batch_y)

    def on_epoch_end(self):
        """Advance the epoch counter and move do_par along the schedule."""
        self.epoch += 1
        self.do_par = self._teacher_forcing_chance(self.epoch)
        self.do_inp = np.ones(shape=(self.batch_size, 1)) * self.do_par

full_gen = test_gen(
    x_set_a=encoder_input_data[:],
    x_set_b=decoder_input_data[:, :steps, :],
    y_set=decoder_output_data[:, :steps, :],
    batch_size=256,
    m=.15,
    b=-6,
)

# Alternative loss previously left commented out here: loss='mse'.
model.compile(
    optimizer=k.optimizers.RMSprop(lr=.003),
    loss=reconstructionLoss(N=steps, order=0, absolute=True),
    metrics=['mae'],
)
model.fit_generator(full_gen, epochs=100)
回答于 2019-10-15T14:41:16.983