I am studying neural networks and I am new to this field. I am implementing an autoencoder for text data. The problem starts when I call the fit function of the variational autoencoder. The code snippet is:

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.losses import binary_crossentropy

text_input_shape                 = (encoder(tweets_dna).shape[1], )
text_input_diensions             = len(encoder.get_vocabulary())
text_embedding_output_dimensions = 64

text_input_shape

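# Encoder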
encoder_input = tf.keras.layers.Input(shape=text_input_shape, dtype='float32', name="TF_input")
encoder_embedded = tf.keras.layers.Embedding(input_dim=text_input_diensions, dtype='float32', output_dim=text_embedding_output_dimensions, mask_zero=True, name="Encoder_Embedding")(encoder_input)
LTSM_layer, state_h, state_c = tf.keras.layers.LSTM(text_embedding_output_dimensions, dtype='float32', return_state=True, name="LTSM")(encoder_embedded)
dense_layer = tf.keras.layers.Dense(text_embedding_output_dimensions, dtype='float32', activation='relu', name="Compression")(LTSM_layer)
final_layer = tf.keras.layers.Dense(text_input_diensions, dtype='float32', activation='relu', name="Lossy_Compression")(dense_layer)
text_mu = tf.keras.layers.Dense(text_input_diensions, dtype='float32', name='latent_mu')(final_layer)
text_sigma = tf.keras.layers.Dense(text_input_diensions, dtype='float32', name='latent_sigma')(final_layer)

text_encoder_states = [state_h, state_c]

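# Reparameterization trick: sample z = mu + exp(sigma / 2) * eps, with eps ~ N(0, I); sigma is treated as the log-variance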
def sample_z(args):
    mu, sigma = args
    batch     = tf.keras.backend.shape(mu)[0]
    dim       = tf.keras.backend.int_shape(mu)[1]
    eps       = tf.keras.backend.random_normal(shape=(batch, dim))
    return mu + tf.keras.backend.exp(sigma / 2) * eps

text_z = tf.keras.layers.Lambda(sample_z, output_shape=(len(encoder.get_vocabulary()), ), name='z')([text_mu, text_sigma])

text_encoder = tf.keras.Model(encoder_input, [text_mu, text_sigma, text_z], name='TextCompressionModel')

# Decoder
decoder_input = tf.keras.layers.Input(shape=(text_input_diensions, ), dtype='float32', name='decoder_input')
decoder_embedded = tf.keras.layers.Embedding(input_dim=text_input_diensions, dtype='float32', output_dim=64, mask_zero=True)(decoder_input)
LTSM_layer = tf.keras.layers.LSTM(64, dtype='float32')(decoder_embedded)
final_output_dense = tf.keras.layers.Dense(MAX_SEQUENCE_LENGHT, dtype='float32')(LTSM_layer)

text_decoder = tf.keras.Model(decoder_input, final_output_dense, name='TextDecoder')

# Instantiate VAE
tweet_vae_outputs = text_decoder(text_encoder(encoder_input)[2])
tweets_vae = tf.keras.Model(encoder_input, tweet_vae_outputs, name='Text-VAE')

def tweets_kl_reconstruction_loss(true, pred):
    # Reconstruction loss
    reconstruction_loss = binary_crossentropy(K.flatten(true), K.flatten(pred)) * text_input_shape[0]
    # KL divergence loss
    kl_loss = -0.5 * (1 + text_sigma - K.square(text_mu) - K.exp(text_sigma))
    kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
    # Total loss = reconstruction loss + KL divergence loss
    return reconstruction_loss + kl_loss

tweets_reconstruction_loss = tf.function(tweets_kl_reconstruction_loss)

tweets_vae.compile(optimizer='adam', loss=tweets_reconstruction_loss)

history = vae.fit(x = tweet_data, y = tweet_data, 
        epochs = no_epochs, 
        batch_size = batch_size, 
        validation_split = validation_split)

The error is:

Epoch 1/10
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input> in <module>
----> 1 history = vae.fit(x = tweet_data, y = tweet_data,
      2         epochs = no_epochs,
      3         batch_size = batch_size,
      4         validation_split = validation_split)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
--> 108       return method(self, *args, **kwargs)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
-> 1098               tmp_logs = train_function(iterator)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\training.py in train_function(iterator)
--> 806       return step_function(self, iterator)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\training.py in step_function(model, iterator)
--> 796       outputs = model.distribute_strategy.run(run_step, args=(data,))

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py in run(self, fn, args, kwargs, options)
-> 1211     return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py in call_for_each_replica(self, fn, args, kwargs)
-> 2585       return self._call_for_each_replica(fn, args, kwargs)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\distribute\distribute_lib.py in _call_for_each_replica(self, fn, args, kwargs)
-> 2945       return fn(*args, **kwargs)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\autograph\impl\api.py in wrapper(*args, **kwargs)
--> 275     return func(*args, **kwargs)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\training.py in run_step(data)
--> 789       outputs = model.train_step(data)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\training.py in train_step(self, data)
--> 747       y_pred = self(x, training=True)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
--> 985         outputs = call_fn(inputs, *args, **kwargs)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\functional.py in call(self, inputs, training, mask)
--> 385     return self._run_internal_graph(
      386         inputs, training=training, mask=mask)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\functional.py in _run_internal_graph(self, inputs, training, mask)
--> 508           outputs = node.layer(*args, **kwargs)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
--> 985         outputs = call_fn(inputs, *args, **kwargs)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\functional.py in call(self, inputs, training, mask)
--> 385     return self._run_internal_graph(
      386         inputs, training=training, mask=mask)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\functional.py in _run_internal_graph(self, inputs, training, mask)
--> 508           outputs = node.layer(*args, **kwargs)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
--> 985         outputs = call_fn(inputs, *args, **kwargs)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\layers\core.py in call(self, inputs)
-> 1193       return core_ops.dense(
   1194           inputs,
   1195           self.kernel,

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\keras\layers\ops\core.py in dense(inputs, kernel, bias, activation, dtype)
---> 53       outputs = gen_math_ops.mat_mul(inputs, kernel)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\ops\gen_math_ops.py in mat_mul(a, b, transpose_a, transpose_b, name)
-> 5624       _ops.raise_from_not_ok_status(e, name)

c:\users\bilal\appdata\local\programs\python\python38\lib\site-packages\tensorflow\python\framework\ops.py in raise_from_not_ok_status(e, name)
-> 6843   six.raise_from(core._status_to_exception(e.code, message), None)

~\AppData\Roaming\Python\Python38\site-packages\six.py in raise_from(value, from_value)

InvalidArgumentError: Matrix size-incompatible: In[0]: [10,3500], In[1]: [9,9] [Op:MatMul]

I am new to this field, and I am not sure whether my implementation is correct, because I do not know where In[1]: [9,9] comes from. Also, this GitHub link points to an executed notebook with the complete code. Please help.
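In case it is useful, here is a small check I can run after building the models above (it assumes only the code in this question). The traceback shows the failing MatMul happens inside a Dense layer, so this prints every Dense kernel shape, but I still cannot tell which layer produces the [9,9] matrix:

text_encoder.summary()
text_decoder.summary()
tweets_vae.summary()

# List every Dense layer's kernel shape; the failing MatMul multiplies
# a layer input by one of these kernels.
for model in (text_encoder, text_decoder):
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Dense):
            print(model.name, layer.name, layer.kernel.shape)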
