Find the Shortest Superstring

hard dynamic programming bitmask travelling salesman

Problem

Given an array of strings words (no string is a substring of another), return the shortest string that contains every word as a substring. If there are multiple valid answers, return any of them.

Two words can be merged by overlapping a suffix of one with a prefix of the next, e.g. "cat" + "atom" share the overlap "at" and merge into "catom". We want the ordering of all words that maximizes total overlap, which is exactly a Travelling-Salesman ordering solved with bitmask DP.

Input: words = ["catg", "ctaagt", "gcta", "ttca", "atgcatc"]
Output: "gctaagttcatgcatc"
Every input word appears inside the output, and no shorter superstring exists.

def shortest_superstring(words):
    """Return a shortest string that contains every word of ``words`` as a substring.

    The words are ordered so that the total suffix/prefix overlap between
    consecutive words is maximal (a Hamiltonian-path search done with a
    bitmask DP), then stitched together by appending only the
    non-overlapping tail of each word.

    Args:
        words: sequence of strings. The algorithm assumes no word is a
            substring of another; it only ever merges words end-to-start.

    Returns:
        One shortest superstring. When several orderings reach the same
        total overlap, any one of them may be returned. An empty input
        yields the empty string.

    Time and memory grow as ``2**n * n`` table entries, so this is only
    practical for a small number of words.
    """
    n = len(words)
    if n == 0:
        # No words to cover; without this guard max() below would raise
        # ValueError on an empty range.
        return ""

    # overlap[i][j] = length of the longest suffix of words[i] that equals
    # a prefix of words[j] (0 on the diagonal, which is never used).
    overlap = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            m = min(len(words[i]), len(words[j]))
            # Try the longest candidate first so the first hit is the maximum.
            for k in range(m, 0, -1):
                if words[i].endswith(words[j][:k]):
                    overlap[i][j] = k
                    break

    # dp[mask][i]     = max total overlap of a path visiting exactly the words
    #                   in `mask` and ending at word i.
    # parent[mask][i] = word visited just before i on that path, or -1 when
    #                   i is the first word of the path.
    dp = [[0] * n for _ in range(1 << n)]
    parent = [[-1] * n for _ in range(1 << n)]
    for mask in range(1, 1 << n):
        for i in range(n):
            if not (mask >> i) & 1:
                continue
            pmask = mask ^ (1 << i)
            if pmask == 0:
                # Single-word path: overlap 0 and no predecessor (the defaults).
                continue
            for j in range(n):
                if not (pmask >> j) & 1:
                    continue
                cand = dp[pmask][j] + overlap[j][i]
                # `parent == -1` marks "no candidate recorded yet", so the first
                # predecessor is always accepted even when cand == 0.
                if parent[mask][i] == -1 or cand > dp[mask][i]:
                    dp[mask][i] = cand
                    parent[mask][i] = j

    # Pick the best endpoint over the full set of words.
    full = (1 << n) - 1
    last = max(range(n), key=lambda i: dp[full][i])

    # Walk the parent links backwards to recover the visiting order.
    order, mask = [], full
    while last != -1:
        order.append(last)
        nxt = parent[mask][last]
        mask ^= (1 << last)
        last = nxt
    order.reverse()

    # Stitch: after the first word, append only the part of each word that is
    # not already covered by the overlap with its predecessor.
    pieces = [words[order[0]]]
    for k in range(1, len(order)):
        a, b = order[k - 1], order[k]
        pieces.append(words[b][overlap[a][b]:])
    return "".join(pieces)
/**
 * Returns a shortest string that contains every word of `words` as a substring.
 *
 * The words are ordered so that the total suffix/prefix overlap between
 * consecutive words is maximal (bitmask DP over subsets of words), then
 * stitched together by appending only the non-overlapping tail of each word.
 *
 * The algorithm assumes no word is a substring of another. The tables hold
 * 2^n * n entries, so this is only practical for a small number of words.
 *
 * @param {string[]} words Words to cover.
 * @returns {string} One shortest superstring; "" for an empty array.
 */
function shortestSuperstring(words) {
  const n = words.length;
  // No words to cover; without this guard the reconstruction below would
  // index into rows that do not exist and throw.
  if (n === 0) return "";

  // overlap[i][j] = longest suffix of words[i] that is a prefix of words[j].
  const overlap = Array.from({ length: n }, () => new Array(n).fill(0));
  for (let i = 0; i < n; i++) {
    for (let j = 0; j < n; j++) {
      if (i === j) continue;
      const m = Math.min(words[i].length, words[j].length);
      // Longest candidate first, so the first hit is the maximum.
      for (let k = m; k > 0; k--) {
        if (words[i].endsWith(words[j].slice(0, k))) {
          overlap[i][j] = k;
          break;
        }
      }
    }
  }

  // dp[mask][i]     = max total overlap of a path over `mask` ending at i.
  // parent[mask][i] = word before i on that path, -1 if i starts the path.
  const full = (1 << n) - 1;
  const dp = Array.from({ length: full + 1 }, () => new Array(n).fill(0));
  const parent = Array.from({ length: full + 1 }, () => new Array(n).fill(-1));
  for (let mask = 1; mask <= full; mask++) {
    for (let i = 0; i < n; i++) {
      if (!((mask >> i) & 1)) continue;
      const pmask = mask ^ (1 << i);
      if (pmask === 0) continue; // single-word path keeps the defaults
      for (let j = 0; j < n; j++) {
        if (!((pmask >> j) & 1)) continue;
        const cand = dp[pmask][j] + overlap[j][i];
        // parent === -1 means nothing recorded yet: accept even if cand is 0.
        if (parent[mask][i] === -1 || cand > dp[mask][i]) {
          dp[mask][i] = cand;
          parent[mask][i] = j;
        }
      }
    }
  }

  // Best endpoint over the full set (first one wins on ties).
  let last = 0;
  for (let i = 1; i < n; i++) {
    if (dp[full][i] > dp[full][last]) last = i;
  }

  // Follow parent links backwards to recover the visiting order.
  const order = [];
  let mask = full;
  while (last !== -1) {
    order.push(last);
    const prev = parent[mask][last];
    mask ^= (1 << last);
    last = prev;
  }
  order.reverse();

  // Append only the part of each word not covered by its predecessor.
  let res = words[order[0]];
  for (let k = 1; k < order.length; k++) {
    res += words[order[k]].slice(overlap[order[k - 1]][order[k]]);
  }
  return res;
}
class Solution {
    /**
     * Returns a shortest string that contains every word of {@code words} as a
     * substring.
     *
     * <p>The words are ordered so that the total suffix/prefix overlap between
     * consecutive words is maximal (bitmask DP over subsets of words), then
     * stitched together by appending only the non-overlapping tail of each word.
     * The algorithm assumes no word is a substring of another. The tables hold
     * {@code 2^n * n} entries, so this is only practical for a small number of
     * words.
     *
     * @param words words to cover
     * @return one shortest superstring; the empty string for an empty array
     */
    public String shortestSuperstring(String[] words) {
        int n = words.length;
        // No words to cover; without this guard the reconstruction below would
        // write to seq[-1].
        if (n == 0) return "";

        // overlap[i][j] = longest suffix of words[i] that is a prefix of words[j].
        int[][] overlap = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (i == j) continue;
                overlap[i][j] = suffixPrefixOverlap(words[i], words[j]);
            }
        }

        // dp[mask][i]     = max total overlap of a path over `mask` ending at i.
        // parent[mask][i] = word before i on that path, -1 if i starts the path.
        int full = (1 << n) - 1;
        int[][] dp = new int[full + 1][n];
        int[][] parent = new int[full + 1][n];
        // Fully qualified so the class does not depend on an import.
        for (int[] row : parent) java.util.Arrays.fill(row, -1);
        for (int mask = 1; mask <= full; mask++) {
            for (int i = 0; i < n; i++) {
                if (((mask >> i) & 1) == 0) continue;
                int pmask = mask ^ (1 << i);
                if (pmask == 0) continue; // single-word path keeps the defaults
                for (int j = 0; j < n; j++) {
                    if (((pmask >> j) & 1) == 0) continue;
                    int cand = dp[pmask][j] + overlap[j][i];
                    // parent == -1 means nothing recorded yet: accept even if cand is 0.
                    if (parent[mask][i] == -1 || cand > dp[mask][i]) {
                        dp[mask][i] = cand;
                        parent[mask][i] = j;
                    }
                }
            }
        }

        // Best endpoint over the full set (first one wins on ties).
        int last = 0;
        for (int i = 1; i < n; i++) {
            if (dp[full][i] > dp[full][last]) last = i;
        }

        // Follow parent links backwards, filling seq from the end so it ends up
        // in visiting order starting at index idx.
        int[] seq = new int[n];
        int idx = n;
        int mask = full;
        while (last != -1) {
            seq[--idx] = last;
            int prev = parent[mask][last];
            mask ^= (1 << last);
            last = prev;
        }

        // Append only the part of each word not covered by its predecessor.
        StringBuilder res = new StringBuilder(words[seq[idx]]);
        for (int k = idx + 1; k < n; k++) {
            res.append(words[seq[k]], overlap[seq[k - 1]][seq[k]], words[seq[k]].length());
        }
        return res.toString();
    }

    /** Length of the longest suffix of {@code a} that equals a prefix of {@code b}. */
    private static int suffixPrefixOverlap(String a, String b) {
        // Longest candidate first, so the first hit is the maximum.
        for (int k = Math.min(a.length(), b.length()); k > 0; k--) {
            // regionMatches compares in place, avoiding a substring allocation per k.
            if (a.regionMatches(a.length() - k, b, 0, k)) return k;
        }
        return 0;
    }
}
/// @brief Returns a shortest string that contains every word as a substring.
///
/// The words are ordered so that the total suffix/prefix overlap between
/// consecutive words is maximal (bitmask DP over subsets of words), then
/// stitched together by appending only the non-overlapping tail of each word.
/// The algorithm assumes no word is a substring of another. The tables hold
/// 2^n * n entries, so this is only practical for a small number of words.
///
/// @param words Words to cover; read only, never modified here.
/// @return One shortest superstring; an empty string for an empty input.
std::string shortestSuperstring(std::vector<std::string>& words) {
    const int n = static_cast<int>(words.size());
    // No words to cover; without this guard the reconstruction below would
    // index an empty row.
    if (n == 0) return std::string();

    // overlap[i][j] = longest suffix of words[i] that is a prefix of words[j].
    std::vector<std::vector<int>> overlap(n, std::vector<int>(n, 0));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            if (i == j) continue;
            const std::string& a = words[i];
            const std::string& b = words[j];
            // Longest candidate first, so the first hit is the maximum.
            for (std::size_t k = std::min(a.size(), b.size()); k > 0; --k) {
                if (a.compare(a.size() - k, k, b, 0, k) == 0) {
                    overlap[i][j] = static_cast<int>(k);
                    break;
                }
            }
        }
    }

    // dp[mask][i]     = max total overlap of a path over `mask` ending at i.
    // parent[mask][i] = word before i on that path, -1 if i starts the path.
    const int full = (1 << n) - 1;
    std::vector<std::vector<int>> dp(full + 1, std::vector<int>(n, 0));
    std::vector<std::vector<int>> parent(full + 1, std::vector<int>(n, -1));
    for (int mask = 1; mask <= full; ++mask) {
        for (int i = 0; i < n; ++i) {
            if (((mask >> i) & 1) == 0) continue;
            const int pmask = mask ^ (1 << i);
            if (pmask == 0) continue;  // single-word path keeps the defaults
            for (int j = 0; j < n; ++j) {
                if (((pmask >> j) & 1) == 0) continue;
                const int cand = dp[pmask][j] + overlap[j][i];
                // parent == -1 means nothing recorded yet: accept even if cand is 0.
                if (parent[mask][i] == -1 || cand > dp[mask][i]) {
                    dp[mask][i] = cand;
                    parent[mask][i] = j;
                }
            }
        }
    }

    // Best endpoint over the full set (first one wins on ties).
    int last = 0;
    for (int i = 1; i < n; ++i) {
        if (dp[full][i] > dp[full][last]) last = i;
    }

    // Follow parent links backwards to recover the visiting order.
    std::vector<int> order;
    order.reserve(n);
    int mask = full;
    while (last != -1) {
        order.push_back(last);
        const int prev = parent[mask][last];
        mask ^= (1 << last);
        last = prev;
    }
    std::reverse(order.begin(), order.end());

    // Append only the part of each word not covered by its predecessor.
    std::string res = words[order[0]];
    for (std::size_t k = 1; k < order.size(); ++k) {
        res += words[order[k]].substr(overlap[order[k - 1]][order[k]]);
    }
    return res;
}
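Time: O(2ⁿ · n² + n² · L²), where L is the length of the longest word (the overlap table tries up to L overlap lengths per pair, each an O(L) comparison). Space: O(2ⁿ · n)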