diff --git a/AGENTS.md b/AGENTS.md
index f08200b..dafd9da 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -93,3 +93,6 @@ please follow these guidelines:
 * **Code Validation**: Don't use `py_compile` for syntax checking.
   Use `pyright` or `make check` instead for proper type checking and
   validation.
+
+* **Deprecations**: Don't deprecate things -- just delete them and fix the usage sites.
+  Don't create backward compatibility APIs or re-exports. Fix the usage sites.
diff --git a/pyproject.toml b/pyproject.toml
index dff653b..379f43e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
     "numpy>=2.2.6",
     "openai>=1.81.0",
     "pydantic>=2.11.4",
+    "pydantic-ai-slim[openai]>=1.39.0",
     "pyreadline3>=3.5.4 ; sys_platform == 'win32'",
     "python-dotenv>=1.1.0",
     "tiktoken>=0.12.0",
@@ -85,11 +86,10 @@ dev = [
     "google-auth-httplib2>=0.2.0",
     "google-auth-oauthlib>=1.2.2",
     "isort>=7.0.0",
-    "logfire>=4.1.0", # So 'make check' passes
+    "logfire>=4.1.0",  # So 'make check' passes
     "msgraph-sdk>=1.54.0",
     "opentelemetry-instrumentation-httpx>=0.57b0",
-    "pydantic-ai-slim[openai]>=1.39.0",
-    "pyright>=1.1.408", # 407 has a regression
+    "pyright>=1.1.408",  # 407 has a regression
     "pytest>=8.3.5",
     "pytest-asyncio>=0.26.0",
     "pytest-mock>=3.14.0",
diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py
index 819993f..8b579df 100644
--- a/src/typeagent/aitools/embeddings.py
+++ b/src/typeagent/aitools/embeddings.py
@@ -1,313 +1,123 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-import asyncio
-import os
+from typing import Protocol, runtime_checkable
 
 import numpy as np
 from numpy.typing import NDArray
 
-from openai import AsyncAzureOpenAI, AsyncOpenAI, DEFAULT_MAX_RETRIES, OpenAIError
-from openai.types import Embedding
-import tiktoken
-from tiktoken import model as tiktoken_model
-from tiktoken.core import Encoding
-
-from .auth import AzureTokenProvider, get_shared_token_provider
-from .utils import timelog
-
 type NormalizedEmbedding = NDArray[np.float32]  # A single embedding
 type NormalizedEmbeddings = NDArray[np.float32]  # An array of embeddings
 
-DEFAULT_MODEL_NAME = "text-embedding-ada-002"
-DEFAULT_EMBEDDING_SIZE = 1536  # Default embedding size (required for ada-002)
-DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING"  # We support OpenAI and Azure OpenAI
-TEST_MODEL_NAME = "test"
-MAX_BATCH_SIZE = 2048
-MAX_TOKEN_SIZE = 4096
-MAX_TOKENS_PER_BATCH = 300_000
-MAX_CHAR_SIZE = MAX_TOKEN_SIZE * 3
-MAX_CHARS_PER_BATCH = MAX_TOKENS_PER_BATCH * 3
-
-model_to_embedding_size_and_envvar: dict[str, tuple[int | None, str]] = {
-    DEFAULT_MODEL_NAME: (DEFAULT_EMBEDDING_SIZE, DEFAULT_ENVVAR),
-    "text-embedding-3-small": (1536, "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL"),
-    "text-embedding-3-large": (3072, "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE"),
-    # For testing only, not a real model (insert real embeddings above)
-    TEST_MODEL_NAME: (3, "SIR_NOT_APPEARING_IN_THIS_FILM"),
-}
+
+@runtime_checkable
+class IEmbedder(Protocol):
+    """Minimal provider interface for embedding models.
+
+    Implement this protocol to add support for a new embedding provider
+    (e.g. Anthropic, Gemini, local models). Only raw embedding computation
+    is required; caching is handled by :class:`CachingEmbeddingModel`.
+
+    The production implementation is
+    :class:`~typeagent.aitools.model_adapters.PydanticAIEmbedder`.
+    """
+
+    @property
+    def model_name(self) -> str: ...
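+
+    # A conforming implementation only needs ``model_name`` plus the two
+    # methods below. Sketch (hypothetical provider, for illustration only):
+    #
+    #     class MyEmbedder:
+    #         model_name = "my-model"
+    #
+    #         async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: ...
+    #         async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: ...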
+
+    async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding:
+        """Compute a single embedding without caching."""
+        ...
+
+    async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings:
+        """Compute embeddings for a batch of strings without caching.
+
+        Raises :class:`ValueError` if *input* is empty.
+        """
+        ...
+
+
+@runtime_checkable
+class IEmbeddingModel(Protocol):
+    """Consumer-facing interface for embedding models with caching.
+
+    This extends the provider interface (:class:`IEmbedder`) with caching
+    methods. Use :class:`CachingEmbeddingModel` to wrap an :class:`IEmbedder`
+    and get a ready-to-use ``IEmbeddingModel``.
+    """
+
+    @property
+    def model_name(self) -> str: ...
+
+    def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None:
+        """Cache an already-computed embedding under the given key."""
+        ...
+
+    async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding:
+        """Compute a single embedding without caching."""
+        ...
+
+    async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings:
+        """Compute embeddings for a batch of strings without caching."""
+        ...
+
+    async def get_embedding(self, key: str) -> NormalizedEmbedding:
+        """Retrieve a single embedding, using cache if available."""
+        ...
+
+    async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings:
+        """Retrieve embeddings for multiple keys, using cache if available."""
+        ...
-class AsyncEmbeddingModel:
-    model_name: str
-    embedding_size: int
-    endpoint_envvar: str
-    use_azure: bool
-    azure_token_provider: AzureTokenProvider | None
-    async_client: AsyncOpenAI | None
-    azure_endpoint: str
-    azure_api_version: str
-    encoding: Encoding | None
-    max_chunk_size: int
-    max_size_per_batch: int
-
-    _embedding_cache: dict[str, NormalizedEmbedding]
-
-    def __init__(
-        self,
-        embedding_size: int | None = None,
-        model_name: str | None = None,
-        endpoint_envvar: str | None = None,
-        max_retries: int = DEFAULT_MAX_RETRIES,
-    ):
-        if model_name is None:
-            model_name = DEFAULT_MODEL_NAME
-        self.model_name = model_name
-
-        suggested_embedding_size, suggested_endpoint_envvar = (
-            model_to_embedding_size_and_envvar.get(model_name, (None, None))
-        )
-
-        if embedding_size is None:
-            if suggested_embedding_size is not None:
-                embedding_size = suggested_embedding_size
-            else:
-                embedding_size = DEFAULT_EMBEDDING_SIZE
-        self.embedding_size = embedding_size
-
-        if (
-            model_name == DEFAULT_MODEL_NAME
-            and embedding_size != DEFAULT_EMBEDDING_SIZE
-        ):
-            raise ValueError(
-                f"Cannot customize embedding_size for default model {DEFAULT_MODEL_NAME}"
-            )
-
-        # Read API keys once
-        openai_api_key = os.getenv("OPENAI_API_KEY")
-        azure_api_key = os.getenv("AZURE_OPENAI_API_KEY")
-
-        # Prefer OpenAI if both are set, use Azure if only Azure is set
-        self.use_azure = bool(azure_api_key) and not bool(openai_api_key)
-
-        if endpoint_envvar is None:
-            # Check if OpenAI credentials are available, prefer OpenAI over Azure
-            if openai_api_key:
-                endpoint_envvar = "OPENAI_BASE_URL"  # Use OpenAI
-            elif suggested_endpoint_envvar is not None:
-                endpoint_envvar = suggested_endpoint_envvar
-            else:
-                endpoint_envvar = DEFAULT_ENVVAR
-
-        self.endpoint_envvar = endpoint_envvar
-        self.azure_token_provider = None
-
-        if self.model_name == TEST_MODEL_NAME:
-            self.async_client = None
-        elif self.use_azure:
-            if not azure_api_key:
-                raise ValueError("AZURE_OPENAI_API_KEY not found in environment.")
-            with timelog("Using Azure OpenAI"):
-                self._setup_azure(azure_api_key)
-        else:
-            if not openai_api_key:
-                raise ValueError("OPENAI_API_KEY not found in environment.")
-            endpoint = os.getenv(self.endpoint_envvar)
-            with timelog("Using OpenAI"):
-                self.async_client = AsyncOpenAI(
-                    base_url=endpoint, api_key=openai_api_key, max_retries=max_retries
-                )
-
-        if self.model_name in tiktoken_model.MODEL_TO_ENCODING:
-            encoding_name = tiktoken.encoding_name_for_model(self.model_name)
-            self.encoding = tiktoken.get_encoding(encoding_name)
-            self.max_chunk_size = MAX_TOKEN_SIZE
-            self.max_size_per_batch = MAX_TOKENS_PER_BATCH
-        else:
-            self.encoding = None
-            self.max_chunk_size = MAX_CHAR_SIZE
-            self.max_size_per_batch = MAX_CHARS_PER_BATCH
-
-        self._embedding_cache = {}
-
-    def _setup_azure(self, azure_api_key: str) -> None:
-        from .utils import get_azure_api_key, parse_azure_endpoint
-
-        azure_api_key = get_azure_api_key(azure_api_key)
-        self.azure_endpoint, self.azure_api_version = parse_azure_endpoint(
-            self.endpoint_envvar
-        )
-
-        if azure_api_key != os.getenv("AZURE_OPENAI_API_KEY"):
-            # If we got a token from identity, store the provider for refresh
-            self.azure_token_provider = get_shared_token_provider()
-
-        self.async_client = AsyncAzureOpenAI(
-            api_version=self.azure_api_version,
-            azure_endpoint=self.azure_endpoint,
-            api_key=azure_api_key,
-        )
-
-    async def refresh_auth(self):
-        """Update client when using a token provider and it's nearly expired."""
-        # refresh_token is synchronous and slow -- run it in a separate thread
-        assert self.azure_token_provider
-        refresh_token = self.azure_token_provider.refresh_token
-        loop = asyncio.get_running_loop()
-        azure_api_key = await loop.run_in_executor(None, refresh_token)
-        assert self.azure_api_version
-        assert self.azure_endpoint
-        self.async_client = AsyncAzureOpenAI(
-            api_version=self.azure_api_version,
-            azure_endpoint=self.azure_endpoint,
-            api_key=azure_api_key,
-        )
+class CachingEmbeddingModel:
+    """Wraps an :class:`IEmbedder` with an in-memory embedding cache.
+
+    This shared base class implements the caching logic once, so individual
+    embedding providers only need to implement the minimal :class:`IEmbedder`
+    protocol (``get_embedding_nocache`` / ``get_embeddings_nocache``).
+    """
+
+    def __init__(self, embedder: IEmbedder) -> None:
+        self._embedder = embedder
+        self._cache: dict[str, NormalizedEmbedding] = {}
+
+    @property
+    def model_name(self) -> str:
+        return self._embedder.model_name
 
     def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None:
-        existing = self._embedding_cache.get(key)
-        if existing is not None:
-            assert np.array_equal(existing, embedding)
-        else:
-            self._embedding_cache[key] = embedding
+        self._cache[key] = embedding
 
     async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding:
-        embeddings = await self.get_embeddings_nocache([input])
-        return embeddings[0]
+        return await self._embedder.get_embedding_nocache(input)
 
     async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings:
-        if not input:
-            empty = np.array([], dtype=np.float32)
-            empty.shape = (0, self.embedding_size)
-            return empty
-        if self.azure_token_provider and self.azure_token_provider.needs_refresh():
-            await self.refresh_auth()
-        extra_args = {}
-        if self.model_name != DEFAULT_MODEL_NAME:
-            extra_args["dimensions"] = self.embedding_size
-        if self.async_client is None:
-            # Compute a random embedding for testing purposes.
-
-            def hashish(s: str) -> int:
-                # Primitive deterministic hash function (hash() varies per run)
-                h = 0
-                for ch in s:
-                    h = (h * 31 + ord(ch)) & 0xFFFFFFFF
-                return h
-
-            prime = 1961
-            fake_data: list[NormalizedEmbedding] = []
-            for item in input:
-                if not item:
-                    raise OpenAIError
-                length = len(item)
-                floats = []
-                for i in range(self.embedding_size):
-                    cut = i % length
-                    scrambled = item[cut:] + item[:cut]
-                    hashed = hashish(scrambled)
-                    reduced = (hashed % prime) / prime
-                    floats.append(reduced)
-                array = np.array(floats, dtype=np.float64)
-                normalized = array / np.sqrt(np.dot(array, array))
-                dot = np.dot(normalized, normalized)
-                assert (
-                    abs(dot - 1.0) < 1e-15
-                ), f"Embedding {normalized} is not normalized: {dot}"
-                fake_data.append(normalized)
-            assert len(fake_data) == len(input), (len(fake_data), "!=", len(input))
-            result = np.array(fake_data, dtype=np.float32)
-            return result
-        else:
-            batches: list[list[str]] = []
-            batch: list[str] = []
-            batch_sum: int = 0
-            for sentence in input:
-                truncated_input, truncated_input_size = await self.truncate_input(
-                    sentence
-                )
-                if (
-                    len(batch) >= MAX_BATCH_SIZE
-                    or batch_sum + truncated_input_size > self.max_size_per_batch
-                ):
-                    batches.append(batch)
-                    batch = []
-                    batch_sum = 0
-                batch.append(truncated_input)
-                batch_sum += truncated_input_size
-            if batch:
-                batches.append(batch)
-
-            data: list[Embedding] = []
-            for batch in batches:
-                embeddings_data = (
-                    await self.async_client.embeddings.create(
-                        input=batch,
-                        model=self.model_name,
-                        encoding_format="float",
-                        **extra_args,
-                    )
-                ).data
-                data.extend(embeddings_data)
-
-            assert len(data) == len(input), (len(data), "!=", len(input))
-            return np.array([d.embedding for d in data], dtype=np.float32)
+        return await self._embedder.get_embeddings_nocache(input)
 
     async def get_embedding(self, key: str) -> NormalizedEmbedding:
-        """Retrieve an embedding, using the cache."""
-        if key in self._embedding_cache:
-            return self._embedding_cache[key]
-        embedding = await self.get_embedding_nocache(key)
-        self._embedding_cache[key] = embedding
+        cached = self._cache.get(key)
+        if cached is not None:
+            return cached
+        embedding = await self._embedder.get_embedding_nocache(key)
+        self._cache[key] = embedding
         return embedding
 
     async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings:
-        """Retrieve embeddings for multiple keys, using the cache."""
-        embeddings: list[NormalizedEmbedding | None] = []
-        missing_keys: list[str] = []
-
-        # Collect cached embeddings and identify missing keys
-        for key in keys:
-            if key in self._embedding_cache:
-                embeddings.append(self._embedding_cache[key])
-            else:
-                embeddings.append(None)  # Placeholder for missing keys
-                missing_keys.append(key)
-
-        # Retrieve embeddings for missing keys
+        if not keys:
+            raise ValueError("Cannot embed an empty list")
+        missing_keys = [k for k in keys if k not in self._cache]
         if missing_keys:
-            new_embeddings = await self.get_embeddings_nocache(missing_keys)
-            for key, embedding in zip(missing_keys, new_embeddings):
-                self._embedding_cache[key] = embedding
-
-        # Replace placeholders with retrieved embeddings
-        for i, key in enumerate(keys):
-            if embeddings[i] is None:
-                embeddings[i] = self._embedding_cache[key]
-        return np.array(embeddings, dtype=np.float32).reshape(
-            (len(keys), self.embedding_size)
-        )
-
-    async def truncate_input(self, input: str) -> tuple[str, int]:
-        """Truncate input strings to fit within model limits.
-
-        args:
-            input: The input string to truncate.
-
-        returns:
-            A tuple of (truncated string, size after truncation).
-        """
-        if self.encoding is None:
-            # Non-token-aware truncation
-            if len(input) > self.max_chunk_size:
-                return input[: self.max_chunk_size], self.max_chunk_size
-            else:
-                return input, len(input)
-        else:
-            # Token-aware truncation
-            tokens = self.encoding.encode(input)
-            if len(tokens) > self.max_chunk_size:
-                truncated_tokens = tokens[: self.max_chunk_size]
-                return self.encoding.decode(truncated_tokens), self.max_chunk_size
-            else:
-                return input, len(tokens)
+            fresh = await self._embedder.get_embeddings_nocache(missing_keys)
+            for i, k in enumerate(missing_keys):
+                self._cache[k] = fresh[i]
+        return np.array([self._cache[k] for k in keys], dtype=np.float32)
+
+
+TEST_MODEL_NAME = "test"
+
+model_to_envvar: dict[str, str] = {
+    "text-embedding-ada-002": "AZURE_OPENAI_ENDPOINT_EMBEDDING",
+    "text-embedding-3-small": "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL",
+    "text-embedding-3-large": "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE",
+}
diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py
new file mode 100644
index 0000000..34d5ac8
--- /dev/null
+++ b/src/typeagent/aitools/model_adapters.py
@@ -0,0 +1,404 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Provider-agnostic model configuration backed by pydantic_ai.
+
+Create chat and embedding models from ``provider:model`` spec strings::
+
+    from typeagent.aitools.model_adapters import configure_models
+
+    chat, embedder = configure_models(
+        "openai:gpt-4o",
+        "openai:text-embedding-3-small",
+    )
+
+The spec format is ``provider:model``, matching pydantic_ai conventions.
+Provider wiring (API keys, endpoints, etc.) is handled by pydantic_ai's
+model registry, which supports 25+ providers including ``openai``,
+``azure``, ``anthropic``, ``google``, ``bedrock``, ``groq``, ``mistral``,
+``ollama``, ``cohere``, and many more.
+
+When a spec uses ``openai:`` as the provider and ``OPENAI_API_KEY`` is not
+set, but ``AZURE_OPENAI_API_KEY`` is available, the provider is
+automatically switched to Azure OpenAI.
+
+See https://ai.pydantic.dev/models/ for all supported providers and their
+required environment variables.
+"""
+
+from collections.abc import Sequence
+import os
+
+import numpy as np
+from numpy.typing import NDArray
+
+from pydantic_ai import Embedder as _PydanticAIEmbedder
+from pydantic_ai.embeddings.base import EmbeddingModel as _PydanticAIEmbeddingModelBase
+from pydantic_ai.embeddings.result import EmbeddingResult, EmbedInputType
+from pydantic_ai.embeddings.settings import EmbeddingSettings
+from pydantic_ai.messages import (
+    ModelMessage,
+    ModelRequest,
+    SystemPromptPart,
+    TextPart,
+    UserPromptPart,
+)
+from pydantic_ai.models import infer_model, Model, ModelRequestParameters
+import typechat
+
+from .embeddings import (
+    CachingEmbeddingModel,
+    NormalizedEmbedding,
+    NormalizedEmbeddings,
+)
+
+# ---------------------------------------------------------------------------
+# Chat model adapter
+# ---------------------------------------------------------------------------
+
+
+class PydanticAIChatModel(typechat.TypeChatLanguageModel):
+    """Adapter from :class:`pydantic_ai.models.Model` to TypeChat's
+    :class:`~typechat.TypeChatLanguageModel`.
+
+    This lets any pydantic_ai chat model (OpenAI, Anthropic, Google, …) be
+    used wherever TypeChat expects a ``TypeChatLanguageModel``.
+    """
+
+    def __init__(self, model: Model) -> None:
+        self._model = model
+
+    async def complete(
+        self, prompt: str | list[typechat.PromptSection]
+    ) -> typechat.Result[str]:
+        parts: list[SystemPromptPart | UserPromptPart] = []
+        if isinstance(prompt, str):
+            parts.append(UserPromptPart(content=prompt))
+        else:
+            for section in prompt:
+                if section["role"] == "system":
+                    parts.append(SystemPromptPart(content=section["content"]))
+                else:
+                    parts.append(UserPromptPart(content=section["content"]))
+
+        messages: list[ModelMessage] = [ModelRequest(parts=parts)]
+        params = ModelRequestParameters()
+
+        response = await self._model.request(messages, None, params)
+        text_parts = [p.content for p in response.parts if isinstance(p, TextPart)]
+        if text_parts:
+            return typechat.Success("".join(text_parts))
+        return typechat.Failure("No text content in model response")
+
+
+# ---------------------------------------------------------------------------
+# Embedding model adapter
+# ---------------------------------------------------------------------------
+
+
+class PydanticAIEmbedder:
+    """Adapter from :class:`pydantic_ai.Embedder` to :class:`IEmbedder`.
+
+    This lets any pydantic_ai embedding provider (OpenAI, Cohere, Google, …)
+    be used wherever the codebase expects an ``IEmbedder``. Wrap in
+    :class:`~typeagent.aitools.embeddings.CachingEmbeddingModel` to get a
+    ready-to-use ``IEmbeddingModel`` with caching.
+    """
+
+    model_name: str
+
+    def __init__(
+        self,
+        embedder: _PydanticAIEmbedder,
+        model_name: str,
+    ) -> None:
+        self._embedder = embedder
+        self.model_name = model_name
+
+    async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding:
+        result = await self._embedder.embed_documents([input])
+        embedding: NDArray[np.float32] = np.array(
+            result.embeddings[0], dtype=np.float32
+        )
+        norm = float(np.linalg.norm(embedding))
+        if norm > 0:
+            embedding = (embedding / norm).astype(np.float32)
+        return embedding
+
+    async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings:
+        if not input:
+            raise ValueError("Cannot embed an empty list")
+        result = await self._embedder.embed_documents(input)
+        embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32)
+        norms = np.linalg.norm(embeddings, axis=1, keepdims=True).astype(np.float32)
+        norms = np.where(norms > 0, norms, np.float32(1.0))
+        embeddings = (embeddings / norms).astype(np.float32)
+        return embeddings
+
+
+# ---------------------------------------------------------------------------
+# Provider auto-detection
+# ---------------------------------------------------------------------------
+
+
+def _needs_azure_fallback(provider: str) -> bool:
+    """Return True if *provider* is ``openai`` but only Azure credentials exist."""
+    return (
+        provider == "openai"
+        and not os.getenv("OPENAI_API_KEY")
+        and bool(os.getenv("AZURE_OPENAI_API_KEY"))
+    )
+
+
+def _make_azure_provider(
+    endpoint_envvar: str = "AZURE_OPENAI_ENDPOINT",
+    api_key_envvar: str = "AZURE_OPENAI_API_KEY",
+):
+    """Create a :class:`pydantic_ai.providers.azure.AzureProvider`.
+
+    Constructs an ``AsyncAzureOpenAI`` client from the given environment
+    variables and wraps it in an ``AzureProvider``. The endpoint env-var
+    may contain a full Azure deployment URL (including path and
+    ``api-version`` query parameter) — the same format used throughout
+    this codebase.
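+
+    For example (illustrative value only, not a real deployment)::
+
+        AZURE_OPENAI_ENDPOINT=https://myresource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-06-01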
+
+    When ``AZURE_OPENAI_API_KEY`` is set to ``"identity"``, the client
+    uses Azure Managed Identity via a token provider callback, which
+    refreshes tokens automatically before each request.
+    """
+    from openai import AsyncAzureOpenAI
+    from pydantic_ai.providers.azure import AzureProvider
+
+    from .utils import parse_azure_endpoint
+
+    raw_key = os.environ[api_key_envvar]
+    azure_endpoint, api_version = parse_azure_endpoint(endpoint_envvar)
+
+    if raw_key.lower() == "identity":
+        from .auth import get_shared_token_provider
+
+        token_provider = get_shared_token_provider()
+        client = AsyncAzureOpenAI(
+            azure_endpoint=azure_endpoint,
+            api_version=api_version,
+            azure_ad_token_provider=token_provider.get_token,
+        )
+    else:
+        client = AsyncAzureOpenAI(
+            azure_endpoint=azure_endpoint,
+            api_version=api_version,
+            api_key=raw_key,
+        )
+    return AzureProvider(openai_client=client)
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+
+DEFAULT_CHAT_SPEC = "openai:gpt-4o"
+
+
+def create_chat_model(
+    model_spec: str | None = None,
+) -> PydanticAIChatModel:
+    """Create a chat model from a ``provider:model`` spec.
+
+    Delegates to :func:`pydantic_ai.models.infer_model` for provider wiring.
+    If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but
+    ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically.
+
+    If *model_spec* is ``None``, it is constructed from the ``OPENAI_MODEL``
+    environment variable (falling back to :data:`DEFAULT_CHAT_SPEC`).
+
+    Examples::
+
+        model = create_chat_model()  # uses OPENAI_MODEL or gpt-4o
+        model = create_chat_model("openai:gpt-4o")
+        model = create_chat_model("anthropic:claude-sonnet-4-20250514")
+        model = create_chat_model("google:gemini-2.0-flash")
+    """
+    if model_spec is None:
+        openai_model = os.getenv("OPENAI_MODEL")
+        if openai_model:
+            model_spec = f"openai:{openai_model}"
+        else:
+            model_spec = DEFAULT_CHAT_SPEC
+    provider, _, model_name = model_spec.partition(":")
+    if _needs_azure_fallback(provider):
+        from pydantic_ai.models.openai import OpenAIChatModel
+
+        if os.getenv("OPENAI_MODEL"):
+            print(
+                f"OPENAI_MODEL={os.getenv('OPENAI_MODEL')!r} ignored; "
+                f"Azure deployment is determined by AZURE_OPENAI_ENDPOINT"
+            )
+        model = OpenAIChatModel(model_name, provider=_make_azure_provider())
+    else:
+        model = infer_model(model_spec)
+    return PydanticAIChatModel(model)
+
+
+DEFAULT_EMBEDDING_SPEC = "openai:text-embedding-ada-002"
+
+
+def create_embedding_model(
+    model_spec: str | None = None,
+) -> CachingEmbeddingModel:
+    """Create an embedding model from a ``provider:model`` spec.
+
+    Delegates to :class:`pydantic_ai.Embedder` for provider wiring.
+    If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but
+    ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically.
+
+    If *model_spec* is ``None``, it is constructed from the
+    ``OPENAI_EMBEDDING_MODEL`` environment variable (falling back to
+    :data:`DEFAULT_EMBEDDING_SPEC`).
+
+    Returns a :class:`~typeagent.aitools.embeddings.CachingEmbeddingModel`
+    wrapping a :class:`PydanticAIEmbedder`.
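+
+    Cache behaviour, as a sketch (assuming valid credentials for the
+    chosen provider)::
+
+        model = create_embedding_model("openai:text-embedding-3-small")
+        e1 = await model.get_embedding("hello")  # computed by the provider
+        e2 = await model.get_embedding("hello")  # served from the in-memory cache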
+
+    Examples::
+
+        model = create_embedding_model()  # uses OPENAI_EMBEDDING_MODEL or ada-002
+        model = create_embedding_model("openai:text-embedding-3-small")
+        model = create_embedding_model("cohere:embed-english-v3.0")
+        model = create_embedding_model("google:text-embedding-004")
+    """
+    if model_spec is None:
+        openai_embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL")
+        if openai_embedding_model:
+            model_spec = f"openai:{openai_embedding_model}"
+        else:
+            model_spec = DEFAULT_EMBEDDING_SPEC
+    provider, _, model_name = model_spec.partition(":")
+    if not model_name:
+        model_name = provider  # No colon in spec
+    if _needs_azure_fallback(provider):
+        from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel
+
+        from .embeddings import model_to_envvar
+
+        # Look up model-specific Azure endpoint, falling back to the generic one.
+        suggested_envvar = model_to_envvar.get(model_name)
+        if suggested_envvar and os.getenv(suggested_envvar):
+            endpoint_envvar = suggested_envvar
+        else:
+            endpoint_envvar = "AZURE_OPENAI_ENDPOINT_EMBEDDING"
+        # Allow a model-specific API key, falling back to the generic one.
+        api_key_envvar = "AZURE_OPENAI_API_KEY_EMBEDDING"
+        if not os.getenv(api_key_envvar):
+            api_key_envvar = "AZURE_OPENAI_API_KEY"
+
+        azure_provider = _make_azure_provider(endpoint_envvar, api_key_envvar)
+        embedding_model = OpenAIEmbeddingModel(model_name, provider=azure_provider)
+        embedder = _PydanticAIEmbedder(embedding_model)
+    else:
+        embedder = _PydanticAIEmbedder(model_spec)
+    return CachingEmbeddingModel(PydanticAIEmbedder(embedder, model_name))
+
+
+# ---------------------------------------------------------------------------
+# Test helpers
+# ---------------------------------------------------------------------------
+
+
+def _hashish(s: str) -> int:
+    """Deterministic hash function for fake embeddings (hash() varies per run)."""
+    h = 0
+    for ch in s:
+        h = (h * 31 + ord(ch)) & 0xFFFFFFFF
+    return h
+
+
+def _compute_fake_embeddings(
+    input_texts: list[str], embedding_size: int
+) -> list[list[float]]:
+    """Generate deterministic fake embeddings for testing (unnormalized).
+
+    Raises :class:`ValueError` on empty input strings.
+    """
+    prime = 1961
+    result: list[list[float]] = []
+    for item in input_texts:
+        if not item:
+            raise ValueError("Empty input text")
+        length = len(item)
+        floats: list[float] = []
+        for i in range(embedding_size):
+            cut = i % length
+            scrambled = item[cut:] + item[:cut]
+            hashed = _hashish(scrambled)
+            reduced = (hashed % prime) / prime
+            floats.append(reduced)
+        result.append(floats)
+    return result
+
+
+class _FakePydanticAIEmbeddingModel(_PydanticAIEmbeddingModelBase):
+    """A pydantic_ai :class:`EmbeddingModel` that returns deterministic fake
+    embeddings. Used only for testing — no network calls are made."""
+
+    def __init__(self, embedding_size: int = 3) -> None:
+        super().__init__()
+        self._embedding_size = embedding_size
+
+    @property
+    def model_name(self) -> str:
+        return "test"
+
+    @property
+    def system(self) -> str:
+        return "test"
+
+    async def embed(
+        self,
+        inputs: str | Sequence[str],
+        *,
+        input_type: EmbedInputType,
+        settings: EmbeddingSettings | None = None,
+    ) -> EmbeddingResult:
+        inputs_list, settings = self.prepare_embed(inputs, settings)
+        embeddings = _compute_fake_embeddings(inputs_list, self._embedding_size)
+        return EmbeddingResult(
+            embeddings=embeddings,
+            inputs=inputs_list,
+            input_type=input_type,
+            model_name="test",
+            provider_name="test",
+        )
+
+
+def create_test_embedding_model(
+    embedding_size: int = 3,
+) -> CachingEmbeddingModel:
+    """Create a :class:`CachingEmbeddingModel` with deterministic fake
+    embeddings for testing. No API keys or network access required."""
+    fake_model = _FakePydanticAIEmbeddingModel(embedding_size)
+    pydantic_embedder = _PydanticAIEmbedder(fake_model)
+    return CachingEmbeddingModel(PydanticAIEmbedder(pydantic_embedder, "test"))
+
+
+def configure_models(
+    chat_model_spec: str,
+    embedding_model_spec: str,
+) -> tuple[PydanticAIChatModel, CachingEmbeddingModel]:
+    """Configure both a chat model and an embedding model at once.
+
+    Delegates to pydantic_ai's model registry for provider wiring.
+
+    Example::
+
+        chat, embedder = configure_models(
+            "openai:gpt-4o",
+            "openai:text-embedding-3-small",
+        )
+
+        settings = ConversationSettings(model=embedder)
+        extractor = KnowledgeExtractor(model=chat)
+    """
+    return (
+        create_chat_model(chat_model_spec),
+        create_embedding_model(embedding_model_spec),
+    )
diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py
index 3bbc572..46c4dba 100644
--- a/src/typeagent/aitools/vectorbase.py
+++ b/src/typeagent/aitools/vectorbase.py
@@ -6,9 +6,12 @@
 
 import numpy as np
 
-from openai import DEFAULT_MAX_RETRIES
-
-from .embeddings import AsyncEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings
+from .embeddings import (
+    IEmbeddingModel,
+    NormalizedEmbedding,
+    NormalizedEmbeddings,
+)
+from .model_adapters import create_embedding_model
 
 
 @dataclass
@@ -19,47 +22,34 @@ class ScoredInt:
 
 
 @dataclass
 class TextEmbeddingIndexSettings:
-    embedding_model: AsyncEmbeddingModel
-    embedding_size: int  # Set to embedding_model.embedding_size
+    embedding_model: IEmbeddingModel
     min_score: float  # Between 0.0 and 1.0
     max_matches: int | None  # >= 1; None means no limit
     batch_size: int  # >= 1
-    max_retries: int
 
     def __init__(
         self,
-        embedding_model: AsyncEmbeddingModel | None = None,
-        embedding_size: int | None = None,
+        embedding_model: IEmbeddingModel | None = None,
        min_score: float | None = None,
         max_matches: int | None = None,
         batch_size: int | None = None,
-        max_retries: int | None = None,
     ):
         self.min_score = min_score if min_score is not None else 0.85
         self.max_matches = max_matches if max_matches and max_matches >= 1 else None
         self.batch_size = batch_size if batch_size and batch_size >= 1 else 8
-        self.max_retries = (
-            max_retries if max_retries is not None else DEFAULT_MAX_RETRIES
-        )
-        self.embedding_model = embedding_model or AsyncEmbeddingModel(
-            embedding_size, max_retries=self.max_retries
-        )
-        self.embedding_size = self.embedding_model.embedding_size
-        assert (
-            embedding_size is None or self.embedding_size == embedding_size
-        ), f"Given embedding size {embedding_size} doesn't match model's embedding size {self.embedding_size}"
+        self.embedding_model = embedding_model or create_embedding_model()
 
 
 class VectorBase:
     settings: TextEmbeddingIndexSettings
     _vectors: NormalizedEmbeddings
-    _model: AsyncEmbeddingModel
+    _model: IEmbeddingModel
     _embedding_size: int
 
     def __init__(self, settings: TextEmbeddingIndexSettings):
         self.settings = settings
         self._model = settings.embedding_model
-        self._embedding_size = self._model.embedding_size
+        self._embedding_size = 0
         self.clear()
 
     async def get_embedding(self, key: str, cache: bool = True) -> NormalizedEmbedding:
@@ -88,6 +78,14 @@ def add_embedding(
     ) -> None:
         if isinstance(embedding, list):
             embedding = np.array(embedding, dtype=np.float32)
+        if self._embedding_size == 0:
+            self._set_embedding_size(len(embedding))
+            self._vectors.shape = (0, self._embedding_size)
+        if len(embedding) != self._embedding_size:
+            raise ValueError(
+                f"Embedding size mismatch: expected {self._embedding_size}, "
+                f"got {len(embedding)}"
+            )
         embeddings = embedding.reshape(1, -1)  # Make it 2D: 1xN
         self._vectors = np.append(self._vectors, embeddings, axis=0)
         if key is not None:
@@ -96,20 +94,30 @@ def add_embeddings(
         self, keys: None | list[str], embeddings: NormalizedEmbeddings
     ) -> None:
-        assert embeddings.ndim == 2
-        assert embeddings.shape[1] == self._embedding_size
+        if embeddings.ndim != 2:
+            raise ValueError(f"Expected 2D embeddings array, got {embeddings.ndim}D")
+        if self._embedding_size == 0:
+            self._set_embedding_size(embeddings.shape[1])
+            self._vectors.shape = (0, self._embedding_size)
+        if embeddings.shape[1] != self._embedding_size:
+            raise ValueError(
+                f"Embedding size mismatch: expected {self._embedding_size}, "
+                f"got {embeddings.shape[1]}"
+            )
         self._vectors = np.concatenate((self._vectors, embeddings), axis=0)
         if keys is not None:
             for key, embedding in zip(keys, embeddings):
                 self._model.add_embedding(key, embedding)
 
     async def add_key(self, key: str, cache: bool = True) -> None:
-        embeddings = (await self.get_embedding(key, cache=cache)).reshape(1, -1)
-        self._vectors = np.append(self._vectors, embeddings, axis=0)
+        embedding = await self.get_embedding(key, cache=cache)
+        self.add_embedding(key if cache else None, embedding)
 
     async def add_keys(self, keys: list[str], cache: bool = True) -> None:
+        if not keys:
+            return
         embeddings = await self.get_embeddings(keys, cache=cache)
-        self._vectors = np.concatenate((self._vectors, embeddings), axis=0)
+        self.add_embeddings(keys if cache else None, embeddings)
 
     def fuzzy_lookup_embedding(
         self,
@@ -122,6 +130,8 @@
             max_hits = 10
         if min_score is None:
             min_score = 0.0
+        if len(self._vectors) == 0:
+            return []
         # This line does most of the work:
         scores: Iterable[float] = np.dot(self._vectors, embedding)
         scored_ordinals = [
@@ -160,9 +170,15 @@
             embedding, max_hits=max_hits, min_score=min_score, predicate=predicate
         )
 
+    def _set_embedding_size(self, size: int) -> None:
+        """Adopt *size* when it was not known at construction time."""
+        assert size > 0
+        self._embedding_size = size
+
     def clear(self) -> None:
         self._vectors = np.array([], dtype=np.float32)
-        self._vectors.shape = (0, self._embedding_size)
+        if self._embedding_size > 0:
+            self._vectors.shape = (0, self._embedding_size)
 
     def get_embedding_at(self, pos: int) -> NormalizedEmbedding:
         if 0 <= pos < len(self._vectors):
@@ -175,13 +191,20 @@ def serialize_embedding_at(self, pos: int) -> NormalizedEmbedding | None:
         return self._vectors[pos] if 0 <= pos < len(self._vectors) else None
 
     def serialize(self) -> NormalizedEmbeddings:
-        assert self._vectors.shape == (len(self._vectors), self._embedding_size)
+        if self._embedding_size > 0:
+            assert self._vectors.shape == (len(self._vectors), self._embedding_size)
         return self._vectors  # TODO: Should we make a copy?
 
     def deserialize(self, data: NormalizedEmbeddings | None) -> None:
         if data is None:
             self.clear()
             return
+        if self._embedding_size == 0:
+            if data.ndim < 2 or data.shape[0] == 0:
+                # Empty data — can't determine size; just clear.
+                self.clear()
+                return
+            self._set_embedding_size(data.shape[1])
         assert data.shape == (len(data), self._embedding_size), [
             data.shape,
             self._embedding_size,
diff --git a/src/typeagent/emails/email_memory.py b/src/typeagent/emails/email_memory.py
index d6cf06c..6dd50cc 100644
--- a/src/typeagent/emails/email_memory.py
+++ b/src/typeagent/emails/email_memory.py
@@ -8,11 +8,10 @@
 
 import typechat
 
-from ..aitools import utils
+from ..aitools import model_adapters, utils
 from ..knowpro import (
     answer_response_schema,
     answers,
-    convknowledge,
     search_query_schema,
     searchlang,
 )
@@ -24,7 +23,7 @@ class EmailMemorySettings:
     def __init__(self, conversation_settings: ConversationSettings) -> None:
-        self.language_model = convknowledge.create_typechat_model()
+        self.language_model = model_adapters.create_chat_model()
         self.query_translator = utils.create_translator(
             self.language_model, search_query_schema.SearchQuery
         )
diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py
index 732e239..07ea155 100644
--- a/src/typeagent/knowpro/conversation_base.py
+++ b/src/typeagent/knowpro/conversation_base.py
@@ -18,7 +18,7 @@
     searchlang,
     secindex,
 )
-from ..aitools import utils
+from ..aitools import model_adapters, utils
 from ..storage.memory import semrefindex
 from .convsettings import ConversationSettings
 from .interfaces import (
@@ -352,12 +352,12 @@ async def query(
         """
         # Create translators lazily (once per conversation instance)
         if self._query_translator is None:
-            model = convknowledge.create_typechat_model()
+            model = model_adapters.create_chat_model()
             self._query_translator = utils.create_translator(
                 model, search_query_schema.SearchQuery
             )
         if self._answer_translator is None:
-            model = convknowledge.create_typechat_model()
+            model = model_adapters.create_chat_model()
             self._answer_translator = utils.create_translator(
                 model, answer_response_schema.AnswerResponse
             )
diff --git a/src/typeagent/knowpro/convknowledge.py b/src/typeagent/knowpro/convknowledge.py
index 4bea97d..fe1d5f5 100644
--- a/src/typeagent/knowpro/convknowledge.py
+++ b/src/typeagent/knowpro/convknowledge.py
@@ -1,67 +1,17 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-import asyncio
 from dataclasses import dataclass, field
-import os
 
 import typechat
 
 from . import kplib
-from ..aitools import auth
-
-# TODO: Move ModelWrapper and create_typechat_model() to aitools package.
-
-
-# TODO: Make these parameters that can be configured (e.g. from command line).
-DEFAULT_MAX_RETRY_ATTEMPTS = 0
-DEFAULT_TIMEOUT_SECONDS = 25
-
-
-class ModelWrapper(typechat.TypeChatLanguageModel):
-    def __init__(
-        self,
-        base_model: typechat.TypeChatLanguageModel,
-        token_provider: auth.AzureTokenProvider,
-    ):
-        self.base_model = base_model
-        self.token_provider = token_provider
-
-    async def complete(
-        self, prompt: str | list[typechat.PromptSection]
-    ) -> typechat.Result[str]:
-        if self.token_provider.needs_refresh():
-            loop = asyncio.get_running_loop()
-            api_key = await loop.run_in_executor(
-                None, self.token_provider.refresh_token
-            )
-            env: dict[str, str | None] = dict(os.environ)
-            key_name = "AZURE_OPENAI_API_KEY"
-            env[key_name] = api_key
-            self.base_model = typechat.create_language_model(env)
-            self.base_model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS
-        return await self.base_model.complete(prompt)
-
-
-def create_typechat_model() -> typechat.TypeChatLanguageModel:
-    env: dict[str, str | None] = dict(os.environ)
-    key_name = "AZURE_OPENAI_API_KEY"
-    key = env.get(key_name)
-    shared_token_provider: auth.AzureTokenProvider | None = None
-    if key is not None and key.lower() == "identity":
-        shared_token_provider = auth.get_shared_token_provider()
-        env[key_name] = shared_token_provider.get_token()
-    model = typechat.create_language_model(env)
-    model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS
-    model.max_retry_attempts = DEFAULT_MAX_RETRY_ATTEMPTS
-    if shared_token_provider is not None:
-        model = ModelWrapper(model, shared_token_provider)
-    return model
+from ..aitools.model_adapters import create_chat_model
 
 
 @dataclass
 class KnowledgeExtractor:
-    model: typechat.TypeChatLanguageModel = field(default_factory=create_typechat_model)
+    model: typechat.TypeChatLanguageModel = field(default_factory=create_chat_model)
     max_chars_per_chunk: int = 2048
     merge_action_knowledge: bool = (
         False  # TODO: Implement merge_action_knowledge_into_response
diff --git a/src/typeagent/knowpro/convsettings.py b/src/typeagent/knowpro/convsettings.py
index 627546e..9dbf121 100644
--- a/src/typeagent/knowpro/convsettings.py
+++ b/src/typeagent/knowpro/convsettings.py
@@ -5,7 +5,8 @@
 
 from dataclasses import dataclass
 
-from ..aitools.embeddings import AsyncEmbeddingModel
+from ..aitools.embeddings import IEmbeddingModel
+from ..aitools.model_adapters import create_embedding_model
 from ..aitools.vectorbase import TextEmbeddingIndexSettings
 from .interfaces import IKnowledgeExtractor, IStorageProvider
 
@@ -38,11 +39,11 @@ class ConversationSettings:
 
     def __init__(
         self,
-        model: AsyncEmbeddingModel | None = None,
+        model: IEmbeddingModel | None = None,
         storage_provider: IStorageProvider | None = None,
     ):
         # All settings share the same model, so they share the embedding cache.
-        model = model or AsyncEmbeddingModel()
+        model = model or create_embedding_model()
         self.embedding_model = model
         min_score = 0.85
         self.related_term_index_settings = RelatedTermIndexSettings(
diff --git a/src/typeagent/knowpro/fuzzyindex.py b/src/typeagent/knowpro/fuzzyindex.py
index 6ace1b3..97138e6 100644
--- a/src/typeagent/knowpro/fuzzyindex.py
+++ b/src/typeagent/knowpro/fuzzyindex.py
@@ -137,7 +137,7 @@ def deserialize(self, embeddings: NormalizedEmbedding) -> None:
         assert embeddings.dtype == np.float32, embeddings.dtype
         assert embeddings.ndim == 2, embeddings.shape
         assert (
-            embeddings.shape[1] == self._vector_base._embedding_size
+            self._vector_base._embedding_size == 0
+            or embeddings.shape[1] == self._vector_base._embedding_size
         ), embeddings.shape
-        self.clear()
-        self.push(embeddings)
+        self._vector_base.deserialize(embeddings)
diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py
index 877e3ba..a82fe7a 100644
--- a/src/typeagent/knowpro/interfaces_storage.py
+++ b/src/typeagent/knowpro/interfaces_storage.py
@@ -52,7 +52,6 @@ class ConversationMetadata:
     schema_version: int | None = None
     created_at: Datetime | None = None
     updated_at: Datetime | None = None
-    embedding_size: int | None = None
     embedding_model: str | None = None
     tags: list[str] | None = None
     extra: dict[str, str] | None = None
diff --git a/src/typeagent/knowpro/knowledge.py b/src/typeagent/knowpro/knowledge.py
index bda7397..60ce302 100644
--- a/src/typeagent/knowpro/knowledge.py
+++ b/src/typeagent/knowpro/knowledge.py
@@ -8,6 +8,7 @@
 from typechat import Result, TypeChatLanguageModel
 
 from . import convknowledge, kplib
+from ..aitools import model_adapters
 from .interfaces import IKnowledgeExtractor
 
 
@@ -15,7 +16,7 @@ def create_knowledge_extractor(
     chat_model: TypeChatLanguageModel | None = None,
 ) -> convknowledge.KnowledgeExtractor:
     """Create a knowledge extractor using the given Chat Model."""
-    chat_model = chat_model or convknowledge.create_typechat_model()
+    chat_model = chat_model or model_adapters.create_chat_model()
     extractor = convknowledge.KnowledgeExtractor(
         chat_model, max_chars_per_chunk=4096, merge_action_knowledge=False
     )
@@ -25,7 +26,6 @@ async def extract_knowledge_from_text(
     knowledge_extractor: IKnowledgeExtractor,
     text: str,
-    max_retries: int,
 ) -> Result[kplib.KnowledgeResponse]:
     """Extract knowledge from a single text input with retries."""
     # TODO: Add a retry mechanism to handle transient errors.
@@ -36,13 +36,10 @@ async def batch_worker(
     q: asyncio.Queue[tuple[int, str] | None],
     knowledge_extractor: IKnowledgeExtractor,
     results: dict[int, Result[kplib.KnowledgeResponse]],
-    max_retries: int,
 ) -> None:
     while item := await q.get():
         index, text = item
-        result = await extract_knowledge_from_text(
-            knowledge_extractor, text, max_retries
-        )
+        result = await extract_knowledge_from_text(knowledge_extractor, text)
         results[index] = result
 
 
@@ -50,7 +47,6 @@ async def extract_knowledge_from_text_batch(
     knowledge_extractor: IKnowledgeExtractor,
     text_batch: list[str],
     concurrency: int = 2,
-    max_retries: int = 3,
 ) -> list[Result[kplib.KnowledgeResponse]]:
     """Extract knowledge from a batch of text inputs concurrently."""
     if not text_batch:
@@ -63,7 +59,7 @@
 
     async with asyncio.TaskGroup() as tg:
         for _ in range(concurrency):
-            tg.create_task(batch_worker(q, knowledge_extractor, results, max_retries))
+            tg.create_task(batch_worker(q, knowledge_extractor, results))
 
         for index, text in enumerate(text_batch):
             await q.put((index, text))
@@ -202,7 +198,6 @@ async def extract_knowledge_for_text_batch_q(
     knowledge_extractor: convknowledge.KnowledgeExtractor,
     text_batch: list[str],
     concurrency: int = 2,
-    max_retries: int = 3,
 ) -> list[Result[kplib.KnowledgeResponse]]:
     """Extract knowledge for a batch of text inputs using a task queue."""
     raise NotImplementedError("TODO")
@@ -211,7 +206,7 @@
 
     #     await run_in_batches(
     #         task_batch,
-    #         lambda text: extract_knowledge_from_text(knowledge_extractor, text, max_retries),
+    #         lambda text: extract_knowledge_from_text(knowledge_extractor, text),
     #         concurrency,
     #     )
diff --git a/src/typeagent/knowpro/serialization.py b/src/typeagent/knowpro/serialization.py
index 1e48b68..cbbe7b7 100644
--- a/src/typeagent/knowpro/serialization.py
+++ b/src/typeagent/knowpro/serialization.py
@@ -46,9 +46,14 @@ def create_file_header() -> FileHeader:
     return FileHeader(version="0.1")
 
 
+class ModelMetadata(TypedDict):
+    embeddingSize: int
+
+
 class EmbeddingFileHeader(TypedDict):
     relatedCount: NotRequired[int | None]
     messageCount: NotRequired[int | None]
+    modelMetadata: NotRequired[ModelMetadata | None]
 
 
 class EmbeddingData(TypedDict):
@@ -104,6 +109,7 @@ def to_conversation_file_data[TMessageData](
     embedding_file_header = EmbeddingFileHeader()
 
     embeddings_list: list[NormalizedEmbeddings] = []
+    embedding_size = 0
 
     related_terms_index_data = conversation_data.get("relatedTermsIndexData")
     if related_terms_index_data is not None:
@@ -114,6 +120,8 @@
             embeddings_list.append(embeddings)
             text_embedding_data["embeddings"] = None
             embedding_file_header["relatedCount"] = len(embeddings)
+            if embedding_size == 0 and embeddings.ndim == 2:
+                embedding_size = embeddings.shape[1]
 
     message_index_data = conversation_data.get("messageIndexData")
     if message_index_data is not None:
@@ -124,6 +132,13 @@
             embeddings_list.append(embeddings)
             text_embedding_data["embeddings"] = None
             embedding_file_header["messageCount"] = len(embeddings)
+            if embedding_size == 0 and embeddings.ndim == 2:
+                embedding_size = embeddings.shape[1]
+
+    if embedding_size > 0:
+        embedding_file_header["modelMetadata"] = ModelMetadata(
+            embeddingSize=embedding_size
+        )
 
     binary_data = ConversationBinaryData(embeddingsList=embeddings_list)
     json_data = ConversationJsonData(
diff --git a/src/typeagent/mcp/server.py b/src/typeagent/mcp/server.py
index 19919c9..dcd4a3c 100644
--- a/src/typeagent/mcp/server.py
+++ b/src/typeagent/mcp/server.py
@@ -102,7 +102,7 @@ class ProcessingContext:
     query_context: query.QueryEvalContext[
         podcast.PodcastMessage, TermToSemanticRefIndex
     ]
-    embedding_model: embeddings.AsyncEmbeddingModel
+    embedding_model: embeddings.IEmbeddingModel
     query_translator: typechat.TypeChatJsonTranslator[SearchQuery]
     answer_translator: typechat.TypeChatJsonTranslator[AnswerResponse]
diff --git a/src/typeagent/podcasts/podcast.py b/src/typeagent/podcasts/podcast.py
index 3ed2639..5376d20 100644
--- a/src/typeagent/podcasts/podcast.py
+++ b/src/typeagent/podcasts/podcast.py
@@ -143,13 +143,19 @@ async def deserialize(
 
     @staticmethod
     def _read_conversation_data_from_file(
-        filename_prefix: str, embedding_size: int
+        filename_prefix: str,
     ) -> ConversationDataWithIndexes[Any]:
         """Read podcast conversation data from files. No exceptions are caught; they just bubble out."""
         with open(filename_prefix + "_data.json", "r", encoding="utf-8") as f:
             json_data: serialization.ConversationJsonData[PodcastMessageData] = (
                 json.load(f)
             )
+        embedding_file_header = json_data.get("embeddingFileHeader")
+        embedding_size = 0
+        if embedding_file_header:
+            model_metadata = embedding_file_header.get("modelMetadata")
+            if model_metadata:
+                embedding_size = model_metadata.get("embeddingSize", 0)
         embeddings_list: list[NormalizedEmbeddings] | None = None
         if embedding_size:
             with open(filename_prefix + "_embeddings.bin", "rb") as f:
@@ -159,7 +165,7 @@ def _read_conversation_data_from_file(
                 embeddings_list = [embeddings]
         else:
             print(
-                "Warning: not reading embeddings file because size is {embedding_size}"
+                f"Warning: not reading embeddings file because size is {embedding_size}"
             )
             embeddings_list = None
         file_data = serialization.ConversationFileData(
@@ -178,10 +184,7 @@ async def read_from_file(
         settings: ConversationSettings,
         dbname: str | None = None,
     ) -> "Podcast":
-        embedding_size = settings.embedding_model.embedding_size
-        data = Podcast._read_conversation_data_from_file(
-            filename_prefix, embedding_size
-        )
+        data = Podcast._read_conversation_data_from_file(filename_prefix)
         provider = await settings.get_storage_provider()
         msgs = await provider.get_message_collection()
diff --git a/src/typeagent/storage/sqlite/messageindex.py b/src/typeagent/storage/sqlite/messageindex.py
index 877cbd6..d48a976 100644
--- a/src/typeagent/storage/sqlite/messageindex.py
+++ b/src/typeagent/storage/sqlite/messageindex.py
@@ -63,6 +63,9 @@ async def add_messages_starting_at(
             for chunk_ord, chunk in enumerate(message.text_chunks):
                 chunks_to_embed.append((msg_ord, chunk_ord, chunk))
 
+        if not chunks_to_embed:
+            return
+
         embeddings = await self._vectorbase.get_embeddings(
             [chunk for _, _, chunk in chunks_to_embed], cache=False
         )
diff --git a/src/typeagent/storage/sqlite/provider.py b/src/typeagent/storage/sqlite/provider.py
index 8fae1b2..3d5a318 100644
--- a/src/typeagent/storage/sqlite/provider.py
+++ b/src/typeagent/storage/sqlite/provider.py
@@ -6,7 +6,7 @@
 from datetime import datetime, timezone
 import sqlite3
 
-from ...aitools.embeddings import AsyncEmbeddingModel
+from ...aitools.model_adapters import create_embedding_model
 from ...aitools.vectorbase import TextEmbeddingIndexSettings
 from ...knowpro import interfaces
 from ...knowpro.convsettings import MessageTextIndexSettings, RelatedTermIndexSettings
@@ -31,7 +31,7 @@ class SqliteStorageProvider[TMessage: interfaces.IMessage](
     """SQLite-backed storage provider implementation.
 
     This provider performs consistency checks on database initialization to ensure
-    that existing embeddings match the configured embedding_size. If a mismatch is
+    that existing embeddings match the configured embedding model. If a mismatch is
     detected, a ValueError is raised with a descriptive error message.
     """
 
@@ -119,19 +119,16 @@ def _resolve_embedding_settings(
         provided_related_settings: RelatedTermIndexSettings | None,
     ) -> tuple[MessageTextIndexSettings, RelatedTermIndexSettings]:
         metadata_exists = self._conversation_metadata_exists()
-        stored_size_str = self._get_single_metadata_value("embedding_size")
         stored_name = self._get_single_metadata_value("embedding_name")
-        stored_size = int(stored_size_str) if stored_size_str else None
 
         if provided_message_settings is None:
-            if stored_size is not None or stored_name is not None:
-                embedding_model = AsyncEmbeddingModel(
-                    embedding_size=stored_size,
-                    model_name=stored_name,
-                )
+            if stored_name is not None:
+                spec = stored_name
+                if spec and ":" not in spec:
+                    spec = f"openai:{spec}"
+                embedding_model = create_embedding_model(spec)
                 base_embedding_settings = TextEmbeddingIndexSettings(
                     embedding_model=embedding_model,
-                    embedding_size=stored_size,
                 )
             else:
                 base_embedding_settings = TextEmbeddingIndexSettings()
@@ -139,13 +136,7 @@
         else:
             message_settings = provided_message_settings
             base_embedding_settings = message_settings.embedding_index_settings
-            provided_size = base_embedding_settings.embedding_size
             provided_name = base_embedding_settings.embedding_model.model_name
-            if stored_size is not None and stored_size != provided_size:
-                raise ValueError(
-                    f"Conversation metadata embedding_size "
-                    f"({stored_size}) does not match provided embedding size ({provided_size})."
-                )
             if stored_name is not None and stored_name != provided_name:
                 raise ValueError(
                     f"Conversation metadata embedding_model "
@@ -157,12 +148,7 @@
         else:
             related_settings = provided_related_settings
         related_embedding_settings = related_settings.embedding_index_settings
-        related_size = related_embedding_settings.embedding_size
         related_name = related_embedding_settings.embedding_model.model_name
-        if related_size != base_embedding_settings.embedding_size:
-            raise ValueError(
-                "Related term index embedding_size does not match message text index embedding_size"
-            )
         if related_name != base_embedding_settings.embedding_model.model_name:
             raise ValueError(
                 "Related term index embedding_model does not match message text index embedding_model"
@@ -170,17 +156,9 @@
         if related_settings.embedding_index_settings is not base_embedding_settings:
             related_settings.embedding_index_settings = base_embedding_settings
 
-        actual_size = base_embedding_settings.embedding_size
         actual_name = base_embedding_settings.embedding_model.model_name
         if self._metadata is not None:
-            if self._metadata.embedding_size is None:
-                self._metadata.embedding_size = actual_size
-            elif self._metadata.embedding_size != actual_size:
-                raise ValueError(
-                    "Conversation metadata embedding_size does not match provider settings"
-                )
-
             if self._metadata.embedding_model is None:
                 self._metadata.embedding_model = actual_name
             elif self._metadata.embedding_model != actual_name:
@@ -190,8 +168,6 @@
         if metadata_exists:
             metadata_updates: dict[str, str] = {}
-            if stored_size is None:
-                metadata_updates["embedding_size"] = str(actual_size)
             if stored_name is None:
                 metadata_updates["embedding_name"] = actual_name
             if metadata_updates:
@@ -200,51 +176,47 @@
         return message_settings, related_settings
 
     def _check_embedding_consistency(self) -> None:
-        """Check that existing embeddings in the database match the expected embedding size.
+        """Check that existing embeddings in the database are consistent.
 
-        This method is called during initialization to ensure that embeddings stored in the
-        database match the embedding_size specified in ConversationSettings. This prevents
-        runtime errors when trying to use embeddings of incompatible sizes.
+        This method is called during initialization to ensure that embeddings
+        stored in the message text index and related terms index have the same
+        size. This prevents runtime errors when trying to use embeddings of
+        incompatible sizes.
 
         Raises:
-            ValueError: If embeddings in the database don't match the expected size.
+            ValueError: If embeddings in the database have inconsistent sizes.
         """
         from .schema import deserialize_embedding
 
         cursor = self.db.cursor()
-        expected_size = (
-            self.message_text_index_settings.embedding_index_settings.embedding_size
-        )
 
-        # Check message text index embeddings
+        # Get size from message text index embeddings
+        message_size: int | None = None
         cursor.execute("SELECT embedding FROM MessageTextIndex LIMIT 1")
         row = cursor.fetchone()
         if row and row[0]:
             embedding = deserialize_embedding(row[0])
-            actual_size = len(embedding)
-            if actual_size != expected_size:
-                raise ValueError(
-                    f"Message text index embedding size mismatch: "
-                    f"database contains embeddings of size {actual_size}, "
-                    f"but ConversationSettings specifies embedding_size={expected_size}. "
-                    f"The database was likely created with a different embedding model. "
-                    f"Please use the same embedding model or create a new database."
-                )
+            message_size = len(embedding)
 
-        # Check related terms fuzzy index embeddings
+        # Get size from related terms fuzzy index embeddings
+        related_size: int | None = None
         cursor.execute("SELECT term_embedding FROM RelatedTermsFuzzy LIMIT 1")
         row = cursor.fetchone()
         if row and row[0]:
             embedding = deserialize_embedding(row[0])
-            actual_size = len(embedding)
-            if actual_size != expected_size:
-                raise ValueError(
-                    f"Related terms index embedding size mismatch: "
-                    f"database contains embeddings of size {actual_size}, "
-                    f"but ConversationSettings specifies embedding_size={expected_size}. "
-                    f"The database was likely created with a different embedding model. "
-                    f"Please use the same embedding model or create a new database."
-                )
+            related_size = len(embedding)
+
+        if (
+            message_size is not None
+            and related_size is not None
+            and message_size != related_size
+        ):
+            raise ValueError(
+                f"Embedding size mismatch: "
+                f"message text index has size {message_size}, "
+                f"but related terms index has size {related_size}. "
+                f"The database may be corrupted."
+            )
 
     def _init_conversation_metadata_if_needed(self) -> None:
         """Initialize conversation metadata if the database is new (empty metadata table).
@@ -273,18 +245,10 @@
             tags = None
             extras = {}
 
-        actual_embedding_size = (
-            self.message_text_index_settings.embedding_index_settings.embedding_size
-        )
         actual_embedding_name = (
             self.message_text_index_settings.embedding_index_settings.embedding_model.model_name
         )
 
-        metadata_embedding_size = (
-            self._metadata.embedding_size
-            if self._metadata and self._metadata.embedding_size is not None
-            else actual_embedding_size
-        )
         metadata_embedding_name = (
             self._metadata.embedding_model
             if self._metadata and self._metadata.embedding_model is not None
@@ -306,7 +270,6 @@
             created_at=format_timestamp_utc(current_time),
             updated_at=format_timestamp_utc(current_time),
             tag=tags,  # None or list of tags
-            embedding_size=str(metadata_embedding_size),
             embedding_name=metadata_embedding_name,
             **extras,
         )
@@ -513,9 +476,6 @@ def parse_datetime(value_str: str) -> datetime:
         updated_at_str = get_single("updated_at")
         updated_at = parse_datetime(updated_at_str) if updated_at_str else None
 
-        embedding_size_str = get_single("embedding_size")
-        embedding_size = int(embedding_size_str) if embedding_size_str else None
-
         embedding_model = get_single("embedding_name")
 
         # Handle tags (multiple values allowed, None if key doesn't exist)
@@ -542,7 +502,6 @@
             schema_version=schema_version,
             created_at=created_at,
             updated_at=updated_at,
-            embedding_size=embedding_size,
             embedding_model=embedding_model,
             tags=tags,
             extra=extra if extra else None,
@@ -589,9 +548,6 @@ async def update_conversation_timestamps(
         # Insert default values if no metadata exists
         name_tag = self._metadata.name_tag if self._metadata else "conversation"
         schema_version = str(CONVERSATION_SCHEMA_VERSION)
-        actual_embedding_size = (
-            self.message_text_index_settings.embedding_index_settings.embedding_size
-        )
         actual_embedding_name = (
             self.message_text_index_settings.embedding_index_settings.embedding_model.model_name
         )
@@ -599,7 +555,6 @@
         metadata_kwds: dict[str, str | None] = {
             "name_tag": name_tag or "conversation",
             "schema_version": schema_version,
-            "embedding_size": str(actual_embedding_size),
"embedding_name": actual_embedding_name, } if created_at is not None: diff --git a/src/typeagent/transcripts/transcript.py b/src/typeagent/transcripts/transcript.py index 494166b..5033e29 100644 --- a/src/typeagent/transcripts/transcript.py +++ b/src/typeagent/transcripts/transcript.py @@ -143,13 +143,19 @@ async def deserialize( @staticmethod def _read_conversation_data_from_file( - filename_prefix: str, embedding_size: int + filename_prefix: str, ) -> ConversationDataWithIndexes[Any]: """Read transcript conversation data from files. No exceptions are caught; they just bubble out.""" with open(filename_prefix + "_data.json", "r", encoding="utf-8") as f: json_data: serialization.ConversationJsonData[TranscriptMessageData] = ( json.load(f) ) + embedding_file_header = json_data.get("embeddingFileHeader") + embedding_size = 0 + if embedding_file_header: + model_metadata = embedding_file_header.get("modelMetadata") + if model_metadata: + embedding_size = model_metadata.get("embeddingSize", 0) embeddings_list: list[NormalizedEmbeddings] | None = None if embedding_size: with open(filename_prefix + "_embeddings.bin", "rb") as f: @@ -159,7 +165,7 @@ def _read_conversation_data_from_file( embeddings_list = [embeddings] else: print( - "Warning: not reading embeddings file because size is {embedding_size}" + f"Warning: not reading embeddings file because size is {embedding_size}" ) embeddings_list = None file_data = serialization.ConversationFileData( @@ -178,10 +184,7 @@ async def read_from_file( settings: ConversationSettings, dbname: str | None = None, ) -> "Transcript": - embedding_size = settings.embedding_model.embedding_size - data = Transcript._read_conversation_data_from_file( - filename_prefix, embedding_size - ) + data = Transcript._read_conversation_data_from_file(filename_prefix) provider = await settings.get_storage_provider() msgs = await provider.get_message_collection() diff --git a/tests/conftest.py b/tests/conftest.py index 40aee88..c4de6d4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,11 +11,8 @@ import pytest import pytest_asyncio -from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage -from openai.types.embedding import Embedding -import tiktoken - -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.embeddings import IEmbeddingModel +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( ConversationSettings, @@ -90,9 +87,9 @@ def really_needs_auth() -> None: @pytest.fixture(scope="session") -def embedding_model() -> AsyncEmbeddingModel: +def embedding_model() -> IEmbeddingModel: """Fixture to create a test embedding model with small embedding size for faster tests.""" - return AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + return create_test_embedding_model() @pytest.fixture(scope="session") @@ -130,7 +127,7 @@ def temp_db_path() -> Iterator[str]: @pytest.fixture def memory_storage( - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, ) -> MemoryStorageProvider: """Create a memory storage provider with settings.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model=embedding_model) @@ -188,7 +185,7 @@ def get_text_location(self) -> TextLocation: @pytest_asyncio.fixture async def sqlite_storage( - temp_db_path: str, embedding_model: AsyncEmbeddingModel + temp_db_path: str, embedding_model: IEmbeddingModel ) -> 
AsyncGenerator[SqliteStorageProvider[FakeMessage], None]: """Create a SqliteStorageProvider for testing.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -299,7 +296,7 @@ def __init__( self._has_secondary_indexes = has_secondary_indexes else: # Create test model for settings - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() self.settings = ConversationSettings(test_model, storage_provider) self._needs_async_init = False self._storage_provider = storage_provider @@ -319,7 +316,7 @@ def __init__( async def ensure_initialized(self): """Ensure async initialization is complete.""" if self._needs_async_init: - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() self.settings = ConversationSettings(test_model) storage_provider = await self.settings.get_storage_provider() self._storage_provider = storage_provider @@ -351,79 +348,3 @@ async def fake_conversation_with_storage( ) -> FakeConversation: """Fixture to create a FakeConversation instance with storage provider.""" return FakeConversation(storage_provider=memory_storage) - - -class FakeEmbeddings: - - def __init__( - self, - max_batch_size: int = 2048, - max_chunk_size: int = 4096, - max_elements_per_batch: int = 300_000, - use_tiktoken: bool = False, - ): - self.model_name = "text-embedding-ada-002" - self.call_count = 0 - self.max_batch_size = max_batch_size - self.max_chunk_size = max_chunk_size - self.max_elements_per_batch = max_elements_per_batch - self.use_tiktoken = use_tiktoken - - def reset_counter(self): - self.call_count = 0 - - async def create(self, **kwargs): - self.call_count += 1 - input = kwargs["input"] - len_input = len(input) - if len_input > self.max_batch_size: - raise ValueError("Embedding model received batch larger 2048") - dimensions = 1536 - if "dimensions" in kwargs: - dimensions = kwargs["dimensions"] - - embedding_result = [] - total_elements = 0 - for index in range(len_input): - entity = input[index] - if self.use_tiktoken: - enc_name = tiktoken.encoding_name_for_model(self.model_name) - enc = tiktoken.get_encoding(enc_name) - entity = enc.encode(entity) - total_elements += len(entity) - if len(entity) > self.max_chunk_size: - raise ValueError( - f"Chunk size {len(entity)} larger than max size {self.max_chunk_size}" - ) - value = index % 2 - embedding_result.append( - Embedding( - embedding=[value] * dimensions, index=index, object="embedding" - ) - ) - - if total_elements > self.max_elements_per_batch: - raise ValueError( - f"Batch size {total_elements} larger than max tokens/chars per batch {self.max_elements_per_batch}" - ) - - response = CreateEmbeddingResponse( - data=embedding_result, - model="test_model", - object="list", - usage=Usage(prompt_tokens=0, total_tokens=0), - ) - - return response - - -@pytest.fixture -def fake_embeddings() -> FakeEmbeddings: - """Fixture to create a FaceEmbedding instance""" - return FakeEmbeddings(max_batch_size=2048, max_chunk_size=4096 * 3) - - -@pytest.fixture -def fake_embeddings_tiktoken() -> FakeEmbeddings: - """Fixture to create a FaceEmbedding instance""" - return FakeEmbeddings(max_batch_size=2048, max_chunk_size=4096, use_tiktoken=True) diff --git a/tests/test_add_messages_with_indexing.py b/tests/test_add_messages_with_indexing.py index 4f00cfb..d3df2c4 100644 --- a/tests/test_add_messages_with_indexing.py +++ b/tests/test_add_messages_with_indexing.py @@ -8,7 +8,7 @@ import pytest -from typeagent.aitools.embeddings import 
AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.storage.sqlite.provider import SqliteStorageProvider from typeagent.transcripts.transcript import ( @@ -24,7 +24,7 @@ async def test_add_messages_with_indexing_basic(): with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -65,7 +65,7 @@ async def test_add_messages_with_indexing_batched(): with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -122,7 +122,7 @@ async def test_transaction_rollback_on_error(): with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False diff --git a/tests/test_conversation_metadata.py b/tests/test_conversation_metadata.py index eadc125..37a194a 100644 --- a/tests/test_conversation_metadata.py +++ b/tests/test_conversation_metadata.py @@ -16,7 +16,8 @@ from pydantic.dataclasses import dataclass -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.embeddings import IEmbeddingModel +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( ConversationSettings, @@ -54,7 +55,7 @@ def get_knowledge(self) -> KnowledgeResponse: @pytest_asyncio.fixture async def storage_provider( - temp_db_path: str, embedding_model: AsyncEmbeddingModel + temp_db_path: str, embedding_model: IEmbeddingModel ) -> AsyncGenerator[SqliteStorageProvider[DummyMessage], None]: """Create a SqliteStorageProvider for testing conversation metadata.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -76,7 +77,7 @@ async def storage_provider_memory() -> ( AsyncGenerator[SqliteStorageProvider[DummyMessage], None] ): """Create an in-memory SqliteStorageProvider for testing conversation metadata.""" - embedding_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + embedding_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(embedding_model) message_text_settings = MessageTextIndexSettings(embedding_settings) related_terms_settings = RelatedTermIndexSettings(embedding_settings) @@ -105,7 +106,6 @@ async def test_get_conversation_metadata_nonexistent( assert metadata.schema_version is None assert metadata.created_at is None assert metadata.updated_at is None - assert metadata.embedding_size is None assert metadata.embedding_model is None assert metadata.tags is None assert metadata.extra is None @@ -130,9 +130,7 @@ async def test_update_conversation_timestamps_new( assert metadata.created_at == created_at assert metadata.updated_at == updated_at settings = 
storage_provider.message_text_index_settings.embedding_index_settings - expected_size = settings.embedding_size expected_model = settings.embedding_model.model_name - assert metadata.embedding_size == expected_size assert metadata.embedding_model == expected_model assert metadata.tags is None assert metadata.extra is None @@ -270,7 +268,7 @@ def test_get_db_version_with_metadata( @pytest.mark.asyncio async def test_multiple_conversations_different_dbs( - self, embedding_model: AsyncEmbeddingModel + self, embedding_model: IEmbeddingModel ): """Test multiple conversations in different database files.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -343,7 +341,7 @@ async def test_multiple_conversations_different_dbs( @pytest.mark.asyncio async def test_conversation_metadata_single_per_db( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test that only one conversation metadata can exist per database.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -418,7 +416,7 @@ async def test_conversation_metadata_with_special_characters( @pytest.mark.asyncio async def test_conversation_metadata_persistence( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test that conversation metadata persists across provider instances.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -457,9 +455,7 @@ async def test_conversation_metadata_persistence( assert metadata.name_tag == "conversation_persistent_test" assert metadata.created_at == created_at assert metadata.updated_at == updated_at - expected_size = embedding_settings.embedding_size expected_model = embedding_settings.embedding_model.model_name - assert metadata.embedding_size == expected_size assert metadata.embedding_model == expected_model finally: await provider2.close() @@ -486,7 +482,7 @@ async def test_empty_string_timestamps( @pytest.mark.asyncio async def test_very_long_name_tag( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test conversation metadata with very long name_tag.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -517,7 +513,7 @@ async def test_very_long_name_tag( @pytest.mark.asyncio async def test_unicode_name_tag( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test conversation metadata with Unicode name_tag.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -548,7 +544,7 @@ async def test_unicode_name_tag( @pytest.mark.asyncio async def test_conversation_metadata_shared_access( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test shared access to metadata using the same database file.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -597,53 +593,9 @@ async def test_conversation_metadata_shared_access( await provider1.close() await provider2.close() - @pytest.mark.asyncio - async def test_embedding_metadata_mismatch_raises( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel - ): - """Ensure a mismatch between stored metadata and provided settings raises.""" - embedding_settings = TextEmbeddingIndexSettings(embedding_model) - message_text_settings = MessageTextIndexSettings(embedding_settings) - 
related_terms_settings = RelatedTermIndexSettings(embedding_settings) - - provider = SqliteStorageProvider( - db_path=temp_db_path, - message_type=DummyMessage, - message_text_index_settings=message_text_settings, - related_term_index_settings=related_terms_settings, - ) - - await provider.update_conversation_timestamps( - created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), - updated_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), - ) - provider.db.commit() - await provider.close() - - mismatched_model = AsyncEmbeddingModel( - embedding_size=embedding_settings.embedding_size + 1, - model_name=embedding_model.model_name, - ) - mismatched_settings = TextEmbeddingIndexSettings( - embedding_model=mismatched_model, - embedding_size=mismatched_model.embedding_size, - ) - - with pytest.raises(ValueError, match="embedding_size"): - SqliteStorageProvider( - db_path=temp_db_path, - message_type=DummyMessage, - message_text_index_settings=MessageTextIndexSettings( - mismatched_settings - ), - related_term_index_settings=RelatedTermIndexSettings( - mismatched_settings - ), - ) - @pytest.mark.asyncio async def test_embedding_model_mismatch_raises( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Ensure providing a different embedding model name raises.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -683,7 +635,7 @@ async def test_embedding_model_mismatch_raises( @pytest.mark.asyncio async def test_updated_at_changes_on_add_messages( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test that updated_at timestamp is updated when messages are added.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) diff --git a/tests/test_demo.py b/tests/test_demo.py index 599f006..39f2f06 100644 --- a/tests/test_demo.py +++ b/tests/test_demo.py @@ -6,7 +6,7 @@ import textwrap import time -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.interfaces import ScoredSemanticRefOrdinal from typeagent.podcasts import podcast @@ -33,7 +33,7 @@ async def main(filename_prefix: str): settings = ConversationSettings() model = settings.embedding_model assert model is not None - assert isinstance(model, AsyncEmbeddingModel), f"model is {model!r}" + assert isinstance(model, IEmbeddingModel), f"model is {model!r}" assert settings.thread_settings.embedding_model is model assert ( settings.message_text_index_settings.embedding_index_settings.embedding_model diff --git a/tests/test_embedding_consistency.py b/tests/test_embedding_consistency.py index 906c2b5..619c921 100644 --- a/tests/test_embedding_consistency.py +++ b/tests/test_embedding_consistency.py @@ -1,39 +1,38 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-"""Test embedding consistency checks between database and settings.""" +"""Test embedding consistency checks between database indexes.""" import os +import sqlite3 import tempfile +import numpy as np import pytest from typeagent import create_conversation -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.storage.sqlite import SqliteStorageProvider +from typeagent.storage.sqlite.schema import serialize_embedding from typeagent.transcripts.transcript import TranscriptMessage, TranscriptMessageMeta @pytest.mark.asyncio -async def test_embedding_size_mismatch_in_message_index(): - """Test that opening a DB with mismatched embedding size raises an error.""" - # Create a temporary database file +async def test_same_embedding_size_no_error(): + """Test that opening a DB with the same model works fine.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name try: - # Create a conversation with test model (embedding size 3) settings1 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=3, model_name="test") + model=create_test_embedding_model(embedding_size=3) ) - # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False conv1 = await create_conversation( db_path, TranscriptMessage, settings=settings1 ) - # Add some messages to populate the index messages = [ TranscriptMessage( text_chunks=["Hello world"], @@ -43,108 +42,151 @@ async def test_embedding_size_mismatch_in_message_index(): await conv1.add_messages_with_indexing(messages) await conv1.storage_provider.close() - # Now try to open the same database with a different embedding size - # This should raise an error + # Reopen with same settings — should work settings2 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=5, model_name="test") + model=create_test_embedding_model(embedding_size=3) ) - - with pytest.raises(ValueError, match="embedding_size"): - provider = SqliteStorageProvider( - db_path=db_path, - message_type=TranscriptMessage, - message_text_index_settings=settings2.message_text_index_settings, - related_term_index_settings=settings2.related_term_index_settings, - ) - await provider.close() + provider = SqliteStorageProvider( + db_path=db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings2.message_text_index_settings, + related_term_index_settings=settings2.related_term_index_settings, + ) + await provider.close() finally: - # Clean up the temporary database if os.path.exists(db_path): os.unlink(db_path) @pytest.mark.asyncio -async def test_embedding_size_mismatch_in_related_terms(): - """Test that opening a DB with mismatched embedding size in related terms raises an error.""" - # Create a temporary database file +async def test_empty_db_no_error(): + """Test that opening an empty DB doesn't raise an error regardless of embedding size.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name try: - # Create a conversation with default embedding size settings1 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=3, model_name="test") + model=create_test_embedding_model(embedding_size=3) ) - # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False conv1 = await create_conversation( 
db_path, TranscriptMessage, settings=settings1 ) - - # Add some messages to populate the related terms index - messages = [ - TranscriptMessage( - text_chunks=["Apple is a fruit"], - metadata=TranscriptMessageMeta(speaker="Alice"), - ) - ] - await conv1.add_messages_with_indexing(messages) await conv1.storage_provider.close() - # Now try to open the same database with a different embedding size - # This should raise an error + # Opening with a different embedding size should work since the DB is empty settings2 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=5, model_name="test") + model=create_test_embedding_model(embedding_size=5) + ) + provider = SqliteStorageProvider( + db_path=db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings2.message_text_index_settings, + related_term_index_settings=settings2.related_term_index_settings, ) + await provider.close() + + finally: + if os.path.exists(db_path): + os.unlink(db_path) + - with pytest.raises(ValueError, match="embedding_size"): - provider = SqliteStorageProvider( +@pytest.mark.asyncio +async def test_embedding_size_mismatch_raises(): + """Test that a mismatch in embedding sizes between indexes raises ValueError.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name + + try: + # Create a conversation so that the schema is set up + settings = ConversationSettings( + model=create_test_embedding_model(embedding_size=3) + ) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + conv = await create_conversation(db_path, TranscriptMessage, settings=settings) + await conv.storage_provider.close() + + # Manually insert embeddings of different sizes into the two tables + conn = sqlite3.connect(db_path) + msg_emb = serialize_embedding(np.array([0.1, 0.2, 0.3], dtype=np.float32)) + term_emb = serialize_embedding( + np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=np.float32) + ) + conn.execute( + "INSERT INTO MessageTextIndex " + "(msg_id, chunk_ordinal, embedding, index_position) " + "VALUES (0, 0, ?, 0)", + (msg_emb,), + ) + conn.execute( + "INSERT INTO RelatedTermsFuzzy (term, term_embedding) VALUES (?, ?)", + ("hello", term_emb), + ) + conn.commit() + conn.close() + + # Reopening should detect the mismatch + settings2 = ConversationSettings( + model=create_test_embedding_model(embedding_size=3) + ) + with pytest.raises(ValueError, match="Embedding size mismatch"): + SqliteStorageProvider( db_path=db_path, message_type=TranscriptMessage, message_text_index_settings=settings2.message_text_index_settings, related_term_index_settings=settings2.related_term_index_settings, ) - await provider.close() finally: - # Clean up the temporary database if os.path.exists(db_path): os.unlink(db_path) @pytest.mark.asyncio -async def test_empty_db_no_error(): - """Test that opening an empty DB doesn't raise an error regardless of embedding size.""" - # Create a temporary database file +async def test_adding_mismatched_embeddings_raises(): + """Test that adding messages with a different embedding size raises ValueError.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name try: - # Create an empty database + # Create and populate with size-3 embeddings settings1 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=3, model_name="test") + model=create_test_embedding_model(embedding_size=3) ) - # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False conv1
= await create_conversation( db_path, TranscriptMessage, settings=settings1 ) + await conv1.add_messages_with_indexing( + [ + TranscriptMessage( + text_chunks=["Hello world"], + metadata=TranscriptMessageMeta(speaker="Alice"), + ) + ] + ) await conv1.storage_provider.close() - # Open with different embedding size should work since DB is empty + # Reopen with size-5 embeddings and try to add more messages settings2 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=5, model_name="test") + model=create_test_embedding_model(embedding_size=5) ) - provider = SqliteStorageProvider( - db_path=db_path, - message_type=TranscriptMessage, - message_text_index_settings=settings2.message_text_index_settings, - related_term_index_settings=settings2.related_term_index_settings, + settings2.semantic_ref_index_settings.auto_extract_knowledge = False + conv2 = await create_conversation( + db_path, TranscriptMessage, settings=settings2 ) - await provider.close() + with pytest.raises(ValueError, match="Embedding size mismatch"): + await conv2.add_messages_with_indexing( + [ + TranscriptMessage( + text_chunks=["Goodbye world"], + metadata=TranscriptMessageMeta(speaker="Bob"), + ) + ] + ) + await conv2.storage_provider.close() finally: - # Clean up the temporary database if os.path.exists(db_path): os.unlink(db_path) diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index ac766f1..24a4ff6 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -3,57 +3,50 @@ import numpy as np import pytest -from pytest import MonkeyPatch from pytest_mock import MockerFixture -import openai - -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import CachingEmbeddingModel, IEmbeddingModel from conftest import ( embedding_model, # type: ignore # Magic, prevents side effects of mocking ) -from conftest import ( - FakeEmbeddings, -) @pytest.mark.asyncio -async def test_get_embedding_nocache(embedding_model: AsyncEmbeddingModel): +async def test_get_embedding_nocache(embedding_model: CachingEmbeddingModel): """Test retrieving an embedding without using the cache.""" input_text = "Hello, world" embedding = await embedding_model.get_embedding_nocache(input_text) assert isinstance(embedding, np.ndarray) - assert embedding.shape == (embedding_model.embedding_size,) assert embedding.dtype == np.float32 @pytest.mark.asyncio -async def test_get_embeddings_nocache(embedding_model: AsyncEmbeddingModel): +async def test_get_embeddings_nocache(embedding_model: CachingEmbeddingModel): """Test retrieving multiple embeddings without using the cache.""" inputs = ["Hello, world", "Foo bar baz"] embeddings = await embedding_model.get_embeddings_nocache(inputs) assert isinstance(embeddings, np.ndarray) - assert embeddings.shape == (len(inputs), embedding_model.embedding_size) + assert embeddings.shape[0] == len(inputs) assert embeddings.dtype == np.float32 @pytest.mark.asyncio async def test_get_embedding_with_cache( - embedding_model: AsyncEmbeddingModel, mocker: MockerFixture + embedding_model: CachingEmbeddingModel, mocker: MockerFixture ): """Test retrieving an embedding with caching.""" input_text = "Hello, world" # First call should populate the cache embedding1 = await embedding_model.get_embedding(input_text) - assert input_text in embedding_model._embedding_cache + assert input_text in embedding_model._cache - # Mock the nocache method to ensure it's not called + # Mock the nocache method on the underlying embedder to ensure it's not called 
mock_get_embedding_nocache = mocker.patch.object( - embedding_model, "get_embedding_nocache", autospec=True + embedding_model._embedder, "get_embedding_nocache", autospec=True ) # Second call should retrieve from the cache @@ -66,7 +59,7 @@ async def test_get_embedding_with_cache( @pytest.mark.asyncio async def test_get_embeddings_with_cache( - embedding_model: AsyncEmbeddingModel, mocker: MockerFixture + embedding_model: CachingEmbeddingModel, mocker: MockerFixture ): """Test retrieving multiple embeddings with caching.""" inputs = ["Hello, world", "Foo bar baz"] @@ -74,11 +67,11 @@ async def test_get_embeddings_with_cache( # First call should populate the cache embeddings1 = await embedding_model.get_embeddings(inputs) for input_text in inputs: - assert input_text in embedding_model._embedding_cache + assert input_text in embedding_model._cache - # Mock the nocache method to ensure it's not called + # Mock the nocache method on the underlying embedder to ensure it's not called mock_get_embeddings_nocache = mocker.patch.object( - embedding_model, "get_embeddings_nocache", autospec=True + embedding_model._embedder, "get_embeddings_nocache", autospec=True ) # Second call should retrieve from the cache @@ -90,233 +83,67 @@ async def test_get_embeddings_with_cache( @pytest.mark.asyncio -async def test_get_embeddings_empty_input(embedding_model: AsyncEmbeddingModel): - """Test retrieving embeddings for an empty input list.""" - inputs = [] - embeddings = await embedding_model.get_embeddings(inputs) - - assert isinstance(embeddings, np.ndarray) - assert embeddings.shape == (0, embedding_model.embedding_size) - assert embeddings.dtype == np.float32 +async def test_get_embeddings_empty_input(embedding_model: CachingEmbeddingModel): + """Test retrieving embeddings for an empty input list raises ValueError.""" + with pytest.raises(ValueError, match="Cannot embed an empty list"): + await embedding_model.get_embeddings([]) @pytest.mark.asyncio -async def test_add_embedding_to_cache(embedding_model: AsyncEmbeddingModel): +async def test_add_embedding_to_cache(embedding_model: CachingEmbeddingModel): """Test adding an embedding to the cache.""" key = "test_key" embedding = np.array([0.1, 0.2, 0.3], dtype=np.float32) embedding_model.add_embedding(key, embedding) - assert key in embedding_model._embedding_cache - assert np.array_equal(embedding_model._embedding_cache[key], embedding) + assert key in embedding_model._cache + assert np.array_equal(embedding_model._cache[key], embedding) @pytest.mark.asyncio -async def test_get_embedding_nocache_empty_input(embedding_model: AsyncEmbeddingModel): +async def test_get_embedding_nocache_empty_input( + embedding_model: CachingEmbeddingModel, +): """Test retrieving an embedding with no cache for an empty input.""" - with pytest.raises(openai.OpenAIError): + with pytest.raises(ValueError, match="Empty input text"): await embedding_model.get_embedding_nocache("") @pytest.mark.asyncio -async def test_refresh_auth( - embedding_model: AsyncEmbeddingModel, mocker: MockerFixture -): - """Test refreshing authentication when using Azure.""" - # Note that pyright doesn't understand mocking, hence the `# type: ignore` below - mocker.patch.object(embedding_model, "azure_token_provider", autospec=True) - mocker.patch.object(embedding_model, "_setup_azure", autospec=True) - - embedding_model.azure_token_provider.needs_refresh.return_value = True # type: ignore - embedding_model.azure_token_provider.refresh_token.return_value = "new_token" # type: ignore - 
embedding_model.azure_api_version = "2023-05-15" - embedding_model.azure_endpoint = "https://example.azure.com" - - await embedding_model.refresh_auth() +async def test_embeddings_are_normalized(embedding_model: CachingEmbeddingModel): + """Test that returned embeddings are unit-normalized.""" + inputs = ["Hello, world", "Foo bar baz", "Testing normalization"] + embeddings = await embedding_model.get_embeddings_nocache(inputs) - embedding_model.azure_token_provider.refresh_token.assert_called_once() # type: ignore - assert embedding_model.async_client is not None + for i in range(len(inputs)): + norm = float(np.linalg.norm(embeddings[i])) + assert abs(norm - 1.0) < 1e-6, f"Embedding {i} not normalized: norm={norm}" @pytest.mark.asyncio -async def test_set_endpoint(monkeypatch: MonkeyPatch): - """Test creating of model with custom endpoint.""" - - monkeypatch.setenv("AZURE_OPENAI_API_KEY", "does-not-matter") - monkeypatch.delenv("OPENAI_API_KEY", raising=False) # Ensure Azure path is used - - # Default - monkeypatch.setenv( - "AZURE_OPENAI_ENDPOINT_EMBEDDING", - "http://localhost:7997?api-version=2024-06-01", - ) - embedding_model = AsyncEmbeddingModel() - assert embedding_model.embedding_size == 1536 - assert embedding_model.model_name == "text-embedding-ada-002" - assert embedding_model.endpoint_envvar == "AZURE_OPENAI_ENDPOINT_EMBEDDING" - - # 3-large - monkeypatch.setenv( - "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE", - "http://localhost:7997?api-version=2024-06-01", - ) - embedding_model = AsyncEmbeddingModel(model_name="text-embedding-3-large") - assert embedding_model.embedding_size == 3072 - assert embedding_model.model_name == "text-embedding-3-large" - assert embedding_model.endpoint_envvar == "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE" - - # 3-small - monkeypatch.setenv( - "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL", - "http://localhost:7998?api-version=2024-06-01", - ) - embedding_model = AsyncEmbeddingModel(model_name="text-embedding-3-small") - assert embedding_model.embedding_size == 1536 - assert embedding_model.model_name == "text-embedding-3-small" - assert embedding_model.endpoint_envvar == "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL" - - # Fully custom with OpenAI - monkeypatch.setenv("OPENAI_API_KEY", "does-not-matter") - monkeypatch.setenv("INFINITY_EMBEDDING_URL", "http://localhost:7997") - embedding_model = AsyncEmbeddingModel( - 1024, "custom_model", endpoint_envvar="INFINITY_EMBEDDING_URL" - ) - assert embedding_model.embedding_size == 1024 - assert embedding_model.model_name == "custom_model" - # NOTE: checking openai.AsyncOpenAI internals - assert embedding_model.async_client is not None - assert embedding_model.async_client.base_url == "http://localhost:7997" - assert embedding_model.async_client.api_key == "does-not-matter" - assert embedding_model.endpoint_envvar == "INFINITY_EMBEDDING_URL" - - # Customized 3-small with Azure (endpoint_envvar must contain "AZURE") - monkeypatch.delenv("OPENAI_API_KEY") # Force Azure path - monkeypatch.setenv( - "AZURE_ALTERNATE_ENDPOINT", - "http://localhost:7999?api-version=2024-06-01", - ) - embedding_model = AsyncEmbeddingModel( - 2000, "text-embedding-3-small", endpoint_envvar="AZURE_ALTERNATE_ENDPOINT" - ) - assert embedding_model.embedding_size == 2000 - assert embedding_model.model_name == "text-embedding-3-small" - assert embedding_model.endpoint_envvar == "AZURE_ALTERNATE_ENDPOINT" - - # Allow explicitly setting default embedding size - AsyncEmbeddingModel(1536) - - # Can't customize embedding_size for default model - with 
pytest.raises(ValueError): - AsyncEmbeddingModel(1024) - - # Not even when default model name specified explicitly - with pytest.raises(ValueError): - AsyncEmbeddingModel(1024, "text-embedding-ada-002") +async def test_embeddings_are_deterministic( + embedding_model: CachingEmbeddingModel, +): + """Test that the same input always produces the same embedding.""" + input_text = "Deterministic test" + e1 = await embedding_model.get_embedding_nocache(input_text) + e2 = await embedding_model.get_embedding_nocache(input_text) + assert np.array_equal(e1, e2) @pytest.mark.asyncio -async def test_embeddings_batching_tiktoken( - fake_embeddings_tiktoken: FakeEmbeddings, monkeypatch: MonkeyPatch +async def test_different_inputs_produce_different_embeddings( + embedding_model: CachingEmbeddingModel, ): - monkeypatch.setenv("OPENAI_API_KEY", "test_key") - - embedding_model = AsyncEmbeddingModel() - assert embedding_model.max_chunk_size == 4096 - - embedding_model.async_client.embeddings = fake_embeddings_tiktoken # type: ignore - - # Check max batch size - inputs = ["a"] * 2049 - embeddings = await embedding_model.get_embeddings(inputs) - assert len(embeddings) == 2049 - assert fake_embeddings_tiktoken.call_count == 2 - - # Check max token size - inputs = ["Very long input longer than 4096 tokens will be truncated" * 500] - embeddings = await embedding_model.get_embeddings(inputs) - assert len(embeddings) == 1 - - fake_embeddings_tiktoken.reset_counter() - - TEST_MAX_TOKEN_SIZE = 10 - TEST_MAX_TOKENS_PER_BATCH = 20 - embedding_model.max_chunk_size = TEST_MAX_TOKEN_SIZE - embedding_model.max_size_per_batch = TEST_MAX_TOKENS_PER_BATCH - fake_embeddings_tiktoken.max_elements_per_batch = TEST_MAX_TOKENS_PER_BATCH - - assert embedding_model.encoding is not None - - token = [500] * 20 # --> 20 tokens - input = [embedding_model.encoding.decode(token)] * 4 - embeddings = await embedding_model.get_embeddings_nocache(input) # type: ignore - - # each input gets truncated to 10 tokens, so 4 inputs fit in 2 batches of 20 tokens - assert fake_embeddings_tiktoken.call_count == 2 - assert len(embeddings) == 4 - - fake_embeddings_tiktoken.reset_counter() - - TEST_MAX_TOKEN_SIZE = 7 - embedding_model.max_chunk_size = TEST_MAX_TOKEN_SIZE - - token = [500] * 20 # --> 20 tokens - input = [embedding_model.encoding.decode(token)] * 5 - embeddings = await embedding_model.get_embeddings_nocache(input) # type: ignore - - # each input gets truncated to 7 tokens, so each batch can hold 2 inputs (14 tokens) - # 5 inputs require 3 batches - assert fake_embeddings_tiktoken.call_count == 3 - assert len(embeddings) == 5 + """Test that different inputs produce different embeddings.""" + e1 = await embedding_model.get_embedding_nocache("Hello") + e2 = await embedding_model.get_embedding_nocache("World") + assert not np.array_equal(e1, e2) @pytest.mark.asyncio -async def test_embeddings_batching( - fake_embeddings: FakeEmbeddings, monkeypatch: MonkeyPatch +async def test_implements_iembedding_model( + embedding_model: CachingEmbeddingModel, ): - monkeypatch.setenv("OPENAI_API_KEY", "test_key") - - embedding_model = AsyncEmbeddingModel(1024, "custom_model") - embedding_model.async_client.embeddings = fake_embeddings # type: ignore - - # Check max batch size - inputs = ["a"] * 2049 - embeddings = await embedding_model.get_embeddings(inputs) - assert len(embeddings) == 2049 - assert fake_embeddings.call_count == 2 - - TEST_MAX_CHAR_SIZE = 10 - TEST_MAX_CHARS_PER_BATCH = 20 - embedding_model.max_chunk_size = TEST_MAX_CHAR_SIZE - 
embedding_model.max_size_per_batch = TEST_MAX_CHARS_PER_BATCH - fake_embeddings.max_elements_per_batch = TEST_MAX_CHARS_PER_BATCH - - # Check max token size - inputs = ["a" * TEST_MAX_CHAR_SIZE] - embeddings = await embedding_model.get_embeddings_nocache(inputs) - assert len(embeddings) == 1 - assert np.all(embeddings[0] == 0) - - fake_embeddings.reset_counter() - - # Check one over max token size - inputs = ["a" * (TEST_MAX_CHAR_SIZE + 1)] - embeddings = await embedding_model.get_embeddings_nocache(inputs) - assert len(embeddings) == 1 - assert fake_embeddings.call_count == 1 - - fake_embeddings.reset_counter() - - # Check input as large as max_size_per_batch - inputs = ["a" * 10, "a" * 5, "a" * 5] - embeddings = await embedding_model.get_embeddings_nocache(inputs) # type: ignore - assert fake_embeddings.call_count == 1 - assert len(embeddings) == 3 - - fake_embeddings.reset_counter() - - # Check input larger than max_size_per_batch - # max chars per batch is 20, so 10*10 chars requires 5 batches - inputs = ["a" * 10] * 10 - embeddings = await embedding_model.get_embeddings_nocache(inputs) # type: ignore - assert fake_embeddings.call_count == 5 - assert len(embeddings) == 10 + """Test that CachingEmbeddingModel satisfies the IEmbeddingModel protocol.""" + assert isinstance(embedding_model, IEmbeddingModel) diff --git a/tests/test_factory.py b/tests/test_factory.py index 0f62220..44c45e5 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -6,7 +6,7 @@ import pytest from typeagent import create_conversation -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.transcripts.transcript import TranscriptMessage, TranscriptMessageMeta @@ -15,7 +15,7 @@ async def test_create_conversation_minimal(): """Test creating a conversation with minimal parameters.""" # Create empty conversation with test model - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) conversation = await create_conversation( None, @@ -36,7 +36,7 @@ async def test_create_conversation_minimal(): @pytest.mark.asyncio async def test_create_conversation_with_tags(): """Test creating a conversation with tags.""" - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) conversation = await create_conversation( None, @@ -54,7 +54,7 @@ async def test_create_conversation_with_tags(): async def test_create_conversation_and_add_messages(really_needs_auth): """Test the complete workflow: create conversation and add messages.""" # 1. 
Create empty conversation - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) conversation = await create_conversation( None, diff --git a/tests/test_incremental_index.py b/tests/test_incremental_index.py index 12f706a..ced12f1 100644 --- a/tests/test_incremental_index.py +++ b/tests/test_incremental_index.py @@ -8,7 +8,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.storage.sqlite.provider import SqliteStorageProvider from typeagent.transcripts.transcript import ( @@ -30,7 +30,7 @@ async def test_incremental_index_building(): db_path = os.path.join(tmpdir, "test.db") # Create settings with test model (no API keys needed) - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -74,7 +74,7 @@ async def test_incremental_index_building(): # Second ingestion - add more messages and rebuild index print("\n=== Second ingestion ===") - test_model2 = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model2 = create_test_embedding_model() settings2 = ConversationSettings(model=test_model2) settings2.semantic_ref_index_settings.auto_extract_knowledge = False storage2 = SqliteStorageProvider( @@ -136,7 +136,7 @@ async def test_incremental_index_with_vtt_files(): db_path = os.path.join(tmpdir, "test.db") # Create settings with test model (no API keys needed) - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -161,9 +161,7 @@ async def test_incremental_index_with_vtt_files(): # Second VTT file ingestion into same database print("\n=== Import second VTT file ===") - settings2 = ConversationSettings( - model=AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) - ) + settings2 = ConversationSettings(model=create_test_embedding_model()) settings2.semantic_ref_index_settings.auto_extract_knowledge = False # Ingest the second transcript diff --git a/tests/test_knowledge.py b/tests/test_knowledge.py index e20ff1f..d4f46fd 100644 --- a/tests/test_knowledge.py +++ b/tests/test_knowledge.py @@ -44,12 +44,12 @@ async def test_extract_knowledge_from_text( mock_knowledge_extractor: convknowledge.KnowledgeExtractor, ): """Test extracting knowledge from a single text input.""" - result = await extract_knowledge_from_text(mock_knowledge_extractor, "test text", 3) + result = await extract_knowledge_from_text(mock_knowledge_extractor, "test text") assert isinstance(result, Success) assert result.value.topics[0] == "test text" failure_result = await extract_knowledge_from_text( - mock_knowledge_extractor, "error", 3 + mock_knowledge_extractor, "error" ) assert isinstance(failure_result, Failure) assert failure_result.message == "Extraction failed" @@ -62,7 +62,7 @@ async def test_extract_knowledge_from_text_batch( """Test extracting knowledge from a batch of text inputs.""" text_batch = ["text 1", "text 2", "error"] results = await extract_knowledge_from_text_batch( - mock_knowledge_extractor, text_batch, 2, 3 + mock_knowledge_extractor, text_batch, 2 ) assert len(results) 
== 3 diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 03fd0e6..24933ca 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -3,16 +3,21 @@ """End-to-end tests for the MCP server.""" +import json import os import sys from typing import Any import pytest -from mcp import StdioServerParameters +from mcp import ClientSession, StdioServerParameters from mcp.client.session import ClientSession as ClientSessionType +from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent +from openai.types.chat import ChatCompletionMessageParam + +from typeagent.aitools.utils import create_async_openai_client from conftest import EPISODE_53_INDEX @@ -38,11 +43,6 @@ async def sampling_callback( params: CreateMessageRequestParams, ) -> CreateMessageResult: """Sampling callback that uses OpenAI to generate responses.""" - # Use OpenAI to generate a response - from openai.types.chat import ChatCompletionMessageParam - - from typeagent.aitools.utils import create_async_openai_client - client = create_async_openai_client() # Convert MCP SamplingMessage to OpenAI format @@ -91,9 +91,6 @@ async def test_mcp_server_query_conversation_slow( really_needs_auth, server_params: StdioServerParameters ): """Test the query_conversation tool end-to-end using MCP client.""" - from mcp import ClientSession - from mcp.client.stdio import stdio_client - # Pass through environment variables needed for authentication # otherwise this test will fail in the CI on Windows only if not (server_params.env) is None: @@ -135,8 +132,6 @@ async def test_mcp_server_query_conversation_slow( response_text = content_item.text # Parse response (it should be JSON with success, answer, time_used) - import json - try: response_data = json.loads(response_text) except json.JSONDecodeError as e: @@ -158,9 +153,6 @@ async def test_mcp_server_query_conversation_slow( @pytest.mark.asyncio async def test_mcp_server_empty_question(server_params: StdioServerParameters): """Test the query_conversation tool with an empty question.""" - from mcp import ClientSession - from mcp.client.stdio import stdio_client - # Create client session and connect to server async with stdio_client(server_params) as (read, write): async with ClientSession( @@ -183,8 +175,6 @@ async def test_mcp_server_empty_question(server_params: StdioServerParameters): assert isinstance(content_item, TextContent) response_text = content_item.text - import json - response_data = json.loads(response_text) assert response_data["success"] is False assert "No question provided" in response_data["answer"] diff --git a/tests/test_message_text_index_population.py b/tests/test_message_text_index_population.py index 384aaa9..13d53c0 100644 --- a/tests/test_message_text_index_population.py +++ b/tests/test_message_text_index_population.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, @@ -30,7 +30,7 @@ async def test_message_text_index_population_from_database(): try: # Use the test model that's already configured in the system - embedding_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + embedding_model = create_test_embedding_model() 
embedding_settings = TextEmbeddingIndexSettings(embedding_model) message_text_settings = MessageTextIndexSettings(embedding_settings) related_terms_settings = RelatedTermIndexSettings(embedding_settings) diff --git a/tests/test_message_text_index_serialization.py b/tests/test_message_text_index_serialization.py index 9b14fbe..4504bf4 100644 --- a/tests/test_message_text_index_serialization.py +++ b/tests/test_message_text_index_serialization.py @@ -8,7 +8,7 @@ import numpy as np import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, TextEmbeddingIndexSettings, @@ -44,7 +44,7 @@ def sqlite_db(self) -> sqlite3.Connection: async def test_message_text_index_serialize_not_empty( self, sqlite_db: sqlite3.Connection, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, needs_auth: None, ): """Test that MessageTextIndex serialization produces non-empty data when populated.""" @@ -111,7 +111,7 @@ async def test_message_text_index_serialize_not_empty( async def test_message_text_index_deserialize_restores_data( self, sqlite_db: sqlite3.Connection, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, needs_auth: None, ): """Test that MessageTextIndex deserialization actually restores data.""" diff --git a/tests/test_messageindex.py b/tests/test_messageindex.py index b4e40dd..7e00cc4 100644 --- a/tests/test_messageindex.py +++ b/tests/test_messageindex.py @@ -4,6 +4,7 @@ from typing import cast from unittest.mock import AsyncMock, MagicMock +import numpy as np import pytest from typeagent.knowpro.convsettings import MessageTextIndexSettings @@ -42,10 +43,10 @@ def message_text_index( mock_text_location_index: MagicMock, ) -> IMessageTextEmbeddingIndex: """Fixture to create a MessageTextIndex instance with a mocked TextToTextLocationIndex.""" - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) settings = MessageTextIndexSettings(embedding_settings) index = MessageTextIndex(settings) @@ -55,10 +56,10 @@ def message_text_index( def test_message_text_index_init(needs_auth: None): """Test initialization of MessageTextIndex.""" - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) settings = MessageTextIndexSettings(embedding_settings) index = MessageTextIndex(settings) @@ -145,13 +146,11 @@ async def test_lookup_messages_in_subset( @pytest.mark.asyncio async def test_generate_embedding(needs_auth: None): """Test generating an embedding for a message without mocking.""" - import numpy as np - - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings # Create real MessageTextIndex with test 
model - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) settings = MessageTextIndexSettings(embedding_settings) index = MessageTextIndex(settings) @@ -159,7 +158,7 @@ async def test_generate_embedding(needs_auth: None): embedding = await index.generate_embedding("test message") assert embedding is not None - assert len(embedding) == test_model.embedding_size # 3 for test model + assert len(embedding) == 3 # test model uses embedding size 3 dot = float(np.dot(embedding, embedding)) assert abs(dot - 1.0) < 1e-6, f"Embedding not normalized: {dot}" @@ -205,14 +204,14 @@ async def test_build_message_index(needs_auth: None): ] # Create storage provider asynchronously - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, RelatedTermIndexSettings, ) - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) message_text_settings = MessageTextIndexSettings(embedding_settings) related_terms_settings = RelatedTermIndexSettings(embedding_settings) diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py new file mode 100644 index 0000000..11907bd --- /dev/null +++ b/tests/test_model_adapters.py @@ -0,0 +1,191 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from unittest.mock import AsyncMock + +import numpy as np +import pytest + +from pydantic_ai import Embedder +from pydantic_ai.embeddings import EmbeddingResult +from pydantic_ai.messages import ( + ModelResponse, + SystemPromptPart, + TextPart, + UserPromptPart, +) +from pydantic_ai.models import Model +import typechat + +from typeagent.aitools.embeddings import CachingEmbeddingModel, NormalizedEmbedding +from typeagent.aitools.model_adapters import ( + configure_models, + create_chat_model, + PydanticAIChatModel, + PydanticAIEmbedder, +) + +# --------------------------------------------------------------------------- +# Spec format +# --------------------------------------------------------------------------- + + +def test_spec_uses_colon_separator() -> None: + """Specs use ``provider:model`` format matching pydantic_ai conventions.""" + with pytest.raises(Exception): + # A nonsense provider should fail + create_chat_model("nonexistent_provider_xyz:fake-model") + + +# --------------------------------------------------------------------------- +# PydanticAIChatModel adapter +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_chat_adapter_complete() -> None: + """PydanticAIChatModel wraps a pydantic_ai Model.""" + mock_model = AsyncMock(spec=Model) + mock_model.request.return_value = ModelResponse(parts=[TextPart(content="hello")]) + + adapter = PydanticAIChatModel(mock_model) + result = await adapter.complete("test prompt") + assert isinstance(result, typechat.Success) + assert result.value == "hello" + + +@pytest.mark.asyncio +async def test_chat_adapter_prompt_sections() -> None: + """PydanticAIChatModel handles list[PromptSection] prompts.""" + mock_model = AsyncMock(spec=Model) + mock_model.request.return_value = ModelResponse( + parts=[TextPart(content="response")] + ) 
+ + adapter = PydanticAIChatModel(mock_model) + sections: list[typechat.PromptSection] = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + ] + result = await adapter.complete(sections) + assert isinstance(result, typechat.Success) + assert result.value == "response" + + # Verify the request was called with proper message structure + call_args = mock_model.request.call_args + messages = call_args[0][0] + assert len(messages) == 1 + request = messages[0] + assert isinstance(request.parts[0], SystemPromptPart) + assert isinstance(request.parts[1], UserPromptPart) + + +# --------------------------------------------------------------------------- +# PydanticAIEmbedder adapter +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_embedding_adapter_single() -> None: + """PydanticAIEmbedder computes a single normalized embedding.""" + mock_embedder = AsyncMock(spec=Embedder) + raw_vec = [3.0, 4.0, 0.0] + mock_embedder.embed_documents.return_value = EmbeddingResult( + embeddings=[raw_vec], + inputs=["test"], + input_type="document", + model_name="test-model", + provider_name="test", + ) + + adapter = PydanticAIEmbedder(mock_embedder, "test-model") + result = await adapter.get_embedding_nocache("test") + assert result.shape == (3,) + norm = float(np.linalg.norm(result)) + assert abs(norm - 1.0) < 1e-6 + + +@pytest.mark.asyncio +async def test_embedding_adapter_empty_batch_raises() -> None: + """Empty batch raises ValueError.""" + mock_embedder = AsyncMock(spec=Embedder) + adapter = PydanticAIEmbedder(mock_embedder, "test-model") + with pytest.raises(ValueError, match="Cannot embed an empty list"): + await adapter.get_embeddings_nocache([]) + + +@pytest.mark.asyncio +async def test_embedding_adapter_batch() -> None: + """PydanticAIEmbedder computes batch embeddings.""" + mock_embedder = AsyncMock(spec=Embedder) + mock_embedder.embed_documents.return_value = EmbeddingResult( + embeddings=[[1.0, 0.0], [0.0, 1.0]], + inputs=["a", "b"], + input_type="document", + model_name="test-model", + provider_name="test", + ) + + adapter = PydanticAIEmbedder(mock_embedder, "test-model") + result = await adapter.get_embeddings_nocache(["a", "b"]) + assert result.shape == (2, 2) + + +@pytest.mark.asyncio +async def test_embedding_adapter_caching() -> None: + """CachingEmbeddingModel avoids re-computing embeddings.""" + mock_embedder = AsyncMock(spec=Embedder) + mock_embedder.embed_documents.return_value = EmbeddingResult( + embeddings=[[1.0, 0.0, 0.0]], + inputs=["cached"], + input_type="document", + model_name="test-model", + provider_name="test", + ) + + embedder = PydanticAIEmbedder(mock_embedder, "test-model") + adapter = CachingEmbeddingModel(embedder) + first = await adapter.get_embedding("cached") + second = await adapter.get_embedding("cached") + np.testing.assert_array_equal(first, second) + # embed_documents() should only be called once + assert mock_embedder.embed_documents.call_count == 1 + + +@pytest.mark.asyncio +async def test_embedding_adapter_add_embedding() -> None: + """add_embedding() populates the cache.""" + mock_embedder = AsyncMock(spec=Embedder) + embedder = PydanticAIEmbedder(mock_embedder, "test-model") + adapter = CachingEmbeddingModel(embedder) + vec: NormalizedEmbedding = np.array([1.0, 0.0, 0.0], dtype=np.float32) + adapter.add_embedding("key", vec) + result = await adapter.get_embedding("key") + np.testing.assert_array_equal(result, vec) + # No embed_documents() call needed 
+        mock_embedder.embed_documents.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_caching_model_empty_batch_raises() -> None:
+    """Empty batch via CachingEmbeddingModel raises ValueError."""
+    mock_embedder = AsyncMock(spec=Embedder)
+    embedder = PydanticAIEmbedder(mock_embedder, "test-model")
+    adapter = CachingEmbeddingModel(embedder)
+    with pytest.raises(ValueError, match="Cannot embed an empty list"):
+        await adapter.get_embeddings([])
+
+
+# ---------------------------------------------------------------------------
+# configure_models
+# ---------------------------------------------------------------------------
+
+
+def test_configure_models_returns_correct_types(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """configure_models returns a chat adapter and a caching embedding model."""
+    monkeypatch.setenv("OPENAI_API_KEY", "test-key")
+    chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small")
+    assert isinstance(chat, PydanticAIChatModel)
+    assert isinstance(embedder, CachingEmbeddingModel)
diff --git a/tests/test_podcast_incremental.py b/tests/test_podcast_incremental.py
index 92d5ad3..4b1732d 100644
--- a/tests/test_podcast_incremental.py
+++ b/tests/test_podcast_incremental.py
@@ -8,7 +8,7 @@
 
 import pytest
 
-from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME
+from typeagent.aitools.model_adapters import create_test_embedding_model
 from typeagent.knowpro.convsettings import ConversationSettings
 from typeagent.podcasts.podcast import Podcast, PodcastMessage, PodcastMessageMeta
 from typeagent.storage.sqlite.provider import SqliteStorageProvider
@@ -20,7 +20,7 @@ async def test_podcast_add_messages_with_indexing():
     with tempfile.TemporaryDirectory() as tmpdir:
         db_path = os.path.join(tmpdir, "test.db")
 
-        test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
+        test_model = create_test_embedding_model()
         settings = ConversationSettings(model=test_model)
         settings.semantic_ref_index_settings.auto_extract_knowledge = False
 
@@ -57,7 +57,7 @@ async def test_podcast_add_messages_batched():
     with tempfile.TemporaryDirectory() as tmpdir:
         db_path = os.path.join(tmpdir, "test.db")
 
-        test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
+        test_model = create_test_embedding_model()
         settings = ConversationSettings(model=test_model)
         settings.semantic_ref_index_settings.auto_extract_knowledge = False
 
diff --git a/tests/test_podcasts.py b/tests/test_podcasts.py
index 6f901a7..97be7c3 100644
--- a/tests/test_podcasts.py
+++ b/tests/test_podcasts.py
@@ -6,7 +6,7 @@
 
 import pytest
 
-from typeagent.aitools.embeddings import AsyncEmbeddingModel
+from typeagent.aitools.embeddings import IEmbeddingModel
 from typeagent.knowpro.convsettings import ConversationSettings
 from typeagent.knowpro.interfaces import Datetime
 from typeagent.knowpro.serialization import DATA_FILE_SUFFIX, EMBEDDING_FILE_SUFFIX
@@ -18,7 +18,7 @@
 
 @pytest.mark.asyncio
 async def test_ingest_podcast(
-    really_needs_auth: None, temp_dir: str, embedding_model: AsyncEmbeddingModel
+    really_needs_auth: None, temp_dir: str, embedding_model: IEmbeddingModel
 ):
     # Import the podcast
     settings = ConversationSettings(embedding_model)
diff --git a/tests/test_property_index_population.py b/tests/test_property_index_population.py
index 8b751bb..f6cc3ed 100644
--- a/tests/test_property_index_population.py
+++ b/tests/test_property_index_population.py
@@ -8,10 +8,9 @@
 import tempfile
 
 from dotenv import load_dotenv
-import numpy as np
 import pytest
 
-from typeagent.aitools.embeddings import AsyncEmbeddingModel
+from typeagent.aitools.model_adapters import create_test_embedding_model
 from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
 from typeagent.knowpro import kplib
 from typeagent.knowpro.convsettings import (
@@ -23,16 +22,6 @@
 from typeagent.storage import SqliteStorageProvider
 
 
-class MockEmbeddingModel(AsyncEmbeddingModel):
-    def __init__(self):
-        super().__init__(embedding_size=3, model_name="test")
-
-    async def get_embeddings(self, keys: list[str]) -> np.ndarray:
-        result = np.random.rand(len(keys), 3).astype(np.float32)
-        norms = np.linalg.norm(result, axis=1, keepdims=True)
-        return result / norms
-
-
 @pytest.mark.asyncio
 async def test_property_index_population_from_database(really_needs_auth):
     """Test that property index is correctly populated when reopening a database."""
@@ -42,7 +31,7 @@ async def test_property_index_population_from_database(really_needs_auth):
     temp_db_file.close()
 
     try:
-        embedding_model = MockEmbeddingModel()
+        embedding_model = create_test_embedding_model()
         embedding_settings = TextEmbeddingIndexSettings(embedding_model)
         message_text_settings = MessageTextIndexSettings(embedding_settings)
         related_terms_settings = RelatedTermIndexSettings(embedding_settings)
@@ -98,7 +87,7 @@ async def test_property_index_population_from_database(really_needs_auth):
 
         # Reopen database and verify property index
         # Use the same embedding settings to avoid dimension mismatch
-        embedding_model2 = MockEmbeddingModel()
+        embedding_model2 = create_test_embedding_model()
         embedding_settings2 = TextEmbeddingIndexSettings(embedding_model2)
         message_text_settings2 = MessageTextIndexSettings(embedding_settings2)
         related_terms_settings2 = RelatedTermIndexSettings(embedding_settings2)
diff --git a/tests/test_query_method.py b/tests/test_query_method.py
index ac60582..bcbf2e0 100644
--- a/tests/test_query_method.py
+++ b/tests/test_query_method.py
@@ -6,7 +6,7 @@
 
 import pytest
 
 from typeagent import create_conversation
-from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME
+from typeagent.aitools.model_adapters import create_test_embedding_model
 from typeagent.knowpro.convsettings import ConversationSettings
 from typeagent.transcripts.transcript import TranscriptMessage, TranscriptMessageMeta
@@ -15,7 +15,7 @@
 async def test_query_method_basic(really_needs_auth: None):
     """Test the basic query method workflow."""
     # Create a conversation with some test data
-    test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
+    test_model = create_test_embedding_model()
     settings = ConversationSettings(model=test_model)
     conversation = await create_conversation(
         None,
@@ -60,7 +60,7 @@
 @pytest.mark.asyncio
 async def test_query_method_empty_conversation(really_needs_auth: None):
     """Test query method on an empty conversation."""
-    test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
+    test_model = create_test_embedding_model()
     settings = ConversationSettings(model=test_model)
     conversation = await create_conversation(
         None,
diff --git a/tests/test_related_terms_fast.py b/tests/test_related_terms_fast.py
index 3919666..fbcf60c 100644
--- a/tests/test_related_terms_fast.py
+++ b/tests/test_related_terms_fast.py
@@ -9,7 +9,7 @@
 
 import pytest
 
-from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME
+from typeagent.aitools.model_adapters import create_test_embedding_model
 from typeagent.knowpro.convsettings import ConversationSettings
 from typeagent.knowpro.interfaces import SemanticRef, TextLocation, TextRange
 from typeagent.knowpro.kplib import ConcreteEntity
@@ -26,7 +26,7 @@ async def test_related_terms_index_minimal():
 
     try:
         # Create minimal test data with test embedding model (no API keys needed)
-        test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
+        test_model = create_test_embedding_model()
         settings = ConversationSettings(model=test_model)
 
         # Use a simple storage provider without AI embeddings
diff --git a/tests/test_related_terms_index_population.py b/tests/test_related_terms_index_population.py
index 9d16936..9de6f01 100644
--- a/tests/test_related_terms_index_population.py
+++ b/tests/test_related_terms_index_population.py
@@ -10,7 +10,7 @@
 from dotenv import load_dotenv
 import pytest
 
-from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME
+from typeagent.aitools.model_adapters import create_test_embedding_model
 from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
 from typeagent.knowpro import kplib
 from typeagent.knowpro.convsettings import (
@@ -32,7 +32,7 @@ async def test_related_terms_index_population_from_database(really_needs_auth):
 
     try:
         # Use the test model that's already configured in the system
-        embedding_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
+        embedding_model = create_test_embedding_model()
         embedding_settings = TextEmbeddingIndexSettings(embedding_model)
         message_text_settings = MessageTextIndexSettings(embedding_settings)
         related_terms_settings = RelatedTermIndexSettings(embedding_settings)
diff --git a/tests/test_reltermsindex.py b/tests/test_reltermsindex.py
index 57afa91..47d21d5 100644
--- a/tests/test_reltermsindex.py
+++ b/tests/test_reltermsindex.py
@@ -8,7 +8,7 @@
 import pytest_asyncio
 
 # TypeAgent imports
-from typeagent.aitools.embeddings import AsyncEmbeddingModel
+from typeagent.aitools.embeddings import IEmbeddingModel
 from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
 from typeagent.knowpro.convsettings import (
     MessageTextIndexSettings,
@@ -30,7 +30,7 @@
 @pytest_asyncio.fixture(params=["memory", "sqlite"])
 async def related_terms_index(
     request: pytest.FixtureRequest,
-    embedding_model: AsyncEmbeddingModel,
+    embedding_model: IEmbeddingModel,
     temp_db_path: str,
 ) -> AsyncGenerator[ITermToRelatedTermsIndex, None]:
     class DummyTestMessage(IMessage):
diff --git a/tests/test_secindex.py b/tests/test_secindex.py
index a9008aa..39665b0 100644
--- a/tests/test_secindex.py
+++ b/tests/test_secindex.py
@@ -3,7 +3,7 @@
 
 import pytest
 
-from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME
+from typeagent.aitools.model_adapters import create_test_embedding_model
 from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
 from typeagent.knowpro.convsettings import (
     ConversationSettings,
@@ -29,9 +29,9 @@ def simple_conversation() -> FakeConversation:
 
 @pytest.fixture
 def conversation_settings(needs_auth: None) -> ConversationSettings:
-    from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME
+    from typeagent.aitools.model_adapters import create_test_embedding_model
 
-    model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
+    model = create_test_embedding_model()
     return ConversationSettings(model)
 
 
@@ -41,7 +41,7 @@ def test_conversation_secondary_indexes_initialization(
     """Test initialization of ConversationSecondaryIndexes."""
     storage_provider = memory_storage
     # Create proper settings for testing
-    test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
+    test_model = create_test_embedding_model()
     embedding_settings = TextEmbeddingIndexSettings(test_model)
     settings = RelatedTermIndexSettings(embedding_settings)
     indexes = ConversationSecondaryIndexes(storage_provider, settings)
diff --git a/tests/test_secindex_storage_integration.py b/tests/test_secindex_storage_integration.py
index a050771..15738bb 100644
--- a/tests/test_secindex_storage_integration.py
+++ b/tests/test_secindex_storage_integration.py
@@ -4,7 +4,7 @@
 # Test that ConversationSecondaryIndexes now uses storage provider properly
 import pytest
 
-from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME
+from typeagent.aitools.model_adapters import create_test_embedding_model
 from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
 from typeagent.knowpro.convsettings import RelatedTermIndexSettings
 from typeagent.knowpro.secindex import ConversationSecondaryIndexes
@@ -19,7 +19,7 @@ async def test_secondary_indexes_use_storage_provider(
     storage_provider = memory_storage
 
     # Create test settings
-    test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
+    test_model = create_test_embedding_model()
     embedding_settings = TextEmbeddingIndexSettings(test_model)
     related_terms_settings = RelatedTermIndexSettings(embedding_settings)
 
diff --git a/tests/test_semrefindex.py b/tests/test_semrefindex.py
index 5fbff6c..20dad6b 100644
--- a/tests/test_semrefindex.py
+++ b/tests/test_semrefindex.py
@@ -8,7 +8,7 @@
 import pytest_asyncio
 
 # TypeAgent imports
-from typeagent.aitools.embeddings import AsyncEmbeddingModel
+from typeagent.aitools.embeddings import IEmbeddingModel
 from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
 from typeagent.knowpro.convsettings import (
     MessageTextIndexSettings,
@@ -37,7 +37,7 @@
 @pytest_asyncio.fixture(params=["memory", "sqlite"])
 async def semantic_ref_index(
     request: pytest.FixtureRequest,
-    embedding_model: AsyncEmbeddingModel,
+    embedding_model: IEmbeddingModel,
     temp_db_path: str,
 ) -> AsyncGenerator[ITermToSemanticRefIndex, None]:
     """Unified fixture to create a semantic ref index for both memory and SQLite providers."""
@@ -97,7 +97,7 @@ def get_knowledge(self):
 @pytest_asyncio.fixture(params=["memory", "sqlite"])
 async def semantic_ref_setup(
     request: pytest.FixtureRequest,
-    embedding_model: AsyncEmbeddingModel,
+    embedding_model: IEmbeddingModel,
     temp_db_path: str,
 ) -> AsyncGenerator[Dict[str, ITermToSemanticRefIndex | ISemanticRefCollection], None]:
     """Unified fixture that provides both semantic ref index and collection for testing helper functions."""
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index 92aa71d..adb46dd 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -113,7 +113,7 @@ def test_write_and_read_conversation_data(
 
     # Read back the data
     read_data = Podcast._read_conversation_data_from_file(
-        str(filename), embedding_size=2
+        str(filename),
     )
     assert read_data is not None
     assert read_data.get("relatedTermsIndexData") is not None
diff --git a/tests/test_sqlite_indexes.py b/tests/test_sqlite_indexes.py
index 8639dab..825f57d 100644
--- a/tests/test_sqlite_indexes.py
+++ b/tests/test_sqlite_indexes.py
@@ -10,7 +10,7 @@
 
 import pytest
 
-from typeagent.aitools.embeddings import AsyncEmbeddingModel
+from typeagent.aitools.embeddings import IEmbeddingModel
 from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
 from typeagent.knowpro import interfaces
 from typeagent.knowpro.convsettings import MessageTextIndexSettings
@@ -35,7 +35,7 @@
 
 @pytest.fixture
 def embedding_settings(
-    embedding_model: AsyncEmbeddingModel,
+    embedding_model: IEmbeddingModel,
 ) -> TextEmbeddingIndexSettings:
     """Create TextEmbeddingIndexSettings for testing."""
     return TextEmbeddingIndexSettings(embedding_model)
diff --git a/tests/test_sqlitestore.py b/tests/test_sqlitestore.py
index 7bd1f98..704ab0b 100644
--- a/tests/test_sqlitestore.py
+++ b/tests/test_sqlitestore.py
@@ -3,13 +3,14 @@
 
 from collections.abc import AsyncGenerator
 from dataclasses import field
+from datetime import datetime
 
 import pytest
 import pytest_asyncio
 from pydantic.dataclasses import dataclass
 
-from typeagent.aitools.embeddings import AsyncEmbeddingModel
+from typeagent.aitools.embeddings import IEmbeddingModel
 from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
 from typeagent.knowpro.convsettings import (
     MessageTextIndexSettings,
@@ -39,7 +40,7 @@ def get_knowledge(self) -> KnowledgeResponse:
 
 @pytest_asyncio.fixture
 async def dummy_sqlite_storage_provider(
-    temp_db_path: str, embedding_model: AsyncEmbeddingModel
+    temp_db_path: str, embedding_model: IEmbeddingModel
 ) -> AsyncGenerator[SqliteStorageProvider[DummyMessage], None]:
     """Create a SqliteStorageProvider for testing."""
     embedding_settings = TextEmbeddingIndexSettings(embedding_model)
@@ -128,8 +129,6 @@ async def test_sqlite_timestamp_index(
     dummy_sqlite_storage_provider: SqliteStorageProvider[DummyMessage],
 ):
     """Test SqliteTimestampToTextRangeIndex functionality."""
-    from datetime import datetime
-
     from typeagent.knowpro.interfaces import DateRange
 
     # Set up database with some messages
diff --git a/tests/test_storage_providers_unified.py b/tests/test_storage_providers_unified.py
index 67ae9d7..d0ecb9c 100644
--- a/tests/test_storage_providers_unified.py
+++ b/tests/test_storage_providers_unified.py
@@ -9,6 +9,8 @@
 """
 
 from dataclasses import field
+import os
+import tempfile
 from typing import assert_never, AsyncGenerator
 
 import pytest
@@ -16,7 +18,7 @@
 
 from pydantic.dataclasses import dataclass
 
-from typeagent.aitools.embeddings import AsyncEmbeddingModel
+from typeagent.aitools.embeddings import IEmbeddingModel
 from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
 from typeagent.knowpro import kplib
 from typeagent.knowpro.convsettings import (
@@ -52,7 +54,7 @@ def get_knowledge(self) -> KnowledgeResponse:
 @pytest_asyncio.fixture(params=["memory", "sqlite"])
 async def storage_provider_type(
     request: pytest.FixtureRequest,
-    embedding_model: AsyncEmbeddingModel,
+    embedding_model: IEmbeddingModel,
     temp_db_path: str,
 ) -> AsyncGenerator[tuple[IStorageProvider, str], None]:
     """Parameterized fixture that provides both memory and sqlite storage providers."""
@@ -328,7 +330,7 @@ async def test_conversation_threads_interface_parity(
 # Cross-provider validation tests
 @pytest.mark.asyncio
 async def test_cross_provider_message_collection_equivalence(
-    embedding_model: AsyncEmbeddingModel, temp_db_path: str, needs_auth: None
+    embedding_model: IEmbeddingModel, temp_db_path: str, needs_auth: None
 ):
     """Test that both providers handle message collections equivalently."""
     # Create both providers with identical settings
@@ -586,7 +588,7 @@ async def test_timestamp_index_with_data(
 
 @pytest.mark.asyncio
 async def test_storage_provider_independence(
-    embedding_model: AsyncEmbeddingModel, temp_db_path: str, needs_auth: None
+    embedding_model: IEmbeddingModel, temp_db_path: str, needs_auth: None
 ):
     """Test that different storage provider instances work independently."""
     # Create settings shared between providers
@@ -605,9 +607,6 @@ async def test_storage_provider_independence(
     )
 
     # Create two sqlite providers (with different temp files)
-    import os
-    import tempfile
-
     temp_file1 = tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False)
     temp_path1 = temp_file1.name
     temp_file1.close()
diff --git a/tests/test_transcripts.py b/tests/test_transcripts.py
index 9d93003..9d98ae8 100644
--- a/tests/test_transcripts.py
+++ b/tests/test_transcripts.py
@@ -5,8 +5,10 @@
 import os
 
 import pytest
+import webvtt
 
-from typeagent.aitools.embeddings import AsyncEmbeddingModel
+from typeagent.aitools.embeddings import IEmbeddingModel
+from typeagent.aitools.model_adapters import create_test_embedding_model
 from typeagent.knowpro.convsettings import ConversationSettings
 from typeagent.knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH
 from typeagent.transcripts.transcript import (
@@ -88,7 +90,7 @@ def test_get_transcript_info():
 
 @pytest.fixture
 def conversation_settings(
-    needs_auth: None, embedding_model: AsyncEmbeddingModel
+    needs_auth: None, embedding_model: IEmbeddingModel
 ) -> ConversationSettings:
     """Create conversation settings for testing."""
     return ConversationSettings(embedding_model)
@@ -101,8 +103,6 @@ def conversation_settings(
 @pytest.mark.asyncio
 async def test_ingest_vtt_transcript(conversation_settings: ConversationSettings):
     """Test importing a VTT file into a Transcript object."""
-    import webvtt
-
     from typeagent.storage.memory.collections import (
         MemoryMessageCollection,
         MemorySemanticRefCollection,
@@ -224,10 +224,8 @@ def test_transcript_message_creation():
 @pytest.mark.asyncio
 async def test_transcript_creation():
     """Test creating an empty transcript."""
-    from typeagent.aitools.embeddings import TEST_MODEL_NAME
-
     # Create a minimal transcript for testing structure
-    embedding_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
+    embedding_model = create_test_embedding_model()
     settings = ConversationSettings(embedding_model)
 
     transcript = await Transcript.create(
@@ -242,7 +240,7 @@
 
 @pytest.mark.asyncio
 async def test_transcript_knowledge_extraction_slow(
-    really_needs_auth: None, embedding_model: AsyncEmbeddingModel
+    really_needs_auth: None, embedding_model: IEmbeddingModel
 ):
     """
     Test that knowledge extraction works during transcript ingestion.
@@ -254,8 +252,6 @@
     4. Verifies both mechanical extraction (entities/actions from metadata)
        and LLM extraction (topics from content) work correctly
     """
-    import webvtt
-
     from typeagent.storage.memory.collections import (
         MemoryMessageCollection,
         MemorySemanticRefCollection,
diff --git a/tests/test_utils.py b/tests/test_utils.py
index cb95c93..ceea367 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -7,6 +7,9 @@
 
 from dotenv import load_dotenv
 
+import pydantic.dataclasses
+import typechat
+
 import typeagent.aitools.utils as utils
@@ -37,14 +40,10 @@ def test_load_dotenv(really_needs_auth):
 
 
 def test_create_translator():
-    import typechat
-
     class DummyModel(typechat.TypeChatLanguageModel):
         async def complete(self, *args, **kwargs) -> typechat.Result:
             return typechat.Failure("dummy response")
 
-    import pydantic.dataclasses
-
     @pydantic.dataclasses.dataclass
     class DummySchema:
         pass
diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py
index 62abd39..81ccecc 100644
--- a/tests/test_vectorbase.py
+++ b/tests/test_vectorbase.py
@@ -5,9 +5,11 @@
 
 import pytest
 
 from typeagent.aitools.embeddings import (
-    AsyncEmbeddingModel,
+    CachingEmbeddingModel,
     NormalizedEmbedding,
-    TEST_MODEL_NAME,
+)
+from typeagent.aitools.model_adapters import (
+    create_test_embedding_model,
 )
 from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase
@@ -19,9 +21,7 @@ def vector_base() -> VectorBase:
 
 
 def make_vector_base() -> VectorBase:
-    settings = TextEmbeddingIndexSettings(
-        AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
-    )
+    settings = TextEmbeddingIndexSettings(create_test_embedding_model())
     return VectorBase(settings)
 
 
@@ -61,8 +61,10 @@ def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples):
     assert len(bulk_vector_base) == len(vector_base)
     np.testing.assert_array_equal(bulk_vector_base.serialize(), vector_base.serialize())
 
-    sequential_cache = vector_base._model._embedding_cache
-    bulk_cache = bulk_vector_base._model._embedding_cache
+    assert isinstance(vector_base._model, CachingEmbeddingModel)
+    assert isinstance(bulk_vector_base._model, CachingEmbeddingModel)
+    sequential_cache = vector_base._model._cache
+    bulk_cache = bulk_vector_base._model._cache
     assert set(sequential_cache.keys()) == set(bulk_cache.keys())
     for key in keys:
         np.testing.assert_array_equal(bulk_cache[key], sequential_cache[key])
@@ -84,9 +86,8 @@ async def test_add_key_no_cache(vector_base: VectorBase, sample_embeddings: Samples):
         await vector_base.add_key(key, cache=False)
 
     assert len(vector_base) == len(sample_embeddings)
-    assert (
-        vector_base._model._embedding_cache == {}
-    ), "Cache should remain empty when cache=False"
+    assert isinstance(vector_base._model, CachingEmbeddingModel)
+    assert vector_base._model._cache == {}, "Cache should remain empty when cache=False"
@@ -105,9 +106,8 @@ async def test_add_keys_no_cache(vector_base: VectorBase, sample_embeddings: Samples):
         await vector_base.add_keys(keys, cache=False)
 
     assert len(vector_base) == len(sample_embeddings)
-    assert (
-        vector_base._model._embedding_cache == {}
-    ), "Cache should remain empty when cache=False"
+    assert isinstance(vector_base._model, CachingEmbeddingModel)
+    assert vector_base._model._cache == {}, "Cache should remain empty when cache=False"
@@ -195,3 +195,28 @@ def test_fuzzy_lookup_embedding_in_subset(
     # Empty subset returns empty list
     result = vector_base.fuzzy_lookup_embedding_in_subset(query, [])
     assert result == []
+
+
+def test_add_embedding_size_mismatch(vector_base: VectorBase) -> None:
+    """Adding an embedding of wrong size raises ValueError."""
+    emb3 = np.array([0.1, 0.2, 0.3], dtype=np.float32)
+    emb5 = np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=np.float32)
+    vector_base.add_embedding(None, emb3)
+    with pytest.raises(ValueError, match="Embedding size mismatch"):
+        vector_base.add_embedding(None, emb5)
+
+
+def test_add_embeddings_size_mismatch(vector_base: VectorBase) -> None:
+    """Adding a batch of embeddings of wrong size raises ValueError."""
+    batch3 = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)
+    batch5 = np.array([[0.1, 0.2, 0.3, 0.4, 0.5]], dtype=np.float32)
+    vector_base.add_embeddings(None, batch3)
+    with pytest.raises(ValueError, match="Embedding size mismatch"):
+        vector_base.add_embeddings(None, batch5)
+
+
+def test_add_embeddings_wrong_ndim(vector_base: VectorBase) -> None:
+    """Adding a 1D array via add_embeddings raises ValueError."""
+    emb1d = np.array([0.1, 0.2, 0.3], dtype=np.float32)
+    with pytest.raises(ValueError, match="Expected 2D"):
+        vector_base.add_embeddings(None, emb1d)
diff --git a/tools/ingest_vtt.py b/tools/ingest_vtt.py
index ad56775..7fbf38f 100644
--- a/tools/ingest_vtt.py
+++ b/tools/ingest_vtt.py
@@ -24,7 +24,7 @@
 from dotenv import load_dotenv
 import webvtt
 
-from typeagent.aitools.embeddings import AsyncEmbeddingModel
+from typeagent.aitools.model_adapters import create_embedding_model
 from typeagent.knowpro.convsettings import ConversationSettings
 from typeagent.knowpro.interfaces import ConversationMetadata
 from typeagent.knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH
@@ -203,7 +203,10 @@ async def ingest_vtt_files(
     if verbose:
         print("Setting up conversation settings...")
     try:
-        embedding_model = AsyncEmbeddingModel(model_name=embedding_name)
+        spec = embedding_name
+        if spec and ":" not in spec:
+            spec = f"openai:{spec}"
+        embedding_model = create_embedding_model(spec)
         settings = ConversationSettings(embedding_model)
 
         # Create metadata with the conversation name
diff --git a/tools/query.py b/tools/query.py
index 3d28a89..b3a3509 100644
--- a/tools/query.py
+++ b/tools/query.py
@@ -32,11 +32,10 @@
 
 import typechat
 
-from typeagent.aitools import embeddings, utils
+from typeagent.aitools import embeddings, model_adapters, utils
 from typeagent.knowpro import (
     answer_response_schema,
     answers,
-    convknowledge,
     kplib,
     query,
     search,
@@ -150,7 +149,7 @@ class ProcessingContext:
     debug2: typing.Literal["none", "diff", "full", "skip"]
     debug3: typing.Literal["none", "diff", "full", "nice"]
     debug4: typing.Literal["none", "diff", "full", "nice"]
-    embedding_model: embeddings.AsyncEmbeddingModel
+    embedding_model: embeddings.IEmbeddingModel
     query_translator: typechat.TypeChatJsonTranslator[search_query_schema.SearchQuery]
     answer_translator: typechat.TypeChatJsonTranslator[
         answer_response_schema.AnswerResponse
@@ -576,7 +575,7 @@ async def main():
                 "Error: non-empty --search-results required for batch mode."
             )
 
-    model = convknowledge.create_typechat_model()
+    model = model_adapters.create_chat_model()
     query_translator = utils.create_translator(model, search_query_schema.SearchQuery)
     if args.alt_schema:
         if args.verbose:
diff --git a/uv.lock b/uv.lock
index 5a2dab3..e26630a 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2370,6 +2370,7 @@ dependencies = [
     { name = "numpy" },
     { name = "openai" },
     { name = "pydantic" },
+    { name = "pydantic-ai-slim", extra = ["openai"] },
     { name = "pyreadline3", marker = "sys_platform == 'win32'" },
     { name = "python-dotenv" },
     { name = "tiktoken" },
@@ -2395,7 +2396,6 @@ dev = [
     { name = "logfire" },
     { name = "msgraph-sdk" },
     { name = "opentelemetry-instrumentation-httpx" },
-    { name = "pydantic-ai-slim", extra = ["openai"] },
     { name = "pyright" },
     { name = "pytest" },
     { name = "pytest-asyncio" },
@@ -2413,6 +2413,7 @@ requires-dist = [
     { name = "openai", specifier = ">=1.81.0" },
     { name = "opentelemetry-instrumentation-httpx", marker = "extra == 'logfire'", specifier = ">=0.57b0" },
     { name = "pydantic", specifier = ">=2.11.4" },
+    { name = "pydantic-ai-slim", extras = ["openai"], specifier = ">=1.39.0" },
     { name = "pyreadline3", marker = "sys_platform == 'win32'", specifier = ">=3.5.4" },
     { name = "python-dotenv", specifier = ">=1.1.0" },
     { name = "tiktoken", specifier = ">=0.12.0" },
@@ -2433,7 +2434,6 @@ dev = [
     { name = "logfire", specifier = ">=4.1.0" },
     { name = "msgraph-sdk", specifier = ">=1.54.0" },
     { name = "opentelemetry-instrumentation-httpx", specifier = ">=0.57b0" },
-    { name = "pydantic-ai-slim", extras = ["openai"], specifier = ">=1.39.0" },
     { name = "pyright", specifier = ">=1.1.408" },
     { name = "pytest", specifier = ">=8.3.5" },
     { name = "pytest-asyncio", specifier = ">=0.26.0" },