我正在尝试使用 TensorFlow 的 `ParameterServerStrategy` 并行化模型的训练步骤。我使用 GCP AI Platform 创建集群并启动任务。由于数据集很大,我使用 `tensorflow-io` 的 BigQuery 连接器读取数据。
我的脚本参考了 tensorflow-io BigQuery reader 文档和 TensorFlow `ParameterServerStrategy` 文档。
在本地,我的脚本运行良好,但是当我使用 AI Platform 启动它时,出现以下错误:
{"created":"@1633444428.903993309","description":"Error received from peer ipv4:10.46.92.135:2222","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Op type not registered \'IO>BigQueryClient\' in binary running on gke-cml-1005-141531--n1-standard-16-2-644bc3f8-7h8p. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.","grpc_status":5}
该脚本在 AI Platform 上使用假数据时可以正常运行,在本地使用 BigQuery 连接器时也可以正常运行。我猜测是模型编译时包含了 BigQuery 连接器的操作,这些操作在其他设备上被调用时产生了错误,但我不知道如何修复。
我读到当设备没有相同的 tensorflow 版本时会发生此错误,因此我检查了每个设备上的 tensorflow 和 tensorflow-io 版本。
tensorflow: 2.5.0
tensorflow-io: 0.19.1
我创建了一个类似的示例,它可以在 AI Platform 上复现该错误:
import os
from tensorflow_io.bigquery import BigQueryClient
from tensorflow_io.bigquery import BigQueryReadSession
import tensorflow as tf
import multiprocessing
import portpicker
from tensorflow.keras.layers.experimental import preprocessing
from google.cloud import bigquery
from tensorflow.python.framework import dtypes
import numpy as np
import pandas as pd
# Module-level BigQuery client (used by init_bq_table) and table/pipeline
# settings shared by the rest of the script.
client = bigquery.Client()
# TODO: replace with your GCP project id. The original `<your_project>` was
# not valid Python; a string placeholder keeps the script importable.
PROJECT_ID = "<your_project>"
DATASET_ID = 'tmp'
TABLE_ID = 'bq_tf_io'
BATCH_SIZE = 32
# BigQuery requirements
def init_bq_table():
    """Overwrite PROJECT_ID.DATASET_ID.TABLE_ID with 1000 rows of toy data.

    Each row has a uniform feature ``x`` in [0, 1) and a target
    ``y = 0.2 + x + noise`` with Gaussian noise (mean 0, std 0.3). The load
    job uses WRITE_TRUNCATE, so existing table contents are replaced. Blocks
    until the load job finishes.
    """
    destination = f"{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}"

    n_rows = 1000
    # Draw features before noise to keep the same random-number sequence.
    features = np.random.random(size=n_rows)
    noise = np.random.normal(loc=0, scale=0.3, size=n_rows)
    targets = 0.2 + features + noise
    frame = pd.DataFrame(data={'x': features, 'y': targets})

    load_config = bigquery.LoadJobConfig(write_disposition="WRITE_TRUNCATE")
    load_job = client.load_table_from_dataframe(
        frame, destination, job_config=load_config)
    load_job.result()
# Uncomment to (re)create the toy BigQuery table before training.
# init_bq_table()

# Columns requested from BigQuery. Despite the name this is a BigQuery
# schema, not a CSV one; read_bigquery() maps FLOAT64 columns to tf.float64
# and any other type to tf.string.
CSV_SCHEMA = [
    bigquery.SchemaField(column_name, "FLOAT64")
    for column_name in ("x", "y")
]
def transform_row(row_dict):
    """Turn one BigQuery row into a ``(features, label)`` pair.

    The label is the ``'y'`` column. The features are every remaining
    column, in the row's key order, followed by a constant float64 ``1``
    (bias term), stacked along the last axis into a single tensor.

    Args:
        row_dict: mapping of column name to tensor for one row, as yielded by
            the dataset from ``read_bigquery``. It is not modified.

    Returns:
        Tuple ``(features, label)``.

    Raises:
        KeyError: if ``row_dict`` has no ``'y'`` column.
    """
    # Work on a copy: the previous version popped 'y' from and inserted
    # 'constant' into the caller's mapping.
    features = dict(row_dict)
    label = features.pop('y')
    # Appended last so the stacked feature order is [<other columns>, constant].
    features['constant'] = tf.cast(1, tf.float64)
    stacked = tf.stack([features[column] for column in features], axis=-1)
    return (stacked, label)
def read_bigquery(table_name, requested_streams=2):
    """Open a BigQuery read session on ``table_name`` and return its rows.

    Reads the columns listed in ``CSV_SCHEMA`` from
    ``PROJECT_ID.DATASET_ID.<table_name>``. FLOAT64 columns become
    ``tf.float64``; every other type becomes ``tf.string``.

    Args:
        table_name: table id inside ``PROJECT_ID.DATASET_ID`` to read.
            (Previously ignored in favour of the global ``TABLE_ID``.)
        requested_streams: number of read streams requested from BigQuery
            and read in parallel. Defaults to 2, the previous hard-coded value.

    Returns:
        The dataset produced by ``read_session.parallel_read_rows()``, with
        one element per row keyed by column name.
    """
    tensorflow_io_bigquery_client = BigQueryClient()
    selected_fields = [field.name for field in CSV_SCHEMA]
    output_types = [
        dtypes.double if field.field_type == 'FLOAT64' else dtypes.string
        for field in CSV_SCHEMA
    ]
    read_session = tensorflow_io_bigquery_client.read_session(
        "projects/" + PROJECT_ID,
        # tensorflow-io's positional order is project, *table*, then dataset.
        PROJECT_ID, table_name, DATASET_ID,
        selected_fields,
        output_types,
        requested_streams=requested_streams)
    return read_session.parallel_read_rows()
def get_data():
    """Build the training pipeline: BigQuery rows -> (features, label) batches.

    Rows from TABLE_ID are converted by transform_row (4 parallel calls),
    grouped into batches of BATCH_SIZE, and 2 batches are prefetched.
    """
    return (
        read_bigquery(TABLE_ID)
        .map(transform_row, num_parallel_calls=4)
        .batch(BATCH_SIZE)
        .prefetch(2)
    )
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()


def _register_tensorflow_io_ops():
    """Make sure tensorflow-io's custom ops are registered in this process.

    The chief reported "Op type not registered 'IO>BigQueryClient' in binary
    running on ..." from a remote peer: that process had never registered the
    tensorflow-io ops, even though this script imports tensorflow_io.bigquery.
    NOTE(review): tensorflow-io loads its shared op library lazily, on first
    use of one of its ops. A worker/ps that only calls server.join() never
    triggers that load, while the chief does (it calls BigQueryClient() while
    tracing get_data()). Constructing a BigQueryClient here forces the load
    before the server starts accepting graphs -- confirm on your cluster.
    """
    BigQueryClient()


def _start_and_join_server(config=None):
    """Start this task's gRPC tf.distribute.Server and block in join()."""
    server = tf.distribute.Server(
        cluster_resolver.cluster_spec(),
        job_name=cluster_resolver.task_type,
        task_index=cluster_resolver.task_id,
        config=config,
        protocol="grpc")
    server.join()


# Parameter servers and workers just wait for work from the coordinator
# (chief); only the chief creates the strategy and drives training.
# Use ==: the previous `task_type in ("worker")` was a substring test on a
# plain string, not a tuple-membership test.
if cluster_resolver.task_type == "worker":
    _register_tensorflow_io_ops()
    _start_and_join_server(config=tf.compat.v1.ConfigProto())
elif cluster_resolver.task_type == "ps":
    _register_tensorflow_io_ops()
    _start_and_join_server()
elif cluster_resolver.task_type == 'chief':
    strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver=cluster_resolver)
# Chief-only section: build the model under the strategy, wire up the
# ClusterCoordinator, and schedule training steps on the workers.
if cluster_resolver.task_type == 'chief':
    learning_rate = 0.01
    # Model, optimizer and metric are created inside strategy.scope() so the
    # ParameterServerStrategy controls their variables.
    with strategy.scope():
        # Model: Input(2, float64) -> Dense(8, relu) -> Dense(1).
        model_input = tf.keras.layers.Input(
            shape=(2,), dtype=tf.float64)
        layer_1 = tf.keras.layers.Dense( 8, activation='relu')(model_input)
        dense_output = tf.keras.layers.Dense(1)(layer_1)
        model = tf.keras.Model(model_input, dense_output)
        # Optimizer: plain SGD.
        optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate)
        # NOTE(review): despite its name, `accuracy` is a mean-squared-error
        # metric, so the epoch log below reports MSE, not accuracy.
        accuracy = tf.keras.metrics.MeanSquaredError()

    @tf.function
    def distributed_train_step(iterator):
        """Pull one batch from a per-worker iterator and train on it.

        Scheduled onto workers via coordinator.schedule() in the loop below;
        returns whatever strategy.run(train_step, ...) returns.
        """
        def train_step(x_batch_train, y_batch_train):
            """Apply one SGD update on a batch and update the MSE metric.

            Returns the loss computed with MeanSquaredError(reduction=NONE).
            """
            with tf.GradientTape() as tape:
                y_predict = model(x_batch_train, training=True)
                loss_value = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)(y_batch_train, y_predict)
            grads = tape.gradient(loss_value, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))
            accuracy.update_state(y_batch_train, y_predict)
            return loss_value
        x_batch_train, y_batch_train = next(iterator)
        return strategy.run(train_step, args=(x_batch_train, y_batch_train))

    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

    # Test-only synthetic pipeline that uses core TensorFlow ops only.
    def dataset_fn(_):
        """Build an in-memory dataset of 1000 rows shaped like get_data()'s.

        Each row is ([x, 1.0], y) with y = 0.2 + x + N(0, 0.3) noise, mapped
        with 4 parallel calls, batched by BATCH_SIZE and prefetched by 2.
        The single argument (the strategy's input context) is ignored.
        """
        def create_toy_data(N):
            x = np.random.random(size = N)
            y = 0.2 + x + np.random.normal(loc=0, scale = 0.3, size = N)
            return np.c_[x,y]
        def toy_transform_row(row):
            dataset_x = tf.stack([row[0], tf.cast(1, tf.float64)], axis=-1)
            dataset_y = row[1]
            return dataset_x, dataset_y
        N = 1000
        data =create_toy_data(N)
        dataset = tf.data.Dataset.from_tensor_slices(data)
        dataset = dataset.map(toy_transform_row, num_parallel_calls=4)
        dataset = dataset.batch(BATCH_SIZE)
        dataset = dataset.prefetch(2)
        return dataset

    @tf.function
    def per_worker_dataset_fn():
        """Create the distributed dataset handed to create_per_worker_dataset.

        The BigQuery pipeline (get_data) contains tensorflow-io ops such as
        IO>BigQueryClient. The process that runs the dataset must therefore
        have the tensorflow-io op library registered. The reported error comes
        from a remote peer (presumably a worker) that had not registered it.
        dataset_fn uses only core TF ops, which is why that variant works.
        NOTE(review): the lambda ignores its input context, so the data is not
        sharded across workers (presumably each worker reads the whole table).
        """
        return strategy.distribute_datasets_from_function(lambda x : get_data()) # BigQuery path: fails unless workers registered tensorflow-io ops
        #return strategy.distribute_datasets_from_function(dataset_fn) # synthetic path: works on AI Platform (core TF ops only)

    per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
    # Train model: 5 epochs of 5 scheduled steps each. coordinator.join() waits
    # for all scheduled steps before the metric is read and logged.
    for epoch in range(5):
        per_worker_iterator = iter(per_worker_dataset)
        accuracy.reset_states()
        for step in range(5):
            coordinator.schedule(distributed_train_step, args=(per_worker_iterator,))
        coordinator.join()
        print ("Finished epoch %d, accuracy is %f." % (epoch, accuracy.result().numpy()))
在 `per_worker_dataset_fn()` 中创建数据集时,我可以使用 BigQuery 连接器(出错),也可以即时生成数据集(正常运行)。
AI Platform 集群配置:
运行时版本(runtimeVersion):"2.5"
Python 版本(pythonVersion):"3.7"
有人遇到过这个问题吗?BigQuery 连接器在 AI Platform 上配合 MirroredStrategy 使用时完全正常。如果这个问题应该在其他地方报告,请告诉我。