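I'm new to TensorFlow and Python and am trying to define a SageMaker pipeline. I'm currently having trouble running a saved model from the MovieLens example in SageMaker. I managed to train the model with the code below and save it to the appropriate S3 bucket using the SavedModel API. When I load the model and try to make a prediction with the loaded model, I get this error: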
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
Positional arguments (3 total):
* <_VariantDataset shapes: (), types: tf.string>
* True
* None
Keyword arguments:
{}
Expected these arguments to match one of the following 0 option(s):
Model loading and prediction:
model_path = "/opt/ml/processing/model"
tar_path = os.path.join(model_path, "model.tar.gz")
logger.info(tar_path)
with tarfile.open(tar_path) as tar:
    tar.extractall(path=model_path)
logger.info("Extracted model.")
model = tf.saved_model.load(model_path)
scores, titles = model(tf_ratings.take(1))
Model class:
class MovieLensModel(tfrs.Model):
    # We derive from a custom base class to help reduce boilerplate. Under the hood,
    # these are still plain Keras Models.

    def __init__(self,
                 user_model: tf.keras.Model,
                 movie_model: tf.keras.Model,
                 task: tfrs.tasks.Retrieval):
        super(tfrs.Model, self).__init__()  # added arguments for super

        # Set up user and movie representations.
        self.user_model = user_model
        self.movie_model = movie_model

        # Set up a retrieval task.
        self.task = task

    @tf.function
    def __call__(self, x, training=True, mask=None):
        user_embeddings = self.user_model(x)
        movie_embeddings = self.movie_model(x)
        return self.task(user_embeddings, movie_embeddings)

    def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
        # Define how the loss is computed.
        user_embeddings = self.user_model(features[0])
        movie_embeddings = self.movie_model(features[0])
        return self.task(user_embeddings, movie_embeddings)
Model training and saving:
task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(
    tf_movies.batch(128).map(movie_model)
))
logger.info("Model training...")

# Create a retrieval model.
model = MovieLensModel(user_model, movie_model, task)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.5))

# Train for args.epochs epochs.
model.fit(tf_ratings.batch(args.batch_size), epochs=args.epochs)

# Save the model.
model_path = os.path.join(args.model_dir, "movie_lens")
logger.info("Model path is " + args.model_dir)
model.task = tfrs.tasks.Retrieval()  # Removes the metrics.
tf.saved_model.save(model, args.model_dir)
The image used for the SageMaker container is tensorflow-training:2.2-cpu-py37.
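I assume the positional arguments in the error above correspond to the __call__ function in the model class. What confuses me is that I don't understand the error itself: if the __call__ function is recognized when I make a prediction with model(value), why do I get an error saying there is no matching function?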
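
One possible reading of the error, sketched under assumptions: a SavedModel only contains the concrete functions that were traced before saving (or that can be traced from an `input_signature`), and `fit` goes through `compute_loss` rather than `__call__`, which would be consistent with the "0 option(s)" in the message. The message also shows that the first argument was a dataset (`_VariantDataset`) rather than a tensor. In the sketch below, `Scorer` is a hypothetical stand-in built on `tf.Module`, not the `MovieLensModel` above, and its computation is a made-up toy. It only illustrates exporting a callable with a fixed input signature and then calling the loaded object with tensors taken from a dataset.

```python
import tempfile

import tensorflow as tf


class Scorer(tf.Module):
    """Hypothetical stand-in for a retrieval model (not the MovieLensModel above)."""

    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(2.0)

    # The fixed input_signature lets TensorFlow trace this function at save
    # time, so the SavedModel holds a concrete function for string batches.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def __call__(self, user_ids):
        # Toy computation (assumption): string length times a variable.
        lengths = tf.strings.length(user_ids)
        return tf.cast(lengths, tf.float32) * self.scale


# Save to a temporary directory and load it back, as in the question.
export_dir = tempfile.mkdtemp()
tf.saved_model.save(Scorer(), export_dir)
loaded = tf.saved_model.load(export_dir)

# A tf.data.Dataset is not a tensor: iterate over it and pass each batch
# to the loaded callable instead of passing the dataset object itself.
dataset = tf.data.Dataset.from_tensor_slices(["42", "138", "7"]).batch(2)
results = [loaded(batch).numpy().tolist() for batch in dataset]
print(results)
```

Whether this maps directly onto a `tfrs.Model` saved from the tensorflow-training:2.2 image is untested here. For the retrieval case, the input signature would have to match whatever `user_model` actually expects.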