0

我有一个这样的数组

sample = np.array([[9.99995470e-01],
                   [9.99992013e-01],
                   [1.00000000e+00],
                   [1.00000000e+00],
                   [1.00000000e+00],
                   [1.00000000e+00],
                   [9.99775827e-01],
                   [9.99439061e-01],
                   [9.98361528e-01],
                   [9.96853650e-01],
                   [1.00000000e+00],
                   [1.00000000e+00],
                   [1.00000000e+00],
                   [1.00000000e+00],
                   [1.00000000e+00],
                   [1.00000000e+00],
                   [9.99999762e-01]])

我想获得值 = 1 的最大索引,并且它连续出现超过 5 次。所以输出应该是索引号 15。

我想知道是否有一个简单的功能可以解决这个问题

4

6 回答 6

2

使用 groupby

代码

import numpy as np
from itertools import groupby

def find_max_index(arr):

  # consecutive runs of ones
  # Use enumerate so we have the index with each value
  run_ones = [list(v) for k, v in groupby(enumerate(sample.flatten()), lambda x: x[1]) if k == 1]

  # Sorting by length to insure that max is at end of the list of lists
  # Since this is a stable last item will still be the largest index
  run_ones.sort(key=len) 

  last_list = run_ones[-1]
  if len(last_list) > 5:        # need max to have at least a run of five
    return last_list[-1][0]     # index of last value in max run of ones
  else:
    return None

print(find_max_index(sample))

# Output: 15

解释

函数find_max_index

  1. groupby 保持子列表中的组运行。每个项目都是索引,值对(来自枚举)

    run_ones = [[(2, 1.0), (3, 1.0), (4, 1.0), (5, 1.0)], [(10, 1.0), (11, 1.0), (12, 1.0), (13 , 1.0), (14, 1.0), (15, 1.0)]]

  2. 排序列表以确保最大值位于末尾

    run_ones: [[(2, 1.0), (3, 1.0), (4, 1.0), (5, 1.0)], [(10, 1.0), (11, 1.0), (12, 1.0), (13 , 1.0), (14, 1.0), (15, 1.0)]]

  3. 包含运行的最后一个列表

    last_list: [(10, 1.0), (11, 1.0), (12, 1.0), (13, 1.0), (14, 1.0), (15, 1.0)]

  4. last_list 中最后一个元组的索引

    最后一个列表[-1][0]

于 2020-05-12T11:38:02.450 回答
1

这应该为您提供 1 在一组 5 中最后出现的索引。

输入:

max([index for index, window in enumerate(windowed(sample, 5)) if list(window) == [1]*5]) + 4

输出:

15
于 2020-05-12T11:38:46.137 回答
1

这是一个可以为您解决问题的功能

def find_repeated_index(sample, min_value, min_repeats):
  max_index = -1
  history   = []
  for index, value in enumerate(np.array(sample).flatten()):
    if value >= min_value: 
        history.append(value)
        if len(history) >= min_repeats: max_index = index
    else:
        if len(history) >= min_repeats: break                  
        history = []
  return max_index

find_repeated_index(sample, 1.0, 5)
15

find_repeated_index(sample, 1.0, 4)
5
于 2020-05-12T11:41:09.750 回答
1

以下是如何使用 O(n) 运行时复杂度且不分配额外内存(不包括展平和列表转换)来解决此问题。

def find_last_index_of_longest_window(array, window_value):

    if len(array) <= 0:
        return -1

    if len(array) == 1:
        return 0 if array[0] == window_value else -1

    max_length = 0
    length = 0

    for i, value in enumerate(array):
        if value == window_value:
            length += 1
        else:
            if length >= max_length:
                max_length = length
                max_index = i - 1
                length = 0

    if length > max_length:
        max_length = length
        max_index = i

    return max_index


print(find_last_index_of_longest_window(sample.flatten().tolist(), 1.0))

更新:如果您想避免展平和转换为列表:

def find_last_index_of_longest_window(array, window_value):

    if len(array) <= 0:
        return -1

    if len(array) == 1:
        return 0 if array[0][0] == window_value else -1

    max_length = 0
    length = 0

    for i, item in enumerate(array):
        value = item[0]
        if value == window_value:
            length += 1
        else:
            if length >= max_length:
                max_length = length
                max_index = i - 1
                length = 0

    if length > max_length:
        max_length = length
        max_index = i

    return max_index


print(find_last_index_of_longest_window(sample, 1.0))
于 2020-05-12T12:03:21.833 回答
0

大型数组的快速分析表明,以下基于在 NumPy 数组中计数连续 1 的代码的解决方案将明显快于此处介绍的其他解决方案:

import numpy as np


def group_cumsum(a):
    """Taken from https://stackoverflow.com/a/42129610"""
    a_ext = np.concatenate(([0], a, [0]))
    idx = np.flatnonzero(a_ext[1:] != a_ext[:-1])
    a_ext[1:][idx[1::2]] = idx[::2] - idx[1::2]
    return a_ext.cumsum()[1:-1]


array = sample[:, 0]
value = 1
n = 5

mask = array == value
cumsums = group_cumsum(mask)
if not np.any(cumsums > n):
    print(f"No more than {n} consecutive {value}'s are found in the array.")
else:
    index = len(sample) - np.argmax(cumsums[::-1] > n) - 1
    print(index)  # returns 15 for your example
于 2020-05-12T16:07:42.173 回答
0

基于此片段

def find_runs(x):
    """Find runs of consecutive items in an array."""

    # ensure array
    x = np.asanyarray(x)
    if x.ndim != 1:
        raise ValueError('only 1D array supported')
    n = x.shape[0]

    # handle empty array
    if n == 0:
        return np.array([]), np.array([]), np.array([])

    else:
        # find run starts
        loc_run_start = np.empty(n, dtype=bool)
        loc_run_start[0] = True
        np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
        run_starts = np.nonzero(loc_run_start)[0]

        # find run values
        run_values = x[loc_run_start]

        # find run lengths
        run_lengths = np.diff(np.append(run_starts, n))

        return run_values, run_starts, run_lengths

# Part added by me

values,indices,lengths = find_runs(sample.flatten())
ones = np.where(values==1)
fiveormore = np.where(lengths[ones]>=5)
r = indices[ones][fiveormore]
last_indices = r + lengths[ones][fiveormore] - 1 

last_indices变量将是数组中每个 5 个或更长的连续部分的最后一个索引的数组,其中值为 1。获取这些索引中的最后一个只是一个last_indices[-1]调用。如果没有这样的索引,则数组将为空。

于 2020-05-12T12:17:04.950 回答