我正在运行此代码以对来自 Google Streetview 的门牌号进行分类,它运行然后在第一步中进入某种循环,我不知道为什么。
我已将问题缩小到输入管道。
def convert_labels(labels):
for i in range(len(labels)):
if len(labels[i])==2:
labels[i].append(10)
elif len(labels[i])==1:
labels[i].append(10)
labels[i].append(10)
return labels
with open('whole_data.pickle','rb') as f:
data_obj=pickle.load(f)
train_images=np.asarray(data_obj['train_data'][0],np.float32)
train_labels=np.asarray(convert_labels(data_obj['train_data'][1]),np.int32)
whole_train=[train_images,train_labels]
whole_train[0]=np.reshape(whole_train[0][:],[whole_train[0].shape[0],32,32,1])
graph = tf.Graph()
with graph.as_default():
with tf.name_scope('Pipelines'):
image_batch,label_batch=tf.train.shuffle_batch([whole_train[0],
batch_size=64,capacity=50000,enqueue_many=True,min_after_dequeue=100,
name='train_pipe')
with tf.Session(graph=graph) as sess:
tf.global_variables_initializer().run()
writer=tf.summary.FileWriter('/tmp/Mark2_logs',graph=graph)
step=1
print('hi')
for step in range(100):
print('step %d' %(step))
batch_images,batch_labels=sess.run([image_batch,label_batch])
plt.imshow(batch_images[0])
我已将链接附加到我存储数据和代码的 Google 云端硬盘:
https://drive.google.com/open?id=0B-hAFmA-zmGdTndyaHJzWEdQaFE
它包含三个文件。
- model.ipynb - 它是包含 CNN 和会话的主文件,
- preprocess.ipynb - 这是我用来处理我的数据和
- data.pickle - 我已经在这个文件中处理并存储了我的数据,所以你不必运行 preprocess.ipynb。