11 changes: 11 additions & 0 deletions src/google/adk/agents/invocation_context.py
@@ -206,6 +206,17 @@ class InvocationContext(BaseModel):
canonical_tools_cache: Optional[list[BaseTool]] = None
"""The cache of canonical tools for this invocation."""

metadata: Optional[dict[str, Any]] = None
"""Per-request metadata passed from Runner entry points.

This field allows passing arbitrary metadata that can be accessed during
the invocation lifecycle, particularly in callbacks like before_model_callback.
Common use cases include passing user_id, trace_id, memory context keys, or
other request-specific context that needs to be available during processing.

Supported entry points: run(), run_async(), run_live(), run_debug().
"""

_invocation_cost_manager: _InvocationCostManager = PrivateAttr(
default_factory=_InvocationCostManager
)
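
Taken together with the runners.py changes further down, this lets callers attach per-request context at the entry point. A minimal producer-side sketch, assuming the stock InMemoryRunner and a plain LlmAgent; the agent name, model, ids, and metadata keys are illustrative, and only the metadata= keyword comes from this PR:

```python
from google.adk.agents import LlmAgent
from google.adk.runners import InMemoryRunner
from google.genai import types

root_agent = LlmAgent(name="assistant", model="gemini-2.0-flash")
runner = InMemoryRunner(agent=root_agent)


async def ask(question: str) -> None:
    # A fresh session per call keeps the example self-contained.
    session = await runner.session_service.create_session(
        app_name=runner.app_name, user_id="user-1"
    )
    async for event in runner.run_async(
        user_id="user-1",
        session_id=session.id,
        new_message=types.Content(role="user", parts=[types.Part(text=question)]),
        # Arbitrary per-request context; after this PR it is exposed as
        # InvocationContext.metadata and LlmRequest.metadata.
        metadata={"trace_id": "req-123", "memory_keys": ["profile"]},
    ):
        if event.content and event.content.parts:
            print(event.content.parts[0].text)
```
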
4 changes: 2 additions & 2 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -88,7 +88,7 @@ async def run_live(
invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""Runs the flow using live api."""
llm_request = LlmRequest()
llm_request = LlmRequest(metadata=invocation_context.metadata)
event_id = Event.new_id()

# Preprocess before calling the LLM.
@@ -376,7 +376,7 @@ async def _run_one_step_async(
invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""One step means one LLM call."""
llm_request = LlmRequest()
llm_request = LlmRequest(metadata=invocation_context.metadata)

# Preprocess before calling the LLM.
async with Aclosing(
10 changes: 10 additions & 0 deletions src/google/adk/models/llm_request.py
@@ -15,6 +15,7 @@
from __future__ import annotations

import logging
from typing import Any
from typing import Optional
from typing import Union

@@ -99,6 +100,15 @@ class LlmRequest(BaseModel):
the full history.
"""

metadata: Optional[dict[str, Any]] = None
"""Per-request metadata for callbacks and custom processing.

This field allows passing arbitrary metadata from the Runner.run_async()
call to callbacks like before_model_callback. This is useful for passing
request-specific context such as user_id, trace_id, or memory context keys
that need to be available during model invocation.
"""

def append_instructions(
self, instructions: Union[list[str], types.Content]
) -> list[types.Content]:
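
On the consuming side, the same metadata now rides on the LlmRequest that callbacks receive. A hedged sketch of a before_model_callback reading it; the callback signature follows ADK's before_model_callback convention, the import paths assume the current ADK layout, and the trace-stamping logic and metadata keys are illustrative:

```python
from typing import Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse


def stamp_trace_id(
    callback_context: CallbackContext, llm_request: LlmRequest
) -> Optional[LlmResponse]:
    meta = llm_request.metadata or {}
    trace_id = meta.get("trace_id")
    if trace_id:
        # Surface the request-scoped trace id to the model as an extra
        # instruction (append_instructions is part of LlmRequest, see above).
        llm_request.append_instructions([f"Trace id for this request: {trace_id}"])
    return None  # Returning None lets the normal model call proceed.
```

The callback would then be registered on the agent, e.g. before_model_callback=stamp_trace_id.
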
38 changes: 37 additions & 1 deletion src/google/adk/runners.py
@@ -348,6 +348,7 @@ def run(
session_id: str,
new_message: types.Content,
run_config: Optional[RunConfig] = None,
metadata: Optional[dict[str, Any]] = None,
) -> Generator[Event, None, None]:
"""Runs the agent.

@@ -365,6 +366,7 @@
session_id: The session ID of the session.
new_message: A new message to append to the session.
run_config: The run config for the agent.
metadata: Optional per-request metadata that will be passed to callbacks.

Yields:
The events generated by the agent.
@@ -380,6 +382,7 @@ async def _invoke_run_async():
session_id=session_id,
new_message=new_message,
run_config=run_config,
metadata=metadata,
)
) as agen:
async for event in agen:
@@ -415,6 +418,7 @@ async def run_async(
new_message: Optional[types.Content] = None,
state_delta: Optional[dict[str, Any]] = None,
run_config: Optional[RunConfig] = None,
metadata: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[Event, None]:
"""Main entry method to run the agent in this runner.

@@ -432,6 +436,13 @@
new_message: A new message to append to the session.
state_delta: Optional state changes to apply to the session.
run_config: The run config for the agent.
metadata: Optional per-request metadata that will be passed to callbacks.
This allows passing request-specific context such as user_id, trace_id,
or memory context keys to before_model_callback and other callbacks.
Comment on lines +439 to +441 (Contributor, severity: medium):

To prevent potential subtle bugs, it's a good practice to clarify the copy behavior of the metadata dictionary in the docstring. Since a shallow copy is performed, modifications to nested mutable objects within a callback will affect the original object passed by the caller. Please add a note about this to help users of the API understand this behavior and avoid unexpected side effects. For example, you could add: "Note: A shallow copy is made of this dictionary, so changes to nested mutable objects will affect the original object."

Note: A shallow copy is made of this dictionary, so top-level changes
within callbacks won't affect the original. However, modifications to
nested mutable objects (e.g., nested dicts or lists) will affect the
original.

Yields:
The events generated by the agent.
@@ -441,13 +452,16 @@ async def run_async(
new_message are None.
"""
run_config = run_config or RunConfig()
# Create a shallow copy to isolate from caller's modifications
metadata = metadata.copy() if metadata is not None else None

if new_message and not new_message.role:
new_message.role = 'user'

async def _run_with_trace(
new_message: Optional[types.Content] = None,
invocation_id: Optional[str] = None,
metadata: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[Event, None]:
with tracer.start_as_current_span('invocation'):
session = await self.session_service.get_session(
@@ -478,6 +492,7 @@ async def _run_with_trace(
invocation_id=invocation_id,
run_config=run_config,
state_delta=state_delta,
metadata=metadata,
)
if invocation_context.end_of_agents.get(
invocation_context.agent.name
Expand All @@ -491,6 +506,7 @@ async def _run_with_trace(
new_message=new_message, # new_message is not None.
run_config=run_config,
state_delta=state_delta,
metadata=metadata,
)

async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
@@ -517,7 +533,9 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
self.app, session, self.session_service
)

async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen:
async with Aclosing(
_run_with_trace(new_message, invocation_id, metadata)
) as agen:
async for event in agen:
yield event

@@ -889,6 +907,7 @@ async def run_live(
live_request_queue: LiveRequestQueue,
run_config: Optional[RunConfig] = None,
session: Optional[Session] = None,
metadata: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[Event, None]:
"""Runs the agent in live mode (experimental feature).

@@ -930,6 +949,7 @@
run_config: The run config for the agent.
session: The session to use. This parameter is deprecated, please use
`user_id` and `session_id` instead.
metadata: Optional per-request metadata that will be passed to callbacks.

Yields:
AsyncGenerator[Event, None]: An asynchronous generator that yields
@@ -944,6 +964,7 @@
Either `session` or both `user_id` and `session_id` must be provided.
"""
run_config = run_config or RunConfig()
metadata = metadata.copy() if metadata is not None else None
# Some native audio models requires the modality to be set. So we set it to
# AUDIO by default.
if run_config.response_modalities is None:
@@ -974,6 +995,7 @@
session,
live_request_queue=live_request_queue,
run_config=run_config,
metadata=metadata,
)

root_agent = self.agent
@@ -1119,6 +1141,7 @@ async def run_debug(
run_config: RunConfig | None = None,
quiet: bool = False,
verbose: bool = False,
metadata: dict[str, Any] | None = None,
) -> list[Event]:
"""Debug helper for quick agent experimentation and testing.

@@ -1142,6 +1165,7 @@ async def run_debug(
shown).
verbose: If True, shows detailed tool calls and responses. Defaults to
False for cleaner output showing only final agent responses.
metadata: Optional per-request metadata that will be passed to callbacks.

Returns:
list[Event]: All events from all messages.
@@ -1204,6 +1228,7 @@ async def run_debug(
session_id=session.id,
new_message=types.UserContent(parts=[types.Part(text=message)]),
run_config=run_config,
metadata=metadata,
):
if not quiet:
print_event(event, verbose=verbose)
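
For quick experiments, the same metadata can also be threaded through run_debug. A hypothetical call, assuming run_debug accepts a plain user-message string as its first argument (as the per-message loop above suggests); the message text and metadata keys are made up:

```python
# Runs one debug turn and returns the list of events it produced.
events = await runner.run_debug(
    "What is on my calendar today?",
    metadata={"trace_id": "debug-001", "user_tier": "beta"},
    verbose=True,
)
```
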
@@ -1219,6 +1244,7 @@ async def _setup_context_for_new_invocation(
new_message: types.Content,
run_config: RunConfig,
state_delta: Optional[dict[str, Any]],
metadata: Optional[dict[str, Any]] = None,
) -> InvocationContext:
"""Sets up the context for a new invocation.

@@ -1227,6 +1253,7 @@ async def _setup_context_for_new_invocation(
new_message: The new message to process and append to the session.
run_config: The run config of the agent.
state_delta: Optional state changes to apply to the session.
metadata: Optional per-request metadata to pass to callbacks.

Returns:
The invocation context for the new invocation.
@@ -1236,6 +1263,7 @@ async def _setup_context_for_new_invocation(
session,
new_message=new_message,
run_config=run_config,
metadata=metadata,
)
# Step 2: Handle new message, by running callbacks and appending to
# session.
@@ -1258,6 +1286,7 @@ async def _setup_context_for_resumed_invocation(
invocation_id: Optional[str],
run_config: RunConfig,
state_delta: Optional[dict[str, Any]],
metadata: Optional[dict[str, Any]] = None,
) -> InvocationContext:
"""Sets up the context for a resumed invocation.

@@ -1267,6 +1296,7 @@ async def _setup_context_for_resumed_invocation(
invocation_id: The invocation id to resume.
run_config: The run config of the agent.
state_delta: Optional state changes to apply to the session.
metadata: Optional per-request metadata to pass to callbacks.

Returns:
The invocation context for the resumed invocation.
@@ -1292,6 +1322,7 @@ async def _setup_context_for_resumed_invocation(
new_message=user_message,
run_config=run_config,
invocation_id=invocation_id,
metadata=metadata,
)
# Step 3: Maybe handle new message.
if new_message:
@@ -1336,6 +1367,7 @@ def _new_invocation_context(
new_message: Optional[types.Content] = None,
live_request_queue: Optional[LiveRequestQueue] = None,
run_config: Optional[RunConfig] = None,
metadata: Optional[dict[str, Any]] = None,
) -> InvocationContext:
"""Creates a new invocation context.

@@ -1345,6 +1377,7 @@ def _new_invocation_context(
new_message: The new message for the context.
live_request_queue: The live request queue for the context.
run_config: The run config for the context.
metadata: Optional per-request metadata for the context.

Returns:
The new invocation context.
@@ -1376,6 +1409,7 @@ def _new_invocation_context(
live_request_queue=live_request_queue,
run_config=run_config,
resumability_config=self.resumability_config,
metadata=metadata,
Contributor (severity: medium):

To prevent accidental modification of the original metadata dictionary by the caller of run_async, it's a good practice to work with a copy of the metadata. Since dictionaries are mutable, any changes made to metadata within the runner's logic would also affect the caller's original dictionary. Creating a shallow copy here isolates the runner's execution context from the caller. This is especially important as run_async is an async generator, and the caller might modify the metadata dictionary while iterating over the yielded events.

Suggested change:
- metadata=metadata,
+ metadata=metadata.copy() if metadata is not None else None,

)

def _new_invocation_context_for_live(
@@ -1384,6 +1418,7 @@ def _new_invocation_context_for_live(
*,
live_request_queue: LiveRequestQueue,
run_config: Optional[RunConfig] = None,
metadata: Optional[dict[str, Any]] = None,
) -> InvocationContext:
"""Creates a new invocation context for live multi-agent."""
run_config = run_config or RunConfig()
Expand All @@ -1402,6 +1437,7 @@ def _new_invocation_context_for_live(
session,
live_request_queue=live_request_queue,
run_config=run_config,
metadata=metadata,
)

async def _handle_new_message(
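
Both review comments above turn on dict.copy() being a shallow copy, which is what run_async() and run_live() now perform on entry. A plain-Python illustration of what that does and does not isolate:

```python
original = {"trace_id": "abc", "memory": {"keys": ["profile"]}}
copied = original.copy()  # same shallow copy run_async()/run_live() make

copied["trace_id"] = "xyz"                   # top-level change: original untouched
copied["memory"]["keys"].append("history")   # nested change: leaks back to the caller

assert original["trace_id"] == "abc"
assert original["memory"]["keys"] == ["profile", "history"]
```
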