diff --git a/pyproject.toml b/pyproject.toml index a16132881..5a17405af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ mistral = ["mistralai>=1.8.2"] ollama = ["ollama>=0.4.8,<1.0.0"] openai = ["openai>=1.68.0,<2.0.0"] writer = ["writer-sdk>=2.2.0,<3.0.0"] +xai = ["xai-sdk>=1.5.0,<2.0.0"] sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface @@ -79,7 +80,7 @@ bidi = [ bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] bidi-openai = ["websockets>=15.0.0,<17.0.0"] -all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,xai,otel]"] bidi-all = ["strands-agents[a2a,bidi,bidi-gemini,bidi-openai,docs,otel]"] dev = [ diff --git a/src/strands/models/__init__.py b/src/strands/models/__init__.py index d5f88d09a..8200a0345 100644 --- a/src/strands/models/__init__.py +++ b/src/strands/models/__init__.py @@ -62,4 +62,8 @@ def __getattr__(name: str) -> Any: from .writer import WriterModel return WriterModel + if name == "xAIModel": + from .xai import xAIModel + + return xAIModel raise AttributeError(f"cannot import name '{name}' from '{__name__}' ({__file__})") diff --git a/src/strands/models/xai.py b/src/strands/models/xai.py new file mode 100644 index 000000000..809cfff30 --- /dev/null +++ b/src/strands/models/xai.py @@ -0,0 +1,765 @@ +"""xAI model provider. + +- Docs: https://docs.x.ai/docs + +This module implements the xAI model provider using the native xAI Python SDK (xai-sdk). +The xAI SDK uses a gRPC-based Chat API pattern: + + from xai_sdk import Client + from xai_sdk.chat import system, user + + client = Client(api_key="...") + chat = client.chat.create(model="grok-4", store_messages=False) + chat.append(system("You are helpful")) + chat.append(user("Hello")) + response = chat.sample() # or: for response, chunk in chat.stream() + +Server-Side State Preservation +============================== + +xAI's server-side tools (x_search, web_search, code_execution) and encrypted reasoning +content present a unique challenge: their results are returned in an encrypted format +that cannot be reconstructed from plain text. To maintain multi-turn conversation +context, we must preserve this encrypted state across turns. + +The solution uses Strands' `reasoningContent.redactedContent` field to store serialized +xAI SDK messages. This field is designed for encrypted/hidden content and is NOT rendered +when printing AgentResult, keeping the output clean for users. + +Flow: +1. After each response with server-side tools or encrypted reasoning, we capture the + SDK's internal protobuf messages (which contain encrypted tool results) +2. Serialize these messages to base64 and wrap them with XAI_STATE markers +3. Store in `reasoningContent.redactedContent` - this field is preserved in message + history but NOT displayed to users (unlike `text` content blocks) +4. On the next turn, extract and deserialize these messages to rebuild the xAI chat + with full encrypted context + +Why `reasoningContent.redactedContent` instead of `text`? 
+- `text` content blocks are rendered in AgentResult.__str__(), showing ugly markers +- `redactedContent` is designed for encrypted content and is NOT rendered +- The Strands event loop already handles accumulating `redactedContent` properly +- This keeps the user-facing output clean while preserving internal state +""" + +import base64 +import json +import logging +import mimetypes +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar + +import pydantic +from typing_extensions import Required, Unpack, override + +from ..types.content import Messages +from ..types.streaming import StreamEvent +from ..types.tools import ToolChoice, ToolSpec +from ._validation import validate_config_keys +from .model import Model + +logger = logging.getLogger(__name__) + +T = TypeVar("T", bound=pydantic.BaseModel) + +try: + from xai_sdk import AsyncClient + from xai_sdk.chat import chat_pb2 as xai_chat_pb2 + from xai_sdk.chat import image as xai_image + from xai_sdk.chat import system as xai_system + from xai_sdk.chat import tool as xai_tool + from xai_sdk.chat import tool_result as xai_tool_result + from xai_sdk.chat import user as xai_user + from xai_sdk.tools import get_tool_call_type +except ImportError as e: + raise ImportError( + "The 'xai-sdk' package is required to use xAIModel. Install it with: pip install strands-agents[grok]" + ) from e + +# Markers for xAI state serialization. +# The state is stored in reasoningContent.redactedContent to keep it hidden from users. +# Format: +# The JSON contains {"messages": [base64_encoded_protobuf_messages...]} +XAI_STATE_MARKER = "" + + +class xAIModel(Model): + """xAI model provider implementation. + + This provider uses the native xAI Python SDK (xai-sdk) which provides a gRPC-based + conversational API with features including server-side agentic tools (web_search, + x_search, code_execution), reasoning models, and stateful conversations. + + - Docs: https://docs.x.ai/docs + """ + + class xAIConfig(TypedDict, total=False): + """Configuration options for xAI models. + + Attributes: + model_id: xAI model ID (e.g., "grok-4", "grok-4-fast", "grok-3-mini"). + params: Additional model parameters (e.g., temperature, max_tokens). + xai_tools: xAI server-side tools (web_search, x_search, code_execution). + reasoning_effort: Reasoning effort level ("low" or "high") for grok-3-mini. + include: Optional xAI features (e.g., ["inline_citations", "verbose_streaming"]). + use_encrypted_content: Return encrypted reasoning for multi-turn context (grok-4). + """ + + model_id: Required[str] + params: dict[str, Any] + xai_tools: list[Any] + reasoning_effort: str + include: list[str] + use_encrypted_content: bool + + def __init__( + self, + *, + client: AsyncClient | None = None, + client_args: dict[str, Any] | None = None, + **model_config: Unpack[xAIConfig], + ) -> None: + """Initialize provider instance. + + Args: + client: Pre-configured AsyncClient to reuse across requests. + client_args: Arguments for the underlying xAI client (e.g., api_key, timeout). + **model_config: Configuration options for the Grok model. + + Raises: + ValueError: If both `client` and `client_args` are provided. 
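+
+        Example:
+            An illustrative construction (the api_key value is a placeholder, and the
+            model_id/params values are only examples of the documented options):
+
+                model = xAIModel(
+                    client_args={"api_key": "<XAI_API_KEY>"},
+                    model_id="grok-4",
+                    params={"temperature": 0.7},
+                )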
+ """ + validate_config_keys(model_config, xAIModel.xAIConfig) + self.config = xAIModel.xAIConfig(**model_config) + + if client is not None and client_args is not None and len(client_args) > 0: + raise ValueError("Only one of 'client' or 'client_args' should be provided, not both.") + + self._custom_client = client + self.client_args = client_args or {} + + if "xai_tools" in self.config: + self._validate_xai_tools(self.config["xai_tools"]) + # Auto-enable encrypted content when using server-side tools + # This is required to preserve server-side tool state across turns + if not self.config.get("use_encrypted_content"): + self.config["use_encrypted_content"] = True + logger.debug("auto-enabled use_encrypted_content for server-side tool state preservation") + + logger.debug("config=<%s> | initializing", self.config) + + def _get_client(self) -> AsyncClient: + """Get an xAI AsyncClient for making requests.""" + if self._custom_client is not None: + return self._custom_client + return AsyncClient(**self.client_args) + + @staticmethod + def _validate_xai_tools(xai_tools: list[Any]) -> None: + """Validate that xai_tools contains only server-side tools.""" + for tool in xai_tools: + if isinstance(tool, dict) and tool.get("type") == "function": + raise ValueError( + "xai_tools should not contain function-based tools. " + "Use the standard tools interface for function calling tools." + ) + + def _format_request_tools(self, tool_specs: list[ToolSpec] | None) -> list[Any]: + """Format tool specs into xAI SDK compatible tools.""" + tools: list[Any] = [] + + for tool_spec in tool_specs or []: + tools.append( + xai_tool( + name=tool_spec["name"], + description=tool_spec["description"], + parameters=tool_spec["inputSchema"]["json"], + ) + ) + + if self.config.get("xai_tools"): + tools.extend(self.config["xai_tools"]) + + return tools + + def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format xAI response events into standardized StreamEvent format.""" + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event.get("data_type") == "tool": + tool_data = event["data"] + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": tool_data["name"], + "toolUseId": tool_data["id"], + } + } + } + } + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event.get("data_type") == "tool": + return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].get("arguments", "")}}}} + if event.get("data_type") == "reasoning_content": + return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}} + if event.get("data_type") == "server_tool": + tool_data = event["data"] + tool_text = f"\n[xAI Tool: {tool_data['name']}({tool_data.get('arguments', '{}')})]\n" + return {"contentBlockDelta": {"delta": {"text": tool_text}}} + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event.get("data"): + case "tool_use": + return {"messageStop": {"stopReason": "tool_use"}} + case "max_tokens" | "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + usage_data = event["data"] + metadata_event: StreamEvent = { + "metadata": { + "usage": { + "inputTokens": getattr(usage_data, "prompt_tokens", 0), + "outputTokens": getattr(usage_data, "completion_tokens", 0), + 
"totalTokens": getattr(usage_data, "total_tokens", 0), + }, + "metrics": {"latencyMs": 0}, + } + } + if hasattr(usage_data, "reasoning_tokens") and usage_data.reasoning_tokens: + metadata_event["metadata"]["usage"]["reasoningTokens"] = usage_data.reasoning_tokens # type: ignore[typeddict-unknown-key] + if event.get("citations"): + metadata_event["metadata"]["citations"] = event["citations"] # type: ignore[typeddict-unknown-key] + if event.get("server_tool_calls"): + metadata_event["metadata"]["serverToolCalls"] = event["server_tool_calls"] # type: ignore[typeddict-unknown-key] + return metadata_event + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']}> | unknown type") + + def _build_chat(self, client: AsyncClient, tool_specs: list[ToolSpec] | None = None) -> Any: + """Build a chat instance with the configured parameters.""" + chat_kwargs: dict[str, Any] = { + "model": self.config["model_id"], + "store_messages": False, + } + + tools = self._format_request_tools(tool_specs) + if tools: + chat_kwargs["tools"] = tools + + if self.config.get("reasoning_effort"): + chat_kwargs["reasoning_effort"] = self.config["reasoning_effort"] + + if self.config.get("include"): + chat_kwargs["include"] = self.config["include"] + + if self.config.get("use_encrypted_content"): + chat_kwargs["use_encrypted_content"] = self.config["use_encrypted_content"] + + if self.config.get("params"): + chat_kwargs.update(self.config["params"]) + + logger.debug("chat_kwargs=<%s> | creating xAI chat", chat_kwargs) + return client.chat.create(**chat_kwargs) + + def _format_image_content(self, content: dict[str, Any]) -> Any: + """Format image content block into xAI SDK image() helper format.""" + image_data = content["image"] + mime_type = mimetypes.types_map.get(f".{image_data['format']}", "image/png") + b64_data = base64.b64encode(image_data["source"]["bytes"]).decode("utf-8") + image_url = f"data:{mime_type};base64,{b64_data}" + return xai_image(image_url=image_url, detail="auto") + + def _extract_xai_state(self, messages: Messages) -> list[bytes] | None: + """Extract serialized xAI SDK messages from Strands message history. + + This method searches for preserved xAI state in the message history. The state + contains serialized protobuf messages from the xAI SDK that include encrypted + server-side tool results and reasoning content. + + The state is stored in `reasoningContent.redactedContent` to keep it hidden + from users (this field is not rendered in AgentResult.__str__). We also check + `text` content for backwards compatibility. + + Why is this needed? + - Server-side tools (x_search, web_search) return encrypted results + - Encrypted reasoning (grok-4 with use_encrypted_content=True) cannot be reconstructed + - The xAI SDK requires the original protobuf messages to maintain context + - Without this, multi-turn conversations would lose server-side tool context + + Args: + messages: The Strands message history to search. + + Returns: + List of serialized protobuf message bytes if found, None otherwise. 
+ """ + for message in messages: + for content in message.get("content", []): + # Primary location: reasoningContent.redactedContent + # This field is designed for encrypted content and is NOT rendered to users + if "reasoningContent" in content: + rc = content["reasoningContent"] + if "redactedContent" in rc: + redacted: Any = rc["redactedContent"] + if isinstance(redacted, bytes): + redacted_text = redacted.decode("utf-8") + else: + redacted_text = str(redacted) + if redacted_text.startswith(XAI_STATE_MARKER) and redacted_text.endswith(XAI_STATE_MARKER_END): + # Extract base64-encoded JSON between markers + encoded = redacted_text[len(XAI_STATE_MARKER) : -len(XAI_STATE_MARKER_END)] + try: + state_data = json.loads(base64.b64decode(encoded).decode("utf-8")) + return [base64.b64decode(msg_b64) for msg_b64 in state_data.get("messages", [])] + except (json.JSONDecodeError, ValueError) as e: + logger.warning("failed to decode xAI state from redactedContent: %s", e) + + # Fallback: check text content (for backwards compatibility with older versions) + if "text" in content: + text_content = content["text"] + if text_content.startswith(XAI_STATE_MARKER) and text_content.endswith(XAI_STATE_MARKER_END): + encoded = text_content[len(XAI_STATE_MARKER) : -len(XAI_STATE_MARKER_END)] + try: + state_data = json.loads(base64.b64decode(encoded).decode("utf-8")) + return [base64.b64decode(msg_b64) for msg_b64 in state_data.get("messages", [])] + except (json.JSONDecodeError, ValueError) as e: + logger.warning("failed to decode xAI state from text: %s", e) + return None + + def _append_messages_to_chat( + self, + chat: Any, + messages: Messages, + system_prompt: str | None = None, + ) -> None: + """Append Strands messages to an xAI chat. + + This method handles two cases: + 1. If xAI state is present (from previous turns with server-side tools or + encrypted reasoning), use the serialized protobuf messages directly to + preserve encrypted content, then append only the new user message. + 2. Otherwise, reconstruct all messages from the Strands format. + + The first case is critical for multi-turn conversations with server-side tools + because the encrypted tool results cannot be reconstructed from plain text. 
+ """ + if system_prompt: + chat.append(xai_system(system_prompt)) + + # Check for preserved xAI state (contains server-side tool results) + xai_state = self._extract_xai_state(messages) + if xai_state: + logger.debug("xai_state_count=<%d> | using preserved xAI messages", len(xai_state)) + for serialized_msg in xai_state: + msg = xai_chat_pb2.Message() + msg.ParseFromString(serialized_msg) + chat.append(msg) + + # Append the new user message (last message in the list) + # The xAI state contains the old conversation, but we need to add the new input + if messages and messages[-1]["role"] == "user": + last_message = messages[-1] + user_parts: list[Any] = [] + for content in last_message["content"]: + if "text" in content: + text = content["text"] + # Skip xAI state markers + if not (text.startswith(XAI_STATE_MARKER) and text.endswith(XAI_STATE_MARKER_END)): + user_parts.append(text) + elif "image" in content: + user_parts.append(self._format_image_content(dict(content))) + + if user_parts: + if len(user_parts) == 1 and isinstance(user_parts[0], str): + chat.append(xai_user(user_parts[0])) + else: + chat.append(xai_user(*user_parts)) + logger.debug("appended new user message after xAI state") + return + + # No preserved state - reconstruct from Strands format + for message in messages: + role = message["role"] + contents = message["content"] + + if role == "user": + tool_results: list[tuple[str, str]] = [] + user_parts_list: list[Any] = [] + + for content in contents: + if "toolResult" in content: + tr = content["toolResult"] + result_parts: list[str] = [] + for tr_content in tr["content"]: + if "json" in tr_content: + result_parts.append(json.dumps(tr_content["json"])) + elif "text" in tr_content: + result_parts.append(tr_content["text"]) + result_str = "\n".join(result_parts) if result_parts else "" + tool_results.append((tr.get("toolUseId", ""), result_str)) + elif "text" in content: + user_parts_list.append(content["text"]) + elif "image" in content: + user_parts_list.append(self._format_image_content(dict(content))) + + for _tool_use_id, result in tool_results: + chat.append(xai_tool_result(result)) + + if user_parts_list: + if len(user_parts_list) == 1 and isinstance(user_parts_list[0], str): + chat.append(xai_user(user_parts_list[0])) + else: + chat.append(xai_user(*user_parts_list)) + + elif role == "assistant": + assistant_msg = xai_chat_pb2.Message() + assistant_msg.role = xai_chat_pb2.ROLE_ASSISTANT + + text_parts: list[str] = [] + reasoning_parts: list[str] = [] + encrypted_content: str | None = None + tool_uses: list[dict[str, Any]] = [] + + for content in contents: + if "text" in content: + # Skip xAI state markers - they're only for state preservation + text = content["text"] + if not (text.startswith(XAI_STATE_MARKER) and text.endswith(XAI_STATE_MARKER_END)): + text_parts.append(text) + elif "reasoningContent" in content: + rc = content["reasoningContent"] + # Handle visible reasoning text (grok-3-mini) + if "reasoningText" in rc: + reasoning_text: Any = rc["reasoningText"] + if isinstance(reasoning_text, dict) and "text" in reasoning_text: + reasoning_parts.append(reasoning_text["text"]) + elif isinstance(reasoning_text, str): + reasoning_parts.append(reasoning_text) + # Handle encrypted reasoning (grok-4 with use_encrypted_content=True) + if "redactedContent" in rc: + redacted: Any = rc["redactedContent"] + if isinstance(redacted, bytes): + encrypted_content = redacted.decode("utf-8") + else: + encrypted_content = str(redacted) + elif "toolUse" in content: + tool_use_block = 
content["toolUse"] + tool_uses.append( + { + "id": tool_use_block.get("toolUseId", ""), + "name": tool_use_block.get("name", ""), + "arguments": tool_use_block.get("input", ""), + } + ) + + # Add reasoning content if present (for grok-3-mini) + if reasoning_parts: + assistant_msg.reasoning_content = " ".join(reasoning_parts) + + # Add encrypted content if present (for grok-4 with use_encrypted_content) + if encrypted_content: + assistant_msg.encrypted_content = encrypted_content + + if text_parts: + text_content = assistant_msg.content.add() + text_content.text = " ".join(text_parts) + + for tool_use_item in tool_uses: + tc = assistant_msg.tool_calls.add() + tc.id = tool_use_item["id"] + tc.type = xai_chat_pb2.TOOL_CALL_TYPE_CLIENT_SIDE_TOOL + tc.function.name = tool_use_item["name"] + args = tool_use_item["arguments"] + if isinstance(args, dict): + tc.function.arguments = json.dumps(args) + else: + tc.function.arguments = str(args) if args else "" + + chat.append(assistant_msg) + + def _capture_xai_state(self, chat: Any, response: Any) -> list[str]: + """Capture xAI SDK messages for state preservation across turns. + + This method is called after receiving a response that contains server-side + tool results or encrypted reasoning content. It serializes the SDK's internal + protobuf messages so they can be restored on the next turn. + + The flow: + 1. Append the response to the chat (SDK creates proper message structure) + 2. Iterate through all messages in the chat (excluding system) + 3. Serialize each message to protobuf bytes, then base64 encode + + These serialized messages contain encrypted content that cannot be reconstructed + from plain text, which is why we must preserve them exactly as-is. + + Args: + chat: The xAI chat instance with the conversation. + response: The response to append before capturing state. + + Returns: + List of base64-encoded serialized protobuf messages. 
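+
+        Example:
+            Sketch of the per-message encoding performed below:
+
+                serialized = msg.SerializeToString()                    # protobuf bytes
+                encoded = base64.b64encode(serialized).decode("utf-8")  # stored in state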
+ """ + chat.append(response) + + serialized_messages: list[str] = [] + for msg in chat.messages: + # Skip system messages - they're added fresh each turn + if msg.role == xai_chat_pb2.ROLE_SYSTEM: + continue + serialized = msg.SerializeToString() + serialized_messages.append(base64.b64encode(serialized).decode("utf-8")) + + logger.debug("captured_messages=<%d> | preserved xAI state", len(serialized_messages)) + return serialized_messages + + @override + def update_config(self, **model_config: Unpack[xAIConfig]) -> None: # type: ignore[override] + """Update the Grok model configuration.""" + validate_config_keys(model_config, xAIModel.xAIConfig) + if "xai_tools" in model_config: + self._validate_xai_tools(model_config["xai_tools"]) + if not self.config.get("use_encrypted_content") and not model_config.get("use_encrypted_content"): + model_config["use_encrypted_content"] = True + logger.debug("auto-enabled use_encrypted_content for server-side tool state preservation") + self.config.update(model_config) + + @override + def get_config(self) -> xAIConfig: + """Get the Grok model configuration.""" + return self.config + + @override + async def stream( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + tool_choice: ToolChoice | None = None, + **kwargs: Any, + ) -> AsyncGenerator[StreamEvent, None]: + """Stream conversation with the Grok model.""" + client = self._get_client() + + try: + chat = self._build_chat(client, tool_specs) + self._append_messages_to_chat(chat, messages, system_prompt) + + yield self._format_chunk({"chunk_type": "message_start"}) + + tool_calls_pending: list[dict[str, Any]] = [] + server_tool_calls: list[dict[str, Any]] = [] + current_content_type: str | None = None + final_response: Any = None + citations: Any = None + + async for response, chunk in chat.stream(): + final_response = response + + if hasattr(response, "citations") and response.citations: + citations = response.citations + + if hasattr(chunk, "reasoning_content") and chunk.reasoning_content: + if current_content_type != "reasoning": + if current_content_type: + yield self._format_chunk({"chunk_type": "content_stop"}) + yield self._format_chunk({"chunk_type": "content_start", "data_type": "reasoning"}) + current_content_type = "reasoning" + yield self._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": chunk.reasoning_content, + } + ) + + if hasattr(chunk, "content") and chunk.content: + if current_content_type != "text": + if current_content_type: + yield self._format_chunk({"chunk_type": "content_stop"}) + yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) + current_content_type = "text" + yield self._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "text", + "data": chunk.content, + } + ) + + if hasattr(chunk, "tool_calls") and chunk.tool_calls: + for tool_call in chunk.tool_calls: + tool_type = get_tool_call_type(tool_call) + tool_data = { + "id": tool_call.id, + "name": tool_call.function.name, + "arguments": tool_call.function.arguments or "", + } + if tool_type == "client_side_tool": + tool_calls_pending.append(tool_data) + else: + logger.debug( + "tool_type=<%s>, tool_name=<%s> | server-side tool executed by xAI", + tool_type, + tool_call.function.name, + ) + server_tool_calls.append(tool_data) + + if current_content_type != "text": + if current_content_type: + yield self._format_chunk({"chunk_type": "content_stop"}) + yield 
self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) + current_content_type = "text" + yield self._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "server_tool", + "data": tool_data, + } + ) + + if current_content_type: + yield self._format_chunk({"chunk_type": "content_stop"}) + + # Emit encrypted reasoning content for visibility (grok-4 with use_encrypted_content=True) + # The actual state preservation happens via xAI state capture below + if final_response and hasattr(final_response, "encrypted_content") and final_response.encrypted_content: + encrypted_bytes = ( + final_response.encrypted_content.encode("utf-8") + if isinstance(final_response.encrypted_content, str) + else final_response.encrypted_content + ) + yield self._format_chunk({"chunk_type": "content_start", "data_type": "encrypted_reasoning"}) + yield {"contentBlockDelta": {"delta": {"reasoningContent": {"redactedContent": encrypted_bytes}}}} + yield self._format_chunk({"chunk_type": "content_stop"}) + + # Emit tool call events (client-side only) + for tool_call in tool_calls_pending: + yield self._format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_call}) + yield self._format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_call}) + yield self._format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + + # Determine if we need to capture xAI state for multi-turn context preservation + # State is needed when: + # 1. Server-side tools were used (encrypted tool results must be preserved) + # 2. Encrypted reasoning content is present (grok-4 with use_encrypted_content=True) + # State is NOT needed for client-side tools only (Strands handles those) + has_encrypted_reasoning = ( + final_response and hasattr(final_response, "encrypted_content") and final_response.encrypted_content + ) + needs_xai_state = server_tool_calls or has_encrypted_reasoning + + # ================================================================= + # STATE PRESERVATION FOR MULTI-TURN CONVERSATIONS + # ================================================================= + # When server-side tools or encrypted reasoning are used, we must + # preserve the xAI SDK's internal state for the next turn. This state + # contains encrypted content that cannot be reconstructed from text. + # + # We store this state in `reasoningContent.redactedContent` because: + # 1. This field is NOT rendered in AgentResult.__str__() - keeps output clean + # 2. The Strands event loop properly accumulates redactedContent + # 3. It's semantically appropriate (encrypted/hidden content) + # + # On the next turn, _extract_xai_state() will find this content and + # restore the full conversation context including encrypted tool results. + # ================================================================= + if final_response and needs_xai_state: + xai_state = self._capture_xai_state(chat, final_response) + if xai_state: + # Encode: protobuf bytes -> base64 -> JSON -> base64 -> markers + state_json = json.dumps({"messages": xai_state}) + state_b64 = base64.b64encode(state_json.encode("utf-8")).decode("utf-8") + state_text = f"{XAI_STATE_MARKER}{state_b64}{XAI_STATE_MARKER_END}" + + # Emit as reasoningContent.redactedContent - this is the key! 
+ # Unlike text content, redactedContent is NOT shown to users + yield self._format_chunk({"chunk_type": "content_start", "data_type": "xai_state"}) + yield { + "contentBlockDelta": { + "delta": {"reasoningContent": {"redactedContent": state_text.encode("utf-8")}} + } + } + yield self._format_chunk({"chunk_type": "content_stop"}) + + stop_reason = "tool_use" if tool_calls_pending else "end_turn" + yield self._format_chunk({"chunk_type": "message_stop", "data": stop_reason}) + + if final_response and hasattr(final_response, "usage") and final_response.usage: + yield self._format_chunk( + { + "chunk_type": "metadata", + "data": final_response.usage, + "citations": citations, + "server_tool_calls": server_tool_calls if server_tool_calls else None, + } + ) + + except Exception as error: + self._handle_stream_error(error) + + logger.debug("finished streaming response from xAI") + + def _handle_stream_error(self, error: Exception) -> None: + """Handle errors from the xAI API and map them to Strands exceptions.""" + from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException + + error_message = str(error).lower() + error_str = str(error) + + if any(x in error_message for x in ["rate limit", "rate_limit", "too many requests"]) or "429" in error_str: + logger.warning("error=<%s> | xAI rate limit error", error) + raise ModelThrottledException(str(error)) from error + + if any(x in error_message for x in ["context length", "maximum context", "token limit"]): + logger.warning("error=<%s> | xAI context window overflow error", error) + raise ContextWindowOverflowException(str(error)) from error + + raise error + + @override + async def structured_output( + self, + output_model: type[T], + prompt: Messages, + system_prompt: str | None = None, + **kwargs: Any, + ) -> AsyncGenerator[dict[str, T | Any], None]: + """Get structured output from the Grok model using chat.parse().""" + from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException + + client = self._get_client() + + try: + chat = self._build_chat(client) + self._append_messages_to_chat(chat, prompt, system_prompt) + + response, parsed_output = await chat.parse(output_model) + yield {"output": parsed_output} + + except Exception as error: + error_message = str(error).lower() + error_str = str(error) + + if any(x in error_message for x in ["rate limit", "rate_limit", "too many requests"]) or "429" in error_str: + logger.warning("error=<%s> | xAI rate limit error", error) + raise ModelThrottledException(str(error)) from error + + if any(x in error_message for x in ["context length", "maximum context", "token limit"]): + logger.warning("error=<%s> | xAI context window overflow error", error) + raise ContextWindowOverflowException(str(error)) from error + + raise diff --git a/tests/strands/models/test_xai.py b/tests/strands/models/test_xai.py new file mode 100644 index 000000000..142d8db19 --- /dev/null +++ b/tests/strands/models/test_xai.py @@ -0,0 +1,1049 @@ +"""Unit tests for the xAI model provider.""" + +import unittest.mock +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any + +import pytest + +import strands +from strands.models.xai import xAIModel + + +@contextmanager +def mock_xai_sdk() -> Generator[dict[str, unittest.mock.Mock], None, None]: + """Context manager to mock the xAI SDK components.""" + with ( + unittest.mock.patch.object(strands.models.xai, "AsyncClient") as mock_client_cls, + unittest.mock.patch.object(strands.models.xai, 
"xai_tool") as mock_xai_tool, + unittest.mock.patch.object(strands.models.xai, "xai_system") as mock_xai_system, + unittest.mock.patch.object(strands.models.xai, "xai_user") as mock_xai_user, + unittest.mock.patch.object(strands.models.xai, "xai_tool_result") as mock_xai_tool_result, + unittest.mock.patch.object( + strands.models.xai, "get_tool_call_type", return_value="client_side_tool" + ) as mock_get_tool_call_type, + ): + mock_client = mock_client_cls.return_value + + def create_tool_mock(name: str, description: str, parameters: dict) -> dict[str, Any]: + return { + "type": "function", + "function": {"name": name, "description": description, "parameters": parameters}, + } + + mock_xai_tool.side_effect = create_tool_mock + + yield { + "client": mock_client, + "client_cls": mock_client_cls, + "xai_tool": mock_xai_tool, + "xai_system": mock_xai_system, + "xai_user": mock_xai_user, + "xai_tool_result": mock_xai_tool_result, + "get_tool_call_type": mock_get_tool_call_type, + } + + +@contextmanager +def mock_xai_client() -> Generator[unittest.mock.Mock, None, None]: + """Context manager to mock the xAI AsyncClient.""" + with mock_xai_sdk() as mocks: + yield mocks["client"] + + +@pytest.fixture +def mock_xai_client_fixture() -> Generator[unittest.mock.Mock, None, None]: + """Pytest fixture to mock the xAI AsyncClient.""" + with mock_xai_client() as client: + yield client + + +@pytest.fixture +def mock_xai_sdk_fixture() -> Generator[dict[str, unittest.mock.Mock], None, None]: + """Pytest fixture to mock the full xAI SDK.""" + with mock_xai_sdk() as mocks: + yield mocks + + +@pytest.fixture +def model_id() -> str: + """Default model ID for tests.""" + return "grok-4" + + +@pytest.fixture +def model(mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str) -> xAIModel: + """Create a xAIModel instance with mocked SDK.""" + _ = mock_xai_sdk_fixture + return xAIModel(model_id=model_id) + + +class TestxAIConfigRoundTrip: + """Tests for configuration round-trip consistency.""" + + @pytest.mark.parametrize( + "model_id", + [ + "grok-3-mini-fast-latest", + "grok-2-latest", + "test-model-123", + "model_with_underscores", + "model-with-dashes", + ], + ) + def test_config_round_trip_model_id_only(self, model_id: str) -> None: + """For any valid model_id, get_config returns equivalent config.""" + with mock_xai_client(): + model = xAIModel(model_id=model_id) + config = model.get_config() + assert config["model_id"] == model_id + + @pytest.mark.parametrize( + "model_id,params", + [ + ("grok-3-mini-fast-latest", {"temperature": 0.7}), + ("grok-2-latest", {"max_tokens": 1000}), + ("test-model", {"temperature": 1.5, "max_tokens": 2048}), + ("model-123", {}), + ], + ) + def test_config_round_trip_with_params(self, model_id: str, params: dict) -> None: + """For any valid model_id and params, config round-trip preserves values.""" + with mock_xai_client(): + model = xAIModel(model_id=model_id, params=params) + config = model.get_config() + assert config["model_id"] == model_id + if params: + assert config["params"] == params + + @pytest.mark.parametrize( + "model_id,reasoning_effort", + [ + ("grok-3-mini-fast-latest", "low"), + ("grok-2-latest", "high"), + ], + ) + def test_config_round_trip_with_reasoning_effort(self, model_id: str, reasoning_effort: str) -> None: + """For any valid model_id and reasoning_effort, config round-trip preserves values.""" + with mock_xai_client(): + model = xAIModel(model_id=model_id, reasoning_effort=reasoning_effort) + config = model.get_config() + assert 
config["model_id"] == model_id + assert config["reasoning_effort"] == reasoning_effort + + @pytest.mark.parametrize( + "model_id,include", + [ + ("grok-3-mini-fast-latest", ["verbose_streaming"]), + ("grok-2-latest", ["inline_citations"]), + ("test-model", ["verbose_streaming", "inline_citations"]), + ("model-123", []), + ], + ) + def test_config_round_trip_with_include(self, model_id: str, include: list) -> None: + """For any valid model_id and include list, config round-trip preserves values.""" + with mock_xai_client(): + model = xAIModel(model_id=model_id, include=include) + config = model.get_config() + assert config["model_id"] == model_id + if include: + assert config["include"] == include + + +class TestxAIModelInit: + """Unit tests for xAIModel initialization.""" + + def test_init_with_model_id(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test initialization with just model_id.""" + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id) + assert model.get_config()["model_id"] == model_id + + def test_init_with_params(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test initialization with params.""" + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id, params={"temperature": 0.7}) + config = model.get_config() + assert config["model_id"] == model_id + assert config["params"] == {"temperature": 0.7} + + def test_init_with_client_args(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test initialization with client_args.""" + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id, client_args={"api_key": "test-key"}) + assert model.client_args == {"api_key": "test-key"} + + def test_init_with_custom_client(self, model_id: str) -> None: + """Test initialization with a custom client.""" + mock_client = unittest.mock.Mock() + model = xAIModel(client=mock_client, model_id=model_id) + assert model._custom_client is mock_client + + def test_init_with_both_client_and_client_args_raises_error(self, model_id: str) -> None: + """Test that providing both client and client_args raises ValueError.""" + mock_client = unittest.mock.Mock() + with pytest.raises(ValueError, match="Only one of 'client' or 'client_args' should be provided"): + xAIModel(client=mock_client, client_args={"api_key": "test"}, model_id=model_id) + + +class TestxAIModelGetClient: + """Unit tests for xAIModel._get_client method.""" + + def test_get_client_returns_custom_client(self, model_id: str) -> None: + """Test that _get_client returns the injected client when provided.""" + mock_client = unittest.mock.Mock() + model = xAIModel(client=mock_client, model_id=model_id) + result = model._get_client() + assert result is mock_client + + def test_get_client_creates_new_client(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test that _get_client creates a new client when no custom client is provided.""" + model = xAIModel(model_id=model_id, client_args={"api_key": "test-key"}) + model._get_client() + strands.models.xai.AsyncClient.assert_called_with(api_key="test-key") + + +class TestGrokToolsValidation: + """Unit tests for xai_tools validation.""" + + def test_validate_xai_tools_rejects_function_tools( + self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str + ) -> None: + """Test that function-based tools (dicts with type=function) are rejected in xai_tools.""" + _ = mock_xai_client_fixture + # Client-side tools created via xai_tool() are dicts with "type": 
"function" + mock_function_tool = {"type": "function", "function": {"name": "test_function"}} + with pytest.raises(ValueError, match="xai_tools should not contain function-based tools"): + xAIModel(model_id=model_id, xai_tools=[mock_function_tool]) + + def test_validate_xai_tools_accepts_server_side_tools( + self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str + ) -> None: + """Test that server-side tools (protobuf objects) are accepted in xai_tools.""" + _ = mock_xai_client_fixture + # Server-side tools like web_search() are protobuf objects, not dicts + mock_server_tool = unittest.mock.Mock(spec=[]) + model = xAIModel(model_id=model_id, xai_tools=[mock_server_tool]) + assert "xai_tools" in model.get_config() + + def test_validate_xai_tools_on_update_config( + self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str + ) -> None: + """Test that xai_tools validation runs on update_config.""" + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id) + # Client-side tools created via xai_tool() are dicts with "type": "function" + mock_function_tool = {"type": "function", "function": {"name": "test_function"}} + with pytest.raises(ValueError, match="xai_tools should not contain function-based tools"): + model.update_config(xai_tools=[mock_function_tool]) + + +class TestFormatRequestTools: + """Unit tests for _format_request_tools method.""" + + def test_format_empty_tools(self, model: xAIModel) -> None: + """Test formatting with no tools.""" + result = model._format_request_tools(None) + assert result == [] + + def test_format_single_tool(self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str) -> None: + """Test formatting a single tool spec.""" + model = xAIModel(model_id=model_id) + tool_specs = [ + { + "name": "get_weather", + "description": "Get weather for a location", + "inputSchema": {"json": {"type": "object", "properties": {"location": {"type": "string"}}}}, + } + ] + result = model._format_request_tools(tool_specs) + assert len(result) == 1 + mock_xai_sdk_fixture["xai_tool"].assert_called_once_with( + name="get_weather", + description="Get weather for a location", + parameters={"type": "object", "properties": {"location": {"type": "string"}}}, + ) + + def test_format_multiple_tools(self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str) -> None: + """Test formatting multiple tool specs.""" + model = xAIModel(model_id=model_id) + tool_specs = [ + {"name": "tool1", "description": "First tool", "inputSchema": {"json": {"type": "object"}}}, + {"name": "tool2", "description": "Second tool", "inputSchema": {"json": {"type": "object"}}}, + ] + result = model._format_request_tools(tool_specs) + assert len(result) == 2 + assert mock_xai_sdk_fixture["xai_tool"].call_count == 2 + + def test_format_tools_with_xai_tools( + self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str + ) -> None: + """Test that xai_tools are appended to formatted tools.""" + mock_server_tool = unittest.mock.Mock(spec=[]) + model = xAIModel(model_id=model_id, xai_tools=[mock_server_tool]) + tool_specs = [{"name": "tool1", "description": "Tool", "inputSchema": {"json": {"type": "object"}}}] + result = model._format_request_tools(tool_specs) + assert len(result) == 2 + assert mock_server_tool in result + + +class TestFormatChunk: + """Unit tests for _format_chunk method.""" + + def test_format_message_start(self, model: xAIModel) -> None: + """Test formatting message_start chunk.""" + result = model._format_chunk({"chunk_type": 
"message_start"}) + assert result == {"messageStart": {"role": "assistant"}} + + def test_format_content_start_text(self, model: xAIModel) -> None: + """Test formatting content_start chunk for text.""" + result = model._format_chunk({"chunk_type": "content_start", "data_type": "text"}) + assert result == {"contentBlockStart": {"start": {}}} + + def test_format_content_start_tool(self, model: xAIModel) -> None: + """Test formatting content_start chunk for tool.""" + result = model._format_chunk( + { + "chunk_type": "content_start", + "data_type": "tool", + "data": {"name": "get_weather", "id": "tool-123"}, + } + ) + assert result == {"contentBlockStart": {"start": {"toolUse": {"name": "get_weather", "toolUseId": "tool-123"}}}} + + def test_format_content_delta_text(self, model: xAIModel) -> None: + """Test formatting content_delta chunk for text.""" + result = model._format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": "Hello"}) + assert result == {"contentBlockDelta": {"delta": {"text": "Hello"}}} + + def test_format_content_delta_tool(self, model: xAIModel) -> None: + """Test formatting content_delta chunk for tool.""" + result = model._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": {"arguments": '{"location": "Paris"}'}, + } + ) + assert result == {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"location": "Paris"}'}}}} + + def test_format_content_delta_reasoning(self, model: xAIModel) -> None: + """Test formatting content_delta chunk for reasoning content.""" + result = model._format_chunk( + {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "Thinking..."} + ) + assert result == {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking..."}}}} + + def test_format_content_stop(self, model: xAIModel) -> None: + """Test formatting content_stop chunk.""" + result = model._format_chunk({"chunk_type": "content_stop"}) + assert result == {"contentBlockStop": {}} + + def test_format_message_stop_end_turn(self, model: xAIModel) -> None: + """Test formatting message_stop chunk with end_turn.""" + result = model._format_chunk({"chunk_type": "message_stop", "data": "end_turn"}) + assert result == {"messageStop": {"stopReason": "end_turn"}} + + def test_format_message_stop_tool_use(self, model: xAIModel) -> None: + """Test formatting message_stop chunk with tool_use.""" + result = model._format_chunk({"chunk_type": "message_stop", "data": "tool_use"}) + assert result == {"messageStop": {"stopReason": "tool_use"}} + + def test_format_message_stop_max_tokens(self, model: xAIModel) -> None: + """Test formatting message_stop chunk with max_tokens.""" + result = model._format_chunk({"chunk_type": "message_stop", "data": "max_tokens"}) + assert result == {"messageStop": {"stopReason": "max_tokens"}} + + def test_format_metadata(self, model: xAIModel) -> None: + """Test formatting metadata chunk.""" + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + mock_usage.reasoning_tokens = None + result = model._format_chunk({"chunk_type": "metadata", "data": mock_usage}) + assert result["metadata"]["usage"]["inputTokens"] == 100 + assert result["metadata"]["usage"]["outputTokens"] == 50 + assert result["metadata"]["usage"]["totalTokens"] == 150 + + def test_format_metadata_with_reasoning_tokens(self, model: xAIModel) -> None: + """Test formatting metadata chunk with reasoning tokens.""" + mock_usage = unittest.mock.Mock() + 
mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + mock_usage.reasoning_tokens = 25 + result = model._format_chunk({"chunk_type": "metadata", "data": mock_usage}) + assert result["metadata"]["usage"]["reasoningTokens"] == 25 + + def test_format_metadata_with_citations(self, model: xAIModel) -> None: + """Test formatting metadata chunk with citations.""" + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + mock_usage.reasoning_tokens = None + citations = [{"url": "https://example.com"}] + result = model._format_chunk({"chunk_type": "metadata", "data": mock_usage, "citations": citations}) + assert result["metadata"]["citations"] == citations + + def test_format_unknown_chunk_raises_error(self, model: xAIModel) -> None: + """Test that unknown chunk types raise RuntimeError.""" + with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): + model._format_chunk({"chunk_type": "unknown"}) + + +class TestHandleStreamError: + """Unit tests for _handle_stream_error method.""" + + def test_rate_limit_error(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test that rate limit errors raise ModelThrottledException.""" + from strands.types.exceptions import ModelThrottledException + + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id) + error = Exception("Rate limit exceeded") + with pytest.raises(ModelThrottledException, match="Rate limit"): + model._handle_stream_error(error) + + def test_rate_limit_error_429(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test that 429 errors raise ModelThrottledException.""" + from strands.types.exceptions import ModelThrottledException + + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id) + error = Exception("Error 429: Too many requests") + with pytest.raises(ModelThrottledException, match="429"): + model._handle_stream_error(error) + + def test_too_many_requests_error(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test that 'too many requests' errors raise ModelThrottledException.""" + from strands.types.exceptions import ModelThrottledException + + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id) + error = Exception("Too many requests, please slow down") + with pytest.raises(ModelThrottledException, match="Too many requests"): + model._handle_stream_error(error) + + def test_context_length_error(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test that context length errors raise ContextWindowOverflowException.""" + from strands.types.exceptions import ContextWindowOverflowException + + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id) + error = Exception("Context length exceeded") + with pytest.raises(ContextWindowOverflowException, match="Context length"): + model._handle_stream_error(error) + + def test_maximum_context_error(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test that maximum context errors raise ContextWindowOverflowException.""" + from strands.types.exceptions import ContextWindowOverflowException + + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id) + error = Exception("Maximum context length reached") + with pytest.raises(ContextWindowOverflowException, match="Maximum context"): + model._handle_stream_error(error) + + def test_token_limit_error(self, 
mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test that token limit errors raise ContextWindowOverflowException.""" + from strands.types.exceptions import ContextWindowOverflowException + + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id) + error = Exception("Token limit exceeded") + with pytest.raises(ContextWindowOverflowException, match="Token limit"): + model._handle_stream_error(error) + + def test_other_error_reraises(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test that other errors are re-raised unchanged.""" + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id) + error = Exception("Some other error") + with pytest.raises(Exception, match="Some other error"): + model._handle_stream_error(error) + + +class TestBuildChat: + """Unit tests for _build_chat method.""" + + def test_build_chat_basic(self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str) -> None: + """Test building a basic chat.""" + model = xAIModel(model_id=model_id) + mock_client = mock_xai_sdk_fixture["client"] + mock_chat = unittest.mock.Mock() + mock_client.chat.create.return_value = mock_chat + + result = model._build_chat(mock_client) + + mock_client.chat.create.assert_called_once_with(model=model_id, store_messages=False) + assert result is mock_chat + + def test_build_chat_with_tools(self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str) -> None: + """Test building a chat with tools.""" + model = xAIModel(model_id=model_id) + mock_client = mock_xai_sdk_fixture["client"] + mock_chat = unittest.mock.Mock() + mock_client.chat.create.return_value = mock_chat + tool_specs = [{"name": "tool1", "description": "Tool", "inputSchema": {"json": {"type": "object"}}}] + + model._build_chat(mock_client, tool_specs) + + call_kwargs = mock_client.chat.create.call_args[1] + assert "tools" in call_kwargs + assert len(call_kwargs["tools"]) == 1 + + def test_build_chat_with_reasoning_effort( + self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str + ) -> None: + """Test building a chat with reasoning_effort.""" + model = xAIModel(model_id=model_id, reasoning_effort="high") + mock_client = mock_xai_sdk_fixture["client"] + mock_chat = unittest.mock.Mock() + mock_client.chat.create.return_value = mock_chat + + model._build_chat(mock_client) + + call_kwargs = mock_client.chat.create.call_args[1] + assert call_kwargs["reasoning_effort"] == "high" + + def test_build_chat_with_include(self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str) -> None: + """Test building a chat with include options.""" + model = xAIModel(model_id=model_id, include=["verbose_streaming"]) + mock_client = mock_xai_sdk_fixture["client"] + mock_chat = unittest.mock.Mock() + mock_client.chat.create.return_value = mock_chat + + model._build_chat(mock_client) + + call_kwargs = mock_client.chat.create.call_args[1] + assert call_kwargs["include"] == ["verbose_streaming"] + + def test_build_chat_with_params(self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str) -> None: + """Test building a chat with additional params.""" + model = xAIModel(model_id=model_id, params={"temperature": 0.7, "max_tokens": 1000}) + mock_client = mock_xai_sdk_fixture["client"] + mock_chat = unittest.mock.Mock() + mock_client.chat.create.return_value = mock_chat + + model._build_chat(mock_client) + + call_kwargs = mock_client.chat.create.call_args[1] + assert call_kwargs["temperature"] == 0.7 + assert 
call_kwargs["max_tokens"] == 1000 + + +class TestAppendMessagesToChat: + """Unit tests for _append_messages_to_chat method.""" + + def test_append_system_prompt(self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str) -> None: + """Test appending system prompt.""" + model = xAIModel(model_id=model_id) + mock_chat = unittest.mock.Mock() + mock_xai_sdk_fixture["xai_system"].return_value = "system_msg" + + model._append_messages_to_chat(mock_chat, [], system_prompt="You are helpful") + + mock_xai_sdk_fixture["xai_system"].assert_called_once_with("You are helpful") + mock_chat.append.assert_called_once_with("system_msg") + + def test_append_user_message_with_text( + self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str + ) -> None: + """Test appending user message with text.""" + model = xAIModel(model_id=model_id) + mock_chat = unittest.mock.Mock() + mock_xai_sdk_fixture["xai_user"].return_value = "user_msg" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + model._append_messages_to_chat(mock_chat, messages) + + mock_xai_sdk_fixture["xai_user"].assert_called_once_with("Hello") + mock_chat.append.assert_called_once_with("user_msg") + + def test_append_user_message_with_tool_result( + self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str + ) -> None: + """Test appending user message with tool result.""" + model = xAIModel(model_id=model_id) + mock_chat = unittest.mock.Mock() + mock_xai_sdk_fixture["xai_tool_result"].return_value = "tool_result_msg" + messages = [ + {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [{"text": "Result"}]}}]} + ] + + model._append_messages_to_chat(mock_chat, messages) + + mock_xai_sdk_fixture["xai_tool_result"].assert_called_once_with("Result") + mock_chat.append.assert_called_once_with("tool_result_msg") + + def test_append_user_message_with_json_tool_result( + self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str + ) -> None: + """Test appending user message with JSON tool result.""" + model = xAIModel(model_id=model_id) + mock_chat = unittest.mock.Mock() + mock_xai_sdk_fixture["xai_tool_result"].return_value = "tool_result_msg" + messages = [ + {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [{"json": {"key": "value"}}]}}]} + ] + + model._append_messages_to_chat(mock_chat, messages) + + mock_xai_sdk_fixture["xai_tool_result"].assert_called_once_with('{"key": "value"}') + + def test_append_assistant_message_with_text( + self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str + ) -> None: + """Test that assistant messages are reconstructed as protobuf messages.""" + model = xAIModel(model_id=model_id) + mock_chat = unittest.mock.Mock() + messages = [{"role": "assistant", "content": [{"text": "Hello"}]}] + + model._append_messages_to_chat(mock_chat, messages) + + # Assistant messages should be appended as protobuf Message objects + mock_chat.append.assert_called_once() + # Verify the appended message is a protobuf Message with correct content + appended_msg = mock_chat.append.call_args[0][0] + assert appended_msg.role == 2 # ROLE_ASSISTANT + assert len(appended_msg.content) == 1 + assert appended_msg.content[0].text == "Hello" + + +class TestUpdateConfig: + """Unit tests for update_config method.""" + + def test_update_model_id(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test updating model_id.""" + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id) + 
model.update_config(model_id="grok-4-fast") + assert model.get_config()["model_id"] == "grok-4-fast" + + def test_update_params(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test updating params.""" + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id) + model.update_config(params={"temperature": 0.5}) + assert model.get_config()["params"] == {"temperature": 0.5} + + def test_update_reasoning_effort(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test updating reasoning_effort.""" + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id) + model.update_config(reasoning_effort="low") + assert model.get_config()["reasoning_effort"] == "low" + + def test_update_include(self, mock_xai_client_fixture: unittest.mock.Mock, model_id: str) -> None: + """Test updating include.""" + _ = mock_xai_client_fixture + model = xAIModel(model_id=model_id) + model.update_config(include=["inline_citations"]) + assert model.get_config()["include"] == ["inline_citations"] + + +class TestStream: + """Unit tests for stream method.""" + + @pytest.mark.asyncio + async def test_stream_basic_response( + self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str + ) -> None: + """Test streaming a basic response.""" + model = xAIModel(model_id=model_id) + mock_client = mock_xai_sdk_fixture["client"] + mock_chat = unittest.mock.Mock() + mock_client.chat.create.return_value = mock_chat + + # Create mock response and chunk + mock_response = unittest.mock.Mock() + mock_response.usage = unittest.mock.Mock() + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + mock_response.usage.reasoning_tokens = None + mock_response.citations = None + mock_response.encrypted_content = None # Explicitly set to avoid xAI state capture + + mock_chunk = unittest.mock.Mock() + mock_chunk.content = "Hello" + mock_chunk.reasoning_content = None + mock_chunk.tool_calls = None + + async def mock_stream(): + yield mock_response, mock_chunk + + mock_chat.stream.return_value = mock_stream() + + events = [] + async for event in model.stream(messages=[], system_prompt="Test"): + events.append(event) + + # Should have: message_start, content_start, content_delta, content_stop, message_stop, metadata + assert len(events) >= 5 + assert events[0] == {"messageStart": {"role": "assistant"}} + + @pytest.mark.asyncio + async def test_stream_with_tool_calls( + self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str + ) -> None: + """Test streaming a response with client-side tool calls.""" + model = xAIModel(model_id=model_id) + mock_client = mock_xai_sdk_fixture["client"] + mock_chat = unittest.mock.Mock() + mock_client.chat.create.return_value = mock_chat + + mock_response = unittest.mock.Mock() + mock_response.usage = unittest.mock.Mock() + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + mock_response.usage.reasoning_tokens = None + mock_response.citations = None + mock_response.encrypted_content = None + + mock_tool_call = unittest.mock.Mock() + mock_tool_call.id = "tool-123" + mock_tool_call.function = unittest.mock.Mock() + mock_tool_call.function.name = "get_weather" + mock_tool_call.function.arguments = '{"location": "Paris"}' + + mock_chunk = unittest.mock.Mock() + mock_chunk.content = None + mock_chunk.reasoning_content = None + mock_chunk.tool_calls = [mock_tool_call] + + async def 
mock_stream(): + yield mock_response, mock_chunk + + mock_chat.stream.return_value = mock_stream() + + events = [] + async for event in model.stream(messages=[]): + events.append(event) + + # Should have tool_use stop reason (get_tool_call_type is mocked to return "client_side_tool") + stop_events = [e for e in events if "messageStop" in e] + assert len(stop_events) == 1 + assert stop_events[0]["messageStop"]["stopReason"] == "tool_use" + + @pytest.mark.asyncio + async def test_stream_with_reasoning_content( + self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str + ) -> None: + """Test streaming a response with reasoning content.""" + model = xAIModel(model_id=model_id) + mock_client = mock_xai_sdk_fixture["client"] + mock_chat = unittest.mock.Mock() + mock_client.chat.create.return_value = mock_chat + + mock_response = unittest.mock.Mock() + mock_response.usage = unittest.mock.Mock() + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + mock_response.usage.reasoning_tokens = 20 + mock_response.citations = None + mock_response.encrypted_content = None # Explicitly set to None to avoid Mock auto-attribute + + mock_chunk = unittest.mock.Mock() + mock_chunk.content = None + mock_chunk.reasoning_content = "Thinking..." + mock_chunk.tool_calls = None + + async def mock_stream(): + yield mock_response, mock_chunk + + mock_chat.stream.return_value = mock_stream() + + events = [] + async for event in model.stream(messages=[]): + events.append(event) + + # Should have reasoning content delta with text (not encrypted) + reasoning_text_events = [ + e + for e in events + if "contentBlockDelta" in e + and "reasoningContent" in e.get("contentBlockDelta", {}).get("delta", {}) + and "text" in e.get("contentBlockDelta", {}).get("delta", {}).get("reasoningContent", {}) + ] + assert len(reasoning_text_events) == 1 + assert reasoning_text_events[0]["contentBlockDelta"]["delta"]["reasoningContent"]["text"] == "Thinking..." 
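For orientation, the TestStream assertions above imply roughly the following event sequence for a single-chunk text response. This is a sketch only: the field names come from the assertions and the usual Strands streaming event conventions, and the usage values mirror the mocks, so treat it as illustrative rather than a verbatim contract.

    expected_events = [
        {"messageStart": {"role": "assistant"}},
        {"contentBlockStart": {"start": {}}},
        {"contentBlockDelta": {"delta": {"text": "Hello"}}},
        {"contentBlockStop": {}},
        {"messageStop": {"stopReason": "end_turn"}},
        {"metadata": {"usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15}}},
    ]

When a client-side tool call is present, the stopReason becomes "tool_use"; server-side calls instead surface under metadata["serverToolCalls"] (see TestServerSideToolCalls below).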
+ + +class TestStructuredOutput: + """Unit tests for structured_output method.""" + + @pytest.mark.asyncio + async def test_structured_output_basic( + self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str + ) -> None: + """Test structured output with a Pydantic model.""" + import pydantic + + class Weather(pydantic.BaseModel): + temperature: int + condition: str + + model = xAIModel(model_id=model_id) + mock_client = mock_xai_sdk_fixture["client"] + mock_chat = unittest.mock.Mock() + mock_client.chat.create.return_value = mock_chat + + parsed_output = Weather(temperature=25, condition="sunny") + mock_response = unittest.mock.Mock() + + async def mock_parse(output_model): + return mock_response, parsed_output + + mock_chat.parse = mock_parse + + messages = [{"role": "user", "content": [{"text": "What's the weather?"}]}] + results = [] + async for result in model.structured_output(Weather, messages): + results.append(result) + + assert len(results) == 1 + assert results[0]["output"] == parsed_output + assert results[0]["output"].temperature == 25 + assert results[0]["output"].condition == "sunny" + + @pytest.mark.asyncio + async def test_structured_output_with_system_prompt( + self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str + ) -> None: + """Test structured output with system prompt.""" + import pydantic + + class Result(pydantic.BaseModel): + value: str + + model = xAIModel(model_id=model_id) + mock_client = mock_xai_sdk_fixture["client"] + mock_chat = unittest.mock.Mock() + mock_client.chat.create.return_value = mock_chat + + parsed_output = Result(value="test") + mock_response = unittest.mock.Mock() + + async def mock_parse(output_model): + return mock_response, parsed_output + + mock_chat.parse = mock_parse + mock_xai_sdk_fixture["xai_system"].return_value = "system_msg" + + messages = [{"role": "user", "content": [{"text": "Test"}]}] + results = [] + async for result in model.structured_output(Result, messages, system_prompt="Be helpful"): + results.append(result) + + mock_xai_sdk_fixture["xai_system"].assert_called_once_with("Be helpful") + + +class TestServerSideToolCalls: + """Unit tests for server-side tool call handling.""" + + def test_format_metadata_with_server_tool_calls(self, model: xAIModel) -> None: + """Test formatting metadata chunk with server tool calls.""" + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + mock_usage.reasoning_tokens = None + + server_tool_calls = [ + {"id": "tool-1", "name": "x_search", "arguments": '{"query": "test"}'}, + {"id": "tool-2", "name": "web_search", "arguments": '{"query": "hello"}'}, + ] + + result = model._format_chunk( + { + "chunk_type": "metadata", + "data": mock_usage, + "server_tool_calls": server_tool_calls, + } + ) + + assert "serverToolCalls" in result["metadata"] + assert len(result["metadata"]["serverToolCalls"]) == 2 + assert result["metadata"]["serverToolCalls"][0]["name"] == "x_search" + assert result["metadata"]["serverToolCalls"][1]["name"] == "web_search" + + def test_format_metadata_without_server_tool_calls(self, model: xAIModel) -> None: + """Test formatting metadata chunk without server tool calls.""" + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + mock_usage.reasoning_tokens = None + + result = model._format_chunk( + { + "chunk_type": "metadata", + "data": mock_usage, + } + ) + + assert 
"serverToolCalls" not in result["metadata"] + + @pytest.mark.asyncio + async def test_stream_with_server_side_tool_calls( + self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str + ) -> None: + """Test streaming a response with server-side tool calls (not executed by Strands).""" + # Override get_tool_call_type to return server_side_tool for this test + mock_xai_sdk_fixture["get_tool_call_type"].return_value = "server_side_tool" + + model = xAIModel(model_id=model_id) + mock_client = mock_xai_sdk_fixture["client"] + mock_chat = unittest.mock.Mock() + mock_client.chat.create.return_value = mock_chat + + # Mock chat.messages for state capture + mock_msg = unittest.mock.Mock() + mock_msg.role = 1 # ROLE_USER + mock_msg.SerializeToString.return_value = b"mock_serialized" + mock_chat.messages = [mock_msg] + + mock_response = unittest.mock.Mock() + mock_response.usage = unittest.mock.Mock() + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + mock_response.usage.reasoning_tokens = None + mock_response.citations = None + mock_response.encrypted_content = None + + # Server-side tool call (e.g., x_search) + mock_tool_call = unittest.mock.Mock() + mock_tool_call.id = "server-tool-123" + mock_tool_call.function = unittest.mock.Mock() + mock_tool_call.function.name = "x_search" + mock_tool_call.function.arguments = '{"query": "test"}' + + mock_chunk = unittest.mock.Mock() + mock_chunk.content = "Here are the search results..." + mock_chunk.reasoning_content = None + mock_chunk.tool_calls = [mock_tool_call] + + async def mock_stream(): + yield mock_response, mock_chunk + + mock_chat.stream.return_value = mock_stream() + + events = [] + async for event in model.stream(messages=[]): + events.append(event) + + # Server-side tools should NOT trigger tool_use stop reason + stop_events = [e for e in events if "messageStop" in e] + assert len(stop_events) == 1 + assert stop_events[0]["messageStop"]["stopReason"] == "end_turn" + + # Server-side tools should be in metadata + metadata_events = [e for e in events if "metadata" in e] + assert len(metadata_events) == 1 + assert "serverToolCalls" in metadata_events[0]["metadata"] + assert len(metadata_events[0]["metadata"]["serverToolCalls"]) == 1 + assert metadata_events[0]["metadata"]["serverToolCalls"][0]["name"] == "x_search" + + @pytest.mark.asyncio + async def test_stream_with_mixed_tool_calls( + self, mock_xai_sdk_fixture: dict[str, unittest.mock.Mock], model_id: str + ) -> None: + """Test streaming with both client-side and server-side tool calls.""" + model = xAIModel(model_id=model_id) + mock_client = mock_xai_sdk_fixture["client"] + mock_chat = unittest.mock.Mock() + mock_client.chat.create.return_value = mock_chat + + # Mock chat.messages for state capture + mock_msg = unittest.mock.Mock() + mock_msg.role = 1 # ROLE_USER + mock_msg.SerializeToString.return_value = b"mock_serialized" + mock_chat.messages = [mock_msg] + + mock_response = unittest.mock.Mock() + mock_response.usage = unittest.mock.Mock() + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + mock_response.usage.reasoning_tokens = None + mock_response.citations = None + mock_response.encrypted_content = None + + # Client-side tool call + mock_client_tool = unittest.mock.Mock() + mock_client_tool.id = "client-tool-123" + mock_client_tool.function = unittest.mock.Mock() + mock_client_tool.function.name = "get_weather" + 
mock_client_tool.function.arguments = '{"city": "Paris"}' + + # Server-side tool call + mock_server_tool = unittest.mock.Mock() + mock_server_tool.id = "server-tool-456" + mock_server_tool.function = unittest.mock.Mock() + mock_server_tool.function.name = "x_search" + mock_server_tool.function.arguments = '{"query": "weather"}' + + mock_chunk = unittest.mock.Mock() + mock_chunk.content = None + mock_chunk.reasoning_content = None + mock_chunk.tool_calls = [mock_client_tool, mock_server_tool] + + # Mock get_tool_call_type to return different types based on tool + def mock_get_type(tool_call): + if tool_call.function.name == "get_weather": + return "client_side_tool" + return "server_side_tool" + + mock_xai_sdk_fixture["get_tool_call_type"].side_effect = mock_get_type + + async def mock_stream(): + yield mock_response, mock_chunk + + mock_chat.stream.return_value = mock_stream() + + events = [] + async for event in model.stream(messages=[]): + events.append(event) + + # Should have tool_use stop reason (client-side tool present) + stop_events = [e for e in events if "messageStop" in e] + assert len(stop_events) == 1 + assert stop_events[0]["messageStop"]["stopReason"] == "tool_use" + + # Should have client-side tool in content blocks + tool_start_events = [ + e + for e in events + if "contentBlockStart" in e and e.get("contentBlockStart", {}).get("start", {}).get("toolUse") + ] + assert len(tool_start_events) == 1 + assert tool_start_events[0]["contentBlockStart"]["start"]["toolUse"]["name"] == "get_weather" + + # Server-side tools should be in metadata + metadata_events = [e for e in events if "metadata" in e] + assert len(metadata_events) == 1 + assert "serverToolCalls" in metadata_events[0]["metadata"] + assert len(metadata_events[0]["metadata"]["serverToolCalls"]) == 1 + assert metadata_events[0]["metadata"]["serverToolCalls"][0]["name"] == "x_search" + + def test_format_content_delta_server_tool(self, model: xAIModel) -> None: + """Test formatting content_delta chunk for server-side tool (inline text).""" + tool_data = {"id": "tool-123", "name": "x_search", "arguments": '{"query": "test"}'} + result = model._format_chunk( + { + "chunk_type": "content_delta", + "data_type": "server_tool", + "data": tool_data, + } + ) + assert "contentBlockDelta" in result + assert "text" in result["contentBlockDelta"]["delta"] + assert "x_search" in result["contentBlockDelta"]["delta"]["text"] + assert '{"query": "test"}' in result["contentBlockDelta"]["delta"]["text"] diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 57614b97f..ed63e373e 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -18,6 +18,11 @@ from strands.models.openai import OpenAIModel from strands.models.writer import WriterModel +try: + from strands.models.xai import xAIModel +except ImportError: + xAIModel = None # type: ignore[misc,assignment] + class ProviderInfo: """Provider-based info for providers that require an APIKey via environment variables.""" @@ -137,6 +142,29 @@ def __init__(self): ), ) + +class xAIProviderInfo(ProviderInfo): + """Special case for xAI as it requires the xai-sdk package.""" + + def __init__(self): + super().__init__( + id="xai", + environment_variable="XAI_API_KEY", + factory=lambda: xAIModel( + client_args={"api_key": os.getenv("XAI_API_KEY")}, + model_id="grok-3-mini-fast-latest", + params={"temperature": 0.15}, + ) + if xAIModel is not None + else None, # type: ignore[return-value] + ) + # Add additional skip condition if xai-sdk is not 
installed + if xAIModel is None: + self.mark = mark.skipif(True, reason="xai-sdk package not installed") + + +xai = xAIProviderInfo() + ollama = OllamaProviderInfo() diff --git a/tests_integ/models/test_model_xai.py b/tests_integ/models/test_model_xai.py new file mode 100644 index 000000000..24d9ddf58 --- /dev/null +++ b/tests_integ/models/test_model_xai.py @@ -0,0 +1,381 @@ +"""Integration tests for the xAI model provider. + +These tests require a valid XAI_API_KEY environment variable. +""" + +import os + +import pydantic +import pytest + +import strands +from strands import Agent +from tests_integ.models import providers + +# Skip all tests if XAI_API_KEY is not set or xai-sdk is not installed +pytestmark = providers.xai.mark + +# Import xAIModel only if available +try: + from strands.models.xai import xAIModel +except ImportError: + xAIModel = None # type: ignore[misc,assignment] + + +@pytest.fixture +def model(): + """Create a basic xAIModel instance.""" + return xAIModel( + client_args={"api_key": os.getenv("XAI_API_KEY")}, + model_id="grok-4-1-fast-non-reasoning-latest", + params={"temperature": 0.15}, + ) + + +@pytest.fixture +def reasoning_model(): + """Create a xAIModel instance with reasoning enabled.""" + return xAIModel( + client_args={"api_key": os.getenv("XAI_API_KEY")}, + model_id="grok-3-mini-fast-latest", # reasoning_effort only supported by grok-3-mini + reasoning_effort="low", + params={"temperature": 0.15}, + ) + + +@pytest.fixture +def tools(): + """Create test tools for function calling.""" + + @strands.tool + def tool_time() -> str: + """Get the current time.""" + return "12:00" + + @strands.tool + def tool_weather(city: str) -> str: + """Get the weather for a city.""" + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def system_prompt(): + """Default system prompt for tests.""" + return "You are a helpful AI assistant." 
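The fixtures below wire this provider into an Agent. A minimal standalone equivalent of what the integration tests exercise, assuming XAI_API_KEY is exported and the xai-sdk package is installed, would look roughly like:

    import os

    from strands import Agent
    from strands.models.xai import xAIModel

    model = xAIModel(
        client_args={"api_key": os.getenv("XAI_API_KEY")},
        model_id="grok-4-1-fast-non-reasoning-latest",
        params={"temperature": 0.15},
    )
    agent = Agent(model=model, system_prompt="You are a helpful AI assistant.")
    result = agent("What color is the sky?")  # result.message holds the assistant content blocks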
+ + +@pytest.fixture +def assistant_agent(model, system_prompt): + """Create an agent without tools.""" + return Agent(model=model, system_prompt=system_prompt) + + +@pytest.fixture +def tool_agent(model, tools, system_prompt): + """Create an agent with tools.""" + return Agent(model=model, tools=tools, system_prompt=system_prompt) + + +@pytest.fixture +def weather(): + """Pydantic model for structured output tests.""" + + class Weather(pydantic.BaseModel): + """Extracts the time and weather from the user's message with the exact strings.""" + + time: str + weather: str + + return Weather(time="12:00", weather="sunny") + + +@pytest.fixture +def yellow_color(): + """Pydantic model for image analysis tests.""" + + class Color(pydantic.BaseModel): + """Describes a color.""" + + name: str + + @pydantic.field_validator("name", mode="after") + @classmethod + def lower(cls, value): + return value.lower() + + return Color(name="yellow") + + +# Basic chat completion tests + + +def test_agent_invoke(tool_agent): + """Test basic agent invocation with tools.""" + result = tool_agent("What is the current time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +@pytest.mark.asyncio +async def test_agent_invoke_async(tool_agent): + """Test async agent invocation with tools.""" + result = await tool_agent.invoke_async("What is the current time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +# Streaming tests + + +@pytest.mark.asyncio +async def test_agent_stream_async(tool_agent): + """Test async streaming with tools.""" + stream = tool_agent.stream_async("What is the current time and weather in New York?") + async for event in stream: + _ = event + + result = event["result"] + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) + + +def test_agent_invoke_multiturn(assistant_agent): + """Test multi-turn conversation.""" + assistant_agent("What color is the sky?") + assistant_agent("What color is lava?") + result = assistant_agent("What was the answer to my first question?") + text = result.message["content"][0]["text"].lower() + + assert "blue" in text + + +# Structured output tests + + +def test_agent_structured_output(assistant_agent, weather): + """Test structured output parsing.""" + result = assistant_agent( + "The time is 12:00 and the weather is sunny", + structured_output_model=type(weather), + ) + assert result.structured_output == weather + + +@pytest.mark.asyncio +async def test_agent_structured_output_async(assistant_agent, weather): + """Test async structured output parsing.""" + result = await assistant_agent.invoke_async( + "The time is 12:00 and the weather is sunny", + structured_output_model=type(weather), + ) + assert result.structured_output == weather + + +# Image understanding tests + + +def test_agent_invoke_image_input(assistant_agent, yellow_img): + """Test image input processing.""" + content = [ + {"text": "what is in this image"}, + { + "image": { + "format": "png", + "source": { + "bytes": yellow_img, + }, + }, + }, + ] + result = assistant_agent(content) + text = result.message["content"][0]["text"].lower() + + assert "yellow" in text + + +def test_agent_structured_output_image_input(assistant_agent, yellow_img, yellow_color): + """Test structured output with image input.""" + content = [ + {"text": "Is this image 
red, blue, or yellow?"}, + { + "image": { + "format": "png", + "source": { + "bytes": yellow_img, + }, + }, + }, + ] + result = assistant_agent(content, structured_output_model=type(yellow_color)) + assert result.structured_output == yellow_color + + +# Reasoning model tests + + +def test_reasoning_model_basic(reasoning_model): + """Test basic reasoning model invocation.""" + agent = Agent(model=reasoning_model, system_prompt="You are a helpful assistant.") + result = agent("What is 15 + 27?") + + # Reasoning models may return reasoningContent before text, so check all content blocks + text_content = "" + for content_block in result.message["content"]: + if "text" in content_block: + text_content += content_block["text"] + + assert "42" in text_content + + +# System prompt tests + + +def test_system_prompt_content_integration(model): + """Test system_prompt_content parameter.""" + from strands.types.content import SystemContentBlock + + system_prompt_content: list[SystemContentBlock] = [ + { + "text": "IMPORTANT: You MUST respond with ONLY the exact text " + "'SYSTEM_TEST_RESPONSE' and nothing else. No greetings, no explanations." + } + ] + + agent = Agent(model=model, system_prompt=system_prompt_content) + result = agent("Say the magic words") + + assert "SYSTEM_TEST_RESPONSE" in result.message["content"][0]["text"] + + +def test_system_prompt_backward_compatibility_integration(model): + """Test backward compatibility with system_prompt parameter.""" + system_prompt = ( + "IMPORTANT: You MUST respond with ONLY the exact text 'BACKWARD_COMPAT_TEST' " + "and nothing else. No greetings, no explanations." + ) + + agent = Agent(model=model, system_prompt=system_prompt) + result = agent("Say the magic words") + + assert "BACKWARD_COMPAT_TEST" in result.message["content"][0]["text"] + + +# Content blocks handling + + +def test_content_blocks_handling(model): + """Test that content blocks are handled properly without failures.""" + content = [{"text": "What is 2+2?"}, {"text": "Please be brief."}] + + agent = Agent(model=model, load_tools_from_directory=False) + result = agent(content) + + assert "4" in result.message["content"][0]["text"] + + +# Reasoning model with tools tests + + +@pytest.fixture +def reasoning_model_with_tools(): + """Create a grok-4 reasoning model for tool tests.""" + return xAIModel( + client_args={"api_key": os.getenv("XAI_API_KEY")}, + model_id="grok-4-1-fast-reasoning-latest", + params={"temperature": 0.15}, + ) + + +def test_reasoning_model_with_tools(reasoning_model_with_tools, tools): + """Test reasoning model with function calling tools.""" + agent = Agent( + model=reasoning_model_with_tools, + tools=tools, + system_prompt="You are a helpful assistant.", + ) + result = agent("What is the current time?") + + text_content = "" + for content_block in result.message["content"]: + if "text" in content_block: + text_content += content_block["text"] + + assert "12:00" in text_content + + +# Encrypted content for multi-turn reasoning tests + + +@pytest.fixture +def encrypted_reasoning_model(): + """Create a grok-4 reasoning model with encrypted content enabled.""" + return xAIModel( + client_args={"api_key": os.getenv("XAI_API_KEY")}, + model_id="grok-4-1-fast-reasoning-latest", + use_encrypted_content=True, + params={"temperature": 0.15}, + ) + + +def test_encrypted_content_multi_turn(encrypted_reasoning_model): + """Test multi-turn conversation with encrypted reasoning content. 
+ + When use_encrypted_content=True, the model returns encrypted reasoning + that must be passed back for context continuity in reasoning models. + """ + agent = Agent( + model=encrypted_reasoning_model, + system_prompt="You are a helpful assistant with perfect memory.", + ) + + # Turn 1: Give the model something to remember + agent("Remember this secret code: ALPHA-7. Just confirm you got it.") + + # Turn 2: Ask for recall + result = agent("What was the secret code I gave you?") + + text_content = "" + for content_block in result.message["content"]: + if "text" in content_block: + text_content += content_block["text"] + + assert "ALPHA-7" in text_content + + +# Server-side tools (xai_tools) tests + + +def test_server_side_web_search(): + """Test server-side web_search tool. + + The xAI SDK provides server-side tools that run on xAI's infrastructure. + """ + from xai_sdk.tools import web_search + + model = xAIModel( + client_args={"api_key": os.getenv("XAI_API_KEY")}, + model_id="grok-4-1-fast-non-reasoning-latest", + xai_tools=[web_search()], + params={"temperature": 0.15}, + ) + + agent = Agent( + model=model, + system_prompt="You are a helpful assistant. Use web search when needed.", + ) + + # Ask something that requires current information + result = agent("What is the current year?") + + text_content = "" + for content_block in result.message["content"]: + if "text" in content_block: + text_content += content_block["text"] + + # Should return current year (2025 or 2026 depending on when test runs) + assert any(year in text_content for year in ["2025", "2026"])
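Finally, note how server-side tool activity is surfaced: the unit tests assert that calls such as x_search never produce a tool_use stop reason and instead appear as a serverToolCalls list on the metadata event. A rough, non-authoritative sketch of collecting them directly from the model stream, with the event shape assumed from those assertions and collect_server_tool_calls being a purely illustrative helper:

    async def collect_server_tool_calls(model, messages):
        # Illustrative helper: gather server-side tool calls reported in stream metadata.
        calls = []
        async for event in model.stream(messages=messages):
            metadata = event.get("metadata", {})
            calls.extend(metadata.get("serverToolCalls", []))
        return calls  # e.g. [{"id": "...", "name": "x_search", "arguments": "{...}"}]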