Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion hamilton/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@


def is_submodule(child: ModuleType, parent: ModuleType):
return parent.__name__ in child.__name__
if child is None:
return False
return child.__name__ == parent.__name__ or child.__name__.startswith(parent.__name__ + ".")


def find_functions(function_module: ModuleType) -> list[tuple[str, Callable]]:
Expand Down
72 changes: 72 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,40 @@
import tests.resources.typing_vs_not_typing


@pytest.mark.parametrize(
("child_name", "parent_name", "expected"),
[
("foo", "foo", True), # same module
("foo.bar", "foo", True), # direct child
("foo.bar.baz", "foo", True), # nested child
("foo.bar.baz", "foo.bar", True), # nested child of subpackage
("foobar", "foo", False), # not a submodule, just a prefix without dot separator
("hamilton.function_modifiers", "modifiers", False), # substring match, not a submodule
("hamilton.function_modifiers.dependencies", "modifiers", False), # substring deeper
("x.foo.y", "foo", False), # parent name in the middle, not a prefix
("bar", "foo", False), # completely unrelated
],
ids=[
"same_module",
"direct_child",
"nested_child",
"nested_child_of_subpackage",
"prefix_without_dot",
"substring_not_submodule",
"substring_deeper",
"parent_in_middle",
"unrelated",
],
)
def test_is_submodule(child_name, parent_name, expected):
"""Tests that is_submodule correctly checks module hierarchy using prefix matching."""
from types import ModuleType
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

top level import please.


child = ModuleType(child_name)
parent = ModuleType(parent_name)
assert hamilton.graph_utils.is_submodule(child, parent) == expected


def test_find_functions():
"""Tests that we filter out _ functions when passed a module and don't pull in anything from the imports."""
expected = [
Expand All @@ -64,6 +98,44 @@ def test_find_functions():
assert actual == expected


def test_find_functions_excludes_imports_with_substring_module_name():
"""Regression test: imported functions should not be included when the user module's
name is a substring of the imported function's module path.

Previously, is_submodule used `parent.__name__ in child.__name__` (substring match),
which caused e.g. a module named 'modifiers' to pull in functions from
'hamilton.function_modifiers'.
"""
import sys
from types import ModuleType

from hamilton.function_modifiers import source, value
Comment on lines +109 to +112
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

top level imports please


# Create a fake module named "modifiers" with one real function and two imports
mod = ModuleType("modifiers")

def my_func(x: int) -> int:
return x * 2

# Assign the function to the module so inspect.getmodule can resolve it
my_func.__module__ = "modifiers"
mod.my_func = my_func
mod.source = source
mod.value = value

# Register in sys.modules so inspect.getmodule can find it
sys.modules["modifiers"] = mod
try:
actual = hamilton.graph_utils.find_functions(mod)
actual_names = [name for name, _ in actual]
assert actual_names == ["my_func"], (
f"Expected only ['my_func'] but got {actual_names}. "
"Imported functions from hamilton.function_modifiers should not be included."
)
finally:
del sys.modules["modifiers"]


def test_find_functions_from_temporary_function_module():
"""Tests that we handle the TemporaryFunctionModule object correctly."""
expected = [
Expand Down