This is the Python code I am running with dist-keras on pyspark.
from keras.models import Sequential
from keras.layers.core import *
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM, GRU

from distkeras.trainers import *

from pyspark import SQLContext
from pyspark.ml.feature import VectorAssembler

# `sc` is the SparkContext provided by the pyspark shell.
sqlContext = SQLContext(sc)
reader = sqlContext

# Load the CSV; types are inferred, so every column comes in as an integer.
raw_dataset = reader.read.format('com.databricks.spark.csv') \
    .options(header='false', inferSchema='true') \
    .load('/home/hpcc/test/11.csv')
# raw_dataset = raw_dataset.repartition(4)

# Assemble every column except the label C0 into a single "features" vector.
features = raw_dataset.columns
features.remove('C0')
vector_assembler = VectorAssembler(inputCols=features, outputCol="features")
dataset = vector_assembler.transform(raw_dataset)
dataset = dataset.select("features", "C0")

# Embedding -> LSTM -> Dropout -> Dense(1) with a sigmoid: a binary classifier.
model1 = Sequential()
model1.add(Embedding(input_dim=52965, output_dim=256))
model1.add(LSTM(128))
model1.add(Dropout(0.5))
model1.add(Dense(1))
model1.add(Activation('sigmoid'))
model1.summary()

optimizer_mlp = 'adagrad'
loss_mlp = 'binary_crossentropy'

# Train with dist-keras' asynchronous DOWNPOUR trainer on 4 workers.
dataset.cache()
trainer = DOWNPOUR(keras_model=model1, worker_optimizer=optimizer_mlp, loss=loss_mlp,
                   num_workers=4, batch_size=16, communication_window=5,
                   learning_rate=0.1, num_epoch=1,
                   features_col="features", label_col="C0")
trainer.set_parallelism_factor(1)
trained_model = trainer.train(dataset)
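For reference, outside of Spark the Keras model itself just expects batches of integer token indices shaped (batch, timesteps). A minimal plain-Keras sketch of that input format, using random dummy data rather than my real dataset (my own illustration, not the path dist-keras actually takes), would be:

import numpy as np
# Dummy batch: 16 rows of 50 token indices, labels in {0, 1} (assumed shapes).
model1.compile(optimizer=optimizer_mlp, loss=loss_mlp)
dummy_x = np.random.randint(0, 52965, size=(16, 50))
dummy_y = np.random.randint(0, 2, size=(16, 1))
model1.train_on_batch(dummy_x, dummy_y)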
When I run this code, it prints the following error:
16/12/30 20:53:09 ERROR Executor: Exception in task 1.0 in stage 4.0 (TID 7)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/usr/local/spark/python/lib/pyspark.zip/pyspark/worker.py", line 111, in main
process()
File "/usr/local/spark/python/lib/pyspark.zip/pyspark/worker.py", line 106, in process
serializer.dump_stream(func(split_index, iterator), outfile)
File "build/bdist.linux-x86_64/egg/distkeras/workers.py", line 193, in train
X = np.asarray([x[self.features_column] for x in feature_iterator])
File "/root/anaconda2/lib/python2.7/site-packages/numpy/core/numeric.py", line 482, in asarray
return array(a, dtype, copy=False, order=order)
File "/usr/local/spark/python/lib/pyspark.zip/pyspark/mllib/linalg/__init__.py", line 769, in __getitem__
raise ValueError("Index %d out of bounds." % index)
ValueError: Index 50 out of bounds.
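If I read the last frame correctly, something ends up asking the 50-element feature vector for position 50, which is one past its last valid index (49). A tiny stand-alone reproduction of that exact ValueError (my own sketch, with a 50-element vector like mine) is:

from pyspark.mllib.linalg import DenseVector
v = DenseVector(range(50))  # 50 elements, valid indices are 0..49
v[50]                       # ValueError: Index 50 out of bounds.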
My 11.csv looks like this:
1,23,5211,871,1223,11,5,355,9999,1,28032,33057,1259,5575,4,1,9604,23790,83,136,18,4695,3,121,1,151,521,2,130,24677,4,1,612,947,612,1645,4534,59,29008,1,21679,1,416,3474,3,4189,1,142,5,11323,3
........
Column C0 is the label and columns C1 through C50 are the features. Any help would be much appreciated, thank you.
Here is the output of dataset.printSchema():
root
|-- features: vector (nullable = true)
|-- C0: integer (nullable = true)
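In case it helps, here is a quick check (my own sketch) to confirm that the assembler really packed the 50 feature columns and that C0 stayed as the label:

print(len(features))        # columns fed to the assembler, expected: 50 (C1..C50)
first = dataset.first()
print(len(first.features))  # length of the assembled vector, expected: 50
print(first.C0)             # the integer label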