In tree problems, the input is usually given as a list in level order, with None marking a missing child, which makes it hard to debug a solution locally. I wrote a simple tree constructor that builds the tree from that list and returns its root, which you can then pass into your function.
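```python
from typing import List, Optional


class TreeNode:
    # Standard binary tree node, as defined in the problem statement.
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_constructor(node_list: List[Optional[int]]) -> Optional[TreeNode]:
    # Build a tree from its level-order list, where None marks a missing child.
    if not node_list or node_list[0] is None:
        return None
    root = TreeNode(node_list[0])
    curr = [root]  # nodes on the current level, waiting for their children
    ptr = 1  # index of the next value in node_list
    while ptr < len(node_list) and curr:
        curr2 = []  # nodes on the next level
        for c in curr:
            # Left child (guard against running off the end of the list)
            if ptr < len(node_list) and node_list[ptr] is not None:
                c.left = TreeNode(node_list[ptr])
                curr2.append(c.left)
            ptr += 1
            # Right child
            if ptr < len(node_list) and node_list[ptr] is not None:
                c.right = TreeNode(node_list[ptr])
                curr2.append(c.right)
            ptr += 1
        curr = curr2
    return root
```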
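When your function returns a tree, it also helps to go the other way and print the result in the same list format, so you can compare it with the expected answer. Below is a minimal sketch of such an inverse, `tree_to_list` (a hypothetical helper name, not part of the original post), assuming the same level-order convention: None for a missing child, children of missing nodes not listed, trailing Nones dropped. To keep the block self-contained it redefines `TreeNode` and uses a compact queue-based (`collections.deque`) variant of the constructor.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    # Same shape as the usual TreeNode from the problem statement (assumed).
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def tree_constructor(node_list: List[Optional[int]]) -> Optional[TreeNode]:
    # Queue-based variant: each dequeued node consumes the next two
    # list entries as its left and right children.
    if not node_list or node_list[0] is None:
        return None
    root = TreeNode(node_list[0])
    queue = deque([root])
    i = 1
    while queue and i < len(node_list):
        node = queue.popleft()
        if node_list[i] is not None:
            node.left = TreeNode(node_list[i])
            queue.append(node.left)
        i += 1
        if i < len(node_list) and node_list[i] is not None:
            node.right = TreeNode(node_list[i])
            queue.append(node.right)
        i += 1
    return root


def tree_to_list(root: Optional[TreeNode]) -> List[Optional[int]]:
    # Hypothetical inverse helper: serialize back to level order, writing
    # None for missing children and dropping trailing Nones.
    result: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            result.append(None)
            continue
        result.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while result and result[-1] is None:
        result.pop()
    return result


data = [1, 2, 3, None, 4, 5]
print(tree_to_list(tree_constructor(data)))  # → [1, 2, 3, None, 4, 5]
```

A round trip like this is a quick sanity check that the constructor and the list format agree before you start debugging your own solution.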