Kth Smallest Product of Two Sorted Arrays

Difficulty: hard · Topics: array, binary search

Problem

Given two integer arrays nums1 and nums2, each sorted in non-decreasing order, and an integer k, return the k-th smallest (1-indexed) value of nums1[i] * nums2[j] taken over all index pairs (i, j).

Input: nums1 = [-2, -1, 0, 1, 2], nums2 = [-3, -1, 2, 4, 5], k = 3
Output: -6
The sorted products begin -10 (= -2 × 5), -8 (= -2 × 4), -6 (= 2 × -3), ..., so the 3rd smallest is -6.

def kth_smallest_product(nums1, nums2, k):
    """Return the k-th smallest (1-indexed) product nums1[i] * nums2[j].

    Both inputs must be integer lists sorted in non-decreasing order. The
    answer is found by binary-searching the product value x and counting how
    many pairs have a product <= x; the smallest x whose count reaches k is
    the k-th smallest product.

    Args:
        nums1: sorted list of ints (may contain negatives and zeros).
        nums2: sorted list of ints (may contain negatives and zeros).
        k: 1-based rank, 1 <= k <= len(nums1) * len(nums2).

    Returns:
        The k-th smallest product.

    Raises:
        ValueError: if k is outside [1, len(nums1) * len(nums2)] (this also
            covers an empty input array).
    """
    from bisect import bisect_left, bisect_right

    n2 = len(nums2)
    if not 1 <= k <= len(nums1) * n2:
        raise ValueError("k must be between 1 and len(nums1) * len(nums2)")

    def count_le(x):
        """Return the number of pairs (i, j) with nums1[i] * nums2[j] <= x."""
        total = 0
        for a in nums1:
            if a > 0:
                # a * b <= x  <=>  b <= floor(x / a)   (b is an integer)
                total += bisect_right(nums2, x // a)
            elif a < 0:
                # Dividing by a negative flips the inequality:
                # a * b <= x  <=>  b >= ceil(x / a) == -((-x) // a)
                total += n2 - bisect_left(nums2, -(-x // a))
            elif x >= 0:
                # a == 0: every product is 0, which is <= x iff x >= 0.
                total += n2
        return total

    # The extreme products of two sorted arrays always come from their
    # endpoints, so this range contains the answer for inputs of any size
    # (unlike a fixed +-10**10 window).
    corners = [a * b for a in (nums1[0], nums1[-1]) for b in (nums2[0], nums2[-1])]
    lo, hi = min(corners), max(corners)
    while lo < hi:
        mid = (lo + hi) // 2  # floor division keeps mid in [lo, hi - 1]
        if count_le(mid) >= k:
            hi = mid
        else:
            lo = mid + 1
    return lo
function kthSmallestProduct(nums1, nums2, k) {
  const countLE = (x) => {
    let total = 0n;
    for (const a of nums1) {
      const A = BigInt(a);
      if (a > 0) {
        let lo = 0, hi = nums2.length;
        while (lo < hi) {
          const m = (lo + hi) >> 1;
          if (A * BigInt(nums2[m]) <= x) lo = m + 1; else hi = m;
        }
        total += BigInt(lo);
      } else if (a < 0) {
        let lo = 0, hi = nums2.length;
        while (lo < hi) {
          const m = (lo + hi) >> 1;
          if (A * BigInt(nums2[m]) <= x) hi = m; else lo = m + 1;
        }
        total += BigInt(nums2.length - lo);
      } else if (x >= 0n) total += BigInt(nums2.length);
    }
    return total;
  };
  let lo = -10n ** 10n, hi = 10n ** 10n;
  while (lo < hi) {
    const mid = (lo + hi) >> 1n;
    if (countLE(mid) >= BigInt(k)) hi = mid; else lo = mid + 1n;
  }
  return Number(lo);
}
/** Finds the k-th smallest product of two sorted integer arrays. */
class Solution {
    /**
     * Returns the k-th smallest (1-indexed) product nums1[i] * nums2[j].
     *
     * <p>Both arrays must be sorted in non-decreasing order. The answer is found by
     * binary-searching the product value and counting pairs whose product is at most
     * that value. Every int*int product fits in a long, and the search range comes from
     * the arrays' endpoint products, so the full int range is supported.
     *
     * @param nums1 sorted array
     * @param nums2 sorted array
     * @param k 1-based rank, 1 <= k <= nums1.length * nums2.length
     * @return the k-th smallest product
     * @throws IllegalArgumentException if k is outside [1, nums1.length * nums2.length],
     *     which includes either array being empty
     */
    public long kthSmallestProduct(int[] nums1, int[] nums2, long k) {
        long pairs = (long) nums1.length * nums2.length;
        if (k < 1 || k > pairs) {
            throw new IllegalArgumentException("k must be between 1 and nums1.length * nums2.length");
        }
        // The minimum and maximum products are always among the four endpoint products.
        int last1 = nums1.length - 1;
        int last2 = nums2.length - 1;
        long[] ends = {
            (long) nums1[0] * nums2[0],
            (long) nums1[0] * nums2[last2],
            (long) nums1[last1] * nums2[0],
            (long) nums1[last1] * nums2[last2],
        };
        long lo = ends[0];
        long hi = ends[0];
        for (long e : ends) {
            lo = Math.min(lo, e);
            hi = Math.max(hi, e);
        }
        while (lo < hi) {
            // For int inputs hi - lo <= 2^63 - 2^31, so this subtraction cannot overflow.
            long mid = lo + (hi - lo) / 2;
            if (countLE(nums1, nums2, mid) >= k) hi = mid;
            else lo = mid + 1;
        }
        return lo;
    }

    /** Returns the number of pairs (i, j) with a[i] * b[j] <= x; b must be sorted. */
    private long countLE(int[] a, int[] b, long x) {
        long total = 0;
        for (int v : a) {
            if (v > 0) {
                // v * b[j] is non-decreasing in j: find the first j where it exceeds x.
                int lo = 0, hi = b.length;
                while (lo < hi) {
                    int m = (lo + hi) >>> 1;
                    if ((long) v * b[m] <= x) lo = m + 1; else hi = m;
                }
                total += lo;
            } else if (v < 0) {
                // v * b[j] is non-increasing in j: find the first j where it is <= x.
                int lo = 0, hi = b.length;
                while (lo < hi) {
                    int m = (lo + hi) >>> 1;
                    if ((long) v * b[m] <= x) hi = m; else lo = m + 1;
                }
                total += b.length - lo;
            } else if (x >= 0) {
                total += b.length; // v == 0: every product is 0, which is <= x.
            }
        }
        return total;
    }
}
/// @brief Returns the k-th smallest (1-indexed) product a[i] * b[j] over all index pairs.
///
/// Both arrays must be sorted in non-decreasing order. The answer is found by
/// binary-searching the product value x and counting pairs whose product is <= x.
/// Every int*int product fits in long long, and the search range is taken from the
/// arrays' endpoint products, so the full int range is supported.
///
/// @param a  sorted array (may contain negatives and zeros)
/// @param b  sorted array (may contain negatives and zeros)
/// @param k  1-based rank; k < 1 yields the smallest product and
///           k > a.size() * b.size() yields the largest.
/// @return   the k-th smallest product, or 0 if either array is empty (no products exist).
long long kthSmallestProduct(const std::vector<int>& a, const std::vector<int>& b, long long k) {
    if (a.empty() || b.empty()) return 0;  // no products exist

    // Number of pairs (i, j) with a[i] * b[j] <= x.
    const auto countLE = [&](long long x) -> long long {
        long long total = 0;
        for (const int v : a) {
            if (v > 0) {
                // v * b[j] is non-decreasing in j: find the first j where it exceeds x.
                std::size_t lo = 0;
                std::size_t hi = b.size();
                while (lo < hi) {
                    const std::size_t m = lo + (hi - lo) / 2;
                    if (static_cast<long long>(v) * b[m] <= x) lo = m + 1; else hi = m;
                }
                total += static_cast<long long>(lo);
            } else if (v < 0) {
                // v * b[j] is non-increasing in j: find the first j where it is <= x.
                std::size_t lo = 0;
                std::size_t hi = b.size();
                while (lo < hi) {
                    const std::size_t m = lo + (hi - lo) / 2;
                    if (static_cast<long long>(v) * b[m] <= x) hi = m; else lo = m + 1;
                }
                total += static_cast<long long>(b.size() - lo);
            } else if (x >= 0) {
                total += static_cast<long long>(b.size());  // v == 0: every product is 0
            }
        }
        return total;
    };

    // The minimum and maximum products are always among the four endpoint products.
    const long long ends[] = {
        static_cast<long long>(a.front()) * b.front(),
        static_cast<long long>(a.front()) * b.back(),
        static_cast<long long>(a.back()) * b.front(),
        static_cast<long long>(a.back()) * b.back(),
    };
    long long lo = ends[0];
    long long hi = ends[0];
    for (const long long e : ends) {
        if (e < lo) lo = e;
        if (e > hi) hi = e;
    }

    while (lo < hi) {
        // For int inputs hi - lo <= 2^63 - 2^31, so this subtraction cannot overflow.
        const long long mid = lo + (hi - lo) / 2;
        if (countLE(mid) >= k) hi = mid;
        else lo = mid + 1;
    }
    return lo;
}
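Time: O(m · log n · log V), where m = len(nums1), n = len(nums2), and V is the size of the range of candidate product values searched (each of the O(log V) probes does one binary search over nums2 per element of nums1). Space: O(1) auxiliary.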