我正在做图像分割（segmentation），因此图像和标签都需要沿相同方向平移相同的距离。我正在尝试添加一个数据增强层，在 x 和 y 方向上进行平移。我在这段 Python Layer 代码中做错了什么？有人可以帮忙吗？
# N: batch size (number of samples in the batch)
# K: number of channels, H: the Height, W: the Width
# N x K x H x W
# 0 1 2 3
'''Usage
layer {
name: 'aug_layer'
type: 'Python'
bottom: 'data'
bottom: 'label'
top: 'data'
top: 'label'
python_param {
module: 'augLayer'
layer: 'AugmentationLayer'
}
include{
phase: TRAIN
}
}
'''
import numpy as np
import caffe
import random
#from augmentation_main import doShift_xy
import os,sys
import h5py
import matplotlib.pyplot as plt
#import unittest
#import tempfile
#import os
import six
import scipy
import scipy.io as sio
import os,sys
from skimage import transform as tf
from skimage.transform import AffineTransform
from skimage.util import random_noise
# Phase identifiers compared against self.phase in the layer below.
# NOTE(review): presumably these mirror caffe.TRAIN / caffe.TEST — confirm.
TRAIN, TEST = 0, 1
class AugmentationLayer(caffe.Layer):
    """Caffe Python layer that randomly shifts images and labels together.

    Bottoms: ``data`` (N x K x H x W) and ``label`` (N x 1 x H x W).
    Tops:    the shifted ``data`` and ``label``, with the same shapes.

    In the TRAIN phase, every sample in the batch gets its own random
    (x, y) translation. The same translation is applied to all image
    channels and to the label, so the two stay pixel-aligned. In any
    other phase, the inputs are copied through unchanged.
    """

    def setup(self, bottom, top):
        """Validate the blob counts and initialise bookkeeping state."""
        if len(bottom) != 2:
            raise Exception("Wrong number of bottom blobs (data, label)")
        if len(top) != 2:
            raise Exception("Wrong number of top blobs (data,label)")
        # Unused by this layer; kept so existing code that reads them still works.
        self.img = []
        self.lbl = []
        # Running count of samples augmented so far.
        self.totalImgs = 0

    def reshape(self, bottom, top):
        """Check that data and label agree, then size the tops like the bottoms.

        Caffe does not size the top blobs of a Python layer automatically.
        Without these reshape calls, the writes in ``forward`` target blobs
        of the wrong size.
        """
        data_shape = bottom[0].data.shape
        label_shape = bottom[1].data.shape
        if data_shape[0] != label_shape[0] or data_shape[2:] != label_shape[2:]:
            raise Exception("data and label must have the same batch size and H x W")
        if label_shape[1] != 1:
            raise Exception("label blob must have exactly one channel")
        top[0].reshape(*data_shape)
        top[1].reshape(*label_shape)

    def forward(self, bottom, top):
        """Shift each (image, label) pair by a shared random offset in TRAIN."""
        image = bottom[0].data
        label = bottom[1].data
        if self.phase != TRAIN:
            # Pass-through so the tops are always populated outside training.
            top[0].data[...] = image
            top[1].data[...] = label
            return

        batch_size = image.shape[0]
        self.totalImgs += batch_size
        for i in range(batch_size):
            # (K, H, W) -> (H, W, K). There is no colour-channel swap: every
            # channel gets the same shift, so the order never needs to change.
            img = image[i].transpose(1, 2, 0)
            lbl = label[i, 0]  # (1, H, W) -> (H, W)
            im, lb = self.doShift_xy(img, lbl)
            # Both results are fresh arrays, so this is safe even when the
            # layer runs in place (top blob == bottom blob).
            top[0].data[i, ...] = im.transpose(2, 0, 1)
            top[1].data[i, 0, ...] = lb

    def doShift_xy(self, img_, lbl_, max_shift=10):
        """Translate an image and its label by one random integer offset.

        Args:
            img_: image array of shape (H, W, K). It is not modified.
            lbl_: label array of shape (H, W). It is not modified.
            max_shift: largest absolute shift, in pixels, along each axis.
                Offsets are drawn uniformly from [-max_shift, max_shift].

        Returns:
            ``(shifted_img, shifted_lbl)`` with the same shapes as the inputs.
            Pixels shifted in from outside the frame are filled with 0
            (``tf.warp`` default ``mode='constant'``, ``cval=0``). If 0 is a
            real class in your labels, consider filling with an ignore label.
        """
        # randint is inclusive on both ends, so the shift range is symmetric.
        x_trans = random.randint(-max_shift, max_shift)
        y_trans = random.randint(-max_shift, max_shift)
        shift = AffineTransform(translation=(x_trans, y_trans))

        shifted_img = np.empty_like(img_)
        for c in range(img_.shape[2]):
            # preserve_range keeps Caffe's raw float values (e.g. 0-255 or
            # mean-subtracted) instead of letting skimage rescale or reject them.
            shifted_img[:, :, c] = tf.warp(img_[:, :, c], shift, preserve_range=True)
        # Nearest-neighbour (order=0) keeps label values as the original
        # class ids; bilinear interpolation would blend ids at region borders.
        shifted_lbl = tf.warp(lbl_, shift, order=0, preserve_range=True)
        return shifted_img, shifted_lbl

    def backward(self, top, propagate_down, bottom):
        """No gradient: this layer only transforms input data."""
        pass
我收到以下与 numpy 相关的错误：
I0207 02:45:33.556780 19447 net.cpp:84] Creating Layer label
I0207 02:45:33.556783 19447 net.cpp:380] label -> label
I0207 02:45:33.556872 19447 data_layer.cpp:45] output data size: 1,1,87,256
I0207 02:45:33.557867 19447 net.cpp:122] Setting up label
I0207 02:45:33.557878 19447 net.cpp:129] Top shape: 1 1 87 256 (22272)
I0207 02:45:33.557880 19447 net.cpp:137] Memory required for data: 356352
I0207 02:45:33.557883 19447 layer_factory.hpp:77] Creating layer aug_layer
Traceback (most recent call last):
File "/home/ubuntu/caffe/python/augmentLayer.py", line 22, in <module>
import numpy as np
File "/home/ubuntu/anaconda2/envs/testcaffe/lib/python2.7/site-packages/numpy/__init__.py", line 142, in <module>
from . import add_newdocs
File "/home/ubuntu/anaconda2/envs/testcaffe/lib/python2.7/site-packages/numpy/add_newdocs.py", line 13, in <module>
from numpy.lib import add_newdoc
File "/home/ubuntu/anaconda2/envs/testcaffe/lib/python2.7/site-packages/numpy/lib/__init__.py", line 8, in <module>
from .type_check import *
File "/home/ubuntu/anaconda2/envs/testcaffe/lib/python2.7/site-packages/numpy/lib/type_check.py", line 11, in <module>
import numpy.core.numeric as _nx
File "/home/ubuntu/anaconda2/envs/testcaffe/lib/python2.7/site-packages/numpy/core/__init__.py", line 74, in <module>
from numpy.testing import _numpy_tester
File "/home/ubuntu/anaconda2/envs/testcaffe/lib/python2.7/site-packages/numpy/testing/__init__.py", line 10, in <module>
from unittest import TestCase
File "/home/ubuntu/anaconda2/envs/testcaffe/lib/python2.7/unittest/__init__.py", line 64, in <module>
from .main import TestProgram, main
File "/home/ubuntu/anaconda2/envs/testcaffe/lib/python2.7/unittest/main.py", line 7, in <module>
from . import loader, runner
File "/home/ubuntu/anaconda2/envs/testcaffe/lib/python2.7/unittest/runner.py", line 7, in <module>
from .signals import registerResult
File "/home/ubuntu/anaconda2/envs/testcaffe/lib/python2.7/unittest/signals.py", line 2, in <module>
import weakref
File "/home/ubuntu/anaconda2/envs/testcaffe/lib/python2.7/weakref.py", line 14, in <module>
from _weakref import (
ImportError: cannot import name _remove_dead_weakref