17

我正在练习算法,我的任务之一是计算给定0 < n <= 10^6数字的所有最长递增子序列的数量。解决方案O(n^2)不是一个选项。

我已经实现了查找 LIS 及其长度(LIS 算法),但该算法将数字切换到尽可能低的值。因此,不可能确定具有先前数字(较大的数字)的子序列是否能够达到最长的长度,否则我猜我可以只计算那些开关。

关于O(nlogn)的任何想法?我知道应该使用动态编程来解决它。

我实现了一个解决方案,它运行良好,但它需要两个嵌套循环(i in 1..n) x (j in 1..i-1)
所以我认为它是O(n^2),但它太慢了。

我什至尝试将这些数字从数组移动到二叉树(因为在每次i迭代中,我都会查找所有较小的数字然后number[i] - 遍历元素i-1..1),但它甚至更慢。

示例测试:

1 3 2 2 4
result: 3 (1,3,4 | 1,2,4 | 1,2,4)

3 2 1
result: 3 (1 | 2 | 3)

16 5 8 6 1 10 5 2 15 3 2 4 1
result: 3 (5,8,10,15 | 5,6,10,15 | 1,2,3,4)
4

5 回答 5

24

查找所有最长递增子序列的数量

下面是改进的 LIS 算法的完整 Java 代码,它不仅发现了最长递增子序列的长度,而且发现了这种长度的子序列的数量。我更喜欢使用泛型不仅允许整数,还允许任何可比较的类型。

@Test
public void testLisNumberAndLength() {

    List<Integer> input = Arrays.asList(16, 5, 8, 6, 1, 10, 5, 2, 15, 3, 2, 4, 1);
    int[] result = lisNumberAndlength(input);
    System.out.println(String.format(
            "This sequence has %s longest increasing subsequenses of length %s", 
            result[0], result[1]
            ));
}


/**
 * Body of improved LIS algorithm
 */
public <T extends Comparable<T>> int[] lisNumberAndLength(List<T> input) {

    if (input.size() == 0) 
        return new int[] {0, 0};

    List<List<Sub<T>>> subs = new ArrayList<>();
    List<Sub<T>> tails = new ArrayList<>();

    for (T e : input) {
        int pos = search(tails, new Sub<>(e, 0), false);      // row for a new sub to be placed
        int sum = 1;
        if (pos > 0) {
            List<Sub<T>> pRow = subs.get(pos - 1);            // previous row
            int index = search(pRow, new Sub<T>(e, 0), true); // index of most left element that <= e
            if (pRow.get(index).value.compareTo(e) < 0) {
                index--;
            } 
            sum = pRow.get(pRow.size() - 1).sum;              // sum of tail element in previous row
            if (index >= 0) {
                sum -= pRow.get(index).sum;
            }
        }

        if (pos >= subs.size()) {                             // add a new row
            List<Sub<T>> row = new ArrayList<>();
            row.add(new Sub<>(e, sum));
            subs.add(row);
            tails.add(new Sub<>(e, 0));

        } else {                                              // add sub to existing row
            List<Sub<T>> row = subs.get(pos);
            Sub<T> tail = row.get(row.size() - 1); 
            if (tail.value.equals(e)) {
                tail.sum += sum;
            } else {
                row.add(new Sub<>(e, tail.sum + sum));
                tails.set(pos, new Sub<>(e, 0));
            }
        }
    }

    List<Sub<T>> lastRow = subs.get(subs.size() - 1);
    Sub<T> last = lastRow.get(lastRow.size() - 1);
    return new int[]{last.sum, subs.size()};
}



/**
 * Implementation of binary search in a sorted list
 */
public <T> int search(List<? extends Comparable<T>> a, T v, boolean reversed) {

    if (a.size() == 0)
        return 0;

    int sign = reversed ? -1 : 1;
    int right = a.size() - 1;

    Comparable<T> vRight = a.get(right);
    if (vRight.compareTo(v) * sign < 0)
        return right + 1;

    int left = 0;
    int pos = 0;
    Comparable<T> vPos;
    Comparable<T> vLeft = a.get(left);

    for(;;) {
        if (right - left <= 1) {
            if (vRight.compareTo(v) * sign >= 0 && vLeft.compareTo(v) * sign < 0) 
                return right;
            else 
                return left;
        }
        pos = (left + right) >>> 1;
        vPos = a.get(pos);
        if (vPos.equals(v)) {
            return pos;
        } else if (vPos.compareTo(v) * sign > 0) {
            right = pos;
            vRight = vPos;
        } else {
            left = pos;
            vLeft = vPos;
        }
    } 
}



/**
 * Class for 'sub' pairs
 */
public static class Sub<T extends Comparable<T>> implements Comparable<Sub<T>> {

    T value;
    int sum;

    public Sub(T value, int sum) { 
        this.value = value; 
        this.sum = sum; 
    }

    @Override public String toString() {
        return String.format("(%s, %s)", value, sum); 
    }

    @Override public int compareTo(Sub<T> another) { 
        return this.value.compareTo(another.value); 
    }
}

解释

由于我的解释似乎很长,我将初始序列称为“seq”及其任何子序列“sub”。所以任务是计算可以从 seq 中获得的最长增加子的计数。

正如我之前提到的,想法是记录在前面的步骤中获得的所有可能的最长潜艇。因此,让我们创建一个编号的行列表,其中每行的数量等于存储在该行中的子项的长度。让我们将 subs 存储为数字对 (v, c),其中“v”是结束元素的“值”,“c”是给定长度的以“v”结尾的 subs 的“计数”。例如:

1: (16, 1) // that means that so far we have 1 sub of length 1 which ends by 16.

我们将逐步构建这样的列表,按顺序从初始序列中获取元素。在每一步中,我们都会尝试将此元素添加到可以添加的最长子中并记录更改。

建立一个列表

让我们使用您示例中的序列构建列表,因为它具有所有可能的选项:

 16 5 8 6 1 10 5 2 15 3 2 4 1

首先,取元素16。到目前为止,我们的列表是空的,所以我们只放了一对:

1: (16, 1) <= one sub that ends by 16

接下来是5。它不能添加到以 16 结尾的 sub,因此它将创建长度为 1 的新 sub。我们创建一对 (5, 1) 并将其放入第 1 行:

1: (16, 1)(5, 1)

元素8即将到来。它不能创建长度为 2 的子 [16, 8],但可以创建子 [5, 8]。所以,这就是算法来的地方。首先,我们颠倒迭代列表行,查看最后一对的“值”。如果我们的元素大于所有行中所有最后一个元素的值,那么我们可以将其添加到现有的 sub(s) 中,将其长度增加一。所以值 8 将创建列表的新行,因为它大于迄今为止列表中存在的所有最后一个元素的值(即 > 5):

1: (16, 1)(5, 1) 
2: (8, ?)   <=== need to resolve how many longest subs ending by 8 can be obtained

元素 8 可以继续 5,但不能继续 16。所以我们需要搜索上一行,从它的末尾开始,计算“值”小于 8 的成对“计数”之和:

(16, 1)(5, 1)^  // sum = 0
(16, 1)^(5, 1)  // sum = 1
^(16, 1)(5, 1)  // value 16 >= 8: stop. count = sum = 1, so write 1 in pair next to 8

1: (16, 1)(5, 1)
2: (8, 1)  <=== so far we have 1 sub of length 2 which ends by 8.

为什么我们不将值 8 存储到长度为 1(第一行)的 subs 中?因为我们需要最大可能长度的 subs,并且 8 可以继续一些以前的 subs。因此,每个大于 8 的下一个数字也将继续这样的 sub,并且没有必要保持 8 作为 sub 的长度小于它可以的长度。

下一个。6 . 按行中的最后一个“值”倒置搜索:

1: (16, 1)(5, 1)  <=== 5 < 6, go next
2: (8, 1)

1: (16, 1)(5, 1)
2: (8, 1 )  <=== 8 >= 6, so 6 should be put here

找到6个房间,需要计算一个数:

take previous line
(16, 1)(5, 1)^  // sum = 0
(16, 1)^(5, 1)  // 5 < 6: sum = 1
^(16, 1)(5, 1)  // 16 >= 6: stop, write count = sum = 1

1: (16, 1)(5, 1)
2: (8, 1)(6, 1) 

处理后1

1: (16, 1)(5, 1)(1, 1) <===
2: (8, 1)(6, 1)

处理后10

1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)
3: (10, 2) <=== count is 2 because both "values" 8 and 6 from previous row are less than 10, so we summarized their "counts": 1 + 1

处理后5

1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1) <===
3: (10, 2)

处理后2

1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 1) <===
3: (10, 2)

处理后15

1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 1)
3: (10, 2)
4: (15, 2) <===

处理后3

1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 1)
3: (10, 2)(3, 1) <===
4: (15, 2)  

处理后2

1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 2) <===
3: (10, 2)(3, 1) 
4: (15, 2)  

如果在按最后一个元素搜索行时,我们找到相等的元素,我们会根据前一行再次计算其“计数”,并添加到现有的“计数”。

处理后4

1: (16, 1)(5, 1)(1, 1)
2: (8, 1)(6, 1)(5, 1)(2, 2)  
3: (10, 2)(3, 1) 
4: (15, 2)(4, 1) <===

处理后1

1: (16, 1)(5, 1)(1, 2) <===
2: (8, 1)(6, 1)(5, 1)(2, 2)  
3: (10, 2)(3, 1) 
4: (15, 2)(4, 1)  

那么在处理完所有初始序列后我们有什么?查看最后一行,我们看到我们有 3 个最长的子,每个子由 4 个元素组成:2 个以 15 结尾,1 个以 4 结尾。

复杂性呢?

在每次迭代中,当从初始序列中获取下一个元素时,我们会进行 2 次循环:第一次是迭代行以查找下一个元素的空间,第二次是汇总前一行中的计数。因此,对于每个元素,我们最多进行n次迭代(最坏的情况:如果初始 seq 由按升序排列的元素组成,我们将得到一个 n 行列表,每行有 1 对;如果 seq 按降序排序,我们将获得包含 n 个元素的 1 行列表)。顺便说一句,O(n 2 )复杂度不是我们想要的。

首先,这很明显,在每个中间状态中,行都按其最后一个“值”的升序排序。因此,可以执行二进制搜索而不是暴力循环,其复杂度为 O(log n)。

其次,我们不需要通过每次循环遍历行元素来总结 subs 的“计数”。当新的对添加到行中时,我们可以在过程中总结它们,例如:

1: (16, 1)(5, 2) <=== instead of 1, put 1 + "count" of previous element in the row

因此,第二个数字将不显示可以在最后以给定值获得的最长子项的计数,而是任何大于或等于该对“值”的元素结尾的所有最长子项的汇总计数。

因此,“计数”将被“总和”取代。而不是迭代前一行中的元素,我们只执行二进制搜索(这是可能的,因为任何行中的对总是按它们的“值”排序)并将新对的“总和”作为前一行中最后一个元素的“总和”从左边的元素减去“sum”到前一行中找到的位置加上当前行中前一个元素的“sum”。

所以在处理4时:

1: (16, 1)(5, 2)(1, 3)
2: (8, 1)(6, 2)(5, 3)(2, 5) 
3: (10, 2)(3, 3) 
4: (15, 2) <=== room for (4, ?)

search in row 3 by "values" < 4:
3: (10, 2)^(3, 3) 

4 将与 (3-2+2) 配对:(前一行最后一对的“sum”)-(前一行左侧对的“sum”)+(当前前一对的“sum”排):

4: (15, 2)(4, 3)

在这种情况下,所有最长子项的最终计数是列表最后一行的最后一对的“总和”,即 3,而不是 3 + 2。

因此,对行搜索和求和搜索执行二进制搜索,我们将得到 O(n*log n) 复杂度。

内存消耗怎么样,在处理完所有数组后,我们获得最大 n 对,所以动态数组的内存消耗将是 O(n)。此外,当使用动态数组或集合时,需要一些额外的时间来分配和调整它们的大小,但大多数操作都是在 O(1) 时间内完成的,因为我们在处理过程中没有进行任何类型的排序和重新排列。所以复杂性估计似乎是最终的。

于 2014-04-08T18:30:28.020 回答
0

Sasha Salauyou 的回答很好,但我不清楚为什么

sum -= pRow.get(index).sum;

这是我基于相同想法的代码

import java.math.BigDecimal;
import java.util.*;

class lisCount {
  static BigDecimal lisCount(int[] a) {
    class Container {
      Integer    v;
      BigDecimal count;

      Container(Integer v) {
        this.v = v;
      }
    }
    List<List<Container>> lisIdxSeq = new ArrayList<List<Container>>();
    int lisLen, lastIdx;
    List<Container> lisSeqL;
    Container lisEle;
    BigDecimal count;
    int pre;
    for (int i = 0; i < a.length; i++){
      pre = -1;
      count = new BigDecimal(1);
      lisLen = lisIdxSeq.size();
      lastIdx = lisLen - 1;
      lisEle = new Container(i);
      if(lisLen == 0 || a[i] > a[lisIdxSeq.get(lastIdx).get(0).v]){
        // lis len increased
        lisSeqL = new ArrayList<Container>();
        lisSeqL.add(lisEle);
        lisIdxSeq.add(lisSeqL);
        pre = lastIdx;
      }else{
        int h = lastIdx;
        int l = 0;

        while(l < h){
          int m = (l + h) / 2;
          if(a[lisIdxSeq.get(m).get(0).v] < a[i]) l = m + 1;
          else h = m;
        }

        List<Container> lisSeqC = lisIdxSeq.get(l);
        if(a[i] <= a[lisSeqC.get(0).v]){
          int hi = lisSeqC.size() - 1;
          int lo = 0;
          while(hi < lo){
            int mi = (hi + lo) / 2;
            if(a[lisSeqC.get(mi).v] < a[i]) lo = mi + 1;
            else hi = mi;
          }
          lisSeqC.add(lo, lisEle);
          pre = l - 1;
        }
      }
      if(pre >= 0){
        Iterator<Container> it = lisIdxSeq.get(pre).iterator();
        count = new BigDecimal(0);
        while(it.hasNext()){
          Container nt = it.next();
          if(a[nt.v] < a[i]){
            count = count.add(nt.count);
          }else break;
        }
      }
      lisEle.count = count;
    }

    BigDecimal rst = new BigDecimal(0);
    Iterator<Container> i = lisIdxSeq.get(lisIdxSeq.size() - 1).iterator();
    while(i.hasNext()){
      rst = rst.add(i.next().count);
    }
    return rst;
  }

  public static void main(String[] args) {
    System.out.println(lisCount(new int[] { 1, 3, 2, 2, 4 }));
    System.out.println(lisCount(new int[] { 3, 2, 1 }));
    System.out.println(lisCount(new int[] { 16, 5, 8, 6, 1, 10, 5, 2, 15, 3, 2, 4, 1 }));
  }
}
于 2015-11-20T23:15:37.523 回答
0

耐心排序也是 O(N*logN),但比基于二分查找的方法更短更简单:

static int[] input = {4, 5, 2, 8, 9, 3, 6, 2, 7, 8, 6, 6, 7, 7, 3, 6};

/**
 * Every time a value is tested it either adds to the length of LIS (by calling decs.add() with it), or reduces the remaining smaller cards that must be found before LIS consists of smaller cards. This way all inputs/cards contribute in one way or another (except if they're equal to the biggest number in the sequence; if want't to include in sequence, replace 'card <= decs.get(decIndex)' with 'card < decs.get(decIndex)'. If they're bigger than all decs, they add to the length of LIS (which is something we want), while if they're smaller than a dec, they replace it. We want this, because the smaller the biggest dec is, the smaller input we need before we can add onto LIS.
 *
 * If we run into a decreasing sequence the input from this sequence will replace each other (because they'll always replace the leftmost dec). Thus this algorithm won't wrongfully register e.g. {2, 1, 3} as {2, 3}, but rather {2} -> {1} -> {1, 3}.
 *
 * WARNING: This can only be used to find length, not actual sequence, seeing how parts of the sequence will be replaced by smaller numbers trying to make their sequence dominate
 *
 * Due to bigger decs being added to the end/right of 'decs' and the leftmost decs always being the first to be replaced with smaller decs, the further a dec is to the right (the bigger it's index), the bigger it must be. Thus, by always replacing the leftmost decs, we don't run the risk of replacing the biggest number in a sequence (the number which determines if more cards can be added to that sequence) before a sequence with the same length but smaller numbers (thus currently equally good, due to length, and potentially better, due to less needed to increase length) has been found.
 */
static void patienceFindLISLength() {
    ArrayList<Integer> decs = new ArrayList<>();
    inputLoop: for (Integer card : input) {
        for (int decIndex = 0; decIndex < decs.size(); decIndex++) {
            if (card <= decs.get(decIndex)) {
                decs.set(decIndex, card);
                continue inputLoop;
            }
        }
        decs.add(card);
    }
    System.out.println(decs.size());
}
于 2016-02-16T12:47:11.200 回答
0

上述逻辑的cpp实现:

#include<bits/stdc++.h>
using namespace std;
#define pb push_back
#define pob pop_back
#define pll pair<ll, ll>
#define pii pair<int, int>
#define ll long long
#define ull unsigned long long
#define fori(a,b) for(i=a;i<b;i++)
#define forj(a,b) for(j=a;j<b;j++)
#define fork(a,b) for(k=a;k<b;k++)
#define forl(a,b) for(l=a;l<b;l++)
#define forir(a,b) for(i=a;i>=b;i--)
#define forjr(a,b) for(j=a;j>=b;j--)
#define mod 1000000007
#define boost std::ios::sync_with_stdio(false)

struct comp_pair_int_rev
{
    bool operator()(const pair<int,int> &a, const int & b)
    {
        return (a.first > b);
    }
    bool operator()(const int & a,const pair<int,int> &b)
    {
        return (a > b.first);
    }
};

struct comp_pair_int
{
    bool operator()(const pair<int,int> &a, const int & b)
    {
        return (a.first < b);
    }
    bool operator()(const int & a,const pair<int,int> &b)
    {
        return (a < b.first);
    }
};

int main()
{
    int n,i,mx=0,p,q,r,t;
    cin>>n;

    int a[n];
    vector<vector<pii > > v(100005);
    vector<pii > v1(100005);

    fori(0,n)
    cin>>a[i];

    v[1].pb({a[0], 1} );
    v1[1]= {a[0], 1};

    mx=1;
    fori(1,n)
    {
        if(a[i]<=v1[1].first)
        {
            r=v1[1].second;

            if(v1[1].first==a[i])
                v[1].pob();

            v1[1]= {a[i], r+1};
            v[1].pb({a[i], r+1});
        }
        else if(a[i]>v1[mx].first)
        {
            q=upper_bound(v[mx].begin(), v[mx].end(), a[i], comp_pair_int_rev() )-v[mx].begin();
            if(q==0)
            {
                r=v1[mx].second;
            }
            else
            {
                r=v1[mx].second-v[mx][q-1].second;
            }

            v1[++mx]= {a[i], r};
            v[mx].pb({a[i], r});
        }
        else if(a[i]==v1[mx].first)
        {
            q=upper_bound(v[mx-1].begin(), v[mx-1].end(), a[i], comp_pair_int_rev() )-v[mx-1].begin();
            if(q==0)
            {
                r=v1[mx-1].second;
            }
            else
            {
                r=v1[mx-1].second-v[mx-1][q-1].second;
            }
            p=v1[mx].second;
            v1[mx]= {a[i], p+r};

            v[mx].pob();
            v[mx].pb({a[i], p+r});


        }
        else
        {
            p=lower_bound(v1.begin()+1, v1.begin()+mx+1, a[i], comp_pair_int() )-v1.begin();
            t=v1[p].second;

            if(v1[p].first==a[i])
            {

                v[p].pob();
            }

            q=upper_bound(v[p-1].begin(), v[p-1].end(), a[i], comp_pair_int_rev() )-v[p-1].begin();
            if(q==0)
            {
                r=v1[p-1].second;
            }
            else
            {
                r=v1[p-1].second-v[p-1][q-1].second;
            }

            v1[p]= {a[i], t+r};
            v[p].pb({a[i], t+r});

        }


    }

    cout<<v1[mx].second;

    return 0;
}
于 2018-01-31T02:14:46.060 回答
0

尽管我完全同意 Alex 的观点,但使用 Segment tree 可以很容易地做到这一点。这是在 NlogN 中使用段树查找 LIS 长度的逻辑。 https://www.quora.com/What-is-the-approach-to-find-the-length-of-the-strictly-increasing-longest-subsequence 这是一种找不到 LIS 但需要 N^ 的方法2 复杂性。 https://codeforces.com/blog/entry/48677

我们使用段树(如此处使用)来优化本文中给出的方法。这是逻辑:

首先按升序对数组进行排序(也保持原始顺序),用零初始化段树,段树应该查询给定范围的两件事(为此使用对):a。最大的第一个。湾。对应于 max-first 的秒的总和。遍历排序数组。设 j 为当前元素的原始索引,然后我们查询 (0 - j-1) 并更新第 j 个元素(如果查询结果是 0,0 则我们用 (1,1) 更新它)。

这是我在 C++ 中的代码:

#include<bits/stdc++.h>
#define tr(container, it) for(typeof(container.begin()) it = container.begin(); it != container.end(); it++)
#define ll          long long
#define pb          push_back
#define endl        '\n'
#define pii         pair<ll int,ll int>
#define vi          vector<ll int>
#define all(a)      (a).begin(),(a).end()
#define F           first
#define S           second
#define sz(x)       (ll int)x.size()
#define hell        1000000007
#define rep(i,a,b)  for(ll int i=a;i<b;i++)
#define lbnd        lower_bound
#define ubnd        upper_bound
#define bs          binary_search
#define mp          make_pair
using namespace std;

#define N  100005

ll max(ll a , ll b)

{
    if( a > b) return a ;
    else return
         b;
}
ll n,l,r;
vector< pii > seg(4*N);

pii query(ll cur,ll st,ll end,ll l,ll r)
{
    if(l<=st&&r>=end)
    return seg[cur];
    if(r<st||l>end)
    return mp(0,0);                           /*  2-change here  */
    ll mid=(st+end)>>1;
    pii ans1=query(2*cur,st,mid,l,r);
    pii ans2=query(2*cur+1,mid+1,end,l,r);
    if(ans1.F>ans2.F)
        return ans1;
    if(ans2.F>ans1.F)
        return ans2;

    return make_pair(ans1.F,ans2.S+ans1.S);                 /*  3-change here  */
}
void update(ll cur,ll st,ll end,ll pos,ll upd1, ll upd2)
{
    if(st==end)
    {
        // a[pos]=upd;                  /*  4-change here  */
        seg[cur].F=upd1;    
        seg[cur].S=upd2;            /*  5-change here  */
        return;
    }
    ll mid=(st+end)>>1;
    if(st<=pos&&pos<=mid)
        update(2*cur,st,mid,pos,upd1,upd2);
    else
        update(2*cur+1,mid+1,end,pos,upd1,upd2);
    seg[cur].F=max(seg[2*cur].F,seg[2*cur+1].F);


    if(seg[2*cur].F==seg[2*cur+1].F)
        seg[cur].S = seg[2*cur].S+seg[2*cur+1].S;
    else
    {
        if(seg[2*cur].F>seg[2*cur+1].F)
            seg[cur].S = seg[2*cur].S;
        else
            seg[cur].S = seg[2*cur+1].S;
        /*  6-change here  */
    }
}

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int TESTS=1;
//  cin>>TESTS;
    while(TESTS--)
    {
        int n ;
        cin >> n;
        vector< pii > arr(n);
        rep(i,0,n)
        {
            cin >> arr[i].F;
            arr[i].S = -i;
        }

        sort(all(arr));
        update(1,0,n-1,-arr[0].S,1,1);
        rep(i,1,n)
        {
            pii x = query(1,0,n-1,-1,-arr[i].S - 1 );
            update(1,0,n-1,-arr[i].S,x.F+1,max(x.S,1));

        }

        cout<<seg[1].S;//answer



    }
    return 0;
}
于 2018-12-12T21:23:31.623 回答