0

张量板图

我正在运行此代码以对来自 Google Streetview 的门牌号进行分类,它运行然后在第一步中进入某种循环,我不知道为什么。

我已将问题缩小到输入管道。

def convert_labels(labels):
    for i in range(len(labels)):
        if len(labels[i])==2:
            labels[i].append(10)
        elif len(labels[i])==1:
            labels[i].append(10)
            labels[i].append(10)
    return labels

with open('whole_data.pickle','rb') as f:
    data_obj=pickle.load(f)
train_images=np.asarray(data_obj['train_data'][0],np.float32)
train_labels=np.asarray(convert_labels(data_obj['train_data'][1]),np.int32)
whole_train=[train_images,train_labels]
whole_train[0]=np.reshape(whole_train[0][:],[whole_train[0].shape[0],32,32,1])

graph = tf.Graph()

with graph.as_default():
    with tf.name_scope('Pipelines'):
        image_batch,label_batch=tf.train.shuffle_batch([whole_train[0],
        batch_size=64,capacity=50000,enqueue_many=True,min_after_dequeue=100,
        name='train_pipe')


with tf.Session(graph=graph) as sess:
tf.global_variables_initializer().run()
writer=tf.summary.FileWriter('/tmp/Mark2_logs',graph=graph)
step=1 
print('hi')
for step in range(100): 
    print('step %d' %(step))
    batch_images,batch_labels=sess.run([image_batch,label_batch])
    plt.imshow(batch_images[0])

我已将链接附加到我存储数据和代码的 Google 云端硬盘:

https://drive.google.com/open?id=0B-hAFmA-zmGdTndyaHJzWEdQaFE

它包含三个文件。

  • model.ipynb - 它是包含 CNN 和会话的主文件,
  • preprocess.ipynb - 这是我用来处理我的数据和
  • data.pickle - 我已经在这个文件中处理并存储了我的数据,所以你不必运行 preprocess.ipynb。
4

0 回答 0