
I am working in TensorFlow's eager execution to develop a variant of a variational autoencoder (VAE) in a sequential-data setting. Since neither the recurrent network structure nor its input/output flow is standard, I have to build my own custom RNNCell, which I can later pass to the tf.nn.raw_rnn API.

For the class that builds the required RNNCell, I use tf.keras.Model as the base class. However, when I pass this RNNCell to tf.nn.raw_rnn, I get nan outputs. What is wrong?

Here is my implementation (please let me know if anything is still unclear):

import tensorflow as tf
tfe = tf.contrib.eager
tf.enable_eager_execution()
import numpy as np
from tensorflow.keras.layers import Input, Dense, Lambda
from tensorflow.keras.models import Model

The dataset is called inputs; all entries are bounded, of dtype float32, and of shape (time_steps, batch_size, input_depth) = (20, 1000, 4). Note that this shape format differs from the more familiar tf.nn.dynamic_rnn API, where the shape is (batch_size, time_steps, input_depth).
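To make the snippets below self-contained, here is roughly how the globals they refer to are set up (a minimal sketch; the random inputs stand in for my real data, and the exact seq_length and init_state values are illustrative choices):

#illustrative global definitions assumed by the code below
batch_size = 1000
max_time = 20           # == time_steps
input_depth = 4
latent_dim = 4
seq_length = max_time * tf.ones(shape=(batch_size,), dtype=tf.int32)
init_state = tf.zeros(shape=(batch_size, latent_dim), dtype=tf.float32)
inputs = tf.random_uniform(shape=(max_time, batch_size, input_depth))  # stand-in for the real dataset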

#defining sampling and reparameterizing function
def sampling(args):
    mean, logvar = args
    batch = batch_size   # global batch size
    dim = latent_dim     # global latent dimension
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = tf.random_normal(shape=(batch, dim))
    return mean + tf.exp(0.5 * logvar) * epsilon
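As a quick sanity check (illustrative only, not part of the model), the reparameterization can be exercised on zero tensors; with mean = 0 and logvar = 0, s is simply a standard-normal draw:

#smoke test of the reparameterization
_mean = tf.zeros((batch_size, latent_dim))
_logvar = tf.zeros((batch_size, latent_dim))
_s = sampling([_mean, _logvar])   # shape (batch_size, latent_dim), finite values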



#defining class of the model (PreSSM = without transition module yet)
class PreSSM(tf.keras.Model):
    def __init__(self, latent_dim = 4, intermediate_dim = 4):
        super(PreSSM, self).__init__()
        self.latent_dim = latent_dim
        self.input_dim = self.latent_dim + 4 #toy problem

        inputs = Input(shape=(self.latent_dim + 4,), name='inference_input')
        layer_1 = Dense(intermediate_dim, activation='relu')(inputs)
        layer_2 = Dense(intermediate_dim, activation='relu')(layer_1)
        mean = Dense(latent_dim, name='mean')(layer_2)
        logvar = Dense(latent_dim, name='logvar')(layer_2)        
        s = Lambda(sampling, output_shape=(latent_dim,), name='s')([mean, logvar])
        self.inference_net = Model(inputs, [mean, logvar, s], name='inference_net')

        latent_inputs = Input(shape=(latent_dim,), name='s_sampling')
        layer_3 = Dense(intermediate_dim, activation='relu')(latent_inputs)
        layer_4 = Dense(intermediate_dim, activation='relu')(layer_3)
        outputs = Dense(2)(layer_4)
        self.generative_net = Model(latent_inputs, outputs, name='generative_net')

    @property
    def state_size(self):
        return self.latent_dim

    @property
    def output_size(self):
        return 2 #(x,y) coordinate

    @property
    def zero_state(self):
        return init_state #global variable we have defined

    def __call__(self, inputs, state):
        next_state = self.inference_net(inputs)[-1]   # sampled s from [mean, logvar, s]
        output = self.generative_net(next_state)
        return output, next_state

#instantiate cell == model instant
model = PreSSM()
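Calling the cell in isolation works and returns tensors of the expected shapes (a quick check with dummy tensors; the names are illustrative):

#shape check of the cell by itself, outside raw_rnn
dummy_input = tf.zeros((batch_size, input_depth + latent_dim))  # observation concatenated with state
out, next_state = model(dummy_input, init_state)
#out:        shape (batch_size, 2)
#next_state: shape (batch_size, latent_dim)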

#define a class with instant super_loop_fn(inputs) that has method called loop_fn
class SuperLoop:
    def __init__(self, inputs, output_dim = 2):
        inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time, clear_after_read=False)
        inputs_ta = inputs_ta.unstack(inputs) # unstack the data along the time axis
        self.inputs_ta = inputs_ta
        self.output_dim = output_dim

    def loop_fn(self, time, cell_output, cell_state, loop_state):
        if cell_output is None:  # time == 0: emit zeros and start from init_state
            next_cell_state = init_state
            emit_output = tf.zeros([self.output_dim])
        else:
            emit_output = cell_output
            next_cell_state = cell_state

        elements_finished = (time >= seq_length)
        finished = tf.reduce_all(elements_finished)

        if finished:
            next_input = tf.zeros(shape=(self.output_dim), dtype=tf.float32)
        else:
            # feed the current observation concatenated with the state back in
            next_input = tf.concat([self.inputs_ta.read(time), next_cell_state], -1)

        next_loop_state = None
        return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)


#defining a model
def SSM_model(inputs, RNN_cell = model, output_dim = 2):
    superloop = SuperLoop(inputs, output_dim)
    outputs_ta, final_state, final_loop_state = tf.nn.raw_rnn(RNN_cell, superloop.loop_fn)
    outputs = outputs_ta.stack()
    return outputs

#model checking
SSM_model(inputs = inputs, RNN_cell = model)
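To see the problem concretely, I inspect the stacked output directly (np is NumPy as imported above):

outputs = SSM_model(inputs = inputs, RNN_cell = model)
print(outputs.shape)                    # expect (max_time, batch_size, 2)
print(np.isnan(outputs.numpy()).any())  # prints True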

And here, the outputs are nan...

Hence I cannot proceed to the training step. What is wrong? Am I missing something when defining the RNNCell above with tf.keras.Model as the base class?

