目前,使用 Python 在 mxnet 中可以做比使用 R 更多的事情。我正在使用 Gluon API,这使得编写代码更加简单,并且允许加载预训练模型。
您参考的教程中使用的模型是Inception 模型。可以在此处找到所有可用的预训练模型的列表。
本教程中的其余操作是数据规范化和扩充。您可以对新数据进行规范化,类似于它们在 API 页面上对其进行规范化的方式:
image = image/255
normalized = mx.image.color_normalize(image,
mean=mx.nd.array([0.485, 0.456, 0.406]),
std=mx.nd.array([0.229, 0.224, 0.225]))
可能的增强列表可在此处获得。
这是适合您的可运行示例。我只做了一个增强,mx.image.CreateAugmenter
如果你想做更多,你可以添加更多参数:
%matplotlib inline
import mxnet as mx
from mxnet.gluon.model_zoo import vision
from matplotlib.pyplot import imshow
def plot_mx_array(array, clip=False):
"""
Array expected to be 3 (channels) x heigh x width, and values are floats between 0 and 255.
"""
assert array.shape[2] == 3, "RGB Channel should be last"
if clip:
array = array.clip(0,255)
else:
assert array.min().asscalar() >= 0, "Value in array is less than 0: found " + str(array.min().asscalar())
assert array.max().asscalar() <= 255, "Value in array is greater than 255: found " + str(array.max().asscalar())
array = array/255
np_array = array.asnumpy()
imshow(np_array)
inception_model = vision.inception_v3(pretrained=True)
with open("/Volumes/Unix/workspace/MxNet/2018-02-20T19-43-45/types_of_data_augmentation/output_4_0.png", 'rb') as open_file:
encoded_image = open_file.read()
example_image = mx.image.imdecode(encoded_image)
example_image = example_image.astype("float32")
plot_mx_array(example_image)
augmenters = mx.image.CreateAugmenter(data_shape=(1, 100, 100))
for augementer in augmenters:
example_image = augementer(example_image)
plot_mx_array(example_image)
example_image = example_image / 255
normalized_image = mx.image.color_normalize(example_image,
mean=mx.nd.array([0.485, 0.456, 0.406]),
std=mx.nd.array([0.229, 0.224, 0.225]))