我正在尝试将 AlexNet 示例与我自己的数据一起使用,但出现错误:
Traceback (most recent call last):
File "D:/tensorflowWorkSpace/tflearn/testAlexnet.py", line 68, in <module>
snapshot_epoch=False, run_id='alexnet_chinese_characters')
File "D:\Anaconda3\lib\site-packages\tflearn\models\dnn.py", line 215, in fit
callbacks=callbacks)
File "D:\Anaconda3\lib\site-packages\tflearn\helpers\trainer.py", line 333, in fit
show_metric)
File "D:\Anaconda3\lib\site-packages\tflearn\helpers\trainer.py", line 774, in _train
feed_batch)
File "D:\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 767, in run
run_metadata_ptr)
File "D:\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 944, in _run
% (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (64, 277, 277) for Tensor 'InputData/X:0', which has shape '(?, 277, 277, 3)'
网络代码如下:
from __future__ import division, print_function, absolute_import
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.normalization import local_response_normalization
from tflearn.layers.estimator import regression
# import tflearn.datasets.oxflower17 as oxflower17
# X, Y = oxflower17.load_data(one_hot=True, resize_pics=(227, 227))
def alexnet(input, num_class):
# Building 'AlexNet'
network = conv_2d(input, 96, 11, strides=4, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)
network = conv_2d(network, 256, 5, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)
network = conv_2d(network, 384, 3, activation='relu')
network = conv_2d(network, 384, 3, activation='relu')
network = conv_2d(network, 256, 3, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)
network = fully_connected(network, 4096, activation='tanh')
network = dropout(network, 0.5)
network = fully_connected(network, 4096, activation='tanh')
network = dropout(network, 0.5)
network = fully_connected(network, num_class, activation='softmax')
network = regression(network, optimizer='momentum',
loss='categorical_crossentropy',
learning_rate=0.001)
return network
from tflearn.data_utils import image_preloader
# data_dir = "D:/tensorflowWorkSpace/tflearn/17flowers/jpg"
data_dir = "D:/课题相关~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~/ocr刘亮亮资料/OCR数据/汉字数据集部分"
X, Y = image_preloader(data_dir, image_shape=(277, 277), mode='folder',
categorical_labels=True, normalize=True,
files_extension=['.jpg', '.png'], filter_channel=False)
num_classes = 17
x = input_data(shape=[None, 277, 277, 3])
network = alexnet(x, num_classes)
# Training
model = tflearn.DNN(network, checkpoint_path='model_alexnet',
max_checkpoints=1, tensorboard_verbose=2)
model.fit(X, Y, n_epoch=30, validation_set=0.1, shuffle=True,
show_metric=True, batch_size=64, snapshot_step=100,
snapshot_epoch=False, run_id='alexnet_chinese_characters')
本例使用数据flower17。我更改了加载数据的方法,它运行成功。我用自己的数据汉字代替flower17,然后出现这个错误。
我该如何解决这个错误?