Last Updated on 2022-12-10 by Clay
題目
Given the root
of a binary tree, split the binary tree into two subtrees by removing one edge such that the product of the sums of the subtrees is maximized.
Return the maximum product of the sums of the two subtrees. Since the answer may be too large, return it modulo 109 + 7
.
Note that you need to maximize the answer before taking the mod and not after taking it.
Example 1:
Input: root = [1,2,3,4,5,6] Output: 110 Explanation: Remove the red edge and get 2 binary trees with sum 11 and 10. Their product is 110 (11*10)
Example 2:
Input: root = [1,null,2,3,4,null,null,5,6] Output: 90 Explanation: Remove the red edge and get 2 binary trees with sum 15 and 6.Their product is 90 (15*6)
Constraints:
- The number of nodes in the tree is in the range
[2, 5 * 104]
. 1 <= Node.val <= 104
題目給定一顆二元搜索樹(BST),我們則要判斷該從哪個連結點『切分』,才能使得分裂成的兩顆樹,假設稱為 A 樹和 B 樹,所有節點加總值 A_Sum 和 B_Sum 的乘積達到最大值。
解題思路
DFS
本質上,最簡單的方法步驟應如下:
- 使用 DFS 把整顆樹遍歷一遍,得到樹中節點的加總值
totalSum
。 - 再次使用 DFS 遍歷整顆樹,這次在每輪的遞迴函式中,都要計算當前加總值(當前加總值其實就是 DFS(左節點) + DFS(右節點) + 當前節點值)和與
totalSum
之間的差值乘積,然後紀錄最大的乘積值。 - 返回時,記得再對 1e9+7 取餘數
DFS 複雜度
Time Complexity | O(H) |
Space Complexity | O(1) |
C++ 範例程式碼
class Solution {
public:
long DFS (TreeNode* root, long& totalSum, long& ans) {
if (!root) {
return 0;
}
long currSum = DFS(root->left, totalSum, ans) + DFS(root->right, totalSum, ans) + root->val;
ans = max(ans, (totalSum-currSum)*currSum);
return currSum;
}
int maxProduct(TreeNode* root) {
// Init
long ans = 0;
// Find the total sum
long totalSum = DFS(root, ans, ans);
// Use total sum to find the answer
DFS(root, totalSum, ans);
return ans % long(1e9 + 7);
}
};
Python 範例程式碼
# Definition for a binary tree node.
# class TreeNode:
# def __init__(self, val=0, left=None, right=None):
# self.val = val
# self.left = left
# self.right = right
class Solution:
def maxProduct(self, root: Optional[TreeNode]) -> int:
self.ans = 0
total_sum = self.DFS(root, 0)
self.DFS(root, total_sum)
return int(self.ans % (1e9 + 7))
def DFS(self, root: Optional[TreeNode], total_sum: int) -> int:
# Base case
if not root:
return 0
curr_sum = self.DFS(root.left, total_sum) + self.DFS(root.right, total_sum) + root.val
self.ans = max(self.ans, (total_sum-curr_sum)*curr_sum)
return curr_sum