我正在尝试为 TensorFlow 2.0 中的自定义训练代码添加 @tf.function,以提高性能。但在使用分类特征列时,运行代码会引发 `FailedPreconditionError: Table already initialized`。我怀疑问题出在 @tf.function 的工作方式上:分类特征列(的查找表)被初始化了多次。

例如,下面的代码在不使用 @tf.function 装饰器时运行正常,但为 `train_step` 添加 @tf.function 装饰器后就会报错。

import numpy as np
import pandas as pd

!pip install sklearn
!pip install tensorflow==2.0.0-alpha0
import tensorflow as tf

from tensorflow import feature_column
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split

# Load the raw CSV into a DataFrame; df_to_dataset below expects a 'target' label column.
URL = 'https://storage.googleapis.com/applied-dl/heart.csv'
dataframe = pd.read_csv(URL)

# Two successive 80/20 splits: first hold out the test set, then carve the
# validation set out of the remaining training rows. No random_state is passed,
# so the split (and therefore the reported metrics) differs between runs.
train, test = train_test_split(dataframe, test_size=0.2)
train, val = train_test_split(train, test_size=0.2)

# A utility method to create a tf.data dataset from a Pandas Dataframe
def df_to_dataset(dataframe, shuffle=True, batch_size=32):
  """Turn a pandas DataFrame into a batched ``tf.data.Dataset``.

  The 'target' column is popped from a copy of ``dataframe`` and used as the
  label, reshaped to a column vector ``(n_rows, 1)`` so it lines up with the
  model's single sigmoid output. Every remaining column becomes one entry of
  the features dict.

  Args:
    dataframe: DataFrame containing a 'target' column. It is not modified.
    shuffle: Whether to shuffle, using a buffer as large as the whole frame.
    batch_size: Number of rows per batch.

  Returns:
    A dataset yielding ``(features_dict, labels)`` batches.
  """
  frame = dataframe.copy()
  labels = frame.pop('target').values.reshape(-1, 1)
  dataset = tf.data.Dataset.from_tensor_slices((dict(frame), labels))
  if shuffle:
    dataset = dataset.shuffle(buffer_size=len(frame))
  return dataset.batch(batch_size)

batch_size = 32
# Only the training split is shuffled; validation and test keep their row order.
train_ds = df_to_dataset(train, batch_size=batch_size)
val_ds, test_ds = [
    df_to_dataset(split, shuffle=False, batch_size=batch_size)
    for split in (val, test)
]

# Every column that DenseFeatures should emit must be appended to this list.
feature_columns = []

# numeric cols
for header in ['age', 'trestbps', 'chol', 'thalach', 'oldpeak', 'slope', 'ca']:
  feature_columns.append(feature_column.numeric_column(header))

# bucketized cols
age = feature_column.numeric_column("age")
age_buckets = feature_column.bucketized_column(age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
feature_columns.append(age_buckets)

# indicator cols
# `thal` maps strings through a vocabulary lookup table. That table's initializer
# (the `thal_lookup/hash_table/table_init` node) is what fails with
# "Table already initialized" under @tf.function on 2.0.0-alpha0.
thal = feature_column.categorical_column_with_vocabulary_list(
      'thal', ['fixed', 'normal', 'reversible'])
thal_one_hot = feature_column.indicator_column(thal)
feature_columns.append(thal_one_hot)

# embedding cols
thal_embedding = feature_column.embedding_column(thal, dimension=8)
feature_columns.append(thal_embedding)

# crossed cols
crossed_feature = feature_column.crossed_column([age_buckets, thal], hash_bucket_size=1000)
crossed_feature = feature_column.indicator_column(crossed_feature)
feature_columns.append(crossed_feature)

# NOTE(review): trainable=False also keeps the `thal_embedding` weights out of
# model.trainable_variables, so the embedding is never learned; drop it if that
# is not intended.
feature_layer = tf.keras.layers.DenseFeatures(feature_columns, trainable=False)

class MyModel(Model):
  """Binary classifier: feature columns -> two 128-unit ReLU layers -> sigmoid.

  The input transformation is the module-level ``feature_layer`` (a
  ``DenseFeatures`` layer built from ``feature_columns``), so ``call`` accepts
  the features dict produced by ``df_to_dataset``.
  """

  def __init__(self):
    super().__init__()
    # Attribute names and assignment order are kept so layer tracking is unchanged.
    self.features = feature_layer
    self.dense = layers.Dense(128, activation='relu')
    self.dense2 = layers.Dense(128, activation='relu')
    self.sigmoid = layers.Dense(1, activation='sigmoid')

  def call(self, x):
    hidden = self.features(x)
    for hidden_layer in (self.dense, self.dense2):
      hidden = hidden_layer(hidden)
    return self.sigmoid(hidden)

# Model, loss and optimizer referenced as globals by train_step / test_step.
model = MyModel()
# The model ends in a sigmoid, so the loss is applied to probabilities
# (BinaryCrossentropy's default from_logits=False).
loss_object = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# Metrics updated by train_step on training batches.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')

# Metrics updated by test_step; the epoch loop feeds them from val_ds.
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')

def train_step(features, label):
  """Run one optimization step on a single batch and update the training metrics.

  Args:
    features: Dict of feature tensors for one batch, as yielded by ``train_ds``.
    label: Label tensor for the batch, shaped ``(batch, 1)`` by ``df_to_dataset``.
  """
  # NOTE(review): this is the function being wrapped in @tf.function. On
  # tensorflow==2.0.0-alpha0 that raised FailedPreconditionError ("Table already
  # initialized") for the vocabulary-list column once the function was traced
  # again; upgrading TensorFlow is the reported fix. Confirm on your version
  # before adding the decorator.
  with tf.GradientTape() as tape:
    predictions = model(features)
    loss = loss_object(label, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  # Feed both metrics so the epoch summary can report loss as well as accuracy.
  train_loss(loss)
  train_accuracy(label, predictions)

def test_step(image, label):
  """Evaluate the model on one batch and update the evaluation metrics.

  Args:
    image: Features for one batch. Despite the name, the epoch loop passes the
      feature dict yielded by ``val_ds``, not images.
    label: Label tensor for the batch.
  """
  predictions = model(image)
  t_loss = loss_object(label, predictions)

  # Record the batch loss; it was previously computed and then discarded.
  test_loss(t_loss)
  test_accuracy(label, predictions)

# The reference run in the post trains for five epochs; EPOCHS was never defined.
EPOCHS = 5

for epoch in range(EPOCHS):
  for features, labels in train_ds:
    train_step(features, labels)

  # The "Test" metrics below are computed on the validation split.
  for features, labels in val_ds:
    test_step(features, labels)

  template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
  print(template.format(epoch + 1,
                        train_loss.result(),
                        train_accuracy.result() * 100,
                        test_loss.result(),
                        test_accuracy.result() * 100))

  # Metrics accumulate across calls; reset them so each epoch reports only its own batches.
  train_loss.reset_states()
  train_accuracy.reset_states()
  test_loss.reset_states()
  test_accuracy.reset_states()


添加 @tf.function 后的输出(第 1 个 epoch 正常结束,之后报错):

Epoch 1, Loss: 1.2603175640106201, Accuracy: 61.65803146362305, Test Loss: 1.3003877401351929, Test Accuracy: 67.34693908691406
FailedPreconditionError                   Traceback (most recent call last)
<ipython-input-3-165a6f89be48> in <module>()
    110   counter = 0
    111   for features, labels in train_ds:
--> 112     train_step(features, labels, counter)
    113     counter +=1

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    412       # In this case we have created variables on the first call, so we run the
    413       # defunned version which is guaranteed to never create variables.
--> 414       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    415     elif self._stateful_fn is not None:
    416       # In this case we have not created variables on the first call. So we can

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   1286     """Calls a graph function specialized to the inputs."""
   1287     graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 1288     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   1290   @property

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs)
    572     """
    573     return self._call_flat(
--> 574         (t for t in nest.flatten((args, kwargs))
    575          if isinstance(t, (ops.Tensor,
    576                            resource_variable_ops.ResourceVariable))))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _call_flat(self, args)
    625     # Only need to override the gradient in graph mode and when we have outputs.
    626     if context.executing_eagerly() or not self.outputs:
--> 627       outputs = self._inference_function.call(ctx, args)
    628     else:
    629       self._register_gradient()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in call(self, ctx, args)
    413             attrs=("executor_type", executor_type,
    414                    "config_proto", config),
--> 415             ctx=ctx)
    416       # Replace empty list with None
    417       outputs = outputs or None

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     64     else:
     65       message = e.message
---> 66     six.raise_from(core._status_to_exception(e.code, message), None)
     67   except TypeError as e:
     68     if any(ops._is_keras_symbolic_tensor(x) for x in inputs):

/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

FailedPreconditionError: Table already initialized.
     [[{{node my_model_2/dense_features_2/age_bucketized_X_thal_indicator/thal_lookup/hash_table/table_init/LookupTableImportV2}}]] [Op:__inference_train_step_34833]


不使用 @tf.function 时的输出:

Epoch 1, Loss: 0.8113343119621277, Accuracy: 63.21243667602539, Test Loss: 0.5340840816497803, Test Accuracy: 71.42857360839844
Epoch 2, Loss: 0.6469629406929016, Accuracy: 69.4300537109375, Test Loss: 0.5265070199966431, Test Accuracy: 72.44898223876953
Epoch 3, Loss: 0.5749971270561218, Accuracy: 71.84800720214844, Test Loss: 0.5283268094062805, Test Accuracy: 72.10884094238281
Epoch 4, Loss: 0.5360371470451355, Accuracy: 72.79792785644531, Test Loss: 0.5270806550979614, Test Accuracy: 72.44898223876953
Epoch 5, Loss: 0.5122867226600647, Accuracy: 73.57512664794922, Test Loss: 0.5229357481002808, Test Accuracy: 72.65306091308594

