def is_submodule(child: "ModuleType | None", parent: ModuleType) -> bool:
    """Return whether ``child`` is ``parent`` itself or a module nested beneath it.

    The decision is made purely on dotted module names: ``child`` qualifies when its
    ``__name__`` equals ``parent.__name__`` or begins with ``parent.__name__ + "."``.
    A plain substring test is deliberately not used, because it would treat
    ``hamilton.function_modifiers`` as a submodule of a module named ``modifiers``,
    and ``foobar`` as a submodule of ``foo``.

    :param child: The module to classify. ``None`` is accepted and is never considered
        a submodule of anything.
    :param parent: The module that ``child`` is tested against.
    :return: ``True`` if ``child`` is ``parent`` or lives under ``parent``'s dotted
        namespace, ``False`` otherwise.
    """
    if child is None:
        return False
    child_name = child.__name__
    parent_name = parent.__name__
    # The trailing dot anchors the prefix at a package boundary, so "foobar" does not
    # count as living under "foo".
    return child_name == parent_name or child_name.startswith(parent_name + ".")
def test_find_functions_excludes_imports_with_substring_module_name():
    """Regression test: imported functions should not be included when the user module's
    name is a substring of the imported function's module path.

    Previously, is_submodule used `parent.__name__ in child.__name__` (substring match),
    which caused e.g. a module named 'modifiers' to pull in functions from
    'hamilton.function_modifiers'.
    """
    import sys
    from types import ModuleType

    from hamilton.function_modifiers import source, value

    module_name = "modifiers"

    # Create a fake module named "modifiers" with one real function and two imports
    mod = ModuleType(module_name)

    def my_func(x: int) -> int:
        return x * 2

    # Point the function at the fake module so inspect.getmodule can resolve it
    my_func.__module__ = module_name
    mod.my_func = my_func
    mod.source = source
    mod.value = value

    # Register in sys.modules so inspect.getmodule can find it. Remember whatever was
    # registered under this name before, so a genuine "modifiers" module that happens
    # to be imported already is put back afterwards instead of being evicted.
    not_registered = object()
    previous = sys.modules.get(module_name, not_registered)
    sys.modules[module_name] = mod
    try:
        actual = hamilton.graph_utils.find_functions(mod)
    finally:
        if previous is not_registered:
            # pop() rather than del: cleanup must not raise and mask the real failure.
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous

    actual_names = [name for name, _ in actual]
    assert actual_names == ["my_func"], (
        f"Expected only ['my_func'] but got {actual_names}. "
        "Imported functions from hamilton.function_modifiers should not be included."
    )