与其逐个遍历所有像素,不如遍历调色板中的每一种颜色:
import pylab as pl
# Snap every pixel of an image to the nearest colour of a small palette by
# looping over the palette entries (few) instead of over the pixels (many).

# Target colours, one RGB triple per row.
palette = pl.array([[180, 0, 0], [255, 150, 0], [255, 200, 0], [0, 128, 0]])

# Keep only the first three channels and work in float so the squared
# differences below are computed without integer wrap-around.
img = pl.imread('lena.jpg')[:, :, :3].astype('float')

# Output planes, one per channel, initialised with the source values.
R, G, B = img[:, :, 0].copy(), img[:, :, 1].copy(), img[:, :, 2].copy()

# Smallest squared distance found so far for each pixel.
# NOTE: this must not be written as `pl.inf * R`: inf * 0 is NaN, and a NaN
# distance never compares as larger, so pixels whose red value is 0 would
# never be replaced.
dist = pl.full(R.shape, pl.inf)

for color in palette:
    # Squared Euclidean distance in RGB space from every pixel to this colour.
    new_dist = (pl.square(img[:, :, 0] - color[0])
                + pl.square(img[:, :, 1] - color[1])
                + pl.square(img[:, :, 2] - color[2]))

    # Pixels for which this colour beats the best match so far.
    closer = new_dist < dist
    R[closer] = color[0]
    G[closer] = color[1]
    B[closer] = color[2]
    dist = pl.minimum(dist, new_dist)

pl.clf()
pl.subplot(1, 2, 1)
pl.imshow(img.astype('uint8'))
pl.subplot(1, 2, 2)
# Cast back to uint8: imshow treats float RGB data as lying in 0..1 and would
# clip these 0..255 values otherwise.
pl.imshow(pl.dstack((R, G, B)).astype('uint8'))
编辑:下面是一个不使用循环的替代方案。;)
import pylab as pl
# Loop-free variant: broadcast every pixel against every palette colour at
# once and pick the index of the nearest colour with argmin.

# Target colours, one RGB triple per row.
palette = pl.array([[180, 0, 0], [255, 150, 0], [255, 200, 0], [0, 128, 0]])

# Keep only the first three channels of the image.
img = pl.imread('lena.jpg')[:, :, :3]

pl.clf()
pl.subplot(1, 2, 1)
pl.imshow(img)

# Append a trailing length-1 axis -> (height, width, 3, 1). The shape is taken
# from the image itself so any image size works, not just 512x512.
IMG = img.reshape(img.shape + (1,))

# Palette laid out as (1, 1, 3, n_colors) so it broadcasts against IMG.
PAL = palette.transpose().reshape((1, 1, 3, -1))

# Sum squared differences over the channel axis -> (height, width, n_colors),
# then take the index of the closest palette colour for each pixel.
# NOTE(review): presumably imread returns uint8 for a JPEG; the subtraction
# then relies on the palette's signed integer dtype for promotion - confirm.
idx = pl.argmin(pl.sum((IMG - PAL)**2, axis=2), axis=2)

# Look the chosen colours up again -> (height, width, 3) quantized image.
img = palette[idx, :]

pl.subplot(1, 2, 2)
pl.imshow(img)