I'm a high school student without much experience in PyTorch or LIME, and I've been having a lot of trouble with image shapes. My image originally has the shape (3, 224, 224), but the LIME algorithm only works with images shaped (..., ..., 3). Because of that, I tried transposing the image first. That seemed to get me somewhere, but now I'm running into a different error.
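To show what I mean by transposing, this is roughly the conversion I have in mind (just a throwaway example on a dummy tensor, not my real image):

import numpy as np
import torch

chw = torch.ones(3, 224, 224)                  # channels-first, like my image
hwc = chw.permute(1, 2, 0).numpy()             # channels-last, shape (224, 224, 3), which is what LIME seems to want
back = torch.from_numpy(hwc).permute(2, 0, 1)  # back to (3, 224, 224) for the model
print(hwc.shape, back.shape)

Below is the relevant part of my code, so you can see what I was doing before the error appeared.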
def get_preprocess_transform():
    transf = transforms.Compose([
        # transforms.ToPILImage(),  # had to convert the image to PIL because an error two cells below said it needs a PIL image
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return transf

preprocess_transform = get_preprocess_transform()  # use your data_transform but in a method version
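Just to check what this transform actually produces, I ran it on a blank PIL image (assuming input_size = 224, which is what I use everywhere else in my notebook):

from PIL import Image
import numpy as np

dummy = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))  # a fake RGB image
out = preprocess_transform(dummy)
print(out.shape)  # torch.Size([3, 224, 224]) -- channels-first, while LIME apparently wants channels-last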
def batch_predict(image):
    model_ft.eval()
    # turn the single image into a batch of one in the (N, C, H, W) layout the model expects
    batch = torch.reshape(image, (1, 3, 224, 224))
    print(type(batch))
    logits = model_ft(batch)
    probs = F.softmax(logits, dim=1)
    return probs.detach().cpu().numpy()
print(img_t.shape)
img_t = torch.reshape(img_t, (1, 3, 224, 224))
test_pred = batch_predict(img_t)
test_pred.squeeze().argmax()
img_t = np.ones((3, 224, 224))
np.transpose(img_t, (2, 1, 0)).shape  # just checking the transposed shape in the notebook
img_x = np.transpose(img_t, (2, 1, 0))  # move the channel axis to the end so the shape ends in 3
print(img_x.shape)
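I also wasn't sure whether the axis order should be (2, 1, 0) or (1, 2, 0). On my square 224x224 image both end in 3, but on a non-square dummy array they clearly give different shapes, so maybe this is part of my problem:

tall = np.ones((3, 224, 100))               # (channels, height, width)
print(np.transpose(tall, (2, 1, 0)).shape)  # (100, 224, 3) -- height and width get swapped
print(np.transpose(tall, (1, 2, 0)).shape)  # (224, 100, 3) -- height and width stay in order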
from lime import lime_image

explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(img_x,          # pass your image, do not transform
                                         batch_predict,  # classification function
                                         top_labels=5,
                                         hide_color=0,
                                         num_samples=1000)
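Since I don't really know what LIME hands to the classification function internally, I also tried swapping in a little wrapper just to print the input before the error happens (debug_predict is only a name I made up, not anything from LIME):

def debug_predict(images):
    # print whatever LIME actually passes in, then call my real batch_predict
    arr = np.asarray(images)
    print(type(images), arr.shape, arr.dtype)
    return batch_predict(images)

and passed debug_predict in place of batch_predict in the explain_instance call above, but I still don't understand what shape I'm supposed to be feeding in.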