I'm trying to work out how to speed up a Python function which uses numpy. The output I got from line_profiler is below, and it shows that the vast majority of the time is spent on the line ind_y, ind_x = np.where(seg_image == i).

seg_image is an integer array which is the result of segmenting an image, thus finding the pixels where seg_image == i extracts a specific segmented object. I am looping through lots of these objects (in the code below I'm just looping through 5 for testing, but I'll actually be looping through over 20,000), and it takes a long time to run!

Is there any way in which the np.where call can be sped up? Or, alternatively, can the penultimate line of the loop (line 47, which also takes a good proportion of the time, about 21%) be sped up?

The ideal solution would be to run the code on the whole array at once, rather than looping, but I don't think this is possible as there are side-effects to some of the functions I need to run (for example, dilating a segmented object can make it 'collide' with the next region and thus give incorrect results later on).

Does anyone have any ideas?

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     5                                           def correct_hot(hot_image, seg_image):
     6         1       239810 239810.0      2.3      new_hot = hot_image.copy()
     7         1       572966 572966.0      5.5      sign = np.zeros_like(hot_image) + 1
     8         1        67565  67565.0      0.6      sign[:,:] = 1
     9         1      1257867 1257867.0     12.1      sign[hot_image > 0] = -1
    10                                           
    11         1          150    150.0      0.0      s_elem = np.ones((3, 3))
    12                                           
    13                                               #for i in xrange(1,seg_image.max()+1):
    14         6           57      9.5      0.0      for i in range(1,6):
    15         5      6092775 1218555.0     58.5          ind_y, ind_x = np.where(seg_image == i)
    16                                           
    17                                                   # Get the average HOT value of the object (really simple!)
    18         5         2408    481.6      0.0          obj_avg = hot_image[ind_y, ind_x].mean()
    19                                           
    20         5          333     66.6      0.0          miny = np.min(ind_y)
    21                                                   
    22         5          162     32.4      0.0          minx = np.min(ind_x)
    23                                                   
    24                                           
    25         5          369     73.8      0.0          new_ind_x = ind_x - minx + 3
    26         5          113     22.6      0.0          new_ind_y = ind_y - miny + 3
    27                                           
    28         5          211     42.2      0.0          maxy = np.max(new_ind_y)
    29         5          143     28.6      0.0          maxx = np.max(new_ind_x)
    30                                           
    31                                                   # 7 is + 1 to deal with the zero-based indexing, + 2 * 3 to deal with the 3 cell padding above
    32         5          217     43.4      0.0          obj = np.zeros( (maxy+7, maxx+7) )
    33                                           
    34         5          158     31.6      0.0          obj[new_ind_y, new_ind_x] = 1
    35                                           
    36         5         2482    496.4      0.0          dilated = ndimage.binary_dilation(obj, s_elem)
    37         5         1370    274.0      0.0          border = mahotas.borders(dilated)
    38                                           
    39         5          122     24.4      0.0          border = np.logical_and(border, dilated)
    40                                           
    41         5          355     71.0      0.0          border_ind_y, border_ind_x = np.where(border == 1)
    42         5          136     27.2      0.0          border_ind_y = border_ind_y + miny - 3
    43         5          123     24.6      0.0          border_ind_x = border_ind_x + minx - 3
    44                                           
    45         5          645    129.0      0.0          border_avg = hot_image[border_ind_y, border_ind_x].mean()
    46                                           
    47         5      2167729 433545.8     20.8          new_hot[seg_image == i] = (new_hot[ind_y, ind_x] + (sign[ind_y, ind_x] * np.abs(obj_avg - border_avg)))
    48         5        10179   2035.8      0.1          print obj_avg, border_avg
    49                                           
    50         1            4      4.0      0.0      return new_hot

2 Answers

EDIT: I have left my original answer at the bottom for the record, but I have since looked at your code in more detail over lunch, and I think that using np.where is a big mistake:

In [63]: a = np.random.randint(100, size=(1000, 1000))

In [64]: %timeit a == 42
1000 loops, best of 3: 950 us per loop

In [65]: %timeit np.where(a == 42)
100 loops, best of 3: 7.55 ms per loop

You could get a boolean array (that you can use for indexing) in 1/8 of the time you need to get the actual coordinates of the points!!!
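To make that concrete: a boolean mask can be used wherever the coordinate arrays were used, and it selects the same pixels in the same (row-major) order, so e.g. hot_image[ind_y, ind_x].mean() can become hot_image[mask].mean(). A minimal check, with random arrays standing in for the real images (the names a and hot are only illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.integers(100, size=(1000, 1000))   # stand-in for seg_image
hot = rng.random((1000, 1000))             # stand-in for hot_image

mask = a == 42                  # boolean array: cheap to build
ind_y, ind_x = np.where(mask)   # explicit coordinates: the slow part

# Both select the same pixels in the same row-major order,
# so the boolean mask can replace the coordinate arrays for indexing
vals_mask = hot[mask]
vals_where = hot[ind_y, ind_x]
```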

There is of course the cropping of the features that you do, but ndimage has a find_objects function that returns the slices enclosing each labelled object, and it appears to be very fast:

In [66]: %timeit ndimage.find_objects(a)
100 loops, best of 3: 11.5 ms per loop

This returns a list of tuples of slices enclosing all of your objects, in only about 50% more time than it takes to find the indices of one single object.
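A small illustration of what find_objects gives back, on a toy label image (the array is made up for the example). Entry k of the list is the bounding box of label k + 1, or None if that label does not occur; and because the box of one object can contain pixels of other objects, the crop still has to be compared against the label:

```python
import numpy as np
from scipy import ndimage

labels = np.array([[0, 1, 1, 0],
                   [0, 1, 0, 0],
                   [2, 0, 0, 2]])

slices = ndimage.find_objects(labels)   # entry k covers label k + 1
pixel_counts = {}
for k, slc in enumerate(slices):
    if slc is None:                     # label k + 1 is absent from the image
        continue
    crop = labels[slc]                  # view of the bounding box only
    mask = crop == k + 1                # other labels may share the box
    pixel_counts[k + 1] = int(mask.sum())
```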

It may not work out of the box as I cannot test it right now, but I would restructure your code into something like the following:

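import numpy as np
from scipy import ndimage
import mahotas

def correct_hot_bis(hot_image, seg_image):
    # Need this to not index out of bounds when computing border_avg
    hot_image_padded = np.pad(hot_image, 3, mode='constant',
                              constant_values=0)
    new_hot = hot_image.copy()
    sign = np.ones_like(hot_image, dtype=np.int8)
    sign[hot_image > 0] = -1
    s_elem = np.ones((3, 3))

    # find_objects returns one bounding box per label 1..max,
    # or None for labels that do not appear in seg_image
    for j, slice_ in enumerate(ndimage.find_objects(seg_image)):
        if slice_ is None:
            continue
        hot_image_view = hot_image[slice_]
        seg_image_view = seg_image[slice_]
        # Bounding box grown by 3 pixels on each side
        new_shape = tuple(dim+6 for dim in hot_image_view.shape)
        # The same grown box, in the coordinates of hot_image_padded
        new_slice = tuple(slice(dim.start,
                                dim.stop+6,
                                None) for dim in slice_)
        # Other labels can fall inside the box, so select label j+1 only
        indices = seg_image_view == j+1

        obj_avg = hot_image_view[indices].mean()

        obj = np.zeros(new_shape, dtype=bool)
        obj[3:-3, 3:-3][indices] = True

        dilated = ndimage.binary_dilation(obj, s_elem)
        border = mahotas.borders(dilated)
        border &= dilated

        border_avg = hot_image_padded[new_slice][border].mean()

        # new_hot[slice_] is a view, so assigning into it updates new_hot;
        # plain assignment (not +=) also works for integer images
        new_hot_view = new_hot[slice_]
        new_hot_view[indices] = (new_hot_view[indices] +
                                 sign[slice_][indices] *
                                 np.abs(obj_avg - border_avg))

    return new_hot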
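If mahotas is not available, here is a scipy-only sketch of the same restructuring. It rests on some assumptions of mine: the border of the dilated object is computed as dilated & ~binary_erosion(dilated, cross), which I believe matches mahotas.borders(dilated) & dilated with mahotas' default 3x3 cross but have not checked against mahotas; new_hot is made a float copy so the in-place += cannot fail on integer images; and correct_hot_scipy is a hypothetical name. Like the answer's version, it treats pixels outside the image as zeros when averaging the border.

```python
import numpy as np
from scipy import ndimage

def correct_hot_scipy(hot_image, seg_image, pad=3):
    # Zero padding keeps the grown boxes inside the array (as in the answer)
    hot_padded = np.pad(hot_image, pad, mode='constant', constant_values=0)
    new_hot = hot_image.astype(float)          # float copy, so += is safe
    sign = np.where(hot_image > 0, -1, 1)
    s_elem = np.ones((3, 3), dtype=bool)
    cross = ndimage.generate_binary_structure(2, 1)   # 3x3 cross

    for j, slc in enumerate(ndimage.find_objects(seg_image)):
        if slc is None:                        # label j + 1 is not present
            continue
        mask = seg_image[slc] == j + 1
        obj_avg = hot_image[slc][mask].mean()

        # Object mask inside its bounding box, grown by `pad` (> 0) on each side
        h, w = mask.shape
        obj = np.zeros((h + 2 * pad, w + 2 * pad), dtype=bool)
        obj[pad:-pad, pad:-pad] = mask
        dilated = ndimage.binary_dilation(obj, s_elem)
        # Assumed equivalent of mahotas.borders(dilated) & dilated:
        # dilated pixels with at least one 4-neighbour outside `dilated`
        border = dilated & ~ndimage.binary_erosion(dilated, cross)

        # obj[k, l] sits at hot_padded[slc[0].start + k, slc[1].start + l]
        padded_slc = tuple(slice(s.start, s.stop + 2 * pad) for s in slc)
        border_avg = hot_padded[padded_slc][border].mean()

        new_hot[slc][mask] += sign[slc][mask] * abs(obj_avg - border_avg)
    return new_hot

# Tiny usage example: two single-pixel objects, label 2 deliberately absent
seg = np.zeros((7, 7), dtype=int)
seg[3, 3] = 1
seg[1, 1] = 3
hot = np.full((7, 7), 2.0)
hot[3, 3] = 10.0
hot[1, 1] = 6.0
out = correct_hot_scipy(hot, seg)
```

Each object's value is pulled towards its border average (here 2.0), and the None entry that find_objects returns for the missing label 2 is skipped.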

You would still need to figure out the collisions, but you could get about a 2x speed-up by computing the indices of all objects at once using an np.unique-based approach:

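import numpy as np

a = np.random.randint(100, size=(1000, 1000))

def get_pos(arr):
    pos = []
    for j in range(100):
        pos.append(np.where(arr == j))
    return pos

def get_pos_bis(arr):
    # Flatten first so the inverse is 1-D on every NumPy version
    unq, flat_idx = np.unique(arr.ravel(), return_inverse=True)
    # Flat positions grouped by value
    pos = np.argsort(flat_idx)
    counts = np.bincount(flat_idx)
    # Split points between groups; dropping the last one avoids an empty trailing piece
    cum_counts = np.cumsum(counts)[:-1]
    multi_dim_idx = np.unravel_index(pos, arr.shape)
    return list(zip(*(np.split(coords, cum_counts) for coords in multi_dim_idx)))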

In [33]: %timeit get_pos(a)
1 loops, best of 3: 766 ms per loop

In [34]: %timeit get_pos_bis(a)
1 loops, best of 3: 388 ms per loop

Note that the pixels for each object are returned in a different order, so you can't simply compare the outputs of the two functions element by element to check equality. But they should contain the same pixels for each object.
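If identical output is wanted, for instance to check one function against the other, a stable sort fixes the ordering: np.where returns coordinates in row-major order, and a stable argsort of the inverse keeps each value's pixels in that same order. A sketch (get_pos_stable is a hypothetical name, and unlike get_pos_bis it returns a dict keyed by value):

```python
import numpy as np

def get_pos_stable(arr):
    # Ravel first so the inverse is 1-D regardless of NumPy version
    unq, flat_idx = np.unique(arr.ravel(), return_inverse=True)
    # Stable sort: within each value, flat positions stay ascending,
    # which is the row-major order np.where uses
    order = np.argsort(flat_idx, kind='stable')
    # Split points between consecutive values (no empty trailing piece)
    cum_counts = np.cumsum(np.bincount(flat_idx))[:-1]
    multi_dim_idx = np.unravel_index(order, arr.shape)
    groups = zip(*(np.split(coords, cum_counts) for coords in multi_dim_idx))
    # Map each value that occurs in arr to its (rows, cols) coordinate arrays
    return dict(zip(unq.tolist(), groups))

# Usage: small random label image to compare against np.where
a = np.random.default_rng(1).integers(10, size=(50, 60))
pos = get_pos_stable(a)
```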

answered 2013-07-08T18:13:34.743

One thing you could do to save a little bit of time is to store the result of seg_image == i so that you don't need to compute it twice. You're computing it on lines 15 and 47; you could add seg_mask = seg_image == i and then reuse that result (it might also be good to separate out that piece for profiling purposes).
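A sketch of what that looks like for one object. The helper correct_object and its border_avg_fn argument are hypothetical: border_avg_fn stands in for the dilation/border code on lines 20-45, which is unchanged by this suggestion. The update keeps the plain assignment of line 47 rather than +=:

```python
import numpy as np

def correct_object(new_hot, hot_image, sign, seg_image, i, border_avg_fn):
    # Compare once and reuse the mask (the question computes it on lines 15 and 47)
    seg_mask = seg_image == i
    ind_y, ind_x = np.where(seg_mask)      # still needed for the cropping step
    obj_avg = hot_image[seg_mask].mean()
    # Hypothetical stand-in for the border computation on lines 20-45
    border_avg = border_avg_fn(ind_y, ind_x)
    # Same update as line 47, indexed with the saved mask
    new_hot[seg_mask] = (new_hot[seg_mask] +
                         sign[seg_mask] * np.abs(obj_avg - border_avg))

# Tiny usage example with a constant border average of 1.0
hot = np.ones((3, 3))
hot[1, 1] = 5.0
seg = np.zeros((3, 3), dtype=int)
seg[1, 1] = 1
sign = np.where(hot > 0, -1, 1)
new_hot = hot.copy()
correct_object(new_hot, hot, sign, seg, 1, lambda y, x: 1.0)
```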

While there are some other minor things that you could do to eke out a little bit of performance, the root issue is that you're using an O(M * N) algorithm, where M is the number of segments and N is the size of your image: every iteration of the loop compares all N pixels against i. It's not obvious to me from your code whether there is a faster algorithm to accomplish the same thing, but that's the first place I'd look for a speedup.
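To illustrate the kind of restructuring that avoids the O(M * N) cost: per-object statistics such as obj_avg can be computed for all M labels in a single pass over the image with np.bincount (scipy.ndimage.mean(hot_image, labels, index) does the same job). This is only a partial sketch: it covers the object averages, not the border averages, which still need the per-object dilation. all_object_means is a hypothetical helper, and it assumes seg_image holds non-negative integer labels:

```python
import numpy as np

def all_object_means(hot_image, seg_image):
    """Mean of hot_image for every label 0..seg_image.max(), in one pass.

    Assumes seg_image holds non-negative integer labels. Labels that do
    not occur get NaN (0 / 0).
    """
    labels = seg_image.ravel()
    sums = np.bincount(labels, weights=hot_image.ravel())   # per-label sums
    counts = np.bincount(labels)                             # per-label pixel counts
    with np.errstate(invalid='ignore'):
        return sums / counts

# Tiny usage example (label 2 is deliberately absent)
seg = np.array([[1, 1], [3, 0]])
hot = np.array([[2.0, 4.0], [6.0, 8.0]])
means = all_object_means(hot, seg)
```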

answered 2013-07-08T17:08:25.857