Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import copy
import inspect
import sys
import textwrap
import types
import typing
from collections.abc import AsyncIterator
Expand Down Expand Up @@ -49,6 +50,56 @@
from typing import Self

from burr.core.state import State


def _validate_declared_reads(fn: Callable, declared_reads: list[str]) -> None:
if not declared_reads:
return

try:
source = inspect.getsource(fn)
except OSError:
return # skip if source unavailable

# detect actual state parameter name
sig = inspect.signature(fn)
state_param_name = None

for name, param in sig.parameters.items():
if param.annotation is State:
state_param_name = name
break

if state_param_name is None:
return

tree = ast.parse(textwrap.dedent(source))

declared = set(declared_reads)
violations = []

class Visitor(ast.NodeVisitor):
def visit_Subscript(self, node):
if (
isinstance(node.value, ast.Name)
and node.value.id == state_param_name
and isinstance(node.slice, ast.Constant)
and isinstance(node.slice.value, str)
):
key = node.slice.value
if key not in declared:
violations.append(key)
self.generic_visit(node)

Visitor().visit(tree)

if violations:
raise ValueError(
f"Action reads undeclared state keys: {violations}. "
f"Declared reads: {declared_reads}"
)


from burr.core.typing import ActionSchema

# This is here to make accessing the pydantic actions easier
Expand Down Expand Up @@ -628,6 +679,8 @@ def __init__(
self._fn = fn
self._reads = reads
self._writes = writes
_validate_declared_reads(self._originating_fn, self._reads)

self._bound_params = bound_params if bound_params is not None else {}
self._inputs = (
derive_inputs_from_fn(self._bound_params, self._fn)
Expand Down Expand Up @@ -1106,9 +1159,12 @@ def __init__(
:param writes:
"""
super(FunctionBasedStreamingAction, self).__init__()
self._originating_fn = originating_fn if originating_fn is not None else fn
self._fn = fn
self._reads = reads
self._writes = writes
_validate_declared_reads(self._originating_fn, self._reads)

self._bound_params = bound_params if bound_params is not None else {}
self._inputs = (
derive_inputs_from_fn(self._bound_params, self._fn)
Expand All @@ -1118,7 +1174,7 @@ def __init__(
[item for item in input_spec[1] if item not in self._bound_params],
)
)
self._originating_fn = originating_fn if originating_fn is not None else fn

self._schema = schema
self._tags = tags if tags is not None else []

Expand Down
55 changes: 55 additions & 0 deletions tests/core/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,3 +823,58 @@ def fn(state, a):
required, optional = derive_inputs_from_fn(bound_params, fn)
assert required == []
assert optional == []


def test_undeclared_state_read_raises_error():
with pytest.raises(ValueError):

@action(reads=["foo"], writes=[])
def bad_action(state: State):
_ = state["bar"]
return {}, state


def test_declared_state_read_passes():
@action(reads=["foo"], writes=[])
def good_action(state: State):
_ = state["foo"]
return {}, state


def test_multiple_undeclared_reads_interleaved():
with pytest.raises(ValueError) as exc:

@action(reads=["foo"], writes=[])
def bad_action(state: State):
_ = state["foo"]
_ = state["bar"]
_ = state["baz"]
return {}, state

message = str(exc.value)
assert "bar" in message
assert "baz" in message


def test_pydantic_action_not_impacted():
try:
from pydantic import BaseModel
except ImportError:
pytest.skip("pydantic not installed")

class MyState(BaseModel):
foo: str

@action.pydantic(
reads=["foo"],
writes=["foo"],
state_input_type=MyState,
state_output_type=MyState,
)
def good_action(state: MyState):
return {"foo": state.foo}

# ensure decoration didn't raise and action is creatable
from burr.core.action import create_action

create_action(good_action, name="test")