二叉树系列算法核心纲领
本文讲解的例题
LeetCode | 力扣 | 难度 |
---|---|---|
543. Diameter of Binary Tree | 543. 二叉树的直径 | 🟢 |
104. Maximum Depth of Binary Tree | 104. 二叉树的最大深度 | 🟢 |
144. Binary Tree Preorder Traversal | 144. 二叉树的前序遍历 | 🟢 |
本文阅读方法
本文会把很多算法进行抽象和归纳,所以会包含大量其他文章链接。
第一次阅读本文的读者不要 DFS 学习本文,遇到没学过的算法或不理解的地方请跳过,只要对本文所总结的理论有些印象即可。在学习本站后面的算法技巧时,你自然可以逐渐理解本文的精髓所在,日后回来重读本文,会有更深的体会。
本站所有文章的脉络都是按照 学习数据结构和算法的框架思维 提出的框架来构建的,其中着重强调了二叉树题目的重要性,所以把本文放在第一章的必读系列中。
我刷了这么多年题,浓缩出二叉树算法的一个总纲放在这里,也许用词不是特别专业化,也没有什么教材会收录我的这些经验总结,但目前各个刷题平台的题库,没有一道二叉树题目能跳出本文划定的框架。如果你能发现一道题目和本文给出的框架不兼容,请留言告知我。
先在开头总结一下,二叉树解题的思维模式分两类:
1、是否可以通过遍历一遍二叉树得到答案?如果可以,用一个 traverse
函数配合外部变量来实现,这叫「遍历」的思维模式。
2、是否可以定义一个递归函数,通过子问题(子树)的答案推导出原问题的答案?如果可以,写出这个递归函数的定义,并充分利用这个函数的返回值,这叫「分解问题」的思维模式。
无论使用哪种思维模式,你都需要思考:
如果单独抽出一个二叉树节点,它需要做什么事情?需要在什么时候(前/中/后序位置)做?其他的节点不用你操心,递归函数会帮你在所有节点上执行相同的操作。
本文中会用题目来举例,但都是最最简单的题目,所以不用担心自己看不懂,我可以帮你从最简单的问题中提炼出所有二叉树题目的共性,并将二叉树中蕴含的思维进行升华,反手用到 动态规划,回溯算法,分治算法,图论算法 中去,这也是我一直强调框架思维的原因。希望你在学习了上述高级算法后,也能回头再来看看本文,会对它们有更深刻的认识。
首先,我还是要不厌其烦地强调一下二叉树这种数据结构及相关算法的重要性。
二叉树的重要性
举个例子,比如两个经典排序算法 快速排序 和 归并排序,对于它俩,你有什么理解?
如果你告诉我,快速排序就是个二叉树的前序遍历,归并排序就是个二叉树的后序遍历,那么我就知道你是个算法高手了。
为什么快速排序和归并排序能和二叉树扯上关系?我们来简单分析一下他们的算法思想和代码框架:
快速排序的逻辑是,若要对 nums[lo..hi]
进行排序,我们先找一个分界点 p
,通过交换元素使得 nums[lo..p-1]
都小于等于 nums[p]
,且 nums[p+1..hi]
都大于 nums[p]
,然后递归地去 nums[lo..p-1]
和 nums[p+1..hi]
中寻找新的分界点,最后整个数组就被排序了。
快速排序的代码框架如下:
void sort(int[] nums, int lo, int hi) {
// ****** 前序遍历位置 ******
// 通过交换元素构建分界点 p
int p = partition(nums, lo, hi);
// ************************
sort(nums, lo, p - 1);
sort(nums, p + 1, hi);
}
void sort(vector<int>& nums, int lo, int hi) {
// ****** 前序遍历位置 ******
// 通过交换元素构建分界点 p
int p = partition(nums, lo, hi);
// ************************
sort(nums, lo, p - 1);
sort(nums, p + 1, hi);
}
def sort(nums: list, lo: int, hi: int):
if lo >= hi:
return
# ****** 前序遍历位置 ******
# 通过交换元素构建分界点 p
p = partition(nums, lo, hi)
# ************************
sort(nums, lo, p - 1)
sort(nums, p + 1, hi)
func sort(nums []int, lo int, hi int) {
// ****** 前序遍历位置 ******
// 通过交换元素构建分界点 p
p := partition(nums, lo, hi)
// ************************
sort(nums, lo, p - 1)
sort(nums, p + 1, hi)
}
var sort = function(nums, lo, hi) {
// ****** 前序遍历位置 ******
// 通过交换元素构建分界点 p
var p = partition(nums, lo, hi);
// ************************
sort(nums, lo, p - 1);
sort(nums, p + 1, hi);
}
先构造分界点,然后去左右子数组构造分界点,你看这不就是一个二叉树的前序遍历吗?
再说说归并排序的逻辑,若要对 nums[lo..hi]
进行排序,我们先对 nums[lo..mid]
排序,再对 nums[mid+1..hi]
排序,最后把这两个有序的子数组合并,整个数组就排好序了。
归并排序的代码框架如下:
// 定义:排序 nums[lo..hi]
void sort(int[] nums, int lo, int hi) {
int mid = (lo + hi) / 2;
// 排序 nums[lo..mid]
sort(nums, lo, mid);
// 排序 nums[mid+1..hi]
sort(nums, mid + 1, hi);
// ****** 后序位置 ******
// 合并 nums[lo..mid] 和 nums[mid+1..hi]
merge(nums, lo, mid, hi);
// *********************
}
// 定义:排序 nums[lo..hi]
void sort(vector<int>& nums, int lo, int hi) {
if (hi <= lo) return;
int mid = lo + (hi - lo) / 2;
// 排序 nums[lo..mid]
sort(nums, lo, mid);
// 排序 nums[mid+1..hi]
sort(nums, mid + 1, hi);
// ****** 后序位置 ******
// 合并 nums[lo..mid] 和 nums[mid+1..hi]
merge(nums, lo, mid, hi);
// *********************
}
def sort(nums: List[int], lo: int, hi: int) -> None:
# 定义:排序 nums[lo..hi]
mid = (lo + hi) // 2
# 排序 nums[lo..mid]
sort(nums, lo, mid)
# 排序 nums[mid+1..hi]
sort(nums, mid + 1, hi)
# ****** 后序位置 ******
# 合并 nums[lo..mid] 和 nums[mid+1..hi]
merge(nums, lo, mid, hi)
# *********************
// 定义:排序 nums[lo..hi]
func sort(nums []int, lo int, hi int) {
mid := (lo + hi) / 2
// 排序 nums[lo..mid]
sort(nums, lo, mid)
// 排序 nums[mid+1..hi]
sort(nums, mid + 1, hi)
// ****** 后序位置 ******
// 合并 nums[lo..mid] 和 nums[mid+1..hi]
merge(nums, lo, mid, hi)
// *********************
}
var sort = function(nums, lo, hi) {
// 定义:排序 nums[lo..hi]
var mid = Math.floor((lo + hi) / 2);
// 排序 nums[lo..mid]
sort(nums, lo, mid);
// 排序 nums[mid+1..hi]
sort(nums, mid + 1, hi);
// ****** 后序位置 ******
// 合并 nums[lo..mid] 和 nums[mid+1..hi]
merge(nums, lo, mid, hi);
// *********************
}
先对左右子数组排序,然后合并(类似合并有序链表的逻辑),你看这是不是二叉树的后序遍历框架?另外,这不就是传说中的分治算法嘛,不过如此呀。
如果你一眼就识破这些排序算法的底细,还需要背这些经典算法吗?不需要。你可以手到擒来,从二叉树遍历框架就能扩展出算法了。
说了这么多,旨在说明,二叉树的算法思想的运用广泛,甚至可以说,只要涉及递归,都可以抽象成二叉树的问题。
接下来我们从二叉树的前中后序开始讲起,让你深刻理解这种数据结构的魅力。
深入理解前中后序
我先甩给你几个问题,请默默思考 30 秒:
1、你理解的二叉树的前中后序遍历是什么,仅仅是三个顺序不同的 List 吗?
2、请分析,后序遍历有什么特殊之处?
3、请分析,为什么多叉树没有中序遍历?
答不上来,说明你对前中后序的理解仅仅局限于教科书,不过没关系,我用类比的方式解释一下我眼中的前中后序遍历。
首先,回顾一下 二叉树的 DFS/BFS 遍历 中说到的二叉树递归遍历框架:
void traverse(TreeNode root) {
if (root == null) {
return;
}
// 前序位置
traverse(root.left);
// 中序位置
traverse(root.right);
// 后序位置
}
void traverse(TreeNode* root) {
if (root == nullptr) {
return;
}
// 前序位置
traverse(root->left);
// 中序位置
traverse(root->right);
// 后序位置
}
def traverse(root):
if root is None:
return
# 前序位置
traverse(root.left)
# 中序位置
traverse(root.right)
# 后序位置
func traverse(root *TreeNode) {
if root == nil {
return
}
// 前序位置
traverse(root.Left)
// 中序位置
traverse(root.Right)
// 后序位置
}
var traverse = function(root) {
if (root == null) {
return;
}
// 前序位置
traverse(root.left);
// 中序位置
traverse(root.right);
// 后序位置
};
先不管所谓前中后序,单看 traverse
函数,你说它在做什么事情?
其实它就是一个能够遍历二叉树所有节点的一个函数,和你遍历数组或者链表本质上没有区别:
// 迭代遍历数组
void traverse(int[] arr) {
for (int i = 0; i < arr.length; i++) {
}
}
// 递归遍历数组
void traverse(int[] arr, int i) {
if (i == arr.length) {
return;
}
// 前序位置
traverse(arr, i + 1);
// 后序位置
}
// 迭代遍历单链表
void traverse(ListNode head) {
for (ListNode p = head; p != null; p = p.next) {
}
}
// 递归遍历单链表
void traverse(ListNode head) {
if (head == null) {
return;
}
// 前序位置
traverse(head.next);
// 后序位置
}
// 迭代遍历数组
void traverse(vector<int>& arr) {
for (int i = 0; i < arr.size(); i++) {
}
}
// 递归遍历数组
void traverse(vector<int>& arr, int i) {
if (i == arr.size()) {
return;
}
// 前序位置
traverse(arr, i + 1);
// 后序位置
}
struct ListNode {
int val;
ListNode *next;
};
// 迭代遍历单链表
void traverse(ListNode* head) {
for (ListNode* p = head; p != nullptr; p = p->next) {
}
}
// 递归遍历单链表
void traverse(ListNode* head) {
if (head == nullptr) {
return;
}
// 前序位置
traverse(head->next);
// 后序位置
}
# 迭代遍历数组
def traverse(arr: List[int]) -> None:
for i in range(len(arr)):
pass
# 递归遍历数组
def traverse_recursive(arr: List[int], i: int) -> None:
if i == len(arr):
return
# 前序位置
traverse_recursive(arr, i + 1)
# 后序位置
# 迭代遍历单链表
def traverse_linked_list(head: ListNode) -> None:
p = head
while p:
p = p.next
# 递归遍历单链表
def traverse_linked_list_recursive(head: ListNode) -> None:
if not head:
return
# 前序位置
traverse_linked_list_recursive(head.next)
# 后序位置
// 迭代遍历数组
func traverse(arr []int) {
for i := 0; i < len(arr); i++ {
}
}
// 递归遍历数组
func traverse(arr []int, i int) {
if i == len(arr) {
return;
}
// 前序位置
traverse(arr, i + 1)
// 后序位置
}
// 迭代遍历单链表
func traverse(head *ListNode) {
for p := head; p != nil; p = p.Next {
}
}
// 递归遍历单链表
func traverse(head *ListNode) {
if head == nil {
return;
}
// 前序位置
traverse(head.Next)
// 后序位置
}
// 迭代遍历数组
var traverse = function(arr) {
for (var i = 0; i < arr.length; i++) {
}
}
// 递归遍历数组
var traverseRecursive = function(arr, i) {
if (i == arr.length) {
return;
}
// 前序位置
traverseRecursive(arr, i + 1);
// 后序位置
}
// 迭代遍历单链表
var traverseList = function(head) {
for (var p = head; p != null; p = p.next) {
}
}
// 递归遍历单链表
var traverseListRecursive = function(head) {
if (head == null) {
return;
}
// 前序位置
traverseListRecursive(head.next);
// 后序位置
}
单链表和数组的遍历可以是迭代的,也可以是递归的,二叉树这种结构无非就是二叉链表,它没办法简单改写成 for 循环的迭代形式,所以我们遍历二叉树一般都使用递归形式。
你也注意到了,只要是递归形式的遍历,都可以有前序位置和后序位置,分别在递归之前和递归之后。
所谓前序位置,就是刚进入一个节点(元素)的时候,后序位置就是即将离开一个节点(元素)的时候,那么进一步,你把代码写在不同位置,代码执行的时机也不同:
比如说,如果让你倒序打印一条单链表上所有节点的值,你怎么搞?
实现方式当然有很多,但如果你对递归的理解足够透彻,可以利用后序位置来操作:
// 递归遍历单链表,倒序打印链表元素
void traverse(ListNode head) {
if (head == null) {
return;
}
traverse(head.next);
// 后序位置
print(head.val);
}
// 递归遍历单链表,倒序打印链表元素
void traverse(ListNode* head) {
if (head == nullptr) {
return;
}
traverse(head->next);
// 后序位置
cout << head->val << endl;
}
# 递归遍历单链表,倒序打印链表元素
def traverse(head):
if head is None:
return
traverse(head.next)
# 后序位置
print(head.val)
// 递归遍历单链表,倒序打印链表元素
func traverse(head *ListNode) {
if head == nil {
return
}
traverse(head.Next)
// 后序位置
fmt.Println(head.Val)
}
// 递归遍历单链表,倒序打印链表元素
var traverse = function(head) {
if (head == null) {
return;
}
traverse(head.next);
// 后序位置
console.log(head.val);
};
结合上面那张图,你应该知道为什么这段代码能够倒序打印单链表了吧,本质上是利用递归的堆栈帮你实现了倒序遍历的效果。
那么说回二叉树也是一样的,只不过多了一个中序位置罢了。
教科书里只会问你前中后序遍历结果分别是什么,所以对于一个只上过大学数据结构课程的人来说,他大概以为二叉树的前中后序只不过对应三种顺序不同的 List<Integer>
列表。
但是我想说,前中后序是遍历二叉树过程中处理每一个节点的三个特殊时间点,绝不仅仅是三个顺序不同的 List:
前序位置的代码在刚刚进入一个二叉树节点的时候执行;
后序位置的代码在将要离开一个二叉树节点的时候执行;
中序位置的代码在一个二叉树节点左子树都遍历完,即将开始遍历右子树的时候执行。
你注意本文的用词,我一直说前中后序「位置」,就是要和大家常说的前中后序「遍历」有所区别:你可以在前序位置写代码往一个 List 里面塞元素,那最后得到的就是前序遍历结果;但并不是说你就不可以写更复杂的代码做更复杂的事。
画成图,前中后序三个位置在二叉树上是这样:
你可以发现每个节点都有「唯一」属于自己的前中后序位置,所以我说前中后序遍历是遍历二叉树过程中处理每一个节点的三个特殊时间点。
这里你也可以理解为什么多叉树没有中序位置,因为二叉树的每个节点只会进行唯一一次左子树切换右子树,而多叉树节点可能有很多子节点,会多次切换子树去遍历,所以多叉树节点没有「唯一」的中序遍历位置。
说了这么多基础的,就是要帮你对二叉树建立正确的认识,然后你会发现:
二叉树的所有问题,就是让你在前中后序位置注入巧妙的代码逻辑,去达到自己的目的,你只需要单独思考每一个节点应该做什么,其他的不用你管,抛给二叉树遍历框架,递归会在所有节点上做相同的操作。
你也可以看到,图论算法基础 把二叉树的遍历框架扩展到了图,并以遍历为基础实现了图论的各种经典算法,不过这是后话,本文就不多说了。
两种解题思路
前文 我的算法学习心得 说过:
二叉树题目的递归解法可以分两类思路,第一类是遍历一遍二叉树得出答案,第二类是通过分解问题计算出答案,这两类思路分别对应着 回溯算法核心框架 和 动态规划核心框架。
提示
这里说一下我的函数命名习惯:二叉树中用遍历思路解题时函数签名一般是 void traverse(...)
,没有返回值,靠更新外部变量来计算结果,而用分解问题思路解题时函数名根据该函数具体功能而定,而且一般会有返回值,返回值是子问题的计算结果。
与此对应的,你会发现我在 回溯算法核心框架 中给出的函数签名一般也是没有返回值的 void backtrack(...)
,而在 动态规划核心框架 中给出的函数签名是带有返回值的 dp
函数。这也说明它俩和二叉树之间千丝万缕的联系。
虽然函数命名没有什么硬性的要求,但我还是建议你也遵循我的这种风格,这样更能突出函数的作用和解题的思维模式,便于你自己理解和运用。
当时我是用二叉树的最大深度这个问题来举例,重点在于把这两种思路和动态规划和回溯算法进行对比,而本文的重点在于分析这两种思路如何解决二叉树的题目。
力扣第 104 题「二叉树的最大深度」就是最大深度的题目,所谓最大深度就是根节点到「最远」叶子节点的最长路径上的节点数,比如输入这棵二叉树,算法应该返回 3:
你做这题的思路是什么?显然遍历一遍二叉树,用一个外部变量记录每个节点所在的深度,取最大值就可以得到最大深度,这就是遍历二叉树计算答案的思路。
解法代码如下:
class Solution {
// 记录最大深度
int res = 0;
// 记录遍历到的节点的深度
int depth = 0;
public int maxDepth(TreeNode root) {
traverse(root);
return res;
}
// 二叉树遍历框架
void traverse(TreeNode root) {
if (root == null) {
return;
}
// 前序位置
depth++;
if (root.left == null && root.right == null) {
// 到达叶子节点,更新最大深度
res = Math.max(res, depth);
}
traverse(root.left);
traverse(root.right);
// 后序位置
depth--;
}
}
class Solution {
public:
// 记录最大深度
int res = 0;
// 记录遍历到的节点的深度
int depth = 0;
// 主函数
int maxDepth(TreeNode* root) {
traverse(root);
return res;
}
// 二叉树遍历框架
void traverse(TreeNode* root) {
if (root == nullptr) {
return;
}
// 前序位置
depth++;
if (root->left == nullptr && root->right == nullptr) {
// 到达叶子节点,更新最大深度
res = std::max(res, depth);
}
traverse(root->left);
traverse(root->right);
// 后序位置
depth--;
}
};
class Solution:
def __init__(self):
# 记录最大深度
self.res = 0
# 记录遍历到的节点的深度
self.depth = 0
def maxDepth(self, root: TreeNode) -> int:
self.traverse(root)
return self.res
# 二叉树遍历框架
def traverse(self, root: TreeNode) -> None:
if root is None:
return
# 前序位置
self.depth += 1
if root.left is None and root.right is None:
# 到达叶子节点,更新最大深度
self.res = max(self.res, self.depth)
self.traverse(root.left)
self.traverse(root.right)
# 后序位置
self.depth -= 1
// 主函数
func maxDepth(root *TreeNode) int {
// 记录最大深度
var res int = 0
// 记录遍历到的节点的深度
var depth int = 0;
// 二叉树遍历框架
var traverse func(root *TreeNode)
traverse = func(root *TreeNode) {
if (root == nil) {
return
}
// 前序位置
depth++
if (root.Left == nil && root.Right == nil) {
// 到达叶子节点,更新最大深度
res = max(res, depth)
}
traverse(root.Left)
traverse(root.Right)
// 后序位置
depth--
}
traverse(root)
return res
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
// 主函数
var maxDepth = function(root) {
// 记录最大深度
var res = 0;
// 记录遍历到的节点的深度
var depth = 0;
// 二叉树遍历框架
var traverse = function(root) {
if (root == null) {
return;
}
// 前序位置
depth++;
if (root.left == null && root.right == null) {
// 到达叶子节点,更新最大深度
res = Math.max(res, depth);
}
traverse(root.left);
traverse(root.right);
// 后序位置
depth--;
};
traverse(root);
return res;
};
这个解法应该很好理解,但为什么需要在前序位置增加 depth
,在后序位置减小 depth
?
因为前面说了,前序位置是进入一个节点的时候,后序位置是离开一个节点的时候,depth
记录当前递归到的节点深度,你把 traverse
理解成在二叉树上游走的一个指针,所以当然要这样维护。
至于对 res
的更新,你放到前中后序位置都可以,只要保证在进入节点之后,离开节点之前(即 depth
自增之后,自减之前)就行了。
当然,你也很容易发现一棵二叉树的最大深度可以通过子树的最大深度推导出来,这就是分解问题计算答案的思路。
解法代码如下:
class Solution {
// 定义:输入根节点,返回这棵二叉树的最大深度
public int maxDepth(TreeNode root) {
if (root == null) {
return 0;
}
// 利用定义,计算左右子树的最大深度
int leftMax = maxDepth(root.left);
int rightMax = maxDepth(root.right);
// 整棵树的最大深度等于左右子树的最大深度取最大值,
// 然后再加上根节点自己
int res = Math.max(leftMax, rightMax) + 1;
return res;
}
}
class Solution {
public:
// 定义:输入根节点,返回这棵二叉树的最大深度
int maxDepth(TreeNode* root) {
if (root == nullptr) {
return 0;
}
// 利用定义,计算左右子树的最大深度
int leftMax = maxDepth(root->left);
int rightMax = maxDepth(root->right);
// 整棵树的最大深度等于左右子树的最大深度取最大值,
// 然后再加上根节点自己
int res = std::max(leftMax, rightMax) + 1;
return res;
}
};
class Solution:
# 定义:输入根节点,返回这棵二叉树的最大深度
def maxDepth(self, root: TreeNode) -> int:
if root is None:
return 0
# 利用定义,计算左右子树的最大深度
leftMax = self.maxDepth(root.left)
rightMax = self.maxDepth(root.right)
# 整棵树的最大深度等于左右子树的最大深度取最大值,
# 然后再加上根节点自己
res = max(leftMax, rightMax) + 1
return res
// 定义:输入根节点,返回这棵二叉树的最大深度
func maxDepth(root *TreeNode) int {
if root == nil {
return 0
}
// 利用定义,计算左右子树的最大深度
leftMax := maxDepth(root.Left)
rightMax := maxDepth(root.Right)
// 整棵树的最大深度等于左右子树的最大深度取最大值,
// 然后再加上根节点自己
res := max(leftMax, rightMax) + 1
return res
}
func max(a int, b int) int {
if a > b {
return a
}
return b
}
// 定义:输入根节点,返回这棵二叉树的最大深度
var maxDepth = function(root) {
if (root === null) {
return 0;
}
// 利用定义,计算左右子树的最大深度
var leftMax = maxDepth(root.left);
var rightMax = maxDepth(root.right);
// 整棵树的最大深度等于左右子树的最大深度取最大值,
// 然后再加上根节点自己
var res = Math.max(leftMax, rightMax) + 1;
return res;
};
只要明确递归函数的定义,这个解法也不难理解,但为什么主要的代码逻辑集中在后序位置?
因为这个思路正确的核心在于,你确实可以通过子树的最大深度推导出原树的深度,所以当然要首先利用递归函数的定义算出左右子树的最大深度,然后推出原树的最大深度,主要逻辑自然放在后序位置。
如果你理解了最大深度这个问题的两种思路,那么我们再回头看看最基本的二叉树前中后序遍历,就比如力扣第 144 题「二叉树的前序遍历」,让你计算前序遍历结果。
我们熟悉的解法就是用「遍历」的思路,我想应该没什么好说的:
class Solution {
// 存放前序遍历结果
List<Integer> res = new LinkedList<>();
// 返回前序遍历结果
public List<Integer> preorderTraversal(TreeNode root) {
traverse(root);
return res;
}
// 二叉树遍历函数
void traverse(TreeNode root) {
if (root == null) {
return;
}
// 前序位置
res.add(root.val);
traverse(root.left);
traverse(root.right);
}
}
class Solution {
public:
// 存放前序遍历结果
vector<int> res;
// 返回前序遍历结果
vector<int> preorderTraversal(TreeNode* root) {
traverse(root);
return res;
}
// 二叉树遍历函数
void traverse(TreeNode* root) {
if (root == nullptr) {
return;
}
// 前序位置
res.push_back(root->val);
traverse(root->left);
traverse(root->right);
}
};
class Solution:
def __init__(self):
self.res = []
# 返回前序遍历结果
def preorderTraversal(self, root: TreeNode) -> List[int]:
self.traverse(root)
return self.res
# 二叉树遍历函数
def traverse(self, root: TreeNode):
if root is None:
return
# 前序位置
self.res.append(root.val)
self.traverse(root.left)
self.traverse(root.right)
// preorderTraverse 返回前序遍历结果
func preorderTraversal(root *TreeNode) []int {
// 存放前序遍历结果
var res []int
// 二叉树遍历函数
var traverse func(root *TreeNode)
traverse = func(root *TreeNode) {
if root == nil {
return
}
// 前序位置
res = append(res, root.Val)
traverse(root.Left)
traverse(root.Right)
}
traverse(root)
return res
}
// 返回前序遍历结果
var preorderTraversal = function(root) {
// 存放前序遍历结果
var res = [];
// 二叉树遍历函数
var traverse = function(root) {
if (root == null) {
return;
}
// 前序位置
res.push(root.val);
traverse(root.left);
traverse(root.right);
};
traverse(root);
return res;
}
但你是否能够用「分解问题」的思路,来计算前序遍历的结果?
换句话说,不要用像 traverse
这样的辅助函数和任何外部变量,单纯用题目给的 preorderTraverse
函数递归解题,你会不会?
我们知道前序遍历的特点是,根节点的值排在首位,接着是左子树的前序遍历结果,最后是右子树的前序遍历结果:
那这不就可以分解问题了么,一棵二叉树的前序遍历结果 = 根节点 + 左子树的前序遍历结果 + 右子树的前序遍历结果。
所以,你可以这样实现前序遍历算法:
class Solution {
// 定义:输入一棵二叉树的根节点,返回这棵树的前序遍历结果
List<Integer> preorderTraversal(TreeNode root) {
List<Integer> res = new LinkedList<>();
if (root == null) {
return res;
}
// 前序遍历的结果,root.val 在第一个
res.add(root.val);
// 利用函数定义,后面接着左子树的前序遍历结果
res.addAll(preorderTraversal(root.left));
// 利用函数定义,最后接着右子树的前序遍历结果
res.addAll(preorderTraversal(root.right));
return res;
}
}
class Solution {
public:
// 定义:输入一棵二叉树的根节点,返回这棵树的前序遍历结果
vector<int> preorderTraversal(TreeNode* root) {
vector<int> res;
if (root == nullptr) {
return res;
}
// 前序遍历的结果,root->val 在第一个
res.push_back(root->val);
// 利用函数定义,后面接着左子树的前序遍历结果
vector<int> left = preorderTraversal(root->left);
res.insert(res.end(), left.begin(), left.end());
// 利用函数定义,最后接着右子树的前序遍历结果
vector<int> right = preorderTraversal(root->right);
res.insert(res.end(), right.begin(), right.end());
return res;
}
};
class Solution:
# 定义:输入一棵二叉树的根节点,返回这棵树的前序遍历结果
def preorderTraversal(self, root):
res = []
if root == None:
return res
# 前序遍历的结果,root.val 在第一个
res.append(root.val)
# 利用函数定义,后面接着左子树的前序遍历结果
res.extend(self.preorderTraversal(root.left))
# 利用函数定义,最后接着右子树的前序遍历结果
res.extend(self.preorderTraversal(root.right))
return res
// 定义:输入一棵二叉树的根节点,返回这棵树的前序遍历结果
func preorderTraversal(root *TreeNode) []int {
var res []int
if root == nil {
return res
}
// 前序遍历的结果,root.Val 在第一个
res = append(res, root.Val)
// 利用函数定义,后面接着左子树的前序遍历结果
res = append(res, preorderTraversal(root.Left)...)
// 利用函数定义,最后接着右子树的前序遍历结果
res = append(res, preorderTraversal(root.Right)...)
return res
}
// 定义:输入一棵二叉树的根节点,返回这棵树的前序遍历结果
var preorderTraversal = function(root) {
var res = [];
if (root == null) {
return res;
}
// 前序遍历的结果,root.val 在第一个
res.push(root.val);
// 利用函数定义,后面接着左子树的前序遍历结果
res = res.concat(preorderTraversal(root.left));
// 利用函数定义,最后接着右子树的前序遍历结果
res = res.concat(preorderTraversal(root.right));
return res;
}
中序和后序遍历也是类似的,只要把 add(root.val)
放到中序和后序对应的位置就行了。
这个解法短小精干,但为什么不常见呢?
一个原因是这个算法的复杂度不好把控,比较依赖语言特性。
Java 的话无论 ArrayList 还是 LinkedList,addAll
方法的复杂度都是 O(N),所以总体的最坏时间复杂度会达到 O(N^2),除非你自己实现一个复杂度为 O(1) 的 addAll
方法,底层用链表的话是可以做到的,因为多条链表只要简单的指针操作就能连接起来。
当然,最主要的原因还是因为教科书上从来没有这么教过……
上文举了两个简单的例子,但还有不少二叉树的题目是可以同时使用两种思路来思考和求解的,这就要靠你自己多去练习和思考,不要仅仅满足于一种熟悉的解法思路。
综上,遇到一道二叉树的题目时的通用思考过程是:
1、是否可以通过遍历一遍二叉树得到答案?如果可以,用一个 traverse
函数配合外部变量来实现。
2、是否可以定义一个递归函数,通过子问题(子树)的答案推导出原问题的答案?如果可以,写出这个递归函数的定义,并充分利用这个函数的返回值。
3、无论使用哪一种思维模式,你都要明白二叉树的每一个节点需要做什么,需要在什么时候(前中后序)做。
本站 二叉树递归专项练习 中列举了 100 多道二叉树习题,完全使用上述两种思维模式手把手带你练习,助你完全掌握递归思维,更容易理解高级的算法。
后序位置的特殊之处
说后序位置之前,先简单说下前序和中序。
前序位置本身其实没有什么特别的性质,之所以你发现好像很多题都是在前序位置写代码,实际上是因为我们习惯把那些对前中后序位置不敏感的代码写在前序位置罢了。
中序位置主要用在 BST 场景中,你完全可以把 BST 的中序遍历认为是遍历有序数组。
划重点
仔细观察,前中后序位置的代码,能力依次增强。
前序位置的代码只能从函数参数中获取父节点传递来的数据。
中序位置的代码不仅可以获取参数数据,还可以获取到左子树通过函数返回值传递回来的数据。
后序位置的代码最强,不仅可以获取参数数据,还可以同时获取到左右子树通过函数返回值传递回来的数据。
所以,某些情况下把代码移到后序位置效率最高;有些事情,只有后序位置的代码能做。
举些具体的例子来感受下它们的能力区别。现在给你一棵二叉树,我问你两个简单的问题:
1、如果把根节点看做第 1 层,如何打印出每一个节点所在的层数?
2、如何打印出每个节点的左右子树各有多少节点?
第一个问题可以这样写代码:
// 二叉树遍历函数
void traverse(TreeNode root, int level) {
if (root == null) {
return;
}
// 前序位置
printf("Node %s at level %d", root.val, level);
traverse(root.left, level + 1);
traverse(root.right, level + 1);
}
// 这样调用
traverse(root, 1);
// 二叉树遍历函数
void traverse(TreeNode* root, int level) {
if (root == nullptr) {
return;
}
// 前序位置
printf("Node %d at level %d", root->val, level);
traverse(root->left, level + 1);
traverse(root->right, level + 1);
}
// 这样调用
traverse(root, 1);
# 二叉树遍历函数
def traverse(root, level):
if root is None:
return
# 前序位置
print(f"Node {root.val} at level {level}")
traverse(root.left, level + 1)
traverse(root.right, level + 1)
# 这样调用
traverse(root, 1)
// 二叉树遍历函数
func traverse(root *TreeNode, level int) {
if root == nil {
return
}
// 前序位置
fmt.Printf("Node %v at level %v\n", root.Val, level)
traverse(root.Left, level + 1)
traverse(root.Right, level + 1)
}
// 这样调用
traverse(root, 1)
// 二叉树遍历函数
var traverse = function(root, level) {
if (root == null) {
return;
}
// 前序位置
console.log("Node " + root.val + " at level " + level);
traverse(root.left, level + 1);
traverse(root.right, level + 1);
}
// 这样调用
traverse(root, 1);
第二个问题可以这样写代码:
// 定义:输入一棵二叉树,返回这棵二叉树的节点总数
int count(TreeNode root) {
if (root == null) {
return 0;
}
int leftCount = count(root.left);
int rightCount = count(root.right);
// 后序位置
printf("节点 %s 的左子树有 %d 个节点,右子树有 %d 个节点",
root, leftCount, rightCount);
return leftCount + rightCount + 1;
}
// 定义:输入一棵二叉树,返回这棵二叉树的节点总数
int count(TreeNode* root) {
if (root == nullptr) {
return 0;
}
int leftCount = count(root->left);
int rightCount = count(root->right);
// 后序位置
printf("节点 %p 的左子树有 %d 个节点,右子树有 %d 个节点",
root, leftCount, rightCount);
return leftCount + rightCount + 1;
}
# 定义:输入一棵二叉树,返回这棵二叉树的节点总数
def count(root):
if root is None:
return 0
leftCount = count(root.left)
rightCount = count(root.right)
# 后序位置
print(f"节点 {root} 的左子树有 {leftCount} 个节点,右子树有 {rightCount} 个节点")
return leftCount + rightCount + 1
// 定义:输入一棺二叉树,返回这棵二叉树的节点总数
func count(root *TreeNode) int {
if root == nil {
return 0
}
leftCount := count(root.Left)
rightCount := count(root.Right)
// 后序位置
fmt.Printf("节点 %v 的左子树有 %d 个节点,右子树有 %d 个节点\n",
root, leftCount, rightCount)
return leftCount + rightCount + 1
}
// 定义:输入一棵二叉树,返回这棵二叉树的节点总数
var count = function(root) {
if (root == null) {
return 0;
}
var leftCount = count(root.left);
var rightCount = count(root.right);
// 后序位置
console.log(`节点 ${root} 的左子树有 ${leftCount} 个节点,右子树有 ${rightCount} 个节点`);
return leftCount + rightCount + 1;
};
这两个问题的根本区别在于
一个节点在第几层,你从根节点遍历过来的过程就能顺带记录,用递归函数的参数就能传递下去;而以一个节点为根的整棵子树有多少个节点,你必须遍历完子树之后才能数清楚,然后通过递归函数的返回值拿到答案。
结合这两个简单的问题,你品味一下后序位置的特点,只有后序位置才能通过返回值获取子树的信息。
那么换句话说,一旦你发现题目和子树有关,那大概率要给函数设置合理的定义和返回值,在后序位置写代码了。
接下来看下后序位置是如何在实际的题目中发挥作用的,简单聊下力扣第 543 题「二叉树的直径」,让你计算一棵二叉树的最长直径长度。
所谓二叉树的「直径」长度,就是任意两个结点之间的路径长度。最长「直径」并不一定要穿过根结点,比如下面这棵二叉树:
它的最长直径是 3,即 [4,2,1,3]
,[4,2,1,9]
或者 [5,2,1,3]
这几条「直径」的长度。
解决这题的关键在于,每一条二叉树的「直径」长度,就是一个节点的左右子树的最大深度之和。
现在让我求整棵树中的最长「直径」,那直截了当的思路就是遍历整棵树中的每个节点,然后通过每个节点的左右子树的最大深度算出每个节点的「直径」,最后把所有「直径」求个最大值即可。
最大深度的算法我们刚才实现过了,上述思路就可以写出以下代码:
class Solution {
// 记录最大直径的长度
int maxDiameter = 0;
public int diameterOfBinaryTree(TreeNode root) {
// 对每个节点计算直径,求最大直径
traverse(root);
return maxDiameter;
}
// 遍历二叉树
void traverse(TreeNode root) {
if (root == null) {
return;
}
// 对每个节点计算直径
int leftMax = maxDepth(root.left);
int rightMax = maxDepth(root.right);
int myDiameter = leftMax + rightMax;
// 更新全局最大直径
maxDiameter = Math.max(maxDiameter, myDiameter);
traverse(root.left);
traverse(root.right);
}
// 计算二叉树的最大深度
int maxDepth(TreeNode root) {
if (root == null) {
return 0;
}
int leftMax = maxDepth(root.left);
int rightMax = maxDepth(root.right);
return 1 + Math.max(leftMax, rightMax);
}
}
class Solution {
public:
// 记录最大直径的长度
int maxDiameter = 0;
int diameterOfBinaryTree(TreeNode* root) {
// 对每个节点计算直径,求最大直径
traverse(root);
return maxDiameter;
}
private:
// 遍历二叉树
void traverse(TreeNode* root) {
if (root == nullptr) {
return;
}
// 对每个节点计算直径
int leftMax = maxDepth(root->left);
int rightMax = maxDepth(root->right);
int myDiameter = leftMax + rightMax;
// 更新全局最大直径
maxDiameter = max(maxDiameter, myDiameter);
traverse(root->left);
traverse(root->right);
}
// 计算二叉树的最大深度
int maxDepth(TreeNode* root) {
if (root == nullptr) {
return 0;
}
int leftMax = maxDepth(root->left);
int rightMax = maxDepth(root->right);
return 1 + max(leftMax, rightMax);
}
};
class Solution:
def __init__(self):
# 记录最大直径的长度
self.maxDiameter = 0
def diameterOfBinaryTree(self, root):
# 对每个节点计算直径,求最大直径
self.traverse(root)
return self.maxDiameter
# 遍历二叉树
def traverse(self, root):
if root is None:
return
# 对每个节点计算直径
leftMax = self.maxDepth(root.left)
rightMax = self.maxDepth(root.right)
myDiameter = leftMax + rightMax
# 更新全局最大直径
self.maxDiameter = max(self.maxDiameter, myDiameter)
self.traverse(root.left)
self.traverse(root.right)
# 计算二叉树的最大深度
def maxDepth(self, root):
if root is None:
return 0
leftMax = self.maxDepth(root.left)
rightMax = self.maxDepth(root.right)
return 1 + max(leftMax, rightMax)
// 记录最大直径的长度
func diameterOfBinaryTree(root *TreeNode) int {
maxDiameter := 0
// 遍历二叉树
var traverse func(root *TreeNode)
traverse = func(root *TreeNode) {
if root == nil {
return
}
// 对每个节点计算直径
leftMax := maxDepth(root.Left)
rightMax := maxDepth(root.Right)
myDiameter := leftMax + rightMax
// 更新全局最大直径
maxDiameter = max(maxDiameter, myDiameter)
traverse(root.Left)
traverse(root.Right)
}
traverse(root)
return maxDiameter
}
// 计算二叉树的最大深度
func maxDepth(root *TreeNode) int {
if root == nil {
return 0
}
leftMax := maxDepth(root.Left)
rightMax := maxDepth(root.Right)
return 1 + max(leftMax, rightMax)
}
// 辅助函数,求两个数的最大值
func max(a, b int) int {
if a > b {
return a
}
return b
}
var diameterOfBinaryTree = function(root) {
// 记录最大直径的长度
let maxDiameter = 0;
// 计算二叉树的最大深度
const maxDepth = function(root) {
if (root === null) {
return 0;
}
let leftMax = maxDepth(root.left);
let rightMax = maxDepth(root.right);
return 1 + Math.max(leftMax, rightMax);
};
// 遍历二叉树
const traverse = function(root) {
if (root === null) {
return;
}
// 对每个节点计算直径
let leftMax = maxDepth(root.left);
let rightMax = maxDepth(root.right);
let myDiameter = leftMax + rightMax;
// 更新全局最大直径
maxDiameter = Math.max(maxDiameter, myDiameter);
traverse(root.left);
traverse(root.right);
};
traverse(root);
return maxDiameter;
};
这个解法是正确的,但是运行时间很长,原因也很明显,traverse
遍历每个节点的时候还会调用递归函数 maxDepth
,而 maxDepth
是要遍历子树的所有节点的,所以最坏时间复杂度是 O(N^2)。
这就出现了刚才探讨的情况,前序位置无法获取子树信息,所以只能让每个节点调用 maxDepth
函数去算子树的深度。
那如何优化?我们应该把计算「直径」的逻辑放在后序位置,准确说应该是放在 maxDepth
的后序位置,因为 maxDepth
的后序位置是知道左右子树的最大深度的。
所以,稍微改一下代码逻辑即可得到更好的解法:
class Solution {
// 记录最大直径的长度
int maxDiameter = 0;
public int diameterOfBinaryTree(TreeNode root) {
maxDepth(root);
return maxDiameter;
}
int maxDepth(TreeNode root) {
if (root == null) {
return 0;
}
int leftMax = maxDepth(root.left);
int rightMax = maxDepth(root.right);
// 后序位置,顺便计算最大直径
int myDiameter = leftMax + rightMax;
maxDiameter = Math.max(maxDiameter, myDiameter);
return 1 + Math.max(leftMax, rightMax);
}
}
class Solution {
// 记录最大直径的长度
int maxDiameter = 0;
public:
int diameterOfBinaryTree(TreeNode* root) {
maxDepth(root);
return maxDiameter;
}
int maxDepth(TreeNode* root) {
if (root == nullptr) {
return 0;
}
int leftMax = maxDepth(root->left);
int rightMax = maxDepth(root->right);
// 后序位置,顺便计算最大直径
int myDiameter = leftMax + rightMax;
maxDiameter = max(maxDiameter, myDiameter);
return 1 + max(leftMax, rightMax);
}
};
class Solution:
def __init__(self):
# 记录最大直径的长度
self.maxDiameter = 0
def diameterOfBinaryTree(self, root):
self.maxDepth(root)
return self.maxDiameter
def maxDepth(self, root):
if root is None:
return 0
leftMax = self.maxDepth(root.left)
rightMax = self.maxDepth(root.right)
# 后序位置,顺便计算最大直径
myDiameter = leftMax + rightMax
self.maxDiameter = max(self.maxDiameter, myDiameter)
return 1 + max(leftMax, rightMax)
func diameterOfBinaryTree(root *TreeNode) int {
// 记录最大直径的长度
maxDiameter := 0
var maxDepth func(root *TreeNode) int
maxDepth = func(root *TreeNode) int {
if root == nil {
return 0
}
leftMax := maxDepth(root.Left)
rightMax := maxDepth(root.Right)
// 后序位置,顺便计算最大直径
myDiameter := leftMax + rightMax
maxDiameter = max(maxDiameter, myDiameter)
return 1 + max(leftMax, rightMax)
}
maxDepth(root)
return maxDiameter
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
var diameterOfBinaryTree = function(root) {
let maxDiameter = 0;
const maxDepth = function(root) {
if (root === null) {
return 0;
}
let leftMax = maxDepth(root.left);
let rightMax = maxDepth(root.right);
// 后序位置,顺便计算最大直径
let myDiameter = leftMax + rightMax;
maxDiameter = Math.max(maxDiameter, myDiameter);
return 1 + Math.max(leftMax, rightMax);
}
maxDepth(root);
return maxDiameter;
};
这下时间复杂度只有 maxDepth
函数的 O(N) 了。
讲到这里,照应一下前文:遇到子树问题,首先想到的是给函数设置返回值,然后在后序位置做文章。
相关信息
思考题:请你思考一下,运用后序遍历的题目使用的是「遍历」的思路还是「分解问题」的思路?
点击查看答案
利用后序位置的题目,一般都使用「分解问题」的思路。因为当前节点接收并利用了子树返回的信息,这就意味着你把原问题分解成了当前节点 + 左右子树的子问题。
反过来,如果你写出了类似一开始的那种递归套递归的解法,大概率也需要反思是不是可以通过后序遍历优化了。
更多利用后序位置的习题参见 手把手带你刷二叉树(后序篇)、手把手带你刷二叉搜索树(后序篇) 和 【强化练习】利用后序位置解题。
以树的视角看动归/回溯/DFS算法的区别和联系
前文我说动态规划/回溯算法就是二叉树算法两种不同思路的表现形式,相信能看到这里的读者应该也认可了我这个观点。但有细心的读者经常提问:你的思考方法让我豁然开朗,但你好像一直没讲过 DFS 算法?
其实我在 一文秒杀所有岛屿题目 中就是用的 DFS 算法,但我确实没有单独用一篇文章讲 DFS 算法,因为 DFS 算法和回溯算法非常类似,只是在细节上有所区别。
这个细节上的差别是什么呢?其实就是「做选择」和「撤销选择」到底在 for 循环外面还是里面的区别,DFS 算法在外面,回溯算法在里面。
为什么有这个区别?还是要结合着二叉树理解。这一部分我就把回溯算法、DFS 算法、动态规划三种经典的算法思想,以及它们和二叉树算法的联系和区别,用一句话来说明:
划重点
动归/DFS/回溯算法都可以看做二叉树问题的扩展,只是它们的关注点不同:
- 动态规划算法属于分解问题(分治)的思路,它的关注点在整棵「子树」。
- 回溯算法属于遍历的思路,它的关注点在节点间的「树枝」。
- DFS 算法属于遍历的思路,它的关注点在单个「节点」。
怎么理解?我分别举三个例子你就懂了。
例子一:分解问题的思想体现
第一个例子,给你一棵二叉树,请你用分解问题的思路写一个 count
函数,计算这棵二叉树共有多少个节点。代码很简单,上文都写过了:
// 定义:输入一棵二叉树,返回这棵二叉树的节点总数
int count(TreeNode root) {
if (root == null) {
return 0;
}
// 当前节点关心的是两个子树的节点总数分别是多少
// 因为用子问题的结果可以推导出原问题的结果
int leftCount = count(root.left);
int rightCount = count(root.right);
// 后序位置,左右子树节点数加上自己就是整棵树的节点数
return leftCount + rightCount + 1;
}
// 定义:输入一棵二叉树,返回这棵二叉树的节点总数
int count(TreeNode* root) {
if (root == nullptr) {
return 0;
}
// 当前节点关心的是两个子树的节点总数分别是多少
// 因为用子问题的结果可以推导出原问题的结果
int leftCount = count(root->left);
int rightCount = count(root->right);
// 后序位置,左右子树节点数加上自己就是整棵树的节点数
return leftCount + rightCount + 1;
}
# 定义:输入一棵二叉树,返回这棵二叉树的节点总数
def count(root):
if root is None:
return 0
# 当前节点关心的是两个子树的节点总数分别是多少
# 因为用子问题的结果可以推导出原问题的结果
leftCount = count(root.left)
rightCount = count(root.right)
# 后序位置,左右子树节点数加上自己就是整棵树的节点数
return leftCount + rightCount + 1
// 定义:输入一棵二叉树,返回这棵二叉树的节点总数
func count(root *TreeNode) int {
if root == nil {
return 0
}
// 当前节点关心的是两个子树的节点总数分别是多少
// 因为用子问题的结果可以推导出原问题的结果
leftCount := count(root.Left)
rightCount := count(root.Right)
// 后序位置,左右子树节点数加上自己就是整棵树的节点数
return leftCount + rightCount + 1
}
// 定义:输入一棵二叉树,返回这棵二叉树的节点总数
var count = function(root) {
if (root == null) {
return 0;
}
// 当前节点关心的是两个子树的节点总数分别是多少
// 因为用子问题的结果可以推导出原问题的结果
var leftCount = count(root.left);
var rightCount = count(root.right);
// 后序位置,左右子树节点数加上自己就是整棵树的节点数
return leftCount + rightCount + 1;
}
你看,这就是动态规划分解问题的思路,它的着眼点永远是结构相同的整个子问题,类比到二叉树上就是「子树」。
你再看看具体的动态规划问题,比如 动态规划框架套路详解 中举的斐波那契的例子,我们的关注点在一棵棵子树的返回值上:
int fib(int N) {
if (N == 1 || N == 2) return 1;
return fib(N - 1) + fib(N - 2);
}
int fib(int N) {
if (N == 1 || N == 2) return 1;
return fib(N - 1) + fib(N - 2);
}
def fib(N: int) -> int:
if N == 1 or N == 2:
return 1
return fib(N - 1) + fib(N - 2)
func fib(N int) int {
if N == 1 || N == 2 {
return 1
}
return fib(N - 1) + fib(N - 2)
}
var fib = function(N) {
if (N == 1 || N == 2) {
return 1;
}
return fib(N - 1) + fib(N - 2);
}
例子二:回溯算法的思想体现
第二个例子,给你一棵二叉树,请你用遍历的思路写一个 traverse
函数,打印出遍历这棵二叉树的过程,你看下代码:
void traverse(TreeNode root) {
if (root == null) return;
printf("从节点 %s 进入节点 %s", root, root.left);
traverse(root.left);
printf("从节点 %s 回到节点 %s", root.left, root);
printf("从节点 %s 进入节点 %s", root, root.right);
traverse(root.right);
printf("从节点 %s 回到节点 %s", root.right, root);
}
void traverse(TreeNode* root) {
// 如果节点为空(说明已经越过叶节点),就返回
if (root == nullptr) return;
// 从节点 root 出发,沿着 left 边走
printf("从节点 %s 进入节点 %s", root, root->left);
traverse(root->left);
// 从节点 root 的左边回来
printf("从节点 %s 回到节点 %s", root->left, root);
// 从节点 root 出发,沿着 right 边走
printf("从节点 %s 进入节点 %s", root, root->right);
traverse(root->right);
// 从节点 root 的右边回来
printf("从节点 %s 回到节点 %s", root->right, root);
}
def traverse(root):
if root is None:
return
print("从节点 %s 进入节点 %s" %(root, root.left))
traverse(root.left)
print("从节点 %s 回到节点 %s" %(root.left, root))
print("从节点 %s 进入节点 %s" %(root, root.right))
traverse(root.right)
print("从节点 %s 回到节点 %s" %(root.right, root))
func traverse(root *TreeNode) {
if root == nil {
return
}
fmt.Printf("从节点 %v 进入节点 %v\n", root, root.Left)
traverse(root.Left)
fmt.Printf("从节点 %v 回到节点 %v\n", root.Left, root)
fmt.Printf("从节点 %v 进入节点 %v\n", root, root.Right)
traverse(root.Right)
fmt.Printf("从节点 %v 回到节点 %v\n", root.Right, root)
}
var traverse = function(root) {
if (root == null) return;
console.log("从节点 "+ root + " 进入节点 " + root.left);
traverse(root.left);
console.log("从节点 "+ root.left + " 回到节点 " + root);
console.log("从节点 "+ root + " 进入节点 " + root.right);
traverse(root.right);
console.log("从节点 "+ root.right + " 回到节点 " + root);
}
不难理解吧,好的,我们现在从二叉树进阶成多叉树,代码也是类似的:
// 多叉树节点
class Node {
int val;
Node[] children;
}
void traverse(Node root) {
if (root == null) return;
for (Node child : root.children) {
printf("从节点 %s 进入节点 %s", root, child);
traverse(child);
printf("从节点 %s 回到节点 %s", child, root);
}
}
// 多叉树节点
class Node {
public:
int val;
std::vector<Node*> children;
};
void traverse(Node* root) {
if (root == NULL) return;
for (Node* child : root->children) {
std::cout << "从节点 " << root->val << " 进入节点 " << child->val << std::endl;
traverse(child);
std::cout << "从节点 " << child->val << " 回到节点 " << root->val << std::endl;
}
}
# 多叉树节点
class Node:
def __init__(self, val=0, children=None):
self.val = val
self.children = children if children is not None else []
def traverse(root):
if not root: return
for child in root.children:
print(f"从节点 {root} 进入节点 {child}")
traverse(child)
print(f"从节点 {child} 回到节点 {root}")
// 多叉树节点
type Node struct {
val int
children []*Node
}
func traverse(root *Node) {
if root == nil {
return
}
for _, child := range root.children {
fmt.Printf("Enter from node %v to node %v\n", root, child)
traverse(child)
fmt.Printf("Return from node %v to node %v\n", child, root)
}
}
// 多叉树节点
var Node = function() {
this.val = 0;
this.children = [];
}
var traverse = function(root) {
if (root === null) return;
for (var i = 0; i < root.children.length; i++) {
var child = root.children[i];
console.log("从节点 " + root + " 进入节点 " + child);
traverse(child);
console.log("从节点 " + child + " 回到节点 " + root);
}
}
这个多叉树的遍历框架就可以延伸出 回溯算法框架套路详解 中的回溯算法框架:
// 回溯算法框架
void backtrack(...) {
// base case
if (...) return;
for (int i = 0; i < ...; i++) {
// 做选择
...
// 进入下一层决策树
backtrack(...);
// 撤销刚才做的选择
...
}
}
你看,这就是回溯算法遍历的思路,它的着眼点永远是在节点之间移动的过程,类比到二叉树上就是「树枝」。
你再看看具体的回溯算法问题,比如 回溯算法秒杀排列组合子集的九种题型 中讲到的全排列,我们的关注点在一条条树枝上:
// 回溯算法核心部分代码
void backtrack(int[] nums) {
// 回溯算法框架
for (int i = 0; i < nums.length; i++) {
// 做选择
used[i] = true;
track.addLast(nums[i]);
// 进入下一层回溯树
backtrack(nums);
// 取消选择
track.removeLast();
used[i] = false;
}
}
例子三:DFS 的思想体现
第三个例子,我给你一棵二叉树,请你写一个 traverse
函数,把这棵二叉树上的每个节点的值都加一。很简单吧,代码如下:
void traverse(TreeNode root) {
if (root == null) return;
// 遍历过的每个节点的值加一
root.val++;
traverse(root.left);
traverse(root.right);
}
void traverse(TreeNode* root) {
if (root == nullptr) return;
// 遍历过的每个节点的值加一
root->val++;
traverse(root->left);
traverse(root->right);
}
def traverse(root):
if root is None:
return
# 遍历过的每个节点的值加一
root.val += 1
traverse(root.left)
traverse(root.right)
func traverse(root *TreeNode) {
if root == nil {
return
}
// 遍历过的每个节点的值加一
root.Val++
traverse(root.Left)
traverse(root.Right)
}
var traverse = function(root) {
if (root === null) return;
// 遍历过的每个节点的值加一
root.val++;
traverse(root.left);
traverse(root.right);
}
你看,这就是 DFS 算法遍历的思路,它的着眼点永远是在单一的节点上,类比到二叉树上就是处理每个「节点」。
你再看看具体的 DFS 算法问题,比如 一文秒杀所有岛屿题目 中讲的前几道题,我们的关注点是 grid
数组的每个格子(节点),我们要对遍历过的格子进行一些处理,所以我说是用 DFS 算法解决这几道题的:
// DFS 算法核心逻辑
void dfs(int[][] grid, int i, int j) {
int m = grid.length, n = grid[0].length;
if (i < 0 || j < 0 || i >= m || j >= n) {
return;
}
if (grid[i][j] == 0) {
return;
}
// 遍历过的每个格子标记为 0
grid[i][j] = 0;
dfs(grid, i + 1, j);
dfs(grid, i, j + 1);
dfs(grid, i - 1, j);
dfs(grid, i, j - 1);
}
好,请你仔细品一下上面三个简单的例子,是不是像我说的:动态规划关注整棵「子树」,回溯算法关注节点间的「树枝」,DFS 算法关注单个「节点」。
有了这些铺垫,你就很容易理解为什么回溯算法和 DFS 算法代码中「做选择」和「撤销选择」的位置不同了,看下面两段代码:
// DFS 算法把「做选择」「撤销选择」的逻辑放在 for 循环外面
void dfs(Node root) {
if (root == null) return;
// 做选择
print("enter node %s", root);
for (Node child : root.children) {
dfs(child);
}
// 撤销选择
print("leave node %s", root);
}
// 回溯算法把「做选择」「撤销选择」的逻辑放在 for 循环里面
void backtrack(Node root) {
if (root == null) return;
for (Node child : root.children) {
// 做选择
print("I'm on the branch from %s to %s", root, child);
backtrack(child);
// 撤销选择
print("I'll leave the branch from %s to %s", child, root);
}
}
// DFS 算法把「做选择」「撤销选择」的逻辑放在 for 循环外面
void dfs(Node* root) {
if (!root) return;
// 做选择
printf("enter node %s\n", root->val.c_str());
for (Node* child : root->children) {
dfs(child);
}
// 撤销选择
printf("leave node %s\n", root->val.c_str());
}
// 回溯算法把「做选择」「撤销选择」的逻辑放在 for 循环里面
void backtrack(Node* root) {
if (!root) return;
for (Node* child : root->children) {
// 做选择
printf("I'm on the branch from %s to %s\n", root->val.c_str(), child->val.c_str());
backtrack(child);
// 撤销选择
printf("I'll leave the branch from %s to %s\n", child->val.c_str(), root->val.c_str());
}
}
# DFS 算法把「做选择」「撤销选择」的逻辑放在 for 循环外面
def dfs(root):
if root is None:
return
# 做选择
print("enter node %s" % root)
for child in root.children:
dfs(child)
# 撤销选择
print("leave node %s" % root)
# 回溯算法把「做选择」「撤销选择」的逻辑放在 for 循环里面
def backtrack(root):
if root is None:
return
for child in root.children:
# 做选择
print("I'm on the branch from %s to %s" % (root, child))
backtrack(child)
# 撤销选择
print("I'll leave the branch from %s to %s" % (child, root))
import "fmt"
// DFS 算法把「做选择」「撤销选择」的逻辑放在 for 循环外面
func Dfs(root *Node) {
if root == nil {
return
}
// 做选择
fmt.Printf("enter node %v\n", root)
for _, child := range root.children {
Dfs(child)
}
// 撤销选择
fmt.Printf("leave node %v\n", root)
}
// 回溯算法把「做选择」「撤销选择」的逻辑放在 for 循环里面
func Backtrack(root *Node) {
if root == nil {
return
}
for _, child := range root.children {
// 做选择
fmt.Printf("I'm on the branch from %v to %v\n", root, child)
Backtrack(child)
// 撤销选择
fmt.Printf("I'll leave the branch from %v to %v\n", child, root)
}
}
// DFS 算法把「做选择」「撤销选择」的逻辑放在 for 循环外面
var dfs = function(root) {
if (root == null) return;
// 做选择
console.log("enter node %s", root);
for (var i = 0; i < root.children.length; i++) {
dfs(root.children[i]);
}
// 撤销选择
console.log("leave node %s", root);
}
// 回溯算法把「做选择」「撤销选择」的逻辑放在 for 循环里面
var backtrack = function(root) {
if (root == null) return;
for (var i = 0; i < root.children.length; i++) {
var child = root.children[i];
// 做选择
console.log("I'm on the branch from %s to %s", root, child);
backtrack(child);
// 撤销选择
console.log("I'll leave the branch from %s to %s", child, root);
}
}
看到了吧,你回溯算法必须把「做选择」和「撤销选择」的逻辑放在 for 循环里面,否则怎么拿到「树枝」的两个端点?
层序遍历
二叉树题型主要是用来培养递归思维的,而层序遍历属于迭代遍历,也比较简单,这里就过一下代码框架吧:
// 输入一棵二叉树的根节点,层序遍历这棵二叉树
void levelTraverse(TreeNode root) {
if (root == null) return;
Queue<TreeNode> q = new LinkedList<>();
q.offer(root);
// 从上到下遍历二叉树的每一层
while (!q.isEmpty()) {
int sz = q.size();
// 从左到右遍历每一层的每个节点
for (int i = 0; i < sz; i++) {
TreeNode cur = q.poll();
// 将下一层节点放入队列
if (cur.left != null) {
q.offer(cur.left);
}
if (cur.right != null) {
q.offer(cur.right);
}
}
}
}
// 输入一棵二叉树的根节点,层序遍历这棵二叉树
void levelTraverse(TreeNode* root) {
if (root == nullptr) return;
queue<TreeNode*> q;
q.push(root);
// 从上到下遍历二叉树的每一层
while (!q.empty()) {
int sz = q.size();
// 从左到右遍历每一层的每个节点
for (int i = 0; i < sz; i++) {
TreeNode* cur = q.front();
q.pop();
// 将下一层节点放入队列
if (cur->left != nullptr) {
q.push(cur->left);
}
if (cur->right != nullptr) {
q.push(cur->right);
}
}
}
}
# 输入一棵二叉树的根节点,层序遍历这棵二叉树
def levelTraverse(root: TreeNode) -> None:
if root is None:
return
q = [root]
# 从上到下遍历二叉树的每一层
while q:
sz = len(q)
# 从左到右遍历每一层的每个节点
for _ in range(sz):
cur = q.pop(0)
# 将下一层节点放入队列
if cur.left is not None:
q.append(cur.left)
if cur.right is not None:
q.append(cur.right)
// 输入一棵二叉树的根节点,层序遍历这棵二叉树
func levelTraverse(root *TreeNode) {
if root == nil {
return
}
q := make([]*TreeNode, 0)
q = append(q, root)
// 从上到下遍历二叉树的每一层
for len(q) > 0 {
sz := len(q)
// 从左到右遍历每一层的每个节点
for i := 0; i < sz; i++ {
cur := q[0]
q = q[1:]
// 将下一层节点放入队列
if cur.Left != nil {
q = append(q, cur.Left)
}
if cur.Right != nil {
q = append(q, cur.Right)
}
}
}
}
// 输入一棵二叉树的根节点,层序遍历这棵二叉树
var levelTraverse = function(root) {
if (root == null) return;
var q = [];
q.push(root);
// 从上到下遍历二叉树的每一层
while (q.length != 0) {
var sz = q.length;
// 从左到右遍历每一层的每个节点
for (var i = 0; i < sz; i++) {
var cur = q.shift();
// 将下一层节点放入队列
if (cur.left != null) {
q.push(cur.left);
}
if (cur.right != null) {
q.push(cur.right);
}
}
}
};
这里面 while 循环和 for 循环分管从上到下和从左到右的遍历:
前文 BFS 算法框架 就是从二叉树的层序遍历扩展出来的,常用于求无权图的最短路径问题。
当然这个框架还可以灵活修改,题目不需要记录层数(步数)时可以去掉上述框架中的 for 循环,比如后文 Dijkstra 算法 中计算加权图的最短路径问题,详细探讨了 BFS 算法的扩展。
值得一提的是,有些很明显需要用层序遍历技巧的二叉树的题目,也可以用递归遍历的方式去解决,而且技巧性会更强,非常考察你对前中后序的把控。
好了,本文已经够长了,围绕前中后序位置算是把二叉树题目里的各种套路给讲透了,真正能运用出来多少,就需要你亲自刷题实践和思考了。
希望大家能探索尽可能多的解法,只要参透二叉树这种基本数据结构的原理,那么就很容易在学习其他高级算法的道路上找到抓手,打通回路,形成闭环(手动狗头)。
最后,二叉树递归专项练习 中会手把手带你运用本文所讲的技巧。
回答评论区的问题
关于层序遍历(以及其扩展出的 BFS 算法框架),我在最后多说几句。
如果你对二叉树足够熟悉,可以想到很多方式通过递归函数得到层序遍历结果,比如下面这种写法:
class Solution {
List<List<Integer>> res = new ArrayList<>();
public List<List<Integer>> levelTraverse(TreeNode root) {
if (root == null) {
return res;
}
// root 视为第 0 层
traverse(root, 0);
return res;
}
void traverse(TreeNode root, int depth) {
if (root == null) {
return;
}
// 前序位置,看看是否已经存储 depth 层的节点了
if (res.size() <= depth) {
// 第一次进入 depth 层
res.add(new LinkedList<>());
}
// 前序位置,在 depth 层添加 root 节点的值
res.get(depth).add(root.val);
traverse(root.left, depth + 1);
traverse(root.right, depth + 1);
}
}
class Solution {
public:
vector<vector<int>> res;
vector<vector<int>> levelTraverse(TreeNode* root) {
if (root == nullptr) {
return res;
}
// root 视为第 0 层
traverse(root, 0);
return res;
}
void traverse(TreeNode* root, int depth) {
if (root == nullptr) {
return;
}
// 前序位置,看看是否已经存储 depth 层的节点了
if (res.size() <= depth) {
// 第一次进入 depth 层
res.push_back(vector<int> {});
}
// 前序位置,在 depth 层添加 root 节点的值
res[depth].push_back(root->val);
traverse(root->left, depth + 1);
traverse(root->right, depth + 1);
}
};
class Solution:
def __init__(self):
self.res = []
def levelTraverse(self, root):
if root is None:
return self.res
# root 视为第 0 层
self.traverse(root, 0)
return self.res
def traverse(self, root, depth):
if root is None:
return
# 前序位置,看看是否已经存储 depth 层的节点了
if len(self.res) <= depth:
# 第一次进入 depth 层
self.res.append([])
# 前序位置,在 depth 层添加 root 节点的值
self.res[depth].append(root.val)
self.traverse(root.left, depth + 1)
self.traverse(root.right, depth + 1)
func levelTraverse(root *TreeNode) [][]int {
res := make([][]int, 0)
if root == nil {
return res
}
// root 视为第 0 层
traverse(root, 0, &res)
return res
}
func traverse(root *TreeNode, depth int, res *[][]int) {
if root == nil {
return
}
// 前序位置,看看是否已经存储 depth 层的节点了
if len(*res) <= depth {
// 第一次进入 depth 层
*res = append(*res, make([]int, 0))
}
// 前序位置,在 depth 层添加 root 节点的值
(*res)[depth] = append((*res)[depth], root.Val)
traverse(root.Left, depth+1, res)
traverse(root.Right, depth+1, res)
}
var levelTraverse = function(){
res = [];
const traverse = function(root, depth) {
if (root === null) {
return;
}
// 前序位置,看看是否已经存储 depth 层的节点了
if (res.length <= depth) {
// 第一次进入 depth 层
res.push([]);
}
// 前序位置,在 depth 层添加 root 节点的值
res[depth].push(root.val);
traverse(root.left, depth + 1);
traverse(root.right, depth + 1);
}
// root 视为第 0 层
traverse(root, 0);
return res;
}
这种思路从结果上说确实可以得到层序遍历结果,但其本质还是二叉树的前序遍历,或者说 DFS 的思路,而不是层序遍历,或者说 BFS 的思路。因为这个解法是依赖前序遍历自顶向下、自左向右的顺序特点得到了正确的结果。
抽象点说,这个解法更像是从左到右的「列序遍历」,而不是自顶向下的「层序遍历」。所以对于计算最小距离的场景,这个解法完全等同于 DFS 算法,没有 BFS 算法的性能的优势。
还有优秀读者评论了这样一种递归进行层序遍历的思路:
class Solution {
List<List<Integer>> res = new LinkedList<>();
public List<List<Integer>> levelTraverse(TreeNode root) {
if (root == null) {
return res;
}
List<TreeNode> nodes = new LinkedList<>();
nodes.add(root);
traverse(nodes);
return res;
}
void traverse(List<TreeNode> curLevelNodes) {
// base case
if (curLevelNodes.isEmpty()) {
return;
}
// 前序位置,计算当前层的值和下一层的节点列表
List<Integer> nodeValues = new LinkedList<>();
List<TreeNode> nextLevelNodes = new LinkedList<>();
for (TreeNode node : curLevelNodes) {
nodeValues.add(node.val);
if (node.left != null) {
nextLevelNodes.add(node.left);
}
if (node.right != null) {
nextLevelNodes.add(node.right);
}
}
// 前序位置添加结果,可以得到自顶向下的层序遍历
res.add(nodeValues);
traverse(nextLevelNodes);
// 后序位置添加结果,可以得到自底向上的层序遍历结果
// res.add(nodeValues);
}
}
class Solution {
private:
vector<vector<int>> res;
public:
vector<vector<int>> levelTraverse(TreeNode* root) {
if (root == NULL) {
return res;
}
vector<TreeNode*> nodes;
nodes.push_back(root);
traverse(nodes);
return res;
}
void traverse(vector<TreeNode*>& curLevelNodes) {
// base case
if (curLevelNodes.empty()) {
return;
}
// 前序位置,计算当前层的值和下一层的节点列表
vector<int> nodeValues;
vector<TreeNode*> nextLevelNodes;
for (TreeNode* node : curLevelNodes) {
nodeValues.push_back(node->val);
if (node->left != NULL) {
nextLevelNodes.push_back(node->left);
}
if (node->right != NULL) {
nextLevelNodes.push_back(node->right);
}
}
// 前序位置添加结果,可以得到自顶向下的层序遍历
res.push_back(nodeValues);
traverse(nextLevelNodes);
// 后序位置添加结果,可以得到自底向上的层序遍历结果
// res.push_back(nodeValues);
}
};
class Solution:
def __init__(self):
self.res = []
def levelTraverse(self, root):
if not root:
return self.res
nodes = [root]
self.traverse(nodes)
return self.res
def traverse(self, curLevelNodes):
# base case
if not curLevelNodes:
return
# 前序位置,计算当前层的值和下一层的节点列表
nodeValues = []
nextLevelNodes = []
for node in curLevelNodes:
nodeValues.append(node.val)
if node.left:
nextLevelNodes.append(node.left)
if node.right:
nextLevelNodes.append(node.right)
# 前序位置添加结果,可以得到自顶向下的层序遍历
self.res.append(nodeValues)
self.traverse(nextLevelNodes)
# 后序位置添加结果,可以得到自底向上的层序遍历结果
# res.append(nodeValues)
// The converted levelTraverse function
func levelTraverse(root *TreeNode) [][]int {
res := [][]int{}
if root == nil {
return res
}
nodes := []*TreeNode{root}
traverse(nodes, &res)
// return the final results
return res
}
// The converted traverse function
func traverse(curLevelNodes []*TreeNode, res *[][]int) {
// base case
if len(curLevelNodes) == 0 {
return
}
// 前序位置,计算当前层的值和下一层的节点列表
nodeValues := []int{}
nextLevelNodes := []*TreeNode{}
for _, node := range curLevelNodes {
nodeValues = append(nodeValues, node.Val)
if node.Left != nil {
nextLevelNodes = append(nextLevelNodes, node.Left)
}
if node.Right != nil {
nextLevelNodes = append(nextLevelNodes, node.Right)
}
}
// 前序位置添加结果,可以得到自顶向下的层序遍历
*res = append(*res, nodeValues)
// Traverse the next level
traverse(nextLevelNodes, res)
// 后序位置添加结果,可以得到自底向上的层序遍历结果
// *res = append(*res, nodeValues)
}
var res = [];
var levelTraverse = function(root) {
if (root === null) {
return res;
}
let nodes = [];
nodes.push(root);
traverse(nodes);
return res;
};
var traverse = function(curLevelNodes) {
// base case
if (curLevelNodes.length === 0) {
return;
}
// 前序位置,计算当前层的值和下一层的节点列表
let nodeValues = [];
let nextLevelNodes = [];
for (let node of curLevelNodes) {
nodeValues.push(node.val);
if (node.left !== null) {
nextLevelNodes.push(node.left);
}
if (node.right !== null) {
nextLevelNodes.push(node.right);
}
}
// 前序位置添加结果,可以得到自顶向下的层序遍历
res.push(nodeValues);
traverse(nextLevelNodes);
// 后序位置添加结果,可以得到自底向上的层序遍历结果
// res.push(nodeValues);
};
这个 traverse
函数很像递归遍历单链表的函数,其实就是把二叉树的每一层抽象理解成单链表的一个节点进行遍历。
相较上一个递归解法,这个递归解法是自顶向下的「层序遍历」,更接近 BFS 的奥义,可以作为 BFS 算法的递归实现扩展一下思维。
引用本文的题目
安装 我的 Chrome 刷题插件 点开下列题目可直接查看解题思路: