Yifei Li
MSCS @ UBC | Vancouver, BC
A common operation on a binary tree is to traverse all of its nodes. There are three standard depth-first orders: pre-order, in-order and post-order. Recursive algorithms are well suited to this task. However, recursion relies on the call stack (and an iterative version needs an explicit stack) to remember which nodes to return to, which costs O(h) extra space for a tree of height h, and O(n) in the worst case.
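For reference, here is a minimal sketch of the recursive in-order traversal that the rest of the post tries to improve on. The TreeNode class is an assumption that mirrors the usual LeetCode definition, and inorder_recursive is a name I made up for illustration.

```python
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node (assumed to mirror the LeetCode definition)."""

    def __init__(self, val: int = 0,
                 left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def inorder_recursive(node: Optional[TreeNode], out: List[int]) -> None:
    """Append values to `out` in left -> node -> right order.

    Every nested call occupies one frame on the call stack, so the extra
    space grows with the height of the tree.
    """
    if node is None:
        return
    inorder_recursive(node.left, out)
    out.append(node.val)
    inorder_recursive(node.right, out)


root = TreeNode(2, TreeNode(1), TreeNode(3))
values: List[int] = []
inorder_recursive(root, values)
print(values)  # [1, 2, 3]
```

On a tree that degenerates into a single long chain, the recursion depth equals the number of nodes, which is where the O(n) worst case comes from.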
In this post, I will introduce an algorithm called Morris traversal, which performs an in-order traversal with O(1) extra space and O(n) time. The intuition is to temporarily redirect the right pointer of the rightmost node in each left subtree so that, once we finish visiting that left subtree, we can jump directly to the next node.
Binary Search Tree Iterator
Implement the BSTIterator class that represents an iterator over the in-order traversal of a binary search tree (BST).
(from LeetCode 173)
Explanation
For the in-order traversal, the general rule is to follow the left child -> node -> right child visiting order for each node. The steps of the Morris traversal for in-order traversal are:
- Start with current_node = root.
- While current_node is not NULL:
  - If current_node has a left child (go to the left half and change a right child pointer if needed):
    - Find the rightmost node in its left subtree as prev_node.
    - If prev_node's right child is current_node, reset the right child pointer and visit current_node. Assign current_node to be current_node's right child.
    - Else, set prev_node's right child to be current_node. Assign current_node to be current_node's left child.
  - Else (no left subtree to check, time to visit the current node and then move on to the right half):
    - Visit current_node, then assign current_node to be current_node's right child.
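The steps above can be sketched as a plain traversal function before turning them into an iterator class. This is a minimal sketch, not the only way to write it: TreeNode is assumed to mirror the LeetCode definition, and morris_inorder is a name chosen for illustration.

```python
from typing import Iterator, Optional


class TreeNode:
    """Minimal binary tree node (assumed to mirror the LeetCode definition)."""

    def __init__(self, val: int = 0,
                 left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def morris_inorder(root: Optional[TreeNode]) -> Iterator[int]:
    """Yield node values in in-order using O(1) extra space."""
    cur = root
    while cur is not None:
        if cur.left is None:
            # No left subtree: visit the node, then move right.
            yield cur.val
            cur = cur.right
            continue
        # Find the rightmost node of the left subtree, stopping early
        # if its right pointer already leads back to cur.
        prev = cur.left
        while prev.right is not None and prev.right is not cur:
            prev = prev.right
        if prev.right is cur:
            # Second arrival: the left subtree is done, so undo the
            # temporary pointer and visit cur.
            prev.right = None
            yield cur.val
            cur = cur.right
        else:
            # First arrival: point prev back at cur, then go left.
            prev.right = cur
            cur = cur.left


root = TreeNode(4,
                TreeNode(2, TreeNode(1), TreeNode(3)),
                TreeNode(6, TreeNode(5), TreeNode(7)))
result = list(morris_inorder(root))
print(result)  # [1, 2, 3, 4, 5, 6, 7]
```

One caveat of writing it as a generator: the temporary pointers are only removed as the traversal proceeds, so a caller that stops consuming early leaves the tree in a modified state.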
To implement an iterator with this algorithm, we only need to keep track of one attribute, self.cur, which stores where the traversal will resume on the next call. The name can be a little misleading: self.cur is not always the node that will be output next, and I only named it that way to stay consistent with the variable name used in the explanation above.
Example
Here's a step-by-step illustration of how Morris traversal works on a binary search tree. Nodes filled in orange have already been visited.
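The same walk-through can also be followed as a list of events. Below is a small sketch that records when a temporary pointer is created ("thread"), when it is removed ("unthread"), and when a node is visited. The three-node tree, the event wording, and the name morris_trace are assumptions made for this illustration.

```python
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node (assumed to mirror the LeetCode definition)."""

    def __init__(self, val: int = 0,
                 left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def morris_trace(root: Optional[TreeNode]) -> List[str]:
    """Run Morris in-order traversal and return a log of what happened."""
    events: List[str] = []
    cur = root
    while cur is not None:
        if cur.left is None:
            events.append(f"visit {cur.val}")
            cur = cur.right
            continue
        prev = cur.left
        while prev.right is not None and prev.right is not cur:
            prev = prev.right
        if prev.right is cur:
            # The left subtree has been fully visited.
            events.append(f"unthread {prev.val}->{cur.val}")
            prev.right = None
            events.append(f"visit {cur.val}")
            cur = cur.right
        else:
            # Remember how to get back to cur after the left subtree.
            events.append(f"thread {prev.val}->{cur.val}")
            prev.right = cur
            cur = cur.left
    return events


#   2
#  / \
# 1   3
events = morris_trace(TreeNode(2, TreeNode(1), TreeNode(3)))
for event in events:
    print(event)
# thread 1->2
# visit 1
# unthread 1->2
# visit 2
# visit 3
```

Node 1 has no right child of its own, so after visiting it the traversal follows the temporary pointer back to node 2, which is exactly the shortcut that replaces the stack.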
Python3 Implementation
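from typing import Optional

# TreeNode (with val, left and right attributes) is the node class provided by LeetCode.
class BSTIterator:
    def __init__(self, root: Optional[TreeNode]):
        # Where the traversal resumes on the next call to next().
        self.cur = root

    def next(self) -> int:
        while self.cur:
            if self.cur.left:
                # Find the rightmost node in the left subtree.
                prev = self.cur.left
                while prev.right and prev.right is not self.cur:
                    prev = prev.right
                if prev.right is self.cur:
                    # Left subtree finished: reset the pointer and output this node.
                    prev.right = None
                    res = self.cur.val
                    self.cur = self.cur.right
                    return res
                else:
                    # First arrival: link prev back to this node, then go left.
                    prev.right = self.cur
                    self.cur = self.cur.left
            else:
                # No left subtree: output this node and move right.
                res = self.cur.val
                self.cur = self.cur.right
                return res

    def hasNext(self) -> bool:
        return self.cur is not None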
Recover Binary Search Tree
You are given the root of a binary search tree (BST), where the values of exactly two nodes of the tree were swapped by mistake. Recover the tree without changing its structure.
(from LeetCode 99)
Explanation
We need a way to detect the two swapped nodes in the tree. The in-order traversal of a valid binary search tree is a sorted (non-decreasing) sequence, so we can look at how that sequence changes when exactly two nodes are swapped.
Assume the in-order traversal of a correct BST is 1,2,3,4,5,6
- 4 and 5 are swapped: the traversal output will be 1,2,3,5,4,6. There is one decreasing step: 5->4.
- 3 and 5 are swapped: the traversal output will be 1,2,5,4,3,6. There are two decreasing steps: 5->4 and 4->3.
- 2 and 5 are swapped: the traversal output will be 1,5,3,4,2,6. There are two decreasing steps: 5->3 and 4->2.
We can observe from the above examples that if we keep track of all the numbers involved in decreasing steps, in the order we encounter them, the first and the last are the two swapped numbers.
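To check this observation on the three examples, here is a small sketch that works on a plain list instead of a tree. The function name find_swapped is made up for illustration, and it assumes the input is a sorted sequence in which exactly two elements have been swapped.

```python
from typing import List, Tuple


def find_swapped(values: List[int]) -> Tuple[int, int]:
    """Return the two swapped values of an otherwise sorted list.

    Collects every number that takes part in a decreasing step and
    returns the first and the last of them.
    """
    involved: List[int] = []
    for a, b in zip(values, values[1:]):
        if b < a:
            involved.append(a)
            involved.append(b)
    if not involved:
        raise ValueError("the sequence is already sorted")
    return involved[0], involved[-1]


print(find_swapped([1, 2, 3, 5, 4, 6]))  # (5, 4)
print(find_swapped([1, 2, 5, 4, 3, 6]))  # (5, 3)
print(find_swapped([1, 5, 3, 4, 2, 6]))  # (5, 2)
```

The tree version does the same comparison, except that each adjacent pair comes from consecutive visits during the traversal rather than from list indices.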
Now we can run Morris traversal, compare each visited node with the previously visited one, and swap the values of the two nodes we find before the function returns.
Python3 Implementation
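from typing import Optional

# TreeNode is the node class provided by LeetCode, which wraps the method in a Solution class.
class Solution:
    def recoverTree(self, root: Optional[TreeNode]) -> None:
        cur = root
        # Sentinel "previously visited" node, smaller than any real value.
        last = TreeNode(float('-inf'))
        # Nodes involved in decreasing steps, in visiting order.
        swaps = []
        while cur:
            if cur.left:
                prev = cur.left
                while prev.right and prev.right is not cur:
                    prev = prev.right
                if prev.right is cur:
                    # Left subtree finished: reset the pointer and visit cur.
                    prev.right = None
                    if cur.val < last.val:
                        swaps.append(last)
                        swaps.append(cur)
                    last = cur
                    cur = cur.right
                else:
                    prev.right = cur
                    cur = cur.left
            else:
                # No left subtree: visit cur and move right.
                if cur.val < last.val:
                    swaps.append(last)
                    swaps.append(cur)
                last = cur
                cur = cur.right
        # The first and last recorded nodes are the swapped pair.
        swaps[0].val, swaps[-1].val = swaps[-1].val, swaps[0].val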