longest increasing subsequence
11 Dec 2016
tl;dr
This post discusses the $O(n^2)$ and $O(n \log(n))$ methods for finding the length of the longest increasing subsequence (LIS), and how to recover the subsequence itself.
the naive dynamic programming
… which has $O(n^2)$ time complexity. dp[i] is the max length of an LIS ending with nums[i]. Thus we have dp[i] = max(1, dp[j] + 1) over all 0 <= j < i with nums[j] < nums[i].
import java.util.*;
public class Solution {
    public int lengthOfLIS(int[] nums) {
        // time O(n^2)
        // space O(n)
        int[] dp = new int[nums.length];
        // dp[i] is the max length of LIS ending with nums[i]
        Arrays.fill(dp, 1); // attention here!
        int res = 0;
        for (int i = 0; i < nums.length; ++i) {
            for (int j = 0; j < i; ++j) {
                if (nums[j] < nums[i]) {
                    dp[i] = Math.max(dp[i], dp[j] + 1);
                }
            }
            res = Math.max(res, dp[i]);
        }
        return res;
    }
}
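The tl;dr also promises a way to recover the subsequence, and the naive method can do that with one extra array. The following is a minimal sketch, not code from the original post: the class name NaiveLisWithPath and the parent array are names introduced here for illustration. parent[i] remembers which j produced the best dp[i], so we can walk backwards from the index where the longest subsequence ends.

```java
import java.util.Arrays;

// Hypothetical helper class, introduced only for this sketch.
class NaiveLisWithPath {
    // Returns one longest strictly increasing subsequence of nums in O(n^2) time.
    static int[] lis(int[] nums) {
        int n = nums.length;
        if (n == 0) {
            return new int[0];
        }
        int[] dp = new int[n];     // dp[i]: length of the LIS ending with nums[i]
        int[] parent = new int[n]; // parent[i]: index before i in that LIS, or -1
        Arrays.fill(dp, 1);
        Arrays.fill(parent, -1);
        int best = 0;              // index where the overall LIS ends
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < i; ++j) {
                if (nums[j] < nums[i] && dp[j] + 1 > dp[i]) {
                    dp[i] = dp[j] + 1;
                    parent[i] = j;
                }
            }
            if (dp[i] > dp[best]) {
                best = i;
            }
        }
        // walk the parent links backwards, filling the result from the end
        int[] res = new int[dp[best]];
        for (int i = best, k = res.length - 1; i >= 0; i = parent[i], --k) {
            res[k] = nums[i];
        }
        return res;
    }

    public static void main(String[] args) {
        // prints [0, 2, 6, 7, 8, 9]
        System.out.println(Arrays.toString(lis(new int[]{0, 2, 6, 1, 7, 4, 8, 5, 9, 7})));
    }
}
```

When several subsequences share the maximum length, this sketch returns the one found first; other tie-breaking rules are equally valid.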
the optimized dynamic programming
… which has $O(n \log(n))$ time complexity because each number is placed with a binary search. dp[i] is the smallest possible ending value among all increasing subsequences of length i + 1. For each number in nums, we binary-search its position in dp and write it there, either replacing the existing value or extending dp by one, so that dp stays strictly increasing. GeeksforGeeks has one of the best explanations.
import java.util.*;
public class Solution {
    public int lengthOfLIS(int[] nums) {
        // time O(n log(n))
        // space O(n)
        int[] dp = new int[nums.length];
        // dp[i] is the smallest possible ending value among all increasing subsequences of length i + 1
        int res = 0;
        for (int i = 0; i < nums.length; ++i) {
            int num = nums[i];
            int pos = Arrays.binarySearch(dp, 0, res, num); // (array, start, end, key) (end exclusive)
            if (pos < 0) {
                // not found, -> insertion point
                pos = - (pos + 1);
            }
            dp[pos] = num;
            if (pos == res) {
                ++res;
            }
        }
        return res;
    }
}
how to reconstruct the subsequence?
Many posts present the optimized method above, but do NOT clearly discuss how to reconstruct the subsequence. For example, one SO post discusses the reconstruction, but doesn’t provide the code!
We should pay attention that the dp array in the end is NOT the LIS.
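A tiny input makes this concrete. The sketch below reruns the optimized method and returns the used prefix of dp; the class name TailsAreNotTheLis and the method name tails are made up for this illustration. For {3, 4, 1} the length is 2 and the only LIS is [3, 4], yet dp ends as [1, 4], which is not even a subsequence of the input because 1 comes after 4.

```java
import java.util.Arrays;

// Hypothetical demo class, introduced only for this sketch.
class TailsAreNotTheLis {
    // Runs the O(n log(n)) method and returns the used prefix dp[0..res).
    static int[] tails(int[] nums) {
        int[] dp = new int[nums.length];
        int res = 0;
        for (int num : nums) {
            int pos = Arrays.binarySearch(dp, 0, res, num);
            if (pos < 0) {
                // not found, -> insertion point
                pos = -(pos + 1);
            }
            dp[pos] = num;
            if (pos == res) {
                ++res;
            }
        }
        return Arrays.copyOf(dp, res);
    }

    public static void main(String[] args) {
        // dp evolves as [3] -> [3, 4] -> [1, 4]; prints [1, 4]
        System.out.println(Arrays.toString(tails(new int[]{3, 4, 1})));
    }
}
```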
So one way is to maintain an array prev. prev[i] indicates the “parent / previous index in nums” for the value nums[i]. Moreover, for simplicity, we add another array dpIdx that stores, for each value in dp, its original index in nums (we could actually replace dp with dpIdx to save space, but we would need to modify the binary search function). In each iteration, we use binary search to find the insertion point pos in dp. If pos is larger than 0, the number extends the subsequence ending with dp[pos - 1], thus prev[i] = dpIdx[pos - 1]. Otherwise pos == 0, and this number could potentially be the start of a new LIS, thus prev[i] = -1 to indicate that the number has no previous element. After we find the length, we can construct the LIS backwards, starting from dpIdx[res - 1] and following prev.
import java.util.*;
public class Solution {
    public int lengthOfLIS(int[] nums) {
        // time O(n log(n))
        // space O(n)
        int[] dp = new int[nums.length];
        int[] dpIdx = new int[nums.length];
        int[] prev = new int[nums.length];
        int res = 0;
        for (int i = 0; i < nums.length; ++i) {
            int num = nums[i];
            int pos = Arrays.binarySearch(dp, 0, res, num); // (array, start, end, key) (end exclusive)
            if (pos < 0) {
                // not found, -> insertion point
                pos = - (pos + 1);
            }
            dp[pos] = num;
            dpIdx[pos] = i;
            if (pos > 0) {
                prev[i] = dpIdx[pos - 1];
            } else {
                prev[i] = -1;
            }
            if (pos == res) {
                ++res;
            }
        }
        // reconstruct the LIS backwards; start at -1 for empty input so the loop is skipped
        int[] lis = new int[res];
        for (int i = res > 0 ? dpIdx[res - 1] : -1, j = lis.length - 1; i >= 0 && j >= 0; i = prev[i], j -= 1) {
            lis[j] = nums[i];
        }
        System.out.printf("nums: %s\n", Arrays.toString(nums));
        System.out.printf("dp: %s\n", Arrays.toString(dp));
        System.out.printf("dpIdx: %s\n", Arrays.toString(dpIdx));
        System.out.printf("prev: %s\n", Arrays.toString(prev));
        System.out.printf("LIS: %s\n", Arrays.toString(lis));
        return res;
    }
    public static void main(String[] args) {
        Solution s = new Solution();
        int res = s.lengthOfLIS(new int[]{0, 2, 6, 1, 7, 4, 8, 5, 9, 7});
    }
}
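The remark about replacing dp with dpIdx can be sketched as follows. This is one possible way to do it, not the original author's code: the class name LisIndexOnly and the array name tailIdx are hypothetical, and Arrays.binarySearch is swapped for a hand-written lower-bound search, since the library call compares array elements directly and cannot look values up through an index array.

```java
import java.util.Arrays;

// Hypothetical variant, introduced only for this sketch.
class LisIndexOnly {
    // Returns one longest strictly increasing subsequence of nums in O(n log(n)) time.
    static int[] lis(int[] nums) {
        int n = nums.length;
        int[] tailIdx = new int[n]; // tailIdx[k]: index in nums of the smallest tail of length k + 1
        int[] prev = new int[n];    // prev[i]: index of the element before nums[i], or -1
        int len = 0;
        for (int i = 0; i < n; ++i) {
            // lower bound: first k in [0, len) with nums[tailIdx[k]] >= nums[i]
            int lo = 0, hi = len;
            while (lo < hi) {
                int mid = (lo + hi) >>> 1;
                if (nums[tailIdx[mid]] < nums[i]) {
                    lo = mid + 1;
                } else {
                    hi = mid;
                }
            }
            prev[i] = lo > 0 ? tailIdx[lo - 1] : -1;
            tailIdx[lo] = i;
            if (lo == len) {
                ++len;
            }
        }
        // follow prev backwards from the tail of the longest subsequence
        int[] res = new int[len];
        for (int i = len > 0 ? tailIdx[len - 1] : -1, k = len - 1; i >= 0; i = prev[i], --k) {
            res[k] = nums[i];
        }
        return res;
    }

    public static void main(String[] args) {
        // prints [0, 2, 6, 7, 8, 9]
        System.out.println(Arrays.toString(lis(new int[]{0, 2, 6, 1, 7, 4, 8, 5, 9, 7})));
    }
}
```

The trade-off is one array fewer at the cost of an extra indirection (nums[tailIdx[mid]]) in every comparison.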
how about longest non-decreasing subsequence?
In the strictly increasing case, the binary search gives us the lower-bound position: the first index whose value is not less than the target. But if we want the subsequence to be non-decreasing, we should instead put num at the first position whose value is strictly larger than num! Actually, std::upper_bound in C++ is what I mean.
In my Java code above, pay attention to:
    int pos = Arrays.binarySearch(dp, 0, res, num); // (array, start, end, key) (end exclusive)
    if (pos < 0) {
        // not found, -> insertion point
        pos = - (pos + 1);
    }
Change it to:
    int pos = Arrays.binarySearch(dp, 0, res, num); // (array, start, end, key) (end exclusive)
    if (pos < 0) {
        // not found, -> insertion point
        pos = - (pos + 1);
    }
    while (pos < res && dp[pos] == num) {
        // shift right
        pos += 1;
    }
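Or, if we are only dealing with integers, we could search for num + 1 instead. Two caveats: once dp may contain duplicates, Arrays.binarySearch gives no guarantee about which of several equal elements it finds, so a hit still has to be shifted left to the first num + 1; and num + 1 overflows when num is Integer.MAX_VALUE.
    // pay attention: search num + 1, not num (assumes num < Integer.MAX_VALUE)
    int pos = Arrays.binarySearch(dp, 0, res, num + 1); // (array, start, end, key) (end exclusive)
    if (pos < 0) {
        // not found, -> insertion point
        pos = - (pos + 1);
    }
    while (pos > 0 && dp[pos - 1] == num + 1) {
        // found among duplicates: shift left to the first num + 1
        pos -= 1;
    }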
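A shift loop after Arrays.binarySearch can cost linear time when dp holds a long run of equal values. One alternative is to write the upper-bound search by hand, which is what std::upper_bound does in C++. The sketch below is an illustration under that assumption; the class name LongestNonDecreasing and the method name lengthOfLNDS are invented here. It also avoids computing num + 1, so Integer.MAX_VALUE is not a special case.

```java
// Hypothetical class, introduced only for this sketch.
class LongestNonDecreasing {
    // Returns the length of the longest non-decreasing subsequence in O(n log(n)) time.
    static int lengthOfLNDS(int[] nums) {
        int[] dp = new int[nums.length]; // dp[k]: smallest tail of a non-decreasing subsequence of length k + 1
        int res = 0;
        for (int num : nums) {
            // upper bound: first pos in [0, res) with dp[pos] > num
            int lo = 0, hi = res;
            while (lo < hi) {
                int mid = (lo + hi) >>> 1;
                if (dp[mid] <= num) {
                    lo = mid + 1;
                } else {
                    hi = mid;
                }
            }
            dp[lo] = num;
            if (lo == res) {
                ++res;
            }
        }
        return res;
    }

    public static void main(String[] args) {
        // prints 3, for the subsequence 1, 2, 2
        System.out.println(lengthOfLNDS(new int[]{1, 3, 2, 2}));
    }
}
```

The only difference from a lower-bound search is the comparison: dp[mid] <= num instead of dp[mid] < num, so equal values are skipped and duplicates extend the subsequence.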
try it
On LeetCode #300.