diff --git a/hamilton/graph.py b/hamilton/graph.py index 11b6f55f3..d807a6ead 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -1090,13 +1090,19 @@ def directional_dfs_traverse( nodes = set() user_nodes = set() - def dfs_traverse(node: node.Node): - nodes.add(node) - for n in next_nodes_fn(node): - if n not in nodes: - dfs_traverse(n) - if node.user_defined: - user_nodes.add(node) + def dfs_traverse_iterative(start_node: node.Node): + """Iterative DFS to avoid recursion depth limits with large DAGs.""" + stack = [start_node] + while stack: + n = stack.pop() + if n in nodes: + continue + nodes.add(n) + if n.user_defined: + user_nodes.add(n) + for next_n in next_nodes_fn(n): + if next_n not in nodes: + stack.append(next_n) missing_vars = [] for var in starting_nodes: @@ -1107,7 +1113,7 @@ def dfs_traverse(node: node.Node): # if it's not in the runtime inputs, it's a properly missing variable missing_vars.append(var) continue # collect all missing final variables - dfs_traverse(self.nodes[var]) + dfs_traverse_iterative(self.nodes[var]) if missing_vars: missing_vars_str = ",\n".join(missing_vars) raise ValueError(f"Unknown nodes [{missing_vars_str}] requested. Check for typos?") diff --git a/tests/test_graph.py b/tests/test_graph.py index 48027e7b6..de02135ed 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -17,6 +17,7 @@ import inspect import pathlib +import sys import uuid from itertools import permutations from typing import List @@ -1479,3 +1480,41 @@ def test_display_name_list_value_uses_first_element(): assert "First Name" in dot_string # The function name should NOT appear since display_name is set assert "node_with_list_display_name" not in dot_string + + +def test_get_upstream_nodes_large_chain_no_recursion_error(): + """Regression test: get_upstream_nodes with only final_node on a large chain DAG. + + A recursive DFS would exceed Python's recursion limit (~1000) when traversing + a long dependency chain from a single final node. This test verifies that + the iterative DFS in directional_dfs_traverse handles large DAGs correctly. + + Chain size is chosen to exceed recursion limit: 1200 nodes > 1000. + """ + from hamilton import ad_hoc_utils + from hamilton import function_modifiers as fm + + def step(prev: float) -> float: + """Single step in a linear chain.""" + return prev + 1.0 + + # Build a linear chain: node_0 -> node_1 -> ... -> node_N + chain_size = sys.getrecursionlimit() + 200 # Exceeds recursion limit + config = {} + for i in range(chain_size): + prev = f"node_{i - 1}" if i > 0 else 0.0 + config[f"node_{i}"] = { + "prev": fm.source(prev) if i > 0 else fm.value(0.0), + } + decorated = fm.parameterize(**config)(step) + module = ad_hoc_utils.create_temporary_module(decorated, module_name="large_chain") + + fg = graph.FunctionGraph.from_modules(module, config={}) + final_node = f"node_{chain_size - 1}" + + # This would raise RecursionError with recursive DFS + nodes, user_nodes = fg.get_upstream_nodes([final_node]) + + assert len(nodes) == chain_size + assert len(user_nodes) == 0 + assert all(fg.nodes[f"node_{i}"] in nodes for i in range(chain_size))