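I have a simple PySpark application that runs on TensorFlowOnSpark.
I use TFCluster.InputMode.SPARK mode to feed RDDs to the model. The fit_generator() method works fine, but the inference phase makes no progress and hangs. For reproducibility, I use the diabetes dataset (https://www.kaggle.com/uciml/pima-indians-diabetes-database). I split the dataset into two parts: one goes in trainFolder and the other in testFolder. The generator generate_rdd_data() feeds batches of data to the model for training, and generate_rdd_test() does the same for inference, except that it does not yield labels.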
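The post does not say how the CSV was split. The sketch below shows one possible way to do it. The helper names split_rows and write_split, the 20% test fraction, the seed and the part-00000 file name are all hypothetical choices. It drops the Kaggle file's header row, because parse() calls float() on every field and would fail on it. It then shuffles the data rows with a fixed seed and writes each part into its own folder, where sc.textFile(trainFolder) and sc.textFile(testFolder) can read it. Both parts keep the Outcome column, because parse_test() discards the last field anyway.

```python
import os
import random


def _is_header(line):
    """A row whose first field is not numeric is treated as a CSV header."""
    try:
        float(line.split(',')[0])
        return False
    except ValueError:
        return True


def split_rows(lines, test_fraction=0.2, seed=7):
    """Hypothetical helper: drop a header, shuffle data rows, return (train, test)."""
    rows = [ln for ln in lines if ln.strip()]
    if rows and _is_header(rows[0]):
        rows = rows[1:]
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    rng.shuffle(rows)
    n_test = int(len(rows) * test_fraction)
    return rows[n_test:], rows[:n_test]


def write_split(csv_path, train_dir, test_dir, test_fraction=0.2, seed=7):
    """Hypothetical helper: write each part as one text file inside its own folder."""
    with open(csv_path) as f:
        train, test = split_rows(f.read().splitlines(), test_fraction, seed)
    for folder, rows in ((train_dir, train), (test_dir, test)):
        if not os.path.isdir(folder):
            os.makedirs(folder)
        with open(os.path.join(folder, "part-00000"), "w") as out:
            out.write("".join(r + "\n" for r in rows))
    return len(train), len(test)
```

For example, write_split("diabetes.csv", "/path/to/train", "/path/to/test") would produce the two folders the code reads from.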
Here is the code:
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
from tensorflowonspark import TFCluster, TFNode


def main_fun(args, ctx):
    from keras.models import Sequential
    from keras.layers import Dense
    import numpy

    def generate_rdd_data(tf_feed, batch_size=1):
        print("generate_rdd_data invoked")
        while not tf_feed.should_stop():
            batch = tf_feed.next_batch(batch_size)
            if len(batch) > 0:
                features = []
                lbls = []
                for item in batch:
                    features.append(item[0])
                    lbls.append(item[1])
                xs = numpy.array(features).astype('float32')
                ys = numpy.array(lbls).astype('float32')
                yield (xs, ys)

    def generate_rdd_test(tf_feed, batch_size=1):
        print("generate_rdd_test invoked")
        while not tf_feed.should_stop():
            batch = tf_feed.next_batch(batch_size)
            print("batch len: %s" % len(batch))
            if len(batch) > 0:
                features = []
                for item in batch:
                    features.append(item[0])
                xs = numpy.array(features).astype('float32')
                yield xs

    batch_size = 10
    # fix random seed for reproducibility
    numpy.random.seed(7)

    # create model
    model = Sequential()
    model.add(Dense(12, input_dim=8, activation='relu'))
    model.add(Dense(8, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    # Compile model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

    # Fit the model
    tf_feed = TFNode.DataFeed(ctx.mgr)
    model.fit_generator(generator=generate_rdd_data(tf_feed, batch_size),
                        steps_per_epoch=20,
                        epochs=5,
                        verbose=1,
                        callbacks=None)
    tf_feed.terminate()

    # evaluate the model
    tf_feed_eval = TFNode.DataFeed(ctx.mgr, train_mode=False)
    predicts = model.predict_generator(generator=generate_rdd_test(tf_feed_eval, batch_size),
                                       steps=20)
    tf_feed_eval.batch_results(predicts)
    # tf_feed_eval.terminate()


sc = SparkContext(conf=SparkConf().setAppName("keras model test on spark (MM)"))
num_executors = 1
num_ps = 0
tensorboard = False
args = None
trainFolder = "/path/to/train"
testFolder = "/path/to/test"


def parse(ln):
    vals = ln.split(',')
    return [float(x) for x in vals[:-1]], int(vals[-1])


def parse_test(ln):
    vals = ln.split(',')
    return [float(x) for x in vals[:-1]]


cluster = TFCluster.run(sc, main_fun, args, num_executors, num_ps, tensorboard, TFCluster.InputMode.SPARK)
dataRDD = sc.textFile(trainFolder).map(parse)
cluster.train(dataRDD, 5)
testRDD = sc.textFile(testFolder).map(parse_test)
predictRDD = cluster.inference(testRDD)
predictRDD.take(2)
cluster.shutdown()
print('Done!')
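The batching logic of generate_rdd_test() can be checked outside Spark by driving it from an in-memory stand-in for the feed. FakeFeed below is hypothetical and is not part of TensorFlowOnSpark. It mimics only the two methods the generator calls, next_batch() and should_stop(). Unlike the real TFNode.DataFeed, whose next_batch() reads from a multiprocessing queue and can block while that queue is empty, it never waits. The sketch also yields plain lists instead of numpy arrays so that it has no dependencies. One point from the code above is worth noting. parse_test() returns a flat list of 8 floats, so item[0] inside generate_rdd_test() is the first feature rather than the whole feature vector. The sketch keeps the whole item, which is presumably what input_dim=8 expects.

```python
class FakeFeed(object):
    """Hypothetical in-memory stand-in for TFNode.DataFeed (not the real class)."""

    def __init__(self, items):
        self.items = list(items)
        self.pos = 0

    def next_batch(self, batch_size):
        # Return up to batch_size items; never blocks, unlike the real feed.
        batch = self.items[self.pos:self.pos + batch_size]
        self.pos += len(batch)
        return batch

    def should_stop(self):
        # "Stop" once every item has been handed out.
        return self.pos >= len(self.items)


def generate_test_batches(tf_feed, batch_size=1):
    """Same loop shape as generate_rdd_test(), yielding lists of feature vectors."""
    while not tf_feed.should_stop():
        batch = tf_feed.next_batch(batch_size)
        if len(batch) > 0:
            # Keep the whole item: parse_test() already returns the feature vector.
            yield [list(map(float, item)) for item in batch]
```

For instance, 25 test rows with batch_size=10 produce three batches of 10, 10 and 5 rows, after which should_stop() returns True and the generator ends.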
I submit the job with one executor to a standalone cluster on my machine. Here is the submit script:
${SPARK_HOME}/bin/spark-submit \
--master ${MASTER} \
--conf spark.cores.max=1 \
--conf spark.task.cpus=1 \
--conf spark.executorEnv.JAVA_HOME="$JAVA_HOME" \
/path/to/this_code.py
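With spark.cores.max=1 and spark.task.cpus=1, Spark can run at most one task at a time, which matches num_executors = 1 in the code. The sketch below makes that arithmetic explicit. It assumes that every TF node must hold its own concurrently running Spark task while TFCluster.run() starts the cluster, which is how TensorFlowOnSpark's node reservation appears to work. The function names are hypothetical and belong to neither Spark nor TensorFlowOnSpark.

```python
def concurrent_task_slots(cores_max, task_cpus):
    """How many Spark tasks can run at once under spark.cores.max / spark.task.cpus."""
    if task_cpus <= 0:
        raise ValueError("spark.task.cpus must be positive")
    return cores_max // task_cpus


def check_cluster_fits(num_executors, cores_max, task_cpus):
    """Raise if there are fewer concurrent task slots than requested TF nodes.

    Assumption: each TF node occupies one running Spark task during startup,
    so all num_executors tasks must be able to run at the same time.
    """
    slots = concurrent_task_slots(cores_max, task_cpus)
    if num_executors > slots:
        raise ValueError("num_executors=%d exceeds %d concurrent task slot(s)"
                         % (num_executors, slots))
    return slots
```

With the settings above, check_cluster_fits(1, 1, 1) passes, so the submit configuration and num_executors are at least consistent with each other.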
What is missing or wrong in this example that keeps it from proceeding to prediction?
Here is the tail of the executor log:
.
.
.
1/20 [>.............................] - ETA: 18s - loss: 1.4276 - acc: 0.7000
13/20 [==================>...........] - ETA: 0s - loss: 1.4257 - acc: 0.5000
20/20 [==============================] - 1s 54ms/step - loss: 1.4379 - acc: 0.4750
2018-08-16 13:50:19,548 INFO (MainThread-50639) Processed 200 items in partition
2018-08-16 13:50:19,578 INFO (MainThread-50639) Connected to TFSparkNode.mgr on 10.5.193.158, executor=0, state='running'
2018-08-16 13:50:19,584 INFO (MainThread-50639) mgr.state='running'
2018-08-16 13:50:19,584 INFO (MainThread-50639) Feeding partition <itertools.chain object at 0x10c00b590> into input queue <multiprocessing.queues.JoinableQueue object at 0xb1f231e50>
2018-08-16 13:50:19,588 INFO (MainThread-50638) terminate() invoked
2018-08-16 13:50:20,616 INFO (MainThread-50639) Processed 200 items in partition
2018-08-16 13:50:20,621 INFO (MainThread-50639) TFSparkNode: requesting stop
2018-08-16 13:50:20,622 INFO (MainThread-50639) connected to server at ('10.5.193.158', 60356)
2018-08-16 13:50:20,646 INFO (MainThread-50639) Connected to TFSparkNode.mgr on 10.5.193.158, executor=0, state='terminating'
2018-08-16 13:50:20,653 INFO (MainThread-50639) mgr.state='terminating'
2018-08-16 13:50:20,653 INFO (MainThread-50639) mgr is terminating, skipping partition
2018-08-16 13:50:20,654 INFO (MainThread-50639) Skipped 200 items from partition
2018-08-16 13:50:20,656 INFO (MainThread-50639) TFSparkNode: requesting stop
2018-08-16 13:50:20,689 INFO (MainThread-50639) Connected to TFSparkNode.mgr on 10.5.193.158, executor=0, state='terminating'
2018-08-16 13:50:20,695 INFO (MainThread-50639) mgr.state='terminating'
2018-08-16 13:50:20,695 INFO (MainThread-50639) mgr is terminating, skipping partition
2018-08-16 13:50:20,696 INFO (MainThread-50639) Skipped 200 items from partition
2018-08-16 13:50:20,699 INFO (MainThread-50639) TFSparkNode: requesting stop
2018-08-16 13:50:20,734 INFO (MainThread-50639) Connected to TFSparkNode.mgr on 10.5.193.158, executor=0, state='terminating'
2018-08-16 13:50:20,740 INFO (MainThread-50639) mgr.state='terminating'
2018-08-16 13:50:20,740 INFO (MainThread-50639) mgr is terminating, skipping partition
2018-08-16 13:50:20,741 INFO (MainThread-50639) Skipped 200 items from partition
2018-08-16 13:50:20,744 INFO (MainThread-50639) TFSparkNode: requesting stop
2018-08-16 13:50:20,778 INFO (MainThread-50639) Connected to TFSparkNode.mgr on 10.5.193.158, executor=0, state='terminating'
2018-08-16 13:50:20,785 INFO (MainThread-50639) mgr.state='terminating'
2018-08-16 13:50:20,785 INFO (MainThread-50639) mgr is terminating, skipping partition
2018-08-16 13:50:20,786 INFO (MainThread-50639) Skipped 200 items from partition
2018-08-16 13:50:20,789 INFO (MainThread-50639) TFSparkNode: requesting stop
2018-08-16 13:50:20,822 INFO (MainThread-50639) Connected to TFSparkNode.mgr on 10.5.193.158, executor=0, state='terminating'
2018-08-16 13:50:20,830 INFO (MainThread-50639) mgr.state='terminating'
2018-08-16 13:50:20,830 INFO (MainThread-50639) mgr is terminating, skipping partition
2018-08-16 13:50:20,831 INFO (MainThread-50639) Skipped 200 items from partition
2018-08-16 13:50:20,834 INFO (MainThread-50639) TFSparkNode: requesting stop
2018-08-16 13:50:20,867 INFO (MainThread-50639) Connected to TFSparkNode.mgr on 10.5.193.158, executor=0, state='terminating'
2018-08-16 13:50:20,874 INFO (MainThread-50639) mgr.state='terminating'
2018-08-16 13:50:20,874 INFO (MainThread-50639) mgr is terminating, skipping partition
2018-08-16 13:50:20,875 INFO (MainThread-50639) Skipped 200 items from partition
2018-08-16 13:50:20,878 INFO (MainThread-50639) TFSparkNode: requesting stop
2018-08-16 13:50:20,912 INFO (MainThread-50639) Connected to TFSparkNode.mgr on 10.5.193.158, executor=0, state='terminating'
2018-08-16 13:50:20,921 INFO (MainThread-50639) mgr.state='terminating'
2018-08-16 13:50:20,922 INFO (MainThread-50639) mgr is terminating, skipping partition
2018-08-16 13:50:20,923 INFO (MainThread-50639) Skipped 200 items from partition
2018-08-16 13:50:20,927 INFO (MainThread-50639) TFSparkNode: requesting stop
2018-08-16 13:50:20,960 INFO (MainThread-50639) Connected to TFSparkNode.mgr on 10.5.193.158, executor=0, state='terminating'
2018-08-16 13:50:20,966 INFO (MainThread-50639) mgr.state='terminating'
2018-08-16 13:50:20,966 INFO (MainThread-50639) mgr is terminating, skipping partition
2018-08-16 13:50:20,967 INFO (MainThread-50639) Skipped 200 items from partition
2018-08-16 13:50:20,970 INFO (MainThread-50639) TFSparkNode: requesting stop
2018-08-16 13:50:21,005 INFO (MainThread-50639) Connected to TFSparkNode.mgr on 10.5.193.158, executor=0, state='terminating'
2018-08-16 13:50:21,012 INFO (MainThread-50639) mgr.state='terminating'
2018-08-16 13:50:21,012 INFO (MainThread-50639) mgr is terminating, skipping partition
2018-08-16 13:50:21,013 INFO (MainThread-50639) Skipped 200 items from partition
2018-08-16 13:50:21,017 INFO (MainThread-50639) TFSparkNode: requesting stop
2018-08-16 13:50:21,163 INFO (MainThread-50618) Connected to TFSparkNode.mgr on 10.5.193.158, executor=0, state='terminating'
2018-08-16 13:50:21,167 INFO (MainThread-50618) Feeding partition <itertools.imap object at 0xb1f231190> into input queue <multiprocessing.queues.JoinableQueue object at 0xb1f231e50>
2018-08-16 13:50:22,181 INFO (MainThread-50618) Processed 101 items in partition
2018-08-16 13:50:26,188 INFO (MainThread-50638) dropped 292 items from queue
generate_rdd_test invoked
Note that the generate_rdd_test() method is invoked, but nothing happens after that. Any help is appreciated.
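A rough count of samples may help narrow down where inference stalls. The sketch below uses only numbers that appear in the code and log above. It compares how many samples each Keras call asks the generator for (steps * batch_size * epochs) with the 101 items reported by "Processed 101 items in partition". The count assumes every yielded batch is full and ignores any extra batches Keras may prefetch. The helper names are hypothetical. Separately, the "batch len:" line that generate_rdd_test() prints right after next_batch() does not appear in the log tail. That suggests the first next_batch() call never returned.

```python
def samples_requested(steps, batch_size, epochs=1):
    """Samples Keras asks the generator for, assuming every yielded batch is full."""
    return steps * batch_size * epochs


def batches_available(num_items, batch_size):
    """Number of (possibly partial) batches that num_items can fill."""
    return (num_items + batch_size - 1) // batch_size  # integer ceiling division


batch_size = 10  # as in main_fun()
print("fit_generator requests: %d" % samples_requested(20, batch_size, epochs=5))  # 1000
print("predict_generator requests: %d" % samples_requested(20, batch_size))  # 200
print("batches in the 101-item test partition: %d" % batches_available(101, batch_size))  # 11
```

Even if the feed delivered every test item, predict_generator(steps=20) would be waiting for more batches than the 101-item partition can fill.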