我在 google colab 上使用 TPU 运行一个非常简单的模型时遇到问题。我把它提炼成一个非常简单的程序。我怀疑它不喜欢嵌套模型(input_2?),但我不知道如何解决这个问题:
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Activation, Dense, Multiply, Input
from tensorflow.keras import metrics
import warnings
warnings.filterwarnings("ignore")
class DataGenerator:
def __init__(self):
pass
def create_train(self, dataset_info, batch_size, shape, augument=True):
assert shape[2] == 3
while True:
random_indexes = np.random.choice(len(dataset_info), batch_size)
batch_images1 = np.empty((batch_size, shape[0], shape[1], shape[2]))
batch_labels = np.zeros((batch_size, 28))
for i, idx in enumerate(random_indexes):
image1= self.load_image(
dataset_info[idx]['path'], shape)
batch_images1[i] = image1
batch_labels[i][dataset_info[idx]['labels']] = 1
yield batch_images1, batch_labels
def load_image(self, path, shape):
image1 = np.stack((
np.ones((256,256)),
np.ones((256,256)),
np.ones((256,256)),
), -1)
return image1.astype(np.float)
train_datagen = DataGenerator()
train_dataset_info = []
for i in range(0, 1000):
train_dataset_info.append({
'path':str(i),
'labels':np.array([5])})
train_dataset_info = np.array(train_dataset_info)
valid_dataset_info = []
for i in range(1000, 1200):
valid_dataset_info.append({
'path':str(i),
'labels':np.array([6])})
valid_dataset_info = np.array(valid_dataset_info)
print(train_dataset_info.shape, valid_dataset_info.shape)
def create_model(input_shape, n_out):
inp_mask = Input(shape=input_shape)
pretrain_model_mask = ResNet50( input_shape = (256,256,3),
include_top=False,
weights=None,
pooling='max')
x = pretrain_model_mask(inp_mask)
out = Dense(n_out, activation='sigmoid')(x)
model = Model(inputs=inp_mask, outputs=[out])
return model
tf.keras.backend.clear_session()
model = create_model(
input_shape=(256,256,3),
n_out=28)
model.compile(
loss='binary_crossentropy',
optimizer=tf.train.AdamOptimizer(learning_rate=1e-3, ),
metrics=['acc'])
TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
model,
strategy=tf.contrib.tpu.TPUDistributionStrategy(
tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)))
tpu_model.summary()
epochs = 4 ;batch_size = 64
# create train and valid datagens
train_generator = train_datagen.create_train(
train_dataset_info, batch_size, (256,256,3))
validation_generator = train_datagen.create_train(
valid_dataset_info, batch_size, (256,256,3))
# train model
history = tpu_model.fit_generator(
train_generator,
steps_per_epoch=1000,
validation_data=validation_generator,
validation_steps=20,
epochs=epochs,
verbose=1)
这是运行它的输出(只需在 colab 中粘贴为单个单元格):
Epoch 1/4
INFO:tensorflow:New input shapes; (re-)compiling: mode=train (# of cores 8), [TensorSpec(shape=(8,), dtype=tf.int32, name='core_id0'), TensorSpec(shape=(8, 512, 512, 3), dtype=tf.float32, name='input_1_10'), TensorSpec(shape=(8, 28), dtype=tf.float32, name='dense_target_30')]
INFO:tensorflow:Overriding default placeholder.
INFO:tensorflow:Remapping placeholder for input_1
INFO:tensorflow:Remapping placeholder for input_2
INFO:tensorflow:Default: input_2
ERROR:tensorflow:Operation of type Placeholder (tpu_140454984405456_1/input_2) is not supported on the TPU. Execution will fail if this op is used in the graph.
INFO:tensorflow:Started compiling
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-36-112706d24f9b> in <module>()
61 validation_steps=len(valid_df)//batch_size,
62 epochs=4,
---> 63 verbose=1,
64 # use_multiprocessing=False,
65 # callbacks=[checkpointer]
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
2175 use_multiprocessing=use_multiprocessing,
2176 shuffle=shuffle,
-> 2177 initial_epoch=initial_epoch)
2178
2179 def evaluate_generator(self,
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
174
175 outs = model.train_on_batch(
--> 176 x, y, sample_weight=sample_weight, class_weight=class_weight)
177
178 if not isinstance(outs, list):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight)
1938
1939 self._make_train_function()
-> 1940 outputs = self.train_function(ins)
1941
1942 if len(outputs) == 1:
/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/tpu/python/tpu/keras_support.py in __call__(***failed resolving arguments***)
1247 input_specs = infeed_instance.make_input_specs(input_tensors)
1248 tpu_model_ops = self._tpu_model_ops_for_input_specs(input_specs,
-> 1249 infeed_manager)
1250 infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops)
1251
/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/tpu/python/tpu/keras_support.py in _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager)
1154 infeed_manager)
1155 self._compilation_cache[shape_key] = new_tpu_model_ops
-> 1156 self._test_model_compiles(new_tpu_model_ops)
1157
1158 return self._compilation_cache[shape_key]
/usr/local/lib/python3.6/dist-packages/tensorflow/contrib/tpu/python/tpu/keras_support.py in _test_model_compiles(self, tpu_model_ops)
1097 if proto.status_error_message:
1098 raise RuntimeError('Compilation failed: {}'.format(
-> 1099 proto.status_error_message))
1100
1101 end_time = time.time()
RuntimeError: Compilation failed: Compilation failure: Detected unsupported operations when trying to compile graph cluster_1_11838307395637379894[] on XLA_TPU_JIT: Placeholder (No registered 'Placeholder' OpKernel for XLA_TPU_JIT devices compatible with node {{node tpu_140454984405456_1/input_2}} = Placeholder[dtype=DT_FLOAT, shape=[?,512,512,3], _device="/device:TPU_REPLICATED_CORE"]()
. Registered: device='TPU'
device='CPU'
device='GPU'
device='XLA_GPU'
device='XLA_CPU'
){{node tpu_140454984405456_1/input_2}}
出于某种原因,stackoverflow 坚持我会写一些更多的细节......没有。