Skip to content

LeetCode: 4-Median of Two Sorted Arrays 解題紀錄

Last Updated on 2021-05-11 by Clay

題目

Given two sorted arrays nums1 and nums2 of size m and n respectively, return the median of the two sorted arrays.

Follow up: The overall run time complexity should be O(log (m+n)).

Example:

Input: nums1 = [1,2], nums2 = [3,4]
Output: 2.50000
Explanation: merged array = [1,2,3,4] and median is (2 + 3) / 2 = 2.5.

題目的輸入為兩組已經排序過的陣列(由小到大),而我們的目標就是回傳中間值

  • 若陣列的總長度為奇數,則回傳中央的值
  • 若陣列的總長度為偶數,則回傳中央兩值加總的平均。

解題思路

第一感是將兩個陣列拼接在一起,排序後再判斷陣列長度,並依照長度奇偶的不同回傳不同的值。這個做法很簡單也很好寫,可是速度跟記憶體用量卻不盡如人意。(註:實測過後下面的方法也不快

思考過後,最後我決定按照以下步驟解題(index 為陣列編號的意思):

  1. 先計算陣列相加後的總長度,確認是奇數或是偶數(處理方式不同)
  2. 確認答案需要的 index
    • 偶數有兩個中央的 index
    • 奇數只有一個中央的 index
  3. 在 nums1 或 num2 不為空的條件下進行迴圈,並確認是否為中央的 index,有以下四種情況
    • nums1 為空,必定取 nums2 的值
    • nums2 為空,必定取 nums1 的值
    • nums1 的最大值比 nums2 的最大值大,取 nums1 的值
    • nums2 的最大值比 nums1 的最大值大,取 nums2 的值
  4. 若匹配到中央的 index 值(偶數情況下則是兩個 index 值都匹配過),則終止迴圈,回傳答案


C++ 程式碼

設計規則硬解

class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        // Init
        int length = nums1.size() + nums2.size();
        int index = 0;
        
        vector<int> results;
        vector<int> hit_index;
        
        // Check length
        if (length % 2 == 1) {
            hit_index.push_back(length/2);
        }
        else {
            hit_index.push_back(length/2);
            hit_index.push_back(length/2-1);
        }
    
        // Loop
        while (!nums1.empty() or !nums2.empty()) {
            if (nums1.empty()) {
                if (index == hit_index[hit_index.size()-1]) {
                    hit_index.pop_back();
                    results.push_back(nums2[nums2.size()-1]);
                }
                nums2.pop_back();
            }
            else if (nums2.empty()) {
                if (index == hit_index[hit_index.size()-1]) {
                    hit_index.pop_back();
                    results.push_back(nums1[nums1.size()-1]);
                }
                nums1.pop_back();
            }
            else if (nums2[nums2.size()-1] >= nums1[nums1.size()-1]) {
                if (index == hit_index[hit_index.size()-1]) {
                    hit_index.pop_back();
                    results.push_back(nums2[nums2.size()-1]);
                }
                nums2.pop_back();                
            }
            else if (nums1[nums1.size()-1] > nums2[nums2.size()-1]) {
                if (index == hit_index[hit_index.size()-1]) {
                    hit_index.pop_back();
                    results.push_back(nums1[nums1.size()-1]);
                }
                nums1.pop_back();                
            }
            
            // Finished
            if (hit_index.empty()) {
                break;
            }
     
            index++;
        }
        return double(std::accumulate(results.begin(), results.end(), 0)) / results.size();;
    }
};



合併陣列後排序、依長度不同情況回傳答案(簡單版)

class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        // Init
        nums1.insert(nums1.end(), nums2.begin(), nums2.end());
        
        // Sort
        sort(nums1.begin(), nums1.end());
        
        // Branch
        if (nums1.size() % 2 == 0) {
            return double(nums1[nums1.size()/2-1]+nums1[nums1.size()/2])/2;
        }
        else {
            return double(nums1[nums1.size()/2]);
        }
    }
};




Python3 程式碼

class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        results = []
        length = len(nums1+nums2)
        i = 0   
        
        if length % 2 == 1:
            hit_num = [length//2]
        else:
            hit_num = [length//2, length//2-1]
        
        while nums1 or nums2:            
            if not nums1:
                temp = nums2.pop()
            elif not nums2:
                temp = nums1.pop()
            elif nums1[-1] >= nums2[-1]:
                temp = nums1.pop()
            elif nums1[-1] < nums2[-1]:
                temp = nums2.pop()
            
            if i in hit_num:
                hit_num.pop()
                results.append(temp)
                
                if not hit_num:
                    if len(results) == 1:                    
                        return results[0]
                    else:
                        return sum(results)/2
            
            i += 1


不過,這個方法也稱不上是真的很好的解決方法。


References

Leave a Reply