From 74097cfc7831f28af7c3272d7389cb62d3290756 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 17 Feb 2026 09:02:51 -0800 Subject: [PATCH 01/23] Unreviewed agent output: make chat and embed interfaces provider-agnostic using pydantic_ai --- src/typeagent/aitools/embeddings.py | 60 +++++++++++++++++++ src/typeagent/aitools/utils.py | 60 +++++++++++++++++++ src/typeagent/aitools/vectorbase.py | 17 ++++-- src/typeagent/knowpro/convknowledge.py | 53 +--------------- src/typeagent/knowpro/convsettings.py | 6 +- src/typeagent/mcp/server.py | 2 +- src/typeagent/storage/sqlite/provider.py | 4 +- tests/conftest.py | 12 ++-- tests/test_conversation_metadata.py | 26 ++++---- tests/test_demo.py | 4 +- .../test_message_text_index_serialization.py | 6 +- tests/test_podcasts.py | 4 +- tests/test_reltermsindex.py | 4 +- tests/test_semrefindex.py | 6 +- tests/test_sqlite_indexes.py | 4 +- tests/test_sqlitestore.py | 4 +- tests/test_storage_providers_unified.py | 8 +-- tests/test_transcripts.py | 6 +- tests/test_vectorbase.py | 8 +-- tools/ingest_vtt.py | 4 +- tools/query.py | 2 +- 21 files changed, 193 insertions(+), 107 deletions(-) diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py index 819993f4..c3403af2 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -3,6 +3,7 @@ import asyncio import os +from typing import Protocol, runtime_checkable import numpy as np from numpy.typing import NDArray @@ -20,6 +21,39 @@ type NormalizedEmbeddings = NDArray[np.float32] # An array of embeddings +@runtime_checkable +class IEmbeddingModel(Protocol): + """Provider-agnostic interface for embedding models. + + Implement this protocol to add support for a new embedding provider + (e.g. Anthropic, Gemini, local models). The existing AsyncEmbeddingModel + implements it for OpenAI and Azure OpenAI. + """ + + model_name: str + embedding_size: int + + def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: + """Cache an already-computed embedding under the given key.""" + ... + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + """Compute a single embedding without caching.""" + ... + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + """Compute embeddings for a batch of strings without caching.""" + ... + + async def get_embedding(self, key: str) -> NormalizedEmbedding: + """Retrieve a single embedding, using cache if available.""" + ... + + async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + """Retrieve embeddings for multiple keys, using cache if available.""" + ... + + DEFAULT_MODEL_NAME = "text-embedding-ada-002" DEFAULT_EMBEDDING_SIZE = 1536 # Default embedding size (required for ada-002) DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING" # We support OpenAI and Azure OpenAI @@ -311,3 +345,29 @@ async def truncate_input(self, input: str) -> tuple[str, int]: return self.encoding.decode(truncated_tokens), self.max_chunk_size else: return input, len(tokens) + + +def create_embedding_model( + embedding_size: int | None = None, + model_name: str | None = None, + **kwargs, +) -> IEmbeddingModel: + """Create an embedding model using OpenAI/Azure OpenAI. + + This is the default factory. To use a different provider, create an + instance of a class that implements ``IEmbeddingModel`` and pass it + directly to ``TextEmbeddingIndexSettings`` or ``ConversationSettings``. + + Args: + embedding_size: Requested embedding dimensionality (provider-specific). + model_name: Model identifier (e.g. "text-embedding-ada-002"). + **kwargs: Extra keyword arguments forwarded to ``AsyncEmbeddingModel``. + + Returns: + An ``IEmbeddingModel`` instance backed by OpenAI / Azure OpenAI. + """ + return AsyncEmbeddingModel( + embedding_size=embedding_size, + model_name=model_name, + **kwargs, + ) diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index b9150334..2459473c 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -3,6 +3,7 @@ """Utilities that are hard to fit in any specific module.""" +import asyncio from contextlib import contextmanager import difflib import os @@ -16,6 +17,8 @@ import typechat +from .auth import AzureTokenProvider, get_shared_token_provider + @contextmanager def timelog(label: str, verbose: bool = True): @@ -87,6 +90,63 @@ def create_translator[T]( return translator +# TODO: Make these parameters that can be configured (e.g. from command line). +DEFAULT_MAX_RETRY_ATTEMPTS = 0 +DEFAULT_TIMEOUT_SECONDS = 25 + + +class ModelWrapper(typechat.TypeChatLanguageModel): + """Wraps a TypeChat model to handle Azure token refresh.""" + + def __init__( + self, + base_model: typechat.TypeChatLanguageModel, + token_provider: AzureTokenProvider, + ): + self.base_model = base_model + self.token_provider = token_provider + + async def complete( + self, prompt: str | list[typechat.PromptSection] + ) -> typechat.Result[str]: + if self.token_provider.needs_refresh(): + loop = asyncio.get_running_loop() + api_key = await loop.run_in_executor( + None, self.token_provider.refresh_token + ) + env: dict[str, str | None] = dict(os.environ) + key_name = "AZURE_OPENAI_API_KEY" + env[key_name] = api_key + self.base_model = typechat.create_language_model(env) + self.base_model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS + return await self.base_model.complete(prompt) + + +def create_typechat_model() -> typechat.TypeChatLanguageModel: + """Create a TypeChat language model using OpenAI or Azure OpenAI. + + Reads ``OPENAI_API_KEY``, ``AZURE_OPENAI_API_KEY`` and related env vars. + Handles Azure ``identity`` token provider for Microsoft internal usage. + + To use a different provider (e.g. Anthropic, Gemini), implement + ``typechat.TypeChatLanguageModel`` directly and pass it to + ``KnowledgeExtractor`` or ``create_translator()``. + """ + env: dict[str, str | None] = dict(os.environ) + key_name = "AZURE_OPENAI_API_KEY" + key = env.get(key_name) + shared_token_provider: AzureTokenProvider | None = None + if key is not None and key.lower() == "identity": + shared_token_provider = get_shared_token_provider() + env[key_name] = shared_token_provider.get_token() + model = typechat.create_language_model(env) + model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS + model.max_retry_attempts = DEFAULT_MAX_RETRY_ATTEMPTS + if shared_token_provider is not None: + model = ModelWrapper(model, shared_token_provider) + return model + + # Vibe-coded by o4-mini-high def list_diff(label_a, a, label_b, b, max_items): """Print colorized diff between two sorted list of numbers.""" diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index 3bbc5729..0a52f838 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -6,9 +6,14 @@ import numpy as np -from openai import DEFAULT_MAX_RETRIES +from .embeddings import ( + create_embedding_model, + IEmbeddingModel, + NormalizedEmbedding, + NormalizedEmbeddings, +) -from .embeddings import AsyncEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings +DEFAULT_MAX_RETRIES = 2 @dataclass @@ -19,7 +24,7 @@ class ScoredInt: @dataclass class TextEmbeddingIndexSettings: - embedding_model: AsyncEmbeddingModel + embedding_model: IEmbeddingModel embedding_size: int # Set to embedding_model.embedding_size min_score: float # Between 0.0 and 1.0 max_matches: int | None # >= 1; None means no limit @@ -28,7 +33,7 @@ class TextEmbeddingIndexSettings: def __init__( self, - embedding_model: AsyncEmbeddingModel | None = None, + embedding_model: IEmbeddingModel | None = None, embedding_size: int | None = None, min_score: float | None = None, max_matches: int | None = None, @@ -41,7 +46,7 @@ def __init__( self.max_retries = ( max_retries if max_retries is not None else DEFAULT_MAX_RETRIES ) - self.embedding_model = embedding_model or AsyncEmbeddingModel( + self.embedding_model = embedding_model or create_embedding_model( embedding_size, max_retries=self.max_retries ) self.embedding_size = self.embedding_model.embedding_size @@ -53,7 +58,7 @@ def __init__( class VectorBase: settings: TextEmbeddingIndexSettings _vectors: NormalizedEmbeddings - _model: AsyncEmbeddingModel + _model: IEmbeddingModel _embedding_size: int def __init__(self, settings: TextEmbeddingIndexSettings): diff --git a/src/typeagent/knowpro/convknowledge.py b/src/typeagent/knowpro/convknowledge.py index 4bea97d4..53a10b25 100644 --- a/src/typeagent/knowpro/convknowledge.py +++ b/src/typeagent/knowpro/convknowledge.py @@ -1,62 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import asyncio from dataclasses import dataclass, field -import os import typechat from . import kplib -from ..aitools import auth +from ..aitools.utils import create_typechat_model # Re-export for backward compat -# TODO: Move ModelWrapper and create_typechat_model() to aitools package. - - -# TODO: Make these parameters that can be configured (e.g. from command line). -DEFAULT_MAX_RETRY_ATTEMPTS = 0 -DEFAULT_TIMEOUT_SECONDS = 25 - - -class ModelWrapper(typechat.TypeChatLanguageModel): - def __init__( - self, - base_model: typechat.TypeChatLanguageModel, - token_provider: auth.AzureTokenProvider, - ): - self.base_model = base_model - self.token_provider = token_provider - - async def complete( - self, prompt: str | list[typechat.PromptSection] - ) -> typechat.Result[str]: - if self.token_provider.needs_refresh(): - loop = asyncio.get_running_loop() - api_key = await loop.run_in_executor( - None, self.token_provider.refresh_token - ) - env: dict[str, str | None] = dict(os.environ) - key_name = "AZURE_OPENAI_API_KEY" - env[key_name] = api_key - self.base_model = typechat.create_language_model(env) - self.base_model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS - return await self.base_model.complete(prompt) - - -def create_typechat_model() -> typechat.TypeChatLanguageModel: - env: dict[str, str | None] = dict(os.environ) - key_name = "AZURE_OPENAI_API_KEY" - key = env.get(key_name) - shared_token_provider: auth.AzureTokenProvider | None = None - if key is not None and key.lower() == "identity": - shared_token_provider = auth.get_shared_token_provider() - env[key_name] = shared_token_provider.get_token() - model = typechat.create_language_model(env) - model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS - model.max_retry_attempts = DEFAULT_MAX_RETRY_ATTEMPTS - if shared_token_provider is not None: - model = ModelWrapper(model, shared_token_provider) - return model +# Re-export: callers may still do ``convknowledge.create_typechat_model()``. +__all__ = ["create_typechat_model", "KnowledgeExtractor"] @dataclass diff --git a/src/typeagent/knowpro/convsettings.py b/src/typeagent/knowpro/convsettings.py index 627546ed..7e25cd93 100644 --- a/src/typeagent/knowpro/convsettings.py +++ b/src/typeagent/knowpro/convsettings.py @@ -5,7 +5,7 @@ from dataclasses import dataclass -from ..aitools.embeddings import AsyncEmbeddingModel +from ..aitools.embeddings import create_embedding_model, IEmbeddingModel from ..aitools.vectorbase import TextEmbeddingIndexSettings from .interfaces import IKnowledgeExtractor, IStorageProvider @@ -38,11 +38,11 @@ class ConversationSettings: def __init__( self, - model: AsyncEmbeddingModel | None = None, + model: IEmbeddingModel | None = None, storage_provider: IStorageProvider | None = None, ): # All settings share the same model, so they share the embedding cache. - model = model or AsyncEmbeddingModel() + model = model or create_embedding_model() self.embedding_model = model min_score = 0.85 self.related_term_index_settings = RelatedTermIndexSettings( diff --git a/src/typeagent/mcp/server.py b/src/typeagent/mcp/server.py index 19919c92..dcd4a3cf 100644 --- a/src/typeagent/mcp/server.py +++ b/src/typeagent/mcp/server.py @@ -102,7 +102,7 @@ class ProcessingContext: query_context: query.QueryEvalContext[ podcast.PodcastMessage, TermToSemanticRefIndex ] - embedding_model: embeddings.AsyncEmbeddingModel + embedding_model: embeddings.IEmbeddingModel query_translator: typechat.TypeChatJsonTranslator[SearchQuery] answer_translator: typechat.TypeChatJsonTranslator[AnswerResponse] diff --git a/src/typeagent/storage/sqlite/provider.py b/src/typeagent/storage/sqlite/provider.py index 8fae1b2b..975d6a70 100644 --- a/src/typeagent/storage/sqlite/provider.py +++ b/src/typeagent/storage/sqlite/provider.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone import sqlite3 -from ...aitools.embeddings import AsyncEmbeddingModel +from ...aitools.embeddings import create_embedding_model from ...aitools.vectorbase import TextEmbeddingIndexSettings from ...knowpro import interfaces from ...knowpro.convsettings import MessageTextIndexSettings, RelatedTermIndexSettings @@ -125,7 +125,7 @@ def _resolve_embedding_settings( if provided_message_settings is None: if stored_size is not None or stored_name is not None: - embedding_model = AsyncEmbeddingModel( + embedding_model = create_embedding_model( embedding_size=stored_size, model_name=stored_name, ) diff --git a/tests/conftest.py b/tests/conftest.py index 40aee885..7f7ce210 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,11 @@ from openai.types.embedding import Embedding import tiktoken -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.embeddings import ( + AsyncEmbeddingModel, + IEmbeddingModel, + TEST_MODEL_NAME, +) from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( ConversationSettings, @@ -90,7 +94,7 @@ def really_needs_auth() -> None: @pytest.fixture(scope="session") -def embedding_model() -> AsyncEmbeddingModel: +def embedding_model() -> IEmbeddingModel: """Fixture to create a test embedding model with small embedding size for faster tests.""" return AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) @@ -130,7 +134,7 @@ def temp_db_path() -> Iterator[str]: @pytest.fixture def memory_storage( - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, ) -> MemoryStorageProvider: """Create a memory storage provider with settings.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model=embedding_model) @@ -188,7 +192,7 @@ def get_text_location(self) -> TextLocation: @pytest_asyncio.fixture async def sqlite_storage( - temp_db_path: str, embedding_model: AsyncEmbeddingModel + temp_db_path: str, embedding_model: IEmbeddingModel ) -> AsyncGenerator[SqliteStorageProvider[FakeMessage], None]: """Create a SqliteStorageProvider for testing.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) diff --git a/tests/test_conversation_metadata.py b/tests/test_conversation_metadata.py index eadc125a..69d80dab 100644 --- a/tests/test_conversation_metadata.py +++ b/tests/test_conversation_metadata.py @@ -16,7 +16,11 @@ from pydantic.dataclasses import dataclass -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.embeddings import ( + AsyncEmbeddingModel, + IEmbeddingModel, + TEST_MODEL_NAME, +) from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( ConversationSettings, @@ -54,7 +58,7 @@ def get_knowledge(self) -> KnowledgeResponse: @pytest_asyncio.fixture async def storage_provider( - temp_db_path: str, embedding_model: AsyncEmbeddingModel + temp_db_path: str, embedding_model: IEmbeddingModel ) -> AsyncGenerator[SqliteStorageProvider[DummyMessage], None]: """Create a SqliteStorageProvider for testing conversation metadata.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -270,7 +274,7 @@ def test_get_db_version_with_metadata( @pytest.mark.asyncio async def test_multiple_conversations_different_dbs( - self, embedding_model: AsyncEmbeddingModel + self, embedding_model: IEmbeddingModel ): """Test multiple conversations in different database files.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -343,7 +347,7 @@ async def test_multiple_conversations_different_dbs( @pytest.mark.asyncio async def test_conversation_metadata_single_per_db( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test that only one conversation metadata can exist per database.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -418,7 +422,7 @@ async def test_conversation_metadata_with_special_characters( @pytest.mark.asyncio async def test_conversation_metadata_persistence( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test that conversation metadata persists across provider instances.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -486,7 +490,7 @@ async def test_empty_string_timestamps( @pytest.mark.asyncio async def test_very_long_name_tag( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test conversation metadata with very long name_tag.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -517,7 +521,7 @@ async def test_very_long_name_tag( @pytest.mark.asyncio async def test_unicode_name_tag( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test conversation metadata with Unicode name_tag.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -548,7 +552,7 @@ async def test_unicode_name_tag( @pytest.mark.asyncio async def test_conversation_metadata_shared_access( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test shared access to metadata using the same database file.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -599,7 +603,7 @@ async def test_conversation_metadata_shared_access( @pytest.mark.asyncio async def test_embedding_metadata_mismatch_raises( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Ensure a mismatch between stored metadata and provided settings raises.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -643,7 +647,7 @@ async def test_embedding_metadata_mismatch_raises( @pytest.mark.asyncio async def test_embedding_model_mismatch_raises( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Ensure providing a different embedding model name raises.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -683,7 +687,7 @@ async def test_embedding_model_mismatch_raises( @pytest.mark.asyncio async def test_updated_at_changes_on_add_messages( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test that updated_at timestamp is updated when messages are added.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) diff --git a/tests/test_demo.py b/tests/test_demo.py index 599f006e..39f2f061 100644 --- a/tests/test_demo.py +++ b/tests/test_demo.py @@ -6,7 +6,7 @@ import textwrap import time -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.interfaces import ScoredSemanticRefOrdinal from typeagent.podcasts import podcast @@ -33,7 +33,7 @@ async def main(filename_prefix: str): settings = ConversationSettings() model = settings.embedding_model assert model is not None - assert isinstance(model, AsyncEmbeddingModel), f"model is {model!r}" + assert isinstance(model, IEmbeddingModel), f"model is {model!r}" assert settings.thread_settings.embedding_model is model assert ( settings.message_text_index_settings.embedding_index_settings.embedding_model diff --git a/tests/test_message_text_index_serialization.py b/tests/test_message_text_index_serialization.py index 9b14fbe2..4504bf42 100644 --- a/tests/test_message_text_index_serialization.py +++ b/tests/test_message_text_index_serialization.py @@ -8,7 +8,7 @@ import numpy as np import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, TextEmbeddingIndexSettings, @@ -44,7 +44,7 @@ def sqlite_db(self) -> sqlite3.Connection: async def test_message_text_index_serialize_not_empty( self, sqlite_db: sqlite3.Connection, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, needs_auth: None, ): """Test that MessageTextIndex serialization produces non-empty data when populated.""" @@ -111,7 +111,7 @@ async def test_message_text_index_serialize_not_empty( async def test_message_text_index_deserialize_restores_data( self, sqlite_db: sqlite3.Connection, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, needs_auth: None, ): """Test that MessageTextIndex deserialization actually restores data.""" diff --git a/tests/test_podcasts.py b/tests/test_podcasts.py index 6f901a75..97be7c3e 100644 --- a/tests/test_podcasts.py +++ b/tests/test_podcasts.py @@ -6,7 +6,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.interfaces import Datetime from typeagent.knowpro.serialization import DATA_FILE_SUFFIX, EMBEDDING_FILE_SUFFIX @@ -18,7 +18,7 @@ @pytest.mark.asyncio async def test_ingest_podcast( - really_needs_auth: None, temp_dir: str, embedding_model: AsyncEmbeddingModel + really_needs_auth: None, temp_dir: str, embedding_model: IEmbeddingModel ): # Import the podcast settings = ConversationSettings(embedding_model) diff --git a/tests/test_reltermsindex.py b/tests/test_reltermsindex.py index 57afa913..47d21d59 100644 --- a/tests/test_reltermsindex.py +++ b/tests/test_reltermsindex.py @@ -8,7 +8,7 @@ import pytest_asyncio # TypeAgent imports -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, @@ -30,7 +30,7 @@ @pytest_asyncio.fixture(params=["memory", "sqlite"]) async def related_terms_index( request: pytest.FixtureRequest, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, temp_db_path: str, ) -> AsyncGenerator[ITermToRelatedTermsIndex, None]: class DummyTestMessage(IMessage): diff --git a/tests/test_semrefindex.py b/tests/test_semrefindex.py index 5fbff6cc..20dad6be 100644 --- a/tests/test_semrefindex.py +++ b/tests/test_semrefindex.py @@ -8,7 +8,7 @@ import pytest_asyncio # TypeAgent imports -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, @@ -37,7 +37,7 @@ @pytest_asyncio.fixture(params=["memory", "sqlite"]) async def semantic_ref_index( request: pytest.FixtureRequest, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, temp_db_path: str, ) -> AsyncGenerator[ITermToSemanticRefIndex, None]: """Unified fixture to create a semantic ref index for both memory and SQLite providers.""" @@ -97,7 +97,7 @@ def get_knowledge(self): @pytest_asyncio.fixture(params=["memory", "sqlite"]) async def semantic_ref_setup( request: pytest.FixtureRequest, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, temp_db_path: str, ) -> AsyncGenerator[Dict[str, ITermToSemanticRefIndex | ISemanticRefCollection], None]: """Unified fixture that provides both semantic ref index and collection for testing helper functions.""" diff --git a/tests/test_sqlite_indexes.py b/tests/test_sqlite_indexes.py index 8639dab3..825f57d8 100644 --- a/tests/test_sqlite_indexes.py +++ b/tests/test_sqlite_indexes.py @@ -10,7 +10,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro import interfaces from typeagent.knowpro.convsettings import MessageTextIndexSettings @@ -35,7 +35,7 @@ @pytest.fixture def embedding_settings( - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, ) -> TextEmbeddingIndexSettings: """Create TextEmbeddingIndexSettings for testing.""" return TextEmbeddingIndexSettings(embedding_model) diff --git a/tests/test_sqlitestore.py b/tests/test_sqlitestore.py index 7bd1f98d..ad784221 100644 --- a/tests/test_sqlitestore.py +++ b/tests/test_sqlitestore.py @@ -9,7 +9,7 @@ from pydantic.dataclasses import dataclass -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, @@ -39,7 +39,7 @@ def get_knowledge(self) -> KnowledgeResponse: @pytest_asyncio.fixture async def dummy_sqlite_storage_provider( - temp_db_path: str, embedding_model: AsyncEmbeddingModel + temp_db_path: str, embedding_model: IEmbeddingModel ) -> AsyncGenerator[SqliteStorageProvider[DummyMessage], None]: """Create a SqliteStorageProvider for testing.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) diff --git a/tests/test_storage_providers_unified.py b/tests/test_storage_providers_unified.py index 67ae9d7c..14829f03 100644 --- a/tests/test_storage_providers_unified.py +++ b/tests/test_storage_providers_unified.py @@ -16,7 +16,7 @@ from pydantic.dataclasses import dataclass -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro import kplib from typeagent.knowpro.convsettings import ( @@ -52,7 +52,7 @@ def get_knowledge(self) -> KnowledgeResponse: @pytest_asyncio.fixture(params=["memory", "sqlite"]) async def storage_provider_type( request: pytest.FixtureRequest, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, temp_db_path: str, ) -> AsyncGenerator[tuple[IStorageProvider, str], None]: """Parameterized fixture that provides both memory and sqlite storage providers.""" @@ -328,7 +328,7 @@ async def test_conversation_threads_interface_parity( # Cross-provider validation tests @pytest.mark.asyncio async def test_cross_provider_message_collection_equivalence( - embedding_model: AsyncEmbeddingModel, temp_db_path: str, needs_auth: None + embedding_model: IEmbeddingModel, temp_db_path: str, needs_auth: None ): """Test that both providers handle message collections equivalently.""" # Create both providers with identical settings @@ -586,7 +586,7 @@ async def test_timestamp_index_with_data( @pytest.mark.asyncio async def test_storage_provider_independence( - embedding_model: AsyncEmbeddingModel, temp_db_path: str, needs_auth: None + embedding_model: IEmbeddingModel, temp_db_path: str, needs_auth: None ): """Test that different storage provider instances work independently.""" # Create settings shared between providers diff --git a/tests/test_transcripts.py b/tests/test_transcripts.py index 9d930034..0d0bfb57 100644 --- a/tests/test_transcripts.py +++ b/tests/test_transcripts.py @@ -6,7 +6,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import AsyncEmbeddingModel, IEmbeddingModel from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH from typeagent.transcripts.transcript import ( @@ -88,7 +88,7 @@ def test_get_transcript_info(): @pytest.fixture def conversation_settings( - needs_auth: None, embedding_model: AsyncEmbeddingModel + needs_auth: None, embedding_model: IEmbeddingModel ) -> ConversationSettings: """Create conversation settings for testing.""" return ConversationSettings(embedding_model) @@ -242,7 +242,7 @@ async def test_transcript_creation(): @pytest.mark.asyncio async def test_transcript_knowledge_extraction_slow( - really_needs_auth: None, embedding_model: AsyncEmbeddingModel + really_needs_auth: None, embedding_model: IEmbeddingModel ): """ Test that knowledge extraction works during transcript ingestion. diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py index 62abd392..be8d2c07 100644 --- a/tests/test_vectorbase.py +++ b/tests/test_vectorbase.py @@ -61,8 +61,8 @@ def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples): assert len(bulk_vector_base) == len(vector_base) np.testing.assert_array_equal(bulk_vector_base.serialize(), vector_base.serialize()) - sequential_cache = vector_base._model._embedding_cache - bulk_cache = bulk_vector_base._model._embedding_cache + sequential_cache = vector_base._model._embedding_cache # type: ignore[attr-defined] + bulk_cache = bulk_vector_base._model._embedding_cache # type: ignore[attr-defined] assert set(sequential_cache.keys()) == set(bulk_cache.keys()) for key in keys: np.testing.assert_array_equal(bulk_cache[key], sequential_cache[key]) @@ -85,7 +85,7 @@ async def test_add_key_no_cache(vector_base: VectorBase, sample_embeddings: Samp assert len(vector_base) == len(sample_embeddings) assert ( - vector_base._model._embedding_cache == {} + vector_base._model._embedding_cache == {} # type: ignore[attr-defined] ), "Cache should remain empty when cache=False" @@ -106,7 +106,7 @@ async def test_add_keys_no_cache(vector_base: VectorBase, sample_embeddings: Sam assert len(vector_base) == len(sample_embeddings) assert ( - vector_base._model._embedding_cache == {} + vector_base._model._embedding_cache == {} # type: ignore[attr-defined] ), "Cache should remain empty when cache=False" diff --git a/tools/ingest_vtt.py b/tools/ingest_vtt.py index ad567756..8bcef6d7 100644 --- a/tools/ingest_vtt.py +++ b/tools/ingest_vtt.py @@ -24,7 +24,7 @@ from dotenv import load_dotenv import webvtt -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import create_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.interfaces import ConversationMetadata from typeagent.knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH @@ -203,7 +203,7 @@ async def ingest_vtt_files( if verbose: print("Setting up conversation settings...") try: - embedding_model = AsyncEmbeddingModel(model_name=embedding_name) + embedding_model = create_embedding_model(model_name=embedding_name) settings = ConversationSettings(embedding_model) # Create metadata with the conversation name diff --git a/tools/query.py b/tools/query.py index 3d28a89e..24b1f4c9 100644 --- a/tools/query.py +++ b/tools/query.py @@ -150,7 +150,7 @@ class ProcessingContext: debug2: typing.Literal["none", "diff", "full", "skip"] debug3: typing.Literal["none", "diff", "full", "nice"] debug4: typing.Literal["none", "diff", "full", "nice"] - embedding_model: embeddings.AsyncEmbeddingModel + embedding_model: embeddings.IEmbeddingModel query_translator: typechat.TypeChatJsonTranslator[search_query_schema.SearchQuery] answer_translator: typechat.TypeChatJsonTranslator[ answer_response_schema.AnswerResponse From d59d7b6909484b9d8ac3b99cf455c6cf9bdc394c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 17 Feb 2026 10:05:17 -0800 Subject: [PATCH 02/23] Agent step 2 -- unreviewed --- src/typeagent/aitools/embeddings.py | 9 +- src/typeagent/aitools/model_registry.py | 214 ++++++++++++++++++++++++ src/typeagent/aitools/utils.py | 9 +- tests/test_model_registry.py | 152 +++++++++++++++++ 4 files changed, 377 insertions(+), 7 deletions(-) create mode 100644 src/typeagent/aitools/model_registry.py create mode 100644 tests/test_model_registry.py diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py index c3403af2..ce8ea062 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -94,6 +94,7 @@ def __init__( model_name: str | None = None, endpoint_envvar: str | None = None, max_retries: int = DEFAULT_MAX_RETRIES, + use_azure: bool | None = None, ): if model_name is None: model_name = DEFAULT_MODEL_NAME @@ -122,8 +123,12 @@ def __init__( openai_api_key = os.getenv("OPENAI_API_KEY") azure_api_key = os.getenv("AZURE_OPENAI_API_KEY") - # Prefer OpenAI if both are set, use Azure if only Azure is set - self.use_azure = bool(azure_api_key) and not bool(openai_api_key) + # Determine provider: explicit use_azure overrides auto-detection. + if use_azure is not None: + self.use_azure = use_azure + else: + # Prefer OpenAI if both are set, use Azure if only Azure is set + self.use_azure = bool(azure_api_key) and not bool(openai_api_key) if endpoint_envvar is None: # Check if OpenAI credentials are available, prefer OpenAI over Azure diff --git a/src/typeagent/aitools/model_registry.py b/src/typeagent/aitools/model_registry.py new file mode 100644 index 00000000..7cc03c4b --- /dev/null +++ b/src/typeagent/aitools/model_registry.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Provider-agnostic model configuration. + +Create chat and embedding models from ``provider/model`` spec strings:: + + from typeagent.aitools.model_registry import configure_models + + chat, embedder = configure_models( + "openai/gpt-4o", + "openai/text-embedding-3-small", + ) + +Supported built-in providers +----------------------------- + +* ``openai/`` — requires ``OPENAI_API_KEY`` env var. +* ``azure/`` — requires ``AZURE_OPENAI_API_KEY`` (and + ``AZURE_OPENAI_ENDPOINT``) env vars. For Azure, the *model* part of the + spec is the **deployment name**. + +Extending with new providers +---------------------------- + +Implement ``typechat.TypeChatLanguageModel`` for chat, or +``IEmbeddingModel`` for embeddings, then register a factory:: + + from typeagent.aitools.model_registry import ( + register_chat_provider, + register_embedding_provider, + ) + + register_chat_provider("anthropic", my_anthropic_chat_factory) + register_embedding_provider("gemini", my_gemini_embedding_factory) + +Each factory is a callable ``(model_name: str) -> Model``. +""" + +from collections.abc import Callable +import os + +import typechat + +from .embeddings import AsyncEmbeddingModel, IEmbeddingModel + +# --------------------------------------------------------------------------- +# Spec parsing +# --------------------------------------------------------------------------- + +type ChatModelFactory = Callable[[str], typechat.TypeChatLanguageModel] +type EmbeddingModelFactory = Callable[[str], IEmbeddingModel] + + +def _parse_model_spec(spec: str) -> tuple[str, str]: + """Parse ``'provider/model'`` into ``(provider, model_name)``. + + Raises ``ValueError`` on malformed specs. + """ + parts = spec.split("/", 1) + if len(parts) != 2 or not parts[0] or not parts[1]: + raise ValueError( + f"Invalid model spec {spec!r}. " + f"Expected 'provider/model', e.g. 'openai/gpt-4o'." + ) + return parts[0], parts[1] + + +# --------------------------------------------------------------------------- +# Chat model registry +# --------------------------------------------------------------------------- + +_chat_providers: dict[str, ChatModelFactory] = {} + + +def register_chat_provider(provider: str, factory: ChatModelFactory) -> None: + """Register a factory that creates chat models for *provider*.""" + _chat_providers[provider] = factory + + +def _openai_chat(model_name: str) -> typechat.TypeChatLanguageModel: + env: dict[str, str | None] = dict(os.environ) + if not env.get("OPENAI_API_KEY"): + raise RuntimeError("OPENAI_API_KEY required for openai/ chat models.") + env["OPENAI_MODEL"] = model_name + # Force the OpenAI path even when Azure env vars are also present. + env.pop("AZURE_OPENAI_API_KEY", None) + return typechat.create_language_model(env) + + +def _azure_chat(model_name: str) -> typechat.TypeChatLanguageModel: + from .auth import AzureTokenProvider, get_shared_token_provider + from .utils import DEFAULT_MAX_RETRY_ATTEMPTS, DEFAULT_TIMEOUT_SECONDS, ModelWrapper + + env: dict[str, str | None] = dict(os.environ) + key = env.get("AZURE_OPENAI_API_KEY") + if not key: + raise RuntimeError("AZURE_OPENAI_API_KEY required for azure/ chat models.") + env["OPENAI_MODEL"] = model_name + # Force the Azure path even when OPENAI_API_KEY is also present. + env.pop("OPENAI_API_KEY", None) + + shared_token_provider: AzureTokenProvider | None = None + if isinstance(key, str) and key.lower() == "identity": + shared_token_provider = get_shared_token_provider() + env["AZURE_OPENAI_API_KEY"] = shared_token_provider.get_token() + + model = typechat.create_language_model(env) + model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS + model.max_retry_attempts = DEFAULT_MAX_RETRY_ATTEMPTS + if shared_token_provider is not None: + model = ModelWrapper(model, shared_token_provider) + return model + + +register_chat_provider("openai", _openai_chat) +register_chat_provider("azure", _azure_chat) + + +# --------------------------------------------------------------------------- +# Embedding model registry +# --------------------------------------------------------------------------- + +_embedding_providers: dict[str, EmbeddingModelFactory] = {} + + +def register_embedding_provider(provider: str, factory: EmbeddingModelFactory) -> None: + """Register a factory that creates embedding models for *provider*.""" + _embedding_providers[provider] = factory + + +def _openai_embedding(model_name: str) -> IEmbeddingModel: + return AsyncEmbeddingModel(model_name=model_name, use_azure=False) + + +def _azure_embedding(model_name: str) -> IEmbeddingModel: + return AsyncEmbeddingModel(model_name=model_name, use_azure=True) + + +register_embedding_provider("openai", _openai_embedding) +register_embedding_provider("azure", _azure_embedding) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def create_chat_model( + model_spec: str, +) -> typechat.TypeChatLanguageModel: + """Create a chat model from a ``provider/model`` spec. + + Examples:: + + model = create_chat_model("openai/gpt-4o") + model = create_chat_model("azure/my-gpt4o-deployment") + + For Azure, *model* is the **deployment name**, not the underlying + model name. + """ + provider, model_name = _parse_model_spec(model_spec) + factory = _chat_providers.get(provider) + if factory is None: + avail = ", ".join(sorted(_chat_providers)) or "(none)" + raise ValueError( + f"Unknown chat provider {provider!r}. " + f"Available: {avail}. " + f"Use register_chat_provider() to add support." + ) + return factory(model_name) + + +def create_embedding_model( + model_spec: str, +) -> IEmbeddingModel: + """Create an embedding model from a ``provider/model`` spec. + + Examples:: + + model = create_embedding_model("openai/text-embedding-3-small") + model = create_embedding_model("azure/text-embedding-3-small") + """ + provider, model_name = _parse_model_spec(model_spec) + factory = _embedding_providers.get(provider) + if factory is None: + avail = ", ".join(sorted(_embedding_providers)) or "(none)" + raise ValueError( + f"Unknown embedding provider {provider!r}. " + f"Available: {avail}. " + f"Use register_embedding_provider() to add support." + ) + return factory(model_name) + + +def configure_models( + chat_model_spec: str, + embedding_model_spec: str, +) -> tuple[typechat.TypeChatLanguageModel, IEmbeddingModel]: + """Configure both a chat model and an embedding model at once. + + Example:: + + chat, embedder = configure_models( + "openai/gpt-4o", + "openai/text-embedding-3-small", + ) + + settings = ConversationSettings(model=embedder) + extractor = KnowledgeExtractor(model=chat) + """ + return create_chat_model(chat_model_spec), create_embedding_model( + embedding_model_spec + ) diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index 2459473c..b475c728 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -125,12 +125,11 @@ async def complete( def create_typechat_model() -> typechat.TypeChatLanguageModel: """Create a TypeChat language model using OpenAI or Azure OpenAI. - Reads ``OPENAI_API_KEY``, ``AZURE_OPENAI_API_KEY`` and related env vars. - Handles Azure ``identity`` token provider for Microsoft internal usage. + Auto-detects the provider from ``OPENAI_API_KEY`` / ``AZURE_OPENAI_API_KEY`` + environment variables. - To use a different provider (e.g. Anthropic, Gemini), implement - ``typechat.TypeChatLanguageModel`` directly and pass it to - ``KnowledgeExtractor`` or ``create_translator()``. + For explicit provider selection, use :func:`model_registry.create_chat_model` + with a spec string like ``"openai/gpt-4o"`` or ``"azure/my-deployment"``. """ env: dict[str, str | None] = dict(os.environ) key_name = "AZURE_OPENAI_API_KEY" diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py new file mode 100644 index 00000000..11b4f71e --- /dev/null +++ b/tests/test_model_registry.py @@ -0,0 +1,152 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pytest + +import typechat + +from typeagent.aitools.embeddings import ( + AsyncEmbeddingModel, + IEmbeddingModel, + TEST_MODEL_NAME, +) +from typeagent.aitools.model_registry import ( + _chat_providers, + _embedding_providers, + _parse_model_spec, + configure_models, + create_chat_model, + create_embedding_model, + register_chat_provider, + register_embedding_provider, +) + +# --------------------------------------------------------------------------- +# Spec parsing +# --------------------------------------------------------------------------- + + +def test_parse_valid_specs() -> None: + assert _parse_model_spec("openai/gpt-4o") == ("openai", "gpt-4o") + assert _parse_model_spec("azure/my-deployment") == ("azure", "my-deployment") + assert _parse_model_spec("anthropic/claude-3.5-sonnet") == ( + "anthropic", + "claude-3.5-sonnet", + ) + + +def test_parse_spec_preserves_slashes() -> None: + """Only the first '/' is a separator; the rest belong to the model name.""" + assert _parse_model_spec("provider/model/variant") == ( + "provider", + "model/variant", + ) + + +def test_parse_invalid_specs() -> None: + with pytest.raises(ValueError, match="Invalid model spec"): + _parse_model_spec("noslash") + with pytest.raises(ValueError, match="Invalid model spec"): + _parse_model_spec("/model") + with pytest.raises(ValueError, match="Invalid model spec"): + _parse_model_spec("provider/") + with pytest.raises(ValueError, match="Invalid model spec"): + _parse_model_spec("") + + +# --------------------------------------------------------------------------- +# Built-in registration +# --------------------------------------------------------------------------- + + +def test_builtin_providers_registered() -> None: + assert "openai" in _chat_providers + assert "azure" in _chat_providers + assert "openai" in _embedding_providers + assert "azure" in _embedding_providers + + +# --------------------------------------------------------------------------- +# Unknown provider errors +# --------------------------------------------------------------------------- + + +def test_unknown_chat_provider() -> None: + with pytest.raises(ValueError, match="Unknown chat provider"): + create_chat_model("magical/unicorn") + + +def test_unknown_embedding_provider() -> None: + with pytest.raises(ValueError, match="Unknown embedding provider"): + create_embedding_model("magical/unicorn") + + +# --------------------------------------------------------------------------- +# Custom provider registration +# --------------------------------------------------------------------------- + + +class FakeChatModel(typechat.TypeChatLanguageModel): + """Minimal chat model for registry tests.""" + + async def complete( + self, prompt: str | list[typechat.PromptSection] + ) -> typechat.Result[str]: + return typechat.Success("fake") + + +def test_register_and_use_custom_chat_provider() -> None: + instance = FakeChatModel() + register_chat_provider("_test_chat", lambda name: instance) + try: + result = create_chat_model("_test_chat/any-model") + assert result is instance + finally: + _chat_providers.pop("_test_chat", None) + + +def test_register_and_use_custom_embedding_provider() -> None: + instance = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + register_embedding_provider("_test_embed", lambda name: instance) + try: + result = create_embedding_model("_test_embed/any-model") + assert result is instance + assert isinstance(result, IEmbeddingModel) + finally: + _embedding_providers.pop("_test_embed", None) + + +def test_model_name_forwarded_to_factory() -> None: + """The model portion of the spec is passed to the factory.""" + received: list[str] = [] + + def capture_factory(model_name: str) -> IEmbeddingModel: + received.append(model_name) + return AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + + register_embedding_provider("_test_fwd", capture_factory) + try: + create_embedding_model("_test_fwd/text-embedding-3-small") + assert received == ["text-embedding-3-small"] + finally: + _embedding_providers.pop("_test_fwd", None) + + +# --------------------------------------------------------------------------- +# configure_models +# --------------------------------------------------------------------------- + + +def test_configure_models() -> None: + chat_instance = FakeChatModel() + embed_instance = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + + register_chat_provider("_test_cm", lambda name: chat_instance) + register_embedding_provider("_test_cm", lambda name: embed_instance) + try: + chat, embedder = configure_models("_test_cm/chat", "_test_cm/embed") + assert chat is chat_instance + assert embedder is embed_instance + finally: + _chat_providers.pop("_test_cm", None) + _embedding_providers.pop("_test_cm", None) From ff733f50b2d5b41f3457d6f34d0ed8d66b7183a6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 17 Feb 2026 10:08:06 -0800 Subject: [PATCH 03/23] Agent step 3 -- unreviewed -- use Pydantic's model registry --- src/typeagent/aitools/model_registry.py | 294 +++++++++++++----------- tests/test_model_registry.py | 271 ++++++++++++++-------- 2 files changed, 334 insertions(+), 231 deletions(-) diff --git a/src/typeagent/aitools/model_registry.py b/src/typeagent/aitools/model_registry.py index 7cc03c4b..3b02e461 100644 --- a/src/typeagent/aitools/model_registry.py +++ b/src/typeagent/aitools/model_registry.py @@ -1,144 +1,165 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""Provider-agnostic model configuration. +"""Provider-agnostic model configuration backed by pydantic_ai. -Create chat and embedding models from ``provider/model`` spec strings:: +Create chat and embedding models from ``provider:model`` spec strings:: from typeagent.aitools.model_registry import configure_models chat, embedder = configure_models( - "openai/gpt-4o", - "openai/text-embedding-3-small", + "openai:gpt-4o", + "openai:text-embedding-3-small", ) -Supported built-in providers ------------------------------ +The spec format is ``provider:model``, matching pydantic_ai conventions. +Provider wiring (API keys, endpoints, etc.) is handled by pydantic_ai's +model registry, which supports 25+ providers including ``openai``, +``azure``, ``anthropic``, ``google``, ``bedrock``, ``groq``, ``mistral``, +``ollama``, ``cohere``, and many more. -* ``openai/`` — requires ``OPENAI_API_KEY`` env var. -* ``azure/`` — requires ``AZURE_OPENAI_API_KEY`` (and - ``AZURE_OPENAI_ENDPOINT``) env vars. For Azure, the *model* part of the - spec is the **deployment name**. - -Extending with new providers ----------------------------- - -Implement ``typechat.TypeChatLanguageModel`` for chat, or -``IEmbeddingModel`` for embeddings, then register a factory:: - - from typeagent.aitools.model_registry import ( - register_chat_provider, - register_embedding_provider, - ) - - register_chat_provider("anthropic", my_anthropic_chat_factory) - register_embedding_provider("gemini", my_gemini_embedding_factory) - -Each factory is a callable ``(model_name: str) -> Model``. +See https://ai.pydantic.dev/models/ for all supported providers and their +required environment variables. """ -from collections.abc import Callable -import os - +import numpy as np +from numpy.typing import NDArray + +from pydantic_ai import Embedder as _PydanticAIEmbedder +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + SystemPromptPart, + TextPart, + UserPromptPart, +) +from pydantic_ai.models import infer_model, Model, ModelRequestParameters import typechat -from .embeddings import AsyncEmbeddingModel, IEmbeddingModel +from .embeddings import IEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings # --------------------------------------------------------------------------- -# Spec parsing +# Known embedding sizes for common models # --------------------------------------------------------------------------- -type ChatModelFactory = Callable[[str], typechat.TypeChatLanguageModel] -type EmbeddingModelFactory = Callable[[str], IEmbeddingModel] - - -def _parse_model_spec(spec: str) -> tuple[str, str]: - """Parse ``'provider/model'`` into ``(provider, model_name)``. - - Raises ``ValueError`` on malformed specs. - """ - parts = spec.split("/", 1) - if len(parts) != 2 or not parts[0] or not parts[1]: - raise ValueError( - f"Invalid model spec {spec!r}. " - f"Expected 'provider/model', e.g. 'openai/gpt-4o'." - ) - return parts[0], parts[1] +_KNOWN_EMBEDDING_SIZES: dict[str, int] = { + "text-embedding-ada-002": 1536, + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + "embed-english-v3.0": 1024, + "embed-multilingual-v3.0": 1024, + "embed-english-light-v3.0": 384, + "embed-multilingual-light-v3.0": 384, + "text-embedding-004": 768, + "embedding-001": 768, +} # --------------------------------------------------------------------------- -# Chat model registry +# Chat model adapter # --------------------------------------------------------------------------- -_chat_providers: dict[str, ChatModelFactory] = {} - - -def register_chat_provider(provider: str, factory: ChatModelFactory) -> None: - """Register a factory that creates chat models for *provider*.""" - _chat_providers[provider] = factory - - -def _openai_chat(model_name: str) -> typechat.TypeChatLanguageModel: - env: dict[str, str | None] = dict(os.environ) - if not env.get("OPENAI_API_KEY"): - raise RuntimeError("OPENAI_API_KEY required for openai/ chat models.") - env["OPENAI_MODEL"] = model_name - # Force the OpenAI path even when Azure env vars are also present. - env.pop("AZURE_OPENAI_API_KEY", None) - return typechat.create_language_model(env) +class PydanticAIChatModel: + """Adapter from :class:`pydantic_ai.models.Model` to TypeChat's + :class:`~typechat.TypeChatLanguageModel`. -def _azure_chat(model_name: str) -> typechat.TypeChatLanguageModel: - from .auth import AzureTokenProvider, get_shared_token_provider - from .utils import DEFAULT_MAX_RETRY_ATTEMPTS, DEFAULT_TIMEOUT_SECONDS, ModelWrapper - - env: dict[str, str | None] = dict(os.environ) - key = env.get("AZURE_OPENAI_API_KEY") - if not key: - raise RuntimeError("AZURE_OPENAI_API_KEY required for azure/ chat models.") - env["OPENAI_MODEL"] = model_name - # Force the Azure path even when OPENAI_API_KEY is also present. - env.pop("OPENAI_API_KEY", None) + This lets any pydantic_ai chat model (OpenAI, Anthropic, Google, …) be + used wherever TypeChat expects a ``TypeChatLanguageModel``. + """ - shared_token_provider: AzureTokenProvider | None = None - if isinstance(key, str) and key.lower() == "identity": - shared_token_provider = get_shared_token_provider() - env["AZURE_OPENAI_API_KEY"] = shared_token_provider.get_token() + def __init__(self, model: Model) -> None: + self._model = model - model = typechat.create_language_model(env) - model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS - model.max_retry_attempts = DEFAULT_MAX_RETRY_ATTEMPTS - if shared_token_provider is not None: - model = ModelWrapper(model, shared_token_provider) - return model + async def complete( + self, prompt: str | list[typechat.PromptSection] + ) -> typechat.Result[str]: + parts: list[SystemPromptPart | UserPromptPart] = [] + if isinstance(prompt, str): + parts.append(UserPromptPart(content=prompt)) + else: + for section in prompt: + if section["role"] == "system": + parts.append(SystemPromptPart(content=section["content"])) + else: + parts.append(UserPromptPart(content=section["content"])) + messages: list[ModelMessage] = [ModelRequest(parts=parts)] + params = ModelRequestParameters() -register_chat_provider("openai", _openai_chat) -register_chat_provider("azure", _azure_chat) + response = await self._model.request(messages, None, params) + text_parts = [p.content for p in response.parts if isinstance(p, TextPart)] + if text_parts: + return typechat.Success("".join(text_parts)) + return typechat.Failure("No text content in model response") # --------------------------------------------------------------------------- -# Embedding model registry +# Embedding model adapter # --------------------------------------------------------------------------- -_embedding_providers: dict[str, EmbeddingModelFactory] = {} - - -def register_embedding_provider(provider: str, factory: EmbeddingModelFactory) -> None: - """Register a factory that creates embedding models for *provider*.""" - _embedding_providers[provider] = factory +class PydanticAIEmbeddingModel: + """Adapter from :class:`pydantic_ai.Embedder` to :class:`IEmbeddingModel`. -def _openai_embedding(model_name: str) -> IEmbeddingModel: - return AsyncEmbeddingModel(model_name=model_name, use_azure=False) - - -def _azure_embedding(model_name: str) -> IEmbeddingModel: - return AsyncEmbeddingModel(model_name=model_name, use_azure=True) - + This lets any pydantic_ai embedding provider (OpenAI, Cohere, Google, …) + be used wherever the codebase expects an ``IEmbeddingModel``, including + :class:`~typeagent.aitools.vectorbase.VectorBase` and + :class:`~typeagent.knowpro.convsettings.ConversationSettings`. + """ -register_embedding_provider("openai", _openai_embedding) -register_embedding_provider("azure", _azure_embedding) + model_name: str + embedding_size: int + + def __init__( + self, + embedder: _PydanticAIEmbedder, + model_name: str, + embedding_size: int, + ) -> None: + self._embedder = embedder + self.model_name = model_name + self.embedding_size = embedding_size + self._cache: dict[str, NormalizedEmbedding] = {} + + def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: + self._cache[key] = embedding + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + result = await self._embedder.embed([input], input_type="document") + embedding: NDArray[np.float32] = np.array( + result.embeddings[0], dtype=np.float32 + ) + norm = float(np.linalg.norm(embedding)) + if norm > 0: + embedding = (embedding / norm).astype(np.float32) + return embedding + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + if not input: + return np.empty((0, self.embedding_size), dtype=np.float32) + result = await self._embedder.embed(input, input_type="document") + embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32) + norms = np.linalg.norm(embeddings, axis=1, keepdims=True).astype(np.float32) + norms = np.where(norms > 0, norms, np.float32(1.0)) + embeddings = (embeddings / norms).astype(np.float32) + return embeddings + + async def get_embedding(self, key: str) -> NormalizedEmbedding: + cached = self._cache.get(key) + if cached is not None: + return cached + embedding = await self.get_embedding_nocache(key) + self._cache[key] = embedding + return embedding + + async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + missing_keys = [k for k in keys if k not in self._cache] + if missing_keys: + fresh = await self.get_embeddings_nocache(missing_keys) + for i, k in enumerate(missing_keys): + self._cache[k] = fresh[i] + return np.array([self._cache[k] for k in keys], dtype=np.float32) # --------------------------------------------------------------------------- @@ -149,66 +170,71 @@ def _azure_embedding(model_name: str) -> IEmbeddingModel: def create_chat_model( model_spec: str, ) -> typechat.TypeChatLanguageModel: - """Create a chat model from a ``provider/model`` spec. + """Create a chat model from a ``provider:model`` spec. - Examples:: + Delegates to :func:`pydantic_ai.models.infer_model` for provider wiring. - model = create_chat_model("openai/gpt-4o") - model = create_chat_model("azure/my-gpt4o-deployment") + Examples:: - For Azure, *model* is the **deployment name**, not the underlying - model name. + model = create_chat_model("openai:gpt-4o") + model = create_chat_model("anthropic:claude-sonnet-4-20250514") + model = create_chat_model("google:gemini-2.0-flash") """ - provider, model_name = _parse_model_spec(model_spec) - factory = _chat_providers.get(provider) - if factory is None: - avail = ", ".join(sorted(_chat_providers)) or "(none)" - raise ValueError( - f"Unknown chat provider {provider!r}. " - f"Available: {avail}. " - f"Use register_chat_provider() to add support." - ) - return factory(model_name) + model = infer_model(model_spec) + return PydanticAIChatModel(model) def create_embedding_model( model_spec: str, + *, + embedding_size: int | None = None, ) -> IEmbeddingModel: - """Create an embedding model from a ``provider/model`` spec. + """Create an embedding model from a ``provider:model`` spec. + + Delegates to :class:`pydantic_ai.Embedder` for provider wiring. + + If *embedding_size* is not given, it is looked up in a table of common + models. For unknown models, pass it explicitly. Examples:: - model = create_embedding_model("openai/text-embedding-3-small") - model = create_embedding_model("azure/text-embedding-3-small") + model = create_embedding_model("openai:text-embedding-3-small") + model = create_embedding_model("cohere:embed-english-v3.0") + model = create_embedding_model("google:text-embedding-004") """ - provider, model_name = _parse_model_spec(model_spec) - factory = _embedding_providers.get(provider) - if factory is None: - avail = ", ".join(sorted(_embedding_providers)) or "(none)" + model_name = model_spec.split(":")[-1] if ":" in model_spec else model_spec + if embedding_size is None: + embedding_size = _KNOWN_EMBEDDING_SIZES.get(model_name) + if embedding_size is None: raise ValueError( - f"Unknown embedding provider {provider!r}. " - f"Available: {avail}. " - f"Use register_embedding_provider() to add support." + f"Unknown embedding size for model {model_name!r}. " + f"Pass embedding_size= explicitly." ) - return factory(model_name) + embedder = _PydanticAIEmbedder(model_spec) + return PydanticAIEmbeddingModel(embedder, model_name, embedding_size) def configure_models( chat_model_spec: str, embedding_model_spec: str, + *, + embedding_size: int | None = None, ) -> tuple[typechat.TypeChatLanguageModel, IEmbeddingModel]: """Configure both a chat model and an embedding model at once. + Delegates to pydantic_ai's model registry for provider wiring. + Example:: chat, embedder = configure_models( - "openai/gpt-4o", - "openai/text-embedding-3-small", + "openai:gpt-4o", + "openai:text-embedding-3-small", ) settings = ConversationSettings(model=embedder) extractor = KnowledgeExtractor(model=chat) """ - return create_chat_model(chat_model_spec), create_embedding_model( - embedding_model_spec + return ( + create_chat_model(chat_model_spec), + create_embedding_model(embedding_model_spec, embedding_size=embedding_size), ) diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index 11b4f71e..d18e071d 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -1,135 +1,218 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import numpy as np import pytest import typechat -from typeagent.aitools.embeddings import ( - AsyncEmbeddingModel, - IEmbeddingModel, - TEST_MODEL_NAME, -) +from typeagent.aitools.embeddings import IEmbeddingModel, NormalizedEmbedding from typeagent.aitools.model_registry import ( - _chat_providers, - _embedding_providers, - _parse_model_spec, + _KNOWN_EMBEDDING_SIZES, configure_models, create_chat_model, create_embedding_model, - register_chat_provider, - register_embedding_provider, + PydanticAIChatModel, + PydanticAIEmbeddingModel, ) # --------------------------------------------------------------------------- -# Spec parsing +# Spec format # --------------------------------------------------------------------------- -def test_parse_valid_specs() -> None: - assert _parse_model_spec("openai/gpt-4o") == ("openai", "gpt-4o") - assert _parse_model_spec("azure/my-deployment") == ("azure", "my-deployment") - assert _parse_model_spec("anthropic/claude-3.5-sonnet") == ( - "anthropic", - "claude-3.5-sonnet", - ) +def test_spec_uses_colon_separator() -> None: + """Specs use ``provider:model`` format matching pydantic_ai conventions.""" + with pytest.raises(Exception): + # A nonsense provider should fail + create_chat_model("nonexistent_provider_xyz:fake-model") -def test_parse_spec_preserves_slashes() -> None: - """Only the first '/' is a separator; the rest belong to the model name.""" - assert _parse_model_spec("provider/model/variant") == ( - "provider", - "model/variant", - ) +# --------------------------------------------------------------------------- +# Known embedding sizes +# --------------------------------------------------------------------------- -def test_parse_invalid_specs() -> None: - with pytest.raises(ValueError, match="Invalid model spec"): - _parse_model_spec("noslash") - with pytest.raises(ValueError, match="Invalid model spec"): - _parse_model_spec("/model") - with pytest.raises(ValueError, match="Invalid model spec"): - _parse_model_spec("provider/") - with pytest.raises(ValueError, match="Invalid model spec"): - _parse_model_spec("") +def test_known_embedding_sizes() -> None: + assert _KNOWN_EMBEDDING_SIZES["text-embedding-3-small"] == 1536 + assert _KNOWN_EMBEDDING_SIZES["text-embedding-3-large"] == 3072 + assert _KNOWN_EMBEDDING_SIZES["text-embedding-ada-002"] == 1536 -# --------------------------------------------------------------------------- -# Built-in registration -# --------------------------------------------------------------------------- +def test_unknown_embedding_size_raises() -> None: + with pytest.raises(ValueError, match="Unknown embedding size"): + create_embedding_model("openai:completely-unknown-model-xyz") -def test_builtin_providers_registered() -> None: - assert "openai" in _chat_providers - assert "azure" in _chat_providers - assert "openai" in _embedding_providers - assert "azure" in _embedding_providers +def test_explicit_embedding_size() -> None: + """Passing embedding_size= bypasses the lookup table.""" + # This should not raise even though the model name is unknown + model = create_embedding_model( + "openai:completely-unknown-model-xyz", embedding_size=42 + ) + assert model.embedding_size == 42 # --------------------------------------------------------------------------- -# Unknown provider errors +# PydanticAIChatModel adapter # --------------------------------------------------------------------------- -def test_unknown_chat_provider() -> None: - with pytest.raises(ValueError, match="Unknown chat provider"): - create_chat_model("magical/unicorn") +@pytest.mark.asyncio +async def test_chat_adapter_complete() -> None: + """PydanticAIChatModel wraps a pydantic_ai Model.""" + from unittest.mock import AsyncMock + + from pydantic_ai.messages import ModelResponse, TextPart + from pydantic_ai.models import Model + + mock_model = AsyncMock(spec=Model) + mock_model.request.return_value = ModelResponse(parts=[TextPart(content="hello")]) + + adapter = PydanticAIChatModel(mock_model) + result = await adapter.complete("test prompt") + assert isinstance(result, typechat.Success) + assert result.value == "hello" + + +@pytest.mark.asyncio +async def test_chat_adapter_prompt_sections() -> None: + """PydanticAIChatModel handles list[PromptSection] prompts.""" + from unittest.mock import AsyncMock + + from pydantic_ai.messages import ModelResponse, TextPart + from pydantic_ai.models import Model + + mock_model = AsyncMock(spec=Model) + mock_model.request.return_value = ModelResponse( + parts=[TextPart(content="response")] + ) + + adapter = PydanticAIChatModel(mock_model) + sections: list[typechat.PromptSection] = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + ] + result = await adapter.complete(sections) + assert isinstance(result, typechat.Success) + assert result.value == "response" + # Verify the request was called with proper message structure + call_args = mock_model.request.call_args + messages = call_args[0][0] + assert len(messages) == 1 + request = messages[0] + from pydantic_ai.messages import SystemPromptPart, UserPromptPart -def test_unknown_embedding_provider() -> None: - with pytest.raises(ValueError, match="Unknown embedding provider"): - create_embedding_model("magical/unicorn") + assert isinstance(request.parts[0], SystemPromptPart) + assert isinstance(request.parts[1], UserPromptPart) # --------------------------------------------------------------------------- -# Custom provider registration +# PydanticAIEmbeddingModel adapter # --------------------------------------------------------------------------- -class FakeChatModel(typechat.TypeChatLanguageModel): - """Minimal chat model for registry tests.""" +@pytest.mark.asyncio +async def test_embedding_adapter_single() -> None: + """PydanticAIEmbeddingModel computes a single normalized embedding.""" + from unittest.mock import AsyncMock - async def complete( - self, prompt: str | list[typechat.PromptSection] - ) -> typechat.Result[str]: - return typechat.Success("fake") + from pydantic_ai import Embedder + from pydantic_ai.embeddings import EmbeddingResult + mock_embedder = AsyncMock(spec=Embedder) + raw_vec = [3.0, 4.0, 0.0] + mock_embedder.embed.return_value = EmbeddingResult( + embeddings=[raw_vec], + inputs=["test"], + input_type="document", + model_name="test-model", + provider_name="test", + ) + + adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) + result = await adapter.get_embedding_nocache("test") + assert result.shape == (3,) + norm = float(np.linalg.norm(result)) + assert abs(norm - 1.0) < 1e-6 + + +@pytest.mark.asyncio +async def test_embedding_adapter_batch() -> None: + """PydanticAIEmbeddingModel computes batch embeddings.""" + from unittest.mock import AsyncMock + + from pydantic_ai import Embedder + from pydantic_ai.embeddings import EmbeddingResult + + mock_embedder = AsyncMock(spec=Embedder) + mock_embedder.embed.return_value = EmbeddingResult( + embeddings=[[1.0, 0.0], [0.0, 1.0]], + inputs=["a", "b"], + input_type="document", + model_name="test-model", + provider_name="test", + ) + + adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 2) + result = await adapter.get_embeddings_nocache(["a", "b"]) + assert result.shape == (2, 2) -def test_register_and_use_custom_chat_provider() -> None: - instance = FakeChatModel() - register_chat_provider("_test_chat", lambda name: instance) - try: - result = create_chat_model("_test_chat/any-model") - assert result is instance - finally: - _chat_providers.pop("_test_chat", None) +@pytest.mark.asyncio +async def test_embedding_adapter_caching() -> None: + """Caching avoids re-computing embeddings.""" + from unittest.mock import AsyncMock -def test_register_and_use_custom_embedding_provider() -> None: - instance = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) - register_embedding_provider("_test_embed", lambda name: instance) - try: - result = create_embedding_model("_test_embed/any-model") - assert result is instance - assert isinstance(result, IEmbeddingModel) - finally: - _embedding_providers.pop("_test_embed", None) + from pydantic_ai import Embedder + from pydantic_ai.embeddings import EmbeddingResult + mock_embedder = AsyncMock(spec=Embedder) + mock_embedder.embed.return_value = EmbeddingResult( + embeddings=[[1.0, 0.0, 0.0]], + inputs=["cached"], + input_type="document", + model_name="test-model", + provider_name="test", + ) + + adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) + first = await adapter.get_embedding("cached") + second = await adapter.get_embedding("cached") + np.testing.assert_array_equal(first, second) + # embed() should only be called once + assert mock_embedder.embed.call_count == 1 + + +@pytest.mark.asyncio +async def test_embedding_adapter_add_embedding() -> None: + """add_embedding() populates the cache.""" + from unittest.mock import AsyncMock -def test_model_name_forwarded_to_factory() -> None: - """The model portion of the spec is passed to the factory.""" - received: list[str] = [] + from pydantic_ai import Embedder - def capture_factory(model_name: str) -> IEmbeddingModel: - received.append(model_name) - return AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + mock_embedder = AsyncMock(spec=Embedder) + adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) + vec: NormalizedEmbedding = np.array([1.0, 0.0, 0.0], dtype=np.float32) + adapter.add_embedding("key", vec) + result = await adapter.get_embedding("key") + np.testing.assert_array_equal(result, vec) + # No embed() call needed + mock_embedder.embed.assert_not_called() - register_embedding_provider("_test_fwd", capture_factory) - try: - create_embedding_model("_test_fwd/text-embedding-3-small") - assert received == ["text-embedding-3-small"] - finally: - _embedding_providers.pop("_test_fwd", None) + +@pytest.mark.asyncio +async def test_embedding_adapter_empty_batch() -> None: + """Empty batch returns empty array.""" + from unittest.mock import AsyncMock + + from pydantic_ai import Embedder + + mock_embedder = AsyncMock(spec=Embedder) + adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 4) + result = await adapter.get_embeddings_nocache([]) + assert result.shape == (0, 4) # --------------------------------------------------------------------------- @@ -137,16 +220,10 @@ def capture_factory(model_name: str) -> IEmbeddingModel: # --------------------------------------------------------------------------- -def test_configure_models() -> None: - chat_instance = FakeChatModel() - embed_instance = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) - - register_chat_provider("_test_cm", lambda name: chat_instance) - register_embedding_provider("_test_cm", lambda name: embed_instance) - try: - chat, embedder = configure_models("_test_cm/chat", "_test_cm/embed") - assert chat is chat_instance - assert embedder is embed_instance - finally: - _chat_providers.pop("_test_cm", None) - _embedding_providers.pop("_test_cm", None) +def test_configure_models_returns_correct_types() -> None: + """configure_models creates both adapters.""" + chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small") + assert isinstance(chat, PydanticAIChatModel) + assert isinstance(embedder, PydanticAIEmbeddingModel) + assert isinstance(embedder, IEmbeddingModel) + assert embedder.embedding_size == 1536 From 1015737b8eb6242609cc2f06350b3834d45e6617 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 18 Feb 2026 08:21:32 -0800 Subject: [PATCH 04/23] Don't hardcode an incomplete table of embedding sizes --- src/typeagent/aitools/model_registry.py | 56 +++++++++------------- tests/test_model_registry.py | 62 ++++++++++++++++++------- 2 files changed, 67 insertions(+), 51 deletions(-) diff --git a/src/typeagent/aitools/model_registry.py b/src/typeagent/aitools/model_registry.py index 3b02e461..b83db9a7 100644 --- a/src/typeagent/aitools/model_registry.py +++ b/src/typeagent/aitools/model_registry.py @@ -38,29 +38,12 @@ from .embeddings import IEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings -# --------------------------------------------------------------------------- -# Known embedding sizes for common models -# --------------------------------------------------------------------------- - -_KNOWN_EMBEDDING_SIZES: dict[str, int] = { - "text-embedding-ada-002": 1536, - "text-embedding-3-small": 1536, - "text-embedding-3-large": 3072, - "embed-english-v3.0": 1024, - "embed-multilingual-v3.0": 1024, - "embed-english-light-v3.0": 384, - "embed-multilingual-light-v3.0": 384, - "text-embedding-004": 768, - "embedding-001": 768, -} - - # --------------------------------------------------------------------------- # Chat model adapter # --------------------------------------------------------------------------- -class PydanticAIChatModel: +class PydanticAIChatModel(typechat.TypeChatLanguageModel): """Adapter from :class:`pydantic_ai.models.Model` to TypeChat's :class:`~typechat.TypeChatLanguageModel`. @@ -99,13 +82,16 @@ async def complete( # --------------------------------------------------------------------------- -class PydanticAIEmbeddingModel: +class PydanticAIEmbeddingModel(IEmbeddingModel): """Adapter from :class:`pydantic_ai.Embedder` to :class:`IEmbeddingModel`. This lets any pydantic_ai embedding provider (OpenAI, Cohere, Google, …) be used wherever the codebase expects an ``IEmbeddingModel``, including :class:`~typeagent.aitools.vectorbase.VectorBase` and :class:`~typeagent.knowpro.convsettings.ConversationSettings`. + + If *embedding_size* is not given, it is probed automatically by making a + single embedding call. """ model_name: str @@ -115,7 +101,7 @@ def __init__( self, embedder: _PydanticAIEmbedder, model_name: str, - embedding_size: int, + embedding_size: int = 0, ) -> None: self._embedder = embedder self.model_name = model_name @@ -125,11 +111,18 @@ def __init__( def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: self._cache[key] = embedding + async def _probe_embedding_size(self) -> None: + """Discover embedding_size by making a single API call.""" + result = await self._embedder.embed(["probe"], input_type="document") + self.embedding_size = len(result.embeddings[0]) + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: result = await self._embedder.embed([input], input_type="document") embedding: NDArray[np.float32] = np.array( result.embeddings[0], dtype=np.float32 ) + if self.embedding_size == 0: + self.embedding_size = len(embedding) norm = float(np.linalg.norm(embedding)) if norm > 0: embedding = (embedding / norm).astype(np.float32) @@ -137,9 +130,13 @@ async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: if not input: + if self.embedding_size == 0: + await self._probe_embedding_size() return np.empty((0, self.embedding_size), dtype=np.float32) result = await self._embedder.embed(input, input_type="document") embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32) + if self.embedding_size == 0: + self.embedding_size = embeddings.shape[1] norms = np.linalg.norm(embeddings, axis=1, keepdims=True).astype(np.float32) norms = np.where(norms > 0, norms, np.float32(1.0)) embeddings = (embeddings / norms).astype(np.float32) @@ -187,14 +184,14 @@ def create_chat_model( def create_embedding_model( model_spec: str, *, - embedding_size: int | None = None, -) -> IEmbeddingModel: + embedding_size: int = 0, +) -> PydanticAIEmbeddingModel: """Create an embedding model from a ``provider:model`` spec. Delegates to :class:`pydantic_ai.Embedder` for provider wiring. - If *embedding_size* is not given, it is looked up in a table of common - models. For unknown models, pass it explicitly. + If *embedding_size* is not given, it will be probed automatically + on the first embedding call. Examples:: @@ -203,13 +200,6 @@ def create_embedding_model( model = create_embedding_model("google:text-embedding-004") """ model_name = model_spec.split(":")[-1] if ":" in model_spec else model_spec - if embedding_size is None: - embedding_size = _KNOWN_EMBEDDING_SIZES.get(model_name) - if embedding_size is None: - raise ValueError( - f"Unknown embedding size for model {model_name!r}. " - f"Pass embedding_size= explicitly." - ) embedder = _PydanticAIEmbedder(model_spec) return PydanticAIEmbeddingModel(embedder, model_name, embedding_size) @@ -218,8 +208,8 @@ def configure_models( chat_model_spec: str, embedding_model_spec: str, *, - embedding_size: int | None = None, -) -> tuple[typechat.TypeChatLanguageModel, IEmbeddingModel]: + embedding_size: int = 0, +) -> tuple[PydanticAIChatModel, PydanticAIEmbeddingModel]: """Configure both a chat model and an embedding model at once. Delegates to pydantic_ai's model registry for provider wiring. diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index d18e071d..d6bcf5ae 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -8,7 +8,6 @@ from typeagent.aitools.embeddings import IEmbeddingModel, NormalizedEmbedding from typeagent.aitools.model_registry import ( - _KNOWN_EMBEDDING_SIZES, configure_models, create_chat_model, create_embedding_model, @@ -29,35 +28,34 @@ def test_spec_uses_colon_separator() -> None: # --------------------------------------------------------------------------- -# Known embedding sizes +# Embedding size # --------------------------------------------------------------------------- -def test_known_embedding_sizes() -> None: - assert _KNOWN_EMBEDDING_SIZES["text-embedding-3-small"] == 1536 - assert _KNOWN_EMBEDDING_SIZES["text-embedding-3-large"] == 3072 - assert _KNOWN_EMBEDDING_SIZES["text-embedding-ada-002"] == 1536 - - -def test_unknown_embedding_size_raises() -> None: - with pytest.raises(ValueError, match="Unknown embedding size"): - create_embedding_model("openai:completely-unknown-model-xyz") - - def test_explicit_embedding_size() -> None: - """Passing embedding_size= bypasses the lookup table.""" - # This should not raise even though the model name is unknown + """Passing embedding_size= sets it immediately.""" model = create_embedding_model( - "openai:completely-unknown-model-xyz", embedding_size=42 + "openai:text-embedding-3-small", embedding_size=42 ) assert model.embedding_size == 42 +def test_default_embedding_size_is_zero() -> None: + """Without embedding_size=, it defaults to 0 (probed on first call).""" + model = create_embedding_model("openai:text-embedding-3-small") + assert model.embedding_size == 0 + + # --------------------------------------------------------------------------- # PydanticAIChatModel adapter # --------------------------------------------------------------------------- +def test_chat_model_is_typechat_model() -> None: + """PydanticAIChatModel inherits from TypeChatLanguageModel.""" + assert issubclass(PydanticAIChatModel, typechat.TypeChatLanguageModel) + + @pytest.mark.asyncio async def test_chat_adapter_complete() -> None: """PydanticAIChatModel wraps a pydantic_ai Model.""" @@ -113,6 +111,11 @@ async def test_chat_adapter_prompt_sections() -> None: # --------------------------------------------------------------------------- +def test_embedding_model_is_iembedding_model() -> None: + """PydanticAIEmbeddingModel inherits from IEmbeddingModel.""" + assert issubclass(PydanticAIEmbeddingModel, IEmbeddingModel) + + @pytest.mark.asyncio async def test_embedding_adapter_single() -> None: """PydanticAIEmbeddingModel computes a single normalized embedding.""" @@ -138,6 +141,29 @@ async def test_embedding_adapter_single() -> None: assert abs(norm - 1.0) < 1e-6 +@pytest.mark.asyncio +async def test_embedding_adapter_probes_size() -> None: + """embedding_size is discovered from the first embedding call.""" + from unittest.mock import AsyncMock + + from pydantic_ai import Embedder + from pydantic_ai.embeddings import EmbeddingResult + + mock_embedder = AsyncMock(spec=Embedder) + mock_embedder.embed.return_value = EmbeddingResult( + embeddings=[[1.0, 0.0, 0.0]], + inputs=["probe"], + input_type="document", + model_name="test-model", + provider_name="test", + ) + + adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model") + assert adapter.embedding_size == 0 + await adapter.get_embedding_nocache("probe") + assert adapter.embedding_size == 3 + + @pytest.mark.asyncio async def test_embedding_adapter_batch() -> None: """PydanticAIEmbeddingModel computes batch embeddings.""" @@ -204,7 +230,7 @@ async def test_embedding_adapter_add_embedding() -> None: @pytest.mark.asyncio async def test_embedding_adapter_empty_batch() -> None: - """Empty batch returns empty array.""" + """Empty batch returns empty array with known size.""" from unittest.mock import AsyncMock from pydantic_ai import Embedder @@ -226,4 +252,4 @@ def test_configure_models_returns_correct_types() -> None: assert isinstance(chat, PydanticAIChatModel) assert isinstance(embedder, PydanticAIEmbeddingModel) assert isinstance(embedder, IEmbeddingModel) - assert embedder.embedding_size == 1536 + assert isinstance(chat, typechat.TypeChatLanguageModel) From 4bd1387fd884831b4b2dfd0a363ac3471d23cf01 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 18 Feb 2026 08:36:03 -0800 Subject: [PATCH 05/23] Fix test failures --- src/typeagent/aitools/model_registry.py | 2 +- tests/test_model_registry.py | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/typeagent/aitools/model_registry.py b/src/typeagent/aitools/model_registry.py index b83db9a7..fcb55480 100644 --- a/src/typeagent/aitools/model_registry.py +++ b/src/typeagent/aitools/model_registry.py @@ -166,7 +166,7 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: def create_chat_model( model_spec: str, -) -> typechat.TypeChatLanguageModel: +) -> PydanticAIChatModel: """Create a chat model from a ``provider:model`` spec. Delegates to :func:`pydantic_ai.models.infer_model` for provider wiring. diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index d6bcf5ae..73217f7a 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -34,9 +34,7 @@ def test_spec_uses_colon_separator() -> None: def test_explicit_embedding_size() -> None: """Passing embedding_size= sets it immediately.""" - model = create_embedding_model( - "openai:text-embedding-3-small", embedding_size=42 - ) + model = create_embedding_model("openai:text-embedding-3-small", embedding_size=42) assert model.embedding_size == 42 @@ -53,7 +51,7 @@ def test_default_embedding_size_is_zero() -> None: def test_chat_model_is_typechat_model() -> None: """PydanticAIChatModel inherits from TypeChatLanguageModel.""" - assert issubclass(PydanticAIChatModel, typechat.TypeChatLanguageModel) + assert typechat.TypeChatLanguageModel in PydanticAIChatModel.__mro__ @pytest.mark.asyncio @@ -113,7 +111,7 @@ async def test_chat_adapter_prompt_sections() -> None: def test_embedding_model_is_iembedding_model() -> None: """PydanticAIEmbeddingModel inherits from IEmbeddingModel.""" - assert issubclass(PydanticAIEmbeddingModel, IEmbeddingModel) + assert IEmbeddingModel in PydanticAIEmbeddingModel.__mro__ @pytest.mark.asyncio @@ -251,5 +249,4 @@ def test_configure_models_returns_correct_types() -> None: chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small") assert isinstance(chat, PydanticAIChatModel) assert isinstance(embedder, PydanticAIEmbeddingModel) - assert isinstance(embedder, IEmbeddingModel) - assert isinstance(chat, typechat.TypeChatLanguageModel) + assert typechat.TypeChatLanguageModel in type(chat).__mro__ From 60aa4031902276fa9678bd109debd94e03b6818c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 18 Feb 2026 08:41:11 -0800 Subject: [PATCH 06/23] Rename model_registry -> model_adapters --- .../aitools/{model_registry.py => model_adapters.py} | 2 +- src/typeagent/aitools/utils.py | 4 ++-- tests/{test_model_registry.py => test_model_adapters.py} | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) rename src/typeagent/aitools/{model_registry.py => model_adapters.py} (99%) rename tests/{test_model_registry.py => test_model_adapters.py} (99%) diff --git a/src/typeagent/aitools/model_registry.py b/src/typeagent/aitools/model_adapters.py similarity index 99% rename from src/typeagent/aitools/model_registry.py rename to src/typeagent/aitools/model_adapters.py index fcb55480..6d8f06a9 100644 --- a/src/typeagent/aitools/model_registry.py +++ b/src/typeagent/aitools/model_adapters.py @@ -5,7 +5,7 @@ Create chat and embedding models from ``provider:model`` spec strings:: - from typeagent.aitools.model_registry import configure_models + from typeagent.aitools.model_adapters import configure_models chat, embedder = configure_models( "openai:gpt-4o", diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index b475c728..9d57b854 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -128,8 +128,8 @@ def create_typechat_model() -> typechat.TypeChatLanguageModel: Auto-detects the provider from ``OPENAI_API_KEY`` / ``AZURE_OPENAI_API_KEY`` environment variables. - For explicit provider selection, use :func:`model_registry.create_chat_model` - with a spec string like ``"openai/gpt-4o"`` or ``"azure/my-deployment"``. + For explicit provider selection, use :func:`model_adapters.create_chat_model` + with a spec string like ``"openai:gpt-4o"`` or ``"azure:my-deployment"``. """ env: dict[str, str | None] = dict(os.environ) key_name = "AZURE_OPENAI_API_KEY" diff --git a/tests/test_model_registry.py b/tests/test_model_adapters.py similarity index 99% rename from tests/test_model_registry.py rename to tests/test_model_adapters.py index 73217f7a..33bd78c7 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_adapters.py @@ -7,7 +7,7 @@ import typechat from typeagent.aitools.embeddings import IEmbeddingModel, NormalizedEmbedding -from typeagent.aitools.model_registry import ( +from typeagent.aitools.model_adapters import ( configure_models, create_chat_model, create_embedding_model, From 067f3b92fd2c0e7ef8bd8aaee701498da58de265 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 24 Feb 2026 13:25:38 -0800 Subject: [PATCH 07/23] Move pydantic-ai to main deps --- pyproject.toml | 2 +- uv.lock | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 848a497d..94aca756 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "numpy>=2.2.6", "openai>=1.81.0", "pydantic>=2.11.4", + "pydantic-ai-slim[openai]>=1.39.0", "pyreadline3>=3.5.4 ; sys_platform == 'win32'", "python-dotenv>=1.1.0", "tiktoken>=0.12.0", @@ -87,7 +88,6 @@ dev = [ "isort>=7.0.0", "logfire>=4.1.0", # So 'make check' passes "opentelemetry-instrumentation-httpx>=0.57b0", - "pydantic-ai-slim[openai]>=1.39.0", "pyright>=1.1.408", # 407 has a regression "pytest>=8.3.5", "pytest-asyncio>=0.26.0", diff --git a/uv.lock b/uv.lock index ce23abdc..a901acd1 100644 --- a/uv.lock +++ b/uv.lock @@ -1821,6 +1821,7 @@ dependencies = [ { name = "numpy" }, { name = "openai" }, { name = "pydantic" }, + { name = "pydantic-ai-slim", extra = ["openai"] }, { name = "pyreadline3", marker = "sys_platform == 'win32'" }, { name = "python-dotenv" }, { name = "tiktoken" }, @@ -1845,7 +1846,6 @@ dev = [ { name = "isort" }, { name = "logfire" }, { name = "opentelemetry-instrumentation-httpx" }, - { name = "pydantic-ai-slim", extra = ["openai"] }, { name = "pyright" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -1863,6 +1863,7 @@ requires-dist = [ { name = "openai", specifier = ">=1.81.0" }, { name = "opentelemetry-instrumentation-httpx", marker = "extra == 'logfire'", specifier = ">=0.57b0" }, { name = "pydantic", specifier = ">=2.11.4" }, + { name = "pydantic-ai-slim", extras = ["openai"], specifier = ">=1.39.0" }, { name = "pyreadline3", marker = "sys_platform == 'win32'", specifier = ">=3.5.4" }, { name = "python-dotenv", specifier = ">=1.1.0" }, { name = "tiktoken", specifier = ">=0.12.0" }, @@ -1882,7 +1883,6 @@ dev = [ { name = "isort", specifier = ">=7.0.0" }, { name = "logfire", specifier = ">=4.1.0" }, { name = "opentelemetry-instrumentation-httpx", specifier = ">=0.57b0" }, - { name = "pydantic-ai-slim", extras = ["openai"], specifier = ">=1.39.0" }, { name = "pyright", specifier = ">=1.1.408" }, { name = "pytest", specifier = ">=8.3.5" }, { name = "pytest-asyncio", specifier = ">=0.26.0" }, From 17b959fbffea9fdd9d679dc489e9b822993ed9b2 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 24 Feb 2026 16:12:24 -0800 Subject: [PATCH 08/23] Remove obsolete create_embedding_model -- wasn't easy --- src/typeagent/aitools/embeddings.py | 26 --------------------- src/typeagent/aitools/model_adapters.py | 8 ++++++- src/typeagent/aitools/vectorbase.py | 29 ++++++++++++++++++++---- src/typeagent/knowpro/convsettings.py | 3 ++- src/typeagent/knowpro/fuzzyindex.py | 6 ++--- src/typeagent/knowpro/serialization.py | 15 ++++++++++++ src/typeagent/podcasts/podcast.py | 15 +++++++----- src/typeagent/storage/sqlite/provider.py | 9 +++++--- src/typeagent/transcripts/transcript.py | 15 +++++++----- tests/test_serialization.py | 2 +- tools/ingest_vtt.py | 7 ++++-- 11 files changed, 82 insertions(+), 53 deletions(-) diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py index ce8ea062..5aacaafa 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -350,29 +350,3 @@ async def truncate_input(self, input: str) -> tuple[str, int]: return self.encoding.decode(truncated_tokens), self.max_chunk_size else: return input, len(tokens) - - -def create_embedding_model( - embedding_size: int | None = None, - model_name: str | None = None, - **kwargs, -) -> IEmbeddingModel: - """Create an embedding model using OpenAI/Azure OpenAI. - - This is the default factory. To use a different provider, create an - instance of a class that implements ``IEmbeddingModel`` and pass it - directly to ``TextEmbeddingIndexSettings`` or ``ConversationSettings``. - - Args: - embedding_size: Requested embedding dimensionality (provider-specific). - model_name: Model identifier (e.g. "text-embedding-ada-002"). - **kwargs: Extra keyword arguments forwarded to ``AsyncEmbeddingModel``. - - Returns: - An ``IEmbeddingModel`` instance backed by OpenAI / Azure OpenAI. - """ - return AsyncEmbeddingModel( - embedding_size=embedding_size, - model_name=model_name, - **kwargs, - ) diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 6d8f06a9..c558111b 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -181,8 +181,11 @@ def create_chat_model( return PydanticAIChatModel(model) +DEFAULT_EMBEDDING_SPEC = "openai:text-embedding-3-small" + + def create_embedding_model( - model_spec: str, + model_spec: str | None = None, *, embedding_size: int = 0, ) -> PydanticAIEmbeddingModel: @@ -190,6 +193,7 @@ def create_embedding_model( Delegates to :class:`pydantic_ai.Embedder` for provider wiring. + If *model_spec* is ``None``, :data:`DEFAULT_EMBEDDING_SPEC` is used. If *embedding_size* is not given, it will be probed automatically on the first embedding call. @@ -199,6 +203,8 @@ def create_embedding_model( model = create_embedding_model("cohere:embed-english-v3.0") model = create_embedding_model("google:text-embedding-004") """ + if model_spec is None: + model_spec = DEFAULT_EMBEDDING_SPEC model_name = model_spec.split(":")[-1] if ":" in model_spec else model_spec embedder = _PydanticAIEmbedder(model_spec) return PydanticAIEmbeddingModel(embedder, model_name, embedding_size) diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index 0a52f838..df93997f 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -7,11 +7,11 @@ import numpy as np from .embeddings import ( - create_embedding_model, IEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings, ) +from .model_adapters import create_embedding_model DEFAULT_MAX_RETRIES = 2 @@ -47,7 +47,7 @@ def __init__( max_retries if max_retries is not None else DEFAULT_MAX_RETRIES ) self.embedding_model = embedding_model or create_embedding_model( - embedding_size, max_retries=self.max_retries + embedding_size=embedding_size or 0, ) self.embedding_size = self.embedding_model.embedding_size assert ( @@ -93,6 +93,9 @@ def add_embedding( ) -> None: if isinstance(embedding, list): embedding = np.array(embedding, dtype=np.float32) + if self._embedding_size == 0: + self._set_embedding_size(len(embedding)) + self._vectors.shape = (0, self._embedding_size) embeddings = embedding.reshape(1, -1) # Make it 2D: 1xN self._vectors = np.append(self._vectors, embeddings, axis=0) if key is not None: @@ -102,6 +105,9 @@ def add_embeddings( self, keys: None | list[str], embeddings: NormalizedEmbeddings ) -> None: assert embeddings.ndim == 2 + if self._embedding_size == 0: + self._set_embedding_size(embeddings.shape[1]) + self._vectors.shape = (0, self._embedding_size) assert embeddings.shape[1] == self._embedding_size self._vectors = np.concatenate((self._vectors, embeddings), axis=0) if keys is not None: @@ -165,9 +171,17 @@ async def fuzzy_lookup( embedding, max_hits=max_hits, min_score=min_score, predicate=predicate ) + def _set_embedding_size(self, size: int) -> None: + """Adopt *size* when it was not known at construction time.""" + assert size > 0 + self._embedding_size = size + self._model.embedding_size = size + self.settings.embedding_size = size + def clear(self) -> None: self._vectors = np.array([], dtype=np.float32) - self._vectors.shape = (0, self._embedding_size) + if self._embedding_size > 0: + self._vectors.shape = (0, self._embedding_size) def get_embedding_at(self, pos: int) -> NormalizedEmbedding: if 0 <= pos < len(self._vectors): @@ -180,13 +194,20 @@ def serialize_embedding_at(self, pos: int) -> NormalizedEmbedding | None: return self._vectors[pos] if 0 <= pos < len(self._vectors) else None def serialize(self) -> NormalizedEmbeddings: - assert self._vectors.shape == (len(self._vectors), self._embedding_size) + if self._embedding_size > 0: + assert self._vectors.shape == (len(self._vectors), self._embedding_size) return self._vectors # TODO: Should we make a copy? def deserialize(self, data: NormalizedEmbeddings | None) -> None: if data is None: self.clear() return + if self._embedding_size == 0: + if data.ndim < 2 or data.shape[0] == 0: + # Empty data — can't determine size; just clear. + self.clear() + return + self._set_embedding_size(data.shape[1]) assert data.shape == (len(data), self._embedding_size), [ data.shape, self._embedding_size, diff --git a/src/typeagent/knowpro/convsettings.py b/src/typeagent/knowpro/convsettings.py index 7e25cd93..9dbf1214 100644 --- a/src/typeagent/knowpro/convsettings.py +++ b/src/typeagent/knowpro/convsettings.py @@ -5,7 +5,8 @@ from dataclasses import dataclass -from ..aitools.embeddings import create_embedding_model, IEmbeddingModel +from ..aitools.embeddings import IEmbeddingModel +from ..aitools.model_adapters import create_embedding_model from ..aitools.vectorbase import TextEmbeddingIndexSettings from .interfaces import IKnowledgeExtractor, IStorageProvider diff --git a/src/typeagent/knowpro/fuzzyindex.py b/src/typeagent/knowpro/fuzzyindex.py index 6ace1b34..97138e6c 100644 --- a/src/typeagent/knowpro/fuzzyindex.py +++ b/src/typeagent/knowpro/fuzzyindex.py @@ -137,7 +137,7 @@ def deserialize(self, embeddings: NormalizedEmbedding) -> None: assert embeddings.dtype == np.float32, embeddings.dtype assert embeddings.ndim == 2, embeddings.shape assert ( - embeddings.shape[1] == self._vector_base._embedding_size + self._vector_base._embedding_size == 0 + or embeddings.shape[1] == self._vector_base._embedding_size ), embeddings.shape - self.clear() - self.push(embeddings) + self._vector_base.deserialize(embeddings) diff --git a/src/typeagent/knowpro/serialization.py b/src/typeagent/knowpro/serialization.py index 1e48b68c..cbbe7b71 100644 --- a/src/typeagent/knowpro/serialization.py +++ b/src/typeagent/knowpro/serialization.py @@ -46,9 +46,14 @@ def create_file_header() -> FileHeader: return FileHeader(version="0.1") +class ModelMetadata(TypedDict): + embeddingSize: int + + class EmbeddingFileHeader(TypedDict): relatedCount: NotRequired[int | None] messageCount: NotRequired[int | None] + modelMetadata: NotRequired[ModelMetadata | None] class EmbeddingData(TypedDict): @@ -104,6 +109,7 @@ def to_conversation_file_data[TMessageData]( embedding_file_header = EmbeddingFileHeader() embeddings_list: list[NormalizedEmbeddings] = [] + embedding_size = 0 related_terms_index_data = conversation_data.get("relatedTermsIndexData") if related_terms_index_data is not None: @@ -114,6 +120,8 @@ def to_conversation_file_data[TMessageData]( embeddings_list.append(embeddings) text_embedding_data["embeddings"] = None embedding_file_header["relatedCount"] = len(embeddings) + if embedding_size == 0 and embeddings.ndim == 2: + embedding_size = embeddings.shape[1] message_index_data = conversation_data.get("messageIndexData") if message_index_data is not None: @@ -124,6 +132,13 @@ def to_conversation_file_data[TMessageData]( embeddings_list.append(embeddings) text_embedding_data["embeddings"] = None embedding_file_header["messageCount"] = len(embeddings) + if embedding_size == 0 and embeddings.ndim == 2: + embedding_size = embeddings.shape[1] + + if embedding_size > 0: + embedding_file_header["modelMetadata"] = ModelMetadata( + embeddingSize=embedding_size + ) binary_data = ConversationBinaryData(embeddingsList=embeddings_list) json_data = ConversationJsonData( diff --git a/src/typeagent/podcasts/podcast.py b/src/typeagent/podcasts/podcast.py index 3ed2639a..5376d20e 100644 --- a/src/typeagent/podcasts/podcast.py +++ b/src/typeagent/podcasts/podcast.py @@ -143,13 +143,19 @@ async def deserialize( @staticmethod def _read_conversation_data_from_file( - filename_prefix: str, embedding_size: int + filename_prefix: str, ) -> ConversationDataWithIndexes[Any]: """Read podcast conversation data from files. No exceptions are caught; they just bubble out.""" with open(filename_prefix + "_data.json", "r", encoding="utf-8") as f: json_data: serialization.ConversationJsonData[PodcastMessageData] = ( json.load(f) ) + embedding_file_header = json_data.get("embeddingFileHeader") + embedding_size = 0 + if embedding_file_header: + model_metadata = embedding_file_header.get("modelMetadata") + if model_metadata: + embedding_size = model_metadata.get("embeddingSize", 0) embeddings_list: list[NormalizedEmbeddings] | None = None if embedding_size: with open(filename_prefix + "_embeddings.bin", "rb") as f: @@ -159,7 +165,7 @@ def _read_conversation_data_from_file( embeddings_list = [embeddings] else: print( - "Warning: not reading embeddings file because size is {embedding_size}" + f"Warning: not reading embeddings file because size is {embedding_size}" ) embeddings_list = None file_data = serialization.ConversationFileData( @@ -178,10 +184,7 @@ async def read_from_file( settings: ConversationSettings, dbname: str | None = None, ) -> "Podcast": - embedding_size = settings.embedding_model.embedding_size - data = Podcast._read_conversation_data_from_file( - filename_prefix, embedding_size - ) + data = Podcast._read_conversation_data_from_file(filename_prefix) provider = await settings.get_storage_provider() msgs = await provider.get_message_collection() diff --git a/src/typeagent/storage/sqlite/provider.py b/src/typeagent/storage/sqlite/provider.py index 975d6a70..515c9cf0 100644 --- a/src/typeagent/storage/sqlite/provider.py +++ b/src/typeagent/storage/sqlite/provider.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone import sqlite3 -from ...aitools.embeddings import create_embedding_model +from ...aitools.model_adapters import create_embedding_model from ...aitools.vectorbase import TextEmbeddingIndexSettings from ...knowpro import interfaces from ...knowpro.convsettings import MessageTextIndexSettings, RelatedTermIndexSettings @@ -125,9 +125,12 @@ def _resolve_embedding_settings( if provided_message_settings is None: if stored_size is not None or stored_name is not None: + spec = stored_name or "" + if spec and ":" not in spec: + spec = f"openai:{spec}" embedding_model = create_embedding_model( - embedding_size=stored_size, - model_name=stored_name, + spec, + embedding_size=stored_size or 0, ) base_embedding_settings = TextEmbeddingIndexSettings( embedding_model=embedding_model, diff --git a/src/typeagent/transcripts/transcript.py b/src/typeagent/transcripts/transcript.py index 494166ba..5033e293 100644 --- a/src/typeagent/transcripts/transcript.py +++ b/src/typeagent/transcripts/transcript.py @@ -143,13 +143,19 @@ async def deserialize( @staticmethod def _read_conversation_data_from_file( - filename_prefix: str, embedding_size: int + filename_prefix: str, ) -> ConversationDataWithIndexes[Any]: """Read transcript conversation data from files. No exceptions are caught; they just bubble out.""" with open(filename_prefix + "_data.json", "r", encoding="utf-8") as f: json_data: serialization.ConversationJsonData[TranscriptMessageData] = ( json.load(f) ) + embedding_file_header = json_data.get("embeddingFileHeader") + embedding_size = 0 + if embedding_file_header: + model_metadata = embedding_file_header.get("modelMetadata") + if model_metadata: + embedding_size = model_metadata.get("embeddingSize", 0) embeddings_list: list[NormalizedEmbeddings] | None = None if embedding_size: with open(filename_prefix + "_embeddings.bin", "rb") as f: @@ -159,7 +165,7 @@ def _read_conversation_data_from_file( embeddings_list = [embeddings] else: print( - "Warning: not reading embeddings file because size is {embedding_size}" + f"Warning: not reading embeddings file because size is {embedding_size}" ) embeddings_list = None file_data = serialization.ConversationFileData( @@ -178,10 +184,7 @@ async def read_from_file( settings: ConversationSettings, dbname: str | None = None, ) -> "Transcript": - embedding_size = settings.embedding_model.embedding_size - data = Transcript._read_conversation_data_from_file( - filename_prefix, embedding_size - ) + data = Transcript._read_conversation_data_from_file(filename_prefix) provider = await settings.get_storage_provider() msgs = await provider.get_message_collection() diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 92aa71d9..adb46dd3 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -113,7 +113,7 @@ def test_write_and_read_conversation_data( # Read back the data read_data = Podcast._read_conversation_data_from_file( - str(filename), embedding_size=2 + str(filename), ) assert read_data is not None assert read_data.get("relatedTermsIndexData") is not None diff --git a/tools/ingest_vtt.py b/tools/ingest_vtt.py index 8bcef6d7..7fbf38fc 100644 --- a/tools/ingest_vtt.py +++ b/tools/ingest_vtt.py @@ -24,7 +24,7 @@ from dotenv import load_dotenv import webvtt -from typeagent.aitools.embeddings import create_embedding_model +from typeagent.aitools.model_adapters import create_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.interfaces import ConversationMetadata from typeagent.knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH @@ -203,7 +203,10 @@ async def ingest_vtt_files( if verbose: print("Setting up conversation settings...") try: - embedding_model = create_embedding_model(model_name=embedding_name) + spec = embedding_name + if spec and ":" not in spec: + spec = f"openai:{spec}" + embedding_model = create_embedding_model(spec) settings = ConversationSettings(embedding_model) # Create metadata with the conversation name From 6f1286ffab1a07f31d072d0782ffc66878b92196 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 24 Feb 2026 19:01:06 -0800 Subject: [PATCH 09/23] Fix test_configure_models_returns_correct_types --- tests/test_model_adapters.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index 33bd78c7..f4c910bf 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -244,8 +244,9 @@ async def test_embedding_adapter_empty_batch() -> None: # --------------------------------------------------------------------------- -def test_configure_models_returns_correct_types() -> None: +def test_configure_models_returns_correct_types(monkeypatch: pytest.MonkeyPatch) -> None: """configure_models creates both adapters.""" + monkeypatch.setenv("OPENAI_API_KEY", "test-key") chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small") assert isinstance(chat, PydanticAIChatModel) assert isinstance(embedder, PydanticAIEmbeddingModel) From 83d6f0a0b6b9dae4b0947cc417633a3a5dee0f4d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 24 Feb 2026 19:15:32 -0800 Subject: [PATCH 10/23] Fall back on Azure for OpenAI models if only Azure key is present --- src/typeagent/aitools/model_adapters.py | 62 +++++++++++++++++++++++-- tests/test_model_adapters.py | 4 +- 2 files changed, 62 insertions(+), 4 deletions(-) diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index c558111b..fea2ed57 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -18,10 +18,16 @@ ``azure``, ``anthropic``, ``google``, ``bedrock``, ``groq``, ``mistral``, ``ollama``, ``cohere``, and many more. +When a spec uses ``openai:`` as the provider and ``OPENAI_API_KEY`` is not +set, but ``AZURE_OPENAI_API_KEY`` is available, the provider is +automatically switched to Azure OpenAI. + See https://ai.pydantic.dev/models/ for all supported providers and their required environment variables. """ +import os + import numpy as np from numpy.typing import NDArray @@ -159,6 +165,36 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: return np.array([self._cache[k] for k in keys], dtype=np.float32) +# --------------------------------------------------------------------------- +# Provider auto-detection +# --------------------------------------------------------------------------- + + +def _needs_azure_fallback(provider: str) -> bool: + """Return True if *provider* is ``openai`` but only Azure credentials exist.""" + return ( + provider == "openai" + and not os.getenv("OPENAI_API_KEY") + and bool(os.getenv("AZURE_OPENAI_API_KEY")) + ) + + +def _make_azure_provider(): + """Create a :class:`pydantic_ai.providers.azure.AzureProvider`.""" + from pydantic_ai.providers.azure import AzureProvider + + from .utils import get_azure_api_key, parse_azure_endpoint + + raw_key = os.environ["AZURE_OPENAI_API_KEY"] + api_key = get_azure_api_key(raw_key) + azure_endpoint, api_version = parse_azure_endpoint("AZURE_OPENAI_ENDPOINT") + return AzureProvider( + azure_endpoint=azure_endpoint, + api_version=api_version, + api_key=api_key, + ) + + # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- @@ -170,6 +206,8 @@ def create_chat_model( """Create a chat model from a ``provider:model`` spec. Delegates to :func:`pydantic_ai.models.infer_model` for provider wiring. + If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but + ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically. Examples:: @@ -177,7 +215,13 @@ def create_chat_model( model = create_chat_model("anthropic:claude-sonnet-4-20250514") model = create_chat_model("google:gemini-2.0-flash") """ - model = infer_model(model_spec) + provider, _, model_name = model_spec.partition(":") + if _needs_azure_fallback(provider): + from pydantic_ai.models.openai import OpenAIChatModel + + model = OpenAIChatModel(model_name, provider=_make_azure_provider()) + else: + model = infer_model(model_spec) return PydanticAIChatModel(model) @@ -192,6 +236,8 @@ def create_embedding_model( """Create an embedding model from a ``provider:model`` spec. Delegates to :class:`pydantic_ai.Embedder` for provider wiring. + If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but + ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically. If *model_spec* is ``None``, :data:`DEFAULT_EMBEDDING_SPEC` is used. If *embedding_size* is not given, it will be probed automatically @@ -205,8 +251,18 @@ def create_embedding_model( """ if model_spec is None: model_spec = DEFAULT_EMBEDDING_SPEC - model_name = model_spec.split(":")[-1] if ":" in model_spec else model_spec - embedder = _PydanticAIEmbedder(model_spec) + provider, _, model_name = model_spec.partition(":") + if not model_name: + model_name = provider # No colon in spec + if _needs_azure_fallback(provider): + from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel + + embedding_model = OpenAIEmbeddingModel( + model_name, provider=_make_azure_provider() + ) + embedder = _PydanticAIEmbedder(embedding_model) + else: + embedder = _PydanticAIEmbedder(model_spec) return PydanticAIEmbeddingModel(embedder, model_name, embedding_size) diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index f4c910bf..998ddd97 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -244,7 +244,9 @@ async def test_embedding_adapter_empty_batch() -> None: # --------------------------------------------------------------------------- -def test_configure_models_returns_correct_types(monkeypatch: pytest.MonkeyPatch) -> None: +def test_configure_models_returns_correct_types( + monkeypatch: pytest.MonkeyPatch, +) -> None: """configure_models creates both adapters.""" monkeypatch.setenv("OPENAI_API_KEY", "test-key") chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small") From 2659f3015675808ed876adbbba7365fa34881efd Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 09:47:51 -0800 Subject: [PATCH 11/23] Use embed_documents() instead of embed(input_type=["document"]) --- src/typeagent/aitools/model_adapters.py | 6 +++--- tests/test_model_adapters.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index fea2ed57..472eae8e 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -119,11 +119,11 @@ def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: async def _probe_embedding_size(self) -> None: """Discover embedding_size by making a single API call.""" - result = await self._embedder.embed(["probe"], input_type="document") + result = await self._embedder.embed_documents(["probe"]) self.embedding_size = len(result.embeddings[0]) async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: - result = await self._embedder.embed([input], input_type="document") + result = await self._embedder.embed_documents([input]) embedding: NDArray[np.float32] = np.array( result.embeddings[0], dtype=np.float32 ) @@ -139,7 +139,7 @@ async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings if self.embedding_size == 0: await self._probe_embedding_size() return np.empty((0, self.embedding_size), dtype=np.float32) - result = await self._embedder.embed(input, input_type="document") + result = await self._embedder.embed_documents(input) embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32) if self.embedding_size == 0: self.embedding_size = embeddings.shape[1] diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index 998ddd97..5e773430 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -124,7 +124,7 @@ async def test_embedding_adapter_single() -> None: mock_embedder = AsyncMock(spec=Embedder) raw_vec = [3.0, 4.0, 0.0] - mock_embedder.embed.return_value = EmbeddingResult( + mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[raw_vec], inputs=["test"], input_type="document", @@ -148,7 +148,7 @@ async def test_embedding_adapter_probes_size() -> None: from pydantic_ai.embeddings import EmbeddingResult mock_embedder = AsyncMock(spec=Embedder) - mock_embedder.embed.return_value = EmbeddingResult( + mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0, 0.0]], inputs=["probe"], input_type="document", @@ -171,7 +171,7 @@ async def test_embedding_adapter_batch() -> None: from pydantic_ai.embeddings import EmbeddingResult mock_embedder = AsyncMock(spec=Embedder) - mock_embedder.embed.return_value = EmbeddingResult( + mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0], [0.0, 1.0]], inputs=["a", "b"], input_type="document", @@ -193,7 +193,7 @@ async def test_embedding_adapter_caching() -> None: from pydantic_ai.embeddings import EmbeddingResult mock_embedder = AsyncMock(spec=Embedder) - mock_embedder.embed.return_value = EmbeddingResult( + mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0, 0.0]], inputs=["cached"], input_type="document", @@ -205,8 +205,8 @@ async def test_embedding_adapter_caching() -> None: first = await adapter.get_embedding("cached") second = await adapter.get_embedding("cached") np.testing.assert_array_equal(first, second) - # embed() should only be called once - assert mock_embedder.embed.call_count == 1 + # embed_documents() should only be called once + assert mock_embedder.embed_documents.call_count == 1 @pytest.mark.asyncio @@ -222,8 +222,8 @@ async def test_embedding_adapter_add_embedding() -> None: adapter.add_embedding("key", vec) result = await adapter.get_embedding("key") np.testing.assert_array_equal(result, vec) - # No embed() call needed - mock_embedder.embed.assert_not_called() + # No embed_documents() call needed + mock_embedder.embed_documents.assert_not_called() @pytest.mark.asyncio From 2b8735b9c9e82d6eaac12ba206e0e0213c904727 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 09:59:24 -0800 Subject: [PATCH 12/23] Fix the mcp test. We now do the right thing with azure endpoint env vars (I hope) --- src/typeagent/aitools/model_adapters.py | 40 ++++++++++++++++++++----- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 472eae8e..ab376034 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -179,20 +179,32 @@ def _needs_azure_fallback(provider: str) -> bool: ) -def _make_azure_provider(): - """Create a :class:`pydantic_ai.providers.azure.AzureProvider`.""" +def _make_azure_provider( + endpoint_envvar: str = "AZURE_OPENAI_ENDPOINT", + api_key_envvar: str = "AZURE_OPENAI_API_KEY", +): + """Create a :class:`pydantic_ai.providers.azure.AzureProvider`. + + Constructs an ``AsyncAzureOpenAI`` client from the given environment + variables and wraps it in an ``AzureProvider``. The endpoint env-var + may contain a full Azure deployment URL (including path and + ``api-version`` query parameter) — the same format used throughout + this codebase. + """ + from openai import AsyncAzureOpenAI from pydantic_ai.providers.azure import AzureProvider from .utils import get_azure_api_key, parse_azure_endpoint - raw_key = os.environ["AZURE_OPENAI_API_KEY"] + raw_key = os.environ[api_key_envvar] api_key = get_azure_api_key(raw_key) - azure_endpoint, api_version = parse_azure_endpoint("AZURE_OPENAI_ENDPOINT") - return AzureProvider( + azure_endpoint, api_version = parse_azure_endpoint(endpoint_envvar) + client = AsyncAzureOpenAI( azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key, ) + return AzureProvider(openai_client=client) # --------------------------------------------------------------------------- @@ -257,9 +269,23 @@ def create_embedding_model( if _needs_azure_fallback(provider): from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel - embedding_model = OpenAIEmbeddingModel( - model_name, provider=_make_azure_provider() + from .embeddings import model_to_embedding_size_and_envvar + + # Look up model-specific Azure endpoint, falling back to the generic one. + _, suggested_envvar = model_to_embedding_size_and_envvar.get( + model_name, (None, None) ) + if suggested_envvar and os.getenv(suggested_envvar): + endpoint_envvar = suggested_envvar + else: + endpoint_envvar = "AZURE_OPENAI_ENDPOINT_EMBEDDING" + # Allow a model-specific API key, falling back to the generic one. + api_key_envvar = "AZURE_OPENAI_API_KEY_EMBEDDING" + if not os.getenv(api_key_envvar): + api_key_envvar = "AZURE_OPENAI_API_KEY" + + azure_provider = _make_azure_provider(endpoint_envvar, api_key_envvar) + embedding_model = OpenAIEmbeddingModel(model_name, provider=azure_provider) embedder = _PydanticAIEmbedder(embedding_model) else: embedder = _PydanticAIEmbedder(model_spec) From 05183d956a06ea5314e11d32144dbd3360df3625 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 11:31:13 -0800 Subject: [PATCH 13/23] Remove AsyncEmbeddingModel; migrate all tests to PydanticAIEmbeddingModel --- src/typeagent/aitools/embeddings.py | 299 +------------------ src/typeagent/aitools/model_adapters.py | 89 ++++++ tests/conftest.py | 93 +----- tests/test_add_messages_with_indexing.py | 8 +- tests/test_conversation_metadata.py | 12 +- tests/test_embedding_consistency.py | 14 +- tests/test_embeddings.py | 251 +++------------- tests/test_factory.py | 8 +- tests/test_incremental_index.py | 12 +- tests/test_message_text_index_population.py | 4 +- tests/test_messageindex.py | 16 +- tests/test_podcast_incremental.py | 6 +- tests/test_property_index_population.py | 17 +- tests/test_query_method.py | 6 +- tests/test_related_terms_fast.py | 4 +- tests/test_related_terms_index_population.py | 4 +- tests/test_secindex.py | 8 +- tests/test_secindex_storage_integration.py | 4 +- tests/test_transcripts.py | 7 +- tests/test_vectorbase.py | 15 +- 20 files changed, 200 insertions(+), 677 deletions(-) diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py index 5aacaafa..f56db25d 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -1,22 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import asyncio -import os from typing import Protocol, runtime_checkable import numpy as np from numpy.typing import NDArray -from openai import AsyncAzureOpenAI, AsyncOpenAI, DEFAULT_MAX_RETRIES, OpenAIError -from openai.types import Embedding -import tiktoken -from tiktoken import model as tiktoken_model -from tiktoken.core import Encoding - -from .auth import AzureTokenProvider, get_shared_token_provider -from .utils import timelog - type NormalizedEmbedding = NDArray[np.float32] # A single embedding type NormalizedEmbeddings = NDArray[np.float32] # An array of embeddings @@ -26,8 +15,8 @@ class IEmbeddingModel(Protocol): """Provider-agnostic interface for embedding models. Implement this protocol to add support for a new embedding provider - (e.g. Anthropic, Gemini, local models). The existing AsyncEmbeddingModel - implements it for OpenAI and Azure OpenAI. + (e.g. Anthropic, Gemini, local models). The production implementation + is :class:`~typeagent.aitools.model_adapters.PydanticAIEmbeddingModel`. """ model_name: str @@ -58,11 +47,6 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: DEFAULT_EMBEDDING_SIZE = 1536 # Default embedding size (required for ada-002) DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING" # We support OpenAI and Azure OpenAI TEST_MODEL_NAME = "test" -MAX_BATCH_SIZE = 2048 -MAX_TOKEN_SIZE = 4096 -MAX_TOKENS_PER_BATCH = 300_000 -MAX_CHAR_SIZE = MAX_TOKEN_SIZE * 3 -MAX_CHARS_PER_BATCH = MAX_TOKENS_PER_BATCH * 3 model_to_embedding_size_and_envvar: dict[str, tuple[int | None, str]] = { DEFAULT_MODEL_NAME: (DEFAULT_EMBEDDING_SIZE, DEFAULT_ENVVAR), @@ -71,282 +55,3 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: # For testing only, not a real model (insert real embeddings above) TEST_MODEL_NAME: (3, "SIR_NOT_APPEARING_IN_THIS_FILM"), } - - -class AsyncEmbeddingModel: - model_name: str - embedding_size: int - endpoint_envvar: str - use_azure: bool - azure_token_provider: AzureTokenProvider | None - async_client: AsyncOpenAI | None - azure_endpoint: str - azure_api_version: str - encoding: Encoding | None - max_chunk_size: int - max_size_per_batch: int - - _embedding_cache: dict[str, NormalizedEmbedding] - - def __init__( - self, - embedding_size: int | None = None, - model_name: str | None = None, - endpoint_envvar: str | None = None, - max_retries: int = DEFAULT_MAX_RETRIES, - use_azure: bool | None = None, - ): - if model_name is None: - model_name = DEFAULT_MODEL_NAME - self.model_name = model_name - - suggested_embedding_size, suggested_endpoint_envvar = ( - model_to_embedding_size_and_envvar.get(model_name, (None, None)) - ) - - if embedding_size is None: - if suggested_embedding_size is not None: - embedding_size = suggested_embedding_size - else: - embedding_size = DEFAULT_EMBEDDING_SIZE - self.embedding_size = embedding_size - - if ( - model_name == DEFAULT_MODEL_NAME - and embedding_size != DEFAULT_EMBEDDING_SIZE - ): - raise ValueError( - f"Cannot customize embedding_size for default model {DEFAULT_MODEL_NAME}" - ) - - # Read API keys once - openai_api_key = os.getenv("OPENAI_API_KEY") - azure_api_key = os.getenv("AZURE_OPENAI_API_KEY") - - # Determine provider: explicit use_azure overrides auto-detection. - if use_azure is not None: - self.use_azure = use_azure - else: - # Prefer OpenAI if both are set, use Azure if only Azure is set - self.use_azure = bool(azure_api_key) and not bool(openai_api_key) - - if endpoint_envvar is None: - # Check if OpenAI credentials are available, prefer OpenAI over Azure - if openai_api_key: - endpoint_envvar = "OPENAI_BASE_URL" # Use OpenAI - elif suggested_endpoint_envvar is not None: - endpoint_envvar = suggested_endpoint_envvar - else: - endpoint_envvar = DEFAULT_ENVVAR - - self.endpoint_envvar = endpoint_envvar - self.azure_token_provider = None - - if self.model_name == TEST_MODEL_NAME: - self.async_client = None - elif self.use_azure: - if not azure_api_key: - raise ValueError("AZURE_OPENAI_API_KEY not found in environment.") - with timelog("Using Azure OpenAI"): - self._setup_azure(azure_api_key) - else: - if not openai_api_key: - raise ValueError("OPENAI_API_KEY not found in environment.") - endpoint = os.getenv(self.endpoint_envvar) - with timelog("Using OpenAI"): - self.async_client = AsyncOpenAI( - base_url=endpoint, api_key=openai_api_key, max_retries=max_retries - ) - - if self.model_name in tiktoken_model.MODEL_TO_ENCODING: - encoding_name = tiktoken.encoding_name_for_model(self.model_name) - self.encoding = tiktoken.get_encoding(encoding_name) - self.max_chunk_size = MAX_TOKEN_SIZE - self.max_size_per_batch = MAX_TOKENS_PER_BATCH - else: - self.encoding = None - self.max_chunk_size = MAX_CHAR_SIZE - self.max_size_per_batch = MAX_CHARS_PER_BATCH - - self._embedding_cache = {} - - def _setup_azure(self, azure_api_key: str) -> None: - from .utils import get_azure_api_key, parse_azure_endpoint - - azure_api_key = get_azure_api_key(azure_api_key) - self.azure_endpoint, self.azure_api_version = parse_azure_endpoint( - self.endpoint_envvar - ) - - if azure_api_key != os.getenv("AZURE_OPENAI_API_KEY"): - # If we got a token from identity, store the provider for refresh - self.azure_token_provider = get_shared_token_provider() - - self.async_client = AsyncAzureOpenAI( - api_version=self.azure_api_version, - azure_endpoint=self.azure_endpoint, - api_key=azure_api_key, - ) - - async def refresh_auth(self): - """Update client when using a token provider and it's nearly expired.""" - # refresh_token is synchronous and slow -- run it in a separate thread - assert self.azure_token_provider - refresh_token = self.azure_token_provider.refresh_token - loop = asyncio.get_running_loop() - azure_api_key = await loop.run_in_executor(None, refresh_token) - assert self.azure_api_version - assert self.azure_endpoint - self.async_client = AsyncAzureOpenAI( - api_version=self.azure_api_version, - azure_endpoint=self.azure_endpoint, - api_key=azure_api_key, - ) - - def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: - existing = self._embedding_cache.get(key) - if existing is not None: - assert np.array_equal(existing, embedding) - else: - self._embedding_cache[key] = embedding - - async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: - embeddings = await self.get_embeddings_nocache([input]) - return embeddings[0] - - async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: - if not input: - empty = np.array([], dtype=np.float32) - empty.shape = (0, self.embedding_size) - return empty - if self.azure_token_provider and self.azure_token_provider.needs_refresh(): - await self.refresh_auth() - extra_args = {} - if self.model_name != DEFAULT_MODEL_NAME: - extra_args["dimensions"] = self.embedding_size - if self.async_client is None: - # Compute a random embedding for testing purposes. - - def hashish(s: str) -> int: - # Primitive deterministic hash function (hash() varies per run) - h = 0 - for ch in s: - h = (h * 31 + ord(ch)) & 0xFFFFFFFF - return h - - prime = 1961 - fake_data: list[NormalizedEmbedding] = [] - for item in input: - if not item: - raise OpenAIError - length = len(item) - floats = [] - for i in range(self.embedding_size): - cut = i % length - scrambled = item[cut:] + item[:cut] - hashed = hashish(scrambled) - reduced = (hashed % prime) / prime - floats.append(reduced) - array = np.array(floats, dtype=np.float64) - normalized = array / np.sqrt(np.dot(array, array)) - dot = np.dot(normalized, normalized) - assert ( - abs(dot - 1.0) < 1e-15 - ), f"Embedding {normalized} is not normalized: {dot}" - fake_data.append(normalized) - assert len(fake_data) == len(input), (len(fake_data), "!=", len(input)) - result = np.array(fake_data, dtype=np.float32) - return result - else: - batches: list[list[str]] = [] - batch: list[str] = [] - batch_sum: int = 0 - for sentence in input: - truncated_input, truncated_input_size = await self.truncate_input( - sentence - ) - if ( - len(batch) >= MAX_BATCH_SIZE - or batch_sum + truncated_input_size > self.max_size_per_batch - ): - batches.append(batch) - batch = [] - batch_sum = 0 - batch.append(truncated_input) - batch_sum += truncated_input_size - if batch: - batches.append(batch) - - data: list[Embedding] = [] - for batch in batches: - embeddings_data = ( - await self.async_client.embeddings.create( - input=batch, - model=self.model_name, - encoding_format="float", - **extra_args, - ) - ).data - data.extend(embeddings_data) - - assert len(data) == len(input), (len(data), "!=", len(input)) - return np.array([d.embedding for d in data], dtype=np.float32) - - async def get_embedding(self, key: str) -> NormalizedEmbedding: - """Retrieve an embedding, using the cache.""" - if key in self._embedding_cache: - return self._embedding_cache[key] - embedding = await self.get_embedding_nocache(key) - self._embedding_cache[key] = embedding - return embedding - - async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: - """Retrieve embeddings for multiple keys, using the cache.""" - embeddings: list[NormalizedEmbedding | None] = [] - missing_keys: list[str] = [] - - # Collect cached embeddings and identify missing keys - for key in keys: - if key in self._embedding_cache: - embeddings.append(self._embedding_cache[key]) - else: - embeddings.append(None) # Placeholder for missing keys - missing_keys.append(key) - - # Retrieve embeddings for missing keys - if missing_keys: - new_embeddings = await self.get_embeddings_nocache(missing_keys) - for key, embedding in zip(missing_keys, new_embeddings): - self._embedding_cache[key] = embedding - - # Replace placeholders with retrieved embeddings - for i, key in enumerate(keys): - if embeddings[i] is None: - embeddings[i] = self._embedding_cache[key] - return np.array(embeddings, dtype=np.float32).reshape( - (len(keys), self.embedding_size) - ) - - async def truncate_input(self, input: str) -> tuple[str, int]: - """Truncate input strings to fit within model limits. - - args: - input: The input string to truncate. - - returns: - A tuple of (truncated string, size after truncation). - """ - if self.encoding is None: - # Non-token-aware truncation - if len(input) > self.max_chunk_size: - return input[: self.max_chunk_size], self.max_chunk_size - else: - return input, len(input) - else: - # Token-aware truncation - tokens = self.encoding.encode(input) - if len(tokens) > self.max_chunk_size: - truncated_tokens = tokens[: self.max_chunk_size] - return self.encoding.decode(truncated_tokens), self.max_chunk_size - else: - return input, len(tokens) diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index ab376034..4acc0ff8 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -26,12 +26,16 @@ required environment variables. """ +from collections.abc import Sequence import os import numpy as np from numpy.typing import NDArray from pydantic_ai import Embedder as _PydanticAIEmbedder +from pydantic_ai.embeddings.base import EmbeddingModel as _PydanticAIEmbeddingModelBase +from pydantic_ai.embeddings.result import EmbeddingResult, EmbedInputType +from pydantic_ai.embeddings.settings import EmbeddingSettings from pydantic_ai.messages import ( ModelMessage, ModelRequest, @@ -157,6 +161,10 @@ async def get_embedding(self, key: str) -> NormalizedEmbedding: return embedding async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + if not keys: + if self.embedding_size == 0: + await self._probe_embedding_size() + return np.empty((0, self.embedding_size), dtype=np.float32) missing_keys = [k for k in keys if k not in self._cache] if missing_keys: fresh = await self.get_embeddings_nocache(missing_keys) @@ -292,6 +300,87 @@ def create_embedding_model( return PydanticAIEmbeddingModel(embedder, model_name, embedding_size) +# --------------------------------------------------------------------------- +# Test helpers +# --------------------------------------------------------------------------- + + +def _hashish(s: str) -> int: + """Deterministic hash function for fake embeddings (hash() varies per run).""" + h = 0 + for ch in s: + h = (h * 31 + ord(ch)) & 0xFFFFFFFF + return h + + +def _compute_fake_embeddings( + input_texts: list[str], embedding_size: int +) -> list[list[float]]: + """Generate deterministic fake embeddings for testing (unnormalized). + + Raises :class:`ValueError` on empty input strings. + """ + prime = 1961 + result: list[list[float]] = [] + for item in input_texts: + if not item: + raise ValueError("Empty input text") + length = len(item) + floats: list[float] = [] + for i in range(embedding_size): + cut = i % length + scrambled = item[cut:] + item[:cut] + hashed = _hashish(scrambled) + reduced = (hashed % prime) / prime + floats.append(reduced) + result.append(floats) + return result + + +class _FakePydanticAIEmbeddingModel(_PydanticAIEmbeddingModelBase): + """A pydantic_ai :class:`EmbeddingModel` that returns deterministic fake + embeddings. Used only for testing — no network calls are made.""" + + def __init__(self, embedding_size: int = 3) -> None: + super().__init__() + self._embedding_size = embedding_size + + @property + def model_name(self) -> str: + return "test" + + @property + def system(self) -> str: + return "test" + + async def embed( + self, + inputs: str | Sequence[str], + *, + input_type: EmbedInputType, + settings: EmbeddingSettings | None = None, + ) -> EmbeddingResult: + inputs_list, settings = self.prepare_embed(inputs, settings) + embeddings = _compute_fake_embeddings(inputs_list, self._embedding_size) + return EmbeddingResult( + embeddings=embeddings, + inputs=inputs_list, + input_type=input_type, + model_name="test", + provider_name="test", + ) + + +def create_test_embedding_model( + embedding_size: int = 3, +) -> PydanticAIEmbeddingModel: + """Create a :class:`PydanticAIEmbeddingModel` with deterministic fake + embeddings for testing. No API keys or network access required.""" + fake_model = _FakePydanticAIEmbeddingModel(embedding_size) + embedder = _PydanticAIEmbedder(fake_model) + return PydanticAIEmbeddingModel(embedder, "test", embedding_size) + + def configure_models( chat_model_spec: str, embedding_model_spec: str, diff --git a/tests/conftest.py b/tests/conftest.py index 7f7ce210..c4de6d47 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,15 +11,8 @@ import pytest import pytest_asyncio -from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage -from openai.types.embedding import Embedding -import tiktoken - -from typeagent.aitools.embeddings import ( - AsyncEmbeddingModel, - IEmbeddingModel, - TEST_MODEL_NAME, -) +from typeagent.aitools.embeddings import IEmbeddingModel +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( ConversationSettings, @@ -96,7 +89,7 @@ def really_needs_auth() -> None: @pytest.fixture(scope="session") def embedding_model() -> IEmbeddingModel: """Fixture to create a test embedding model with small embedding size for faster tests.""" - return AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + return create_test_embedding_model() @pytest.fixture(scope="session") @@ -303,7 +296,7 @@ def __init__( self._has_secondary_indexes = has_secondary_indexes else: # Create test model for settings - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() self.settings = ConversationSettings(test_model, storage_provider) self._needs_async_init = False self._storage_provider = storage_provider @@ -323,7 +316,7 @@ def __init__( async def ensure_initialized(self): """Ensure async initialization is complete.""" if self._needs_async_init: - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() self.settings = ConversationSettings(test_model) storage_provider = await self.settings.get_storage_provider() self._storage_provider = storage_provider @@ -355,79 +348,3 @@ async def fake_conversation_with_storage( ) -> FakeConversation: """Fixture to create a FakeConversation instance with storage provider.""" return FakeConversation(storage_provider=memory_storage) - - -class FakeEmbeddings: - - def __init__( - self, - max_batch_size: int = 2048, - max_chunk_size: int = 4096, - max_elements_per_batch: int = 300_000, - use_tiktoken: bool = False, - ): - self.model_name = "text-embedding-ada-002" - self.call_count = 0 - self.max_batch_size = max_batch_size - self.max_chunk_size = max_chunk_size - self.max_elements_per_batch = max_elements_per_batch - self.use_tiktoken = use_tiktoken - - def reset_counter(self): - self.call_count = 0 - - async def create(self, **kwargs): - self.call_count += 1 - input = kwargs["input"] - len_input = len(input) - if len_input > self.max_batch_size: - raise ValueError("Embedding model received batch larger 2048") - dimensions = 1536 - if "dimensions" in kwargs: - dimensions = kwargs["dimensions"] - - embedding_result = [] - total_elements = 0 - for index in range(len_input): - entity = input[index] - if self.use_tiktoken: - enc_name = tiktoken.encoding_name_for_model(self.model_name) - enc = tiktoken.get_encoding(enc_name) - entity = enc.encode(entity) - total_elements += len(entity) - if len(entity) > self.max_chunk_size: - raise ValueError( - f"Chunk size {len(entity)} larger than max size {self.max_chunk_size}" - ) - value = index % 2 - embedding_result.append( - Embedding( - embedding=[value] * dimensions, index=index, object="embedding" - ) - ) - - if total_elements > self.max_elements_per_batch: - raise ValueError( - f"Batch size {total_elements} larger than max tokens/chars per batch {self.max_elements_per_batch}" - ) - - response = CreateEmbeddingResponse( - data=embedding_result, - model="test_model", - object="list", - usage=Usage(prompt_tokens=0, total_tokens=0), - ) - - return response - - -@pytest.fixture -def fake_embeddings() -> FakeEmbeddings: - """Fixture to create a FaceEmbedding instance""" - return FakeEmbeddings(max_batch_size=2048, max_chunk_size=4096 * 3) - - -@pytest.fixture -def fake_embeddings_tiktoken() -> FakeEmbeddings: - """Fixture to create a FaceEmbedding instance""" - return FakeEmbeddings(max_batch_size=2048, max_chunk_size=4096, use_tiktoken=True) diff --git a/tests/test_add_messages_with_indexing.py b/tests/test_add_messages_with_indexing.py index 4f00cfb1..d3df2c4d 100644 --- a/tests/test_add_messages_with_indexing.py +++ b/tests/test_add_messages_with_indexing.py @@ -8,7 +8,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.storage.sqlite.provider import SqliteStorageProvider from typeagent.transcripts.transcript import ( @@ -24,7 +24,7 @@ async def test_add_messages_with_indexing_basic(): with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -65,7 +65,7 @@ async def test_add_messages_with_indexing_batched(): with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -122,7 +122,7 @@ async def test_transaction_rollback_on_error(): with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False diff --git a/tests/test_conversation_metadata.py b/tests/test_conversation_metadata.py index 69d80dab..887c50b2 100644 --- a/tests/test_conversation_metadata.py +++ b/tests/test_conversation_metadata.py @@ -16,11 +16,8 @@ from pydantic.dataclasses import dataclass -from typeagent.aitools.embeddings import ( - AsyncEmbeddingModel, - IEmbeddingModel, - TEST_MODEL_NAME, -) +from typeagent.aitools.embeddings import IEmbeddingModel +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( ConversationSettings, @@ -80,7 +77,7 @@ async def storage_provider_memory() -> ( AsyncGenerator[SqliteStorageProvider[DummyMessage], None] ): """Create an in-memory SqliteStorageProvider for testing conversation metadata.""" - embedding_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + embedding_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(embedding_model) message_text_settings = MessageTextIndexSettings(embedding_settings) related_terms_settings = RelatedTermIndexSettings(embedding_settings) @@ -624,9 +621,8 @@ async def test_embedding_metadata_mismatch_raises( provider.db.commit() await provider.close() - mismatched_model = AsyncEmbeddingModel( + mismatched_model = create_test_embedding_model( embedding_size=embedding_settings.embedding_size + 1, - model_name=embedding_model.model_name, ) mismatched_settings = TextEmbeddingIndexSettings( embedding_model=mismatched_model, diff --git a/tests/test_embedding_consistency.py b/tests/test_embedding_consistency.py index 906c2b52..f032c856 100644 --- a/tests/test_embedding_consistency.py +++ b/tests/test_embedding_consistency.py @@ -9,7 +9,7 @@ import pytest from typeagent import create_conversation -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.storage.sqlite import SqliteStorageProvider from typeagent.transcripts.transcript import TranscriptMessage, TranscriptMessageMeta @@ -25,7 +25,7 @@ async def test_embedding_size_mismatch_in_message_index(): try: # Create a conversation with test model (embedding size 3) settings1 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=3, model_name="test") + model=create_test_embedding_model(embedding_size=3) ) # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False @@ -46,7 +46,7 @@ async def test_embedding_size_mismatch_in_message_index(): # Now try to open the same database with a different embedding size # This should raise an error settings2 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=5, model_name="test") + model=create_test_embedding_model(embedding_size=5) ) with pytest.raises(ValueError, match="embedding_size"): @@ -74,7 +74,7 @@ async def test_embedding_size_mismatch_in_related_terms(): try: # Create a conversation with default embedding size settings1 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=3, model_name="test") + model=create_test_embedding_model(embedding_size=3) ) # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False @@ -95,7 +95,7 @@ async def test_embedding_size_mismatch_in_related_terms(): # Now try to open the same database with a different embedding size # This should raise an error settings2 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=5, model_name="test") + model=create_test_embedding_model(embedding_size=5) ) with pytest.raises(ValueError, match="embedding_size"): @@ -123,7 +123,7 @@ async def test_empty_db_no_error(): try: # Create an empty database settings1 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=3, model_name="test") + model=create_test_embedding_model(embedding_size=3) ) # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False @@ -134,7 +134,7 @@ async def test_empty_db_no_error(): # Open with different embedding size should work since DB is empty settings2 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=5, model_name="test") + model=create_test_embedding_model(embedding_size=5) ) provider = SqliteStorageProvider( db_path=db_path, diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index ac766f1b..2ae09845 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -3,23 +3,18 @@ import numpy as np import pytest -from pytest import MonkeyPatch from pytest_mock import MockerFixture -import openai - -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel +from typeagent.aitools.model_adapters import PydanticAIEmbeddingModel from conftest import ( embedding_model, # type: ignore # Magic, prevents side effects of mocking ) -from conftest import ( - FakeEmbeddings, -) @pytest.mark.asyncio -async def test_get_embedding_nocache(embedding_model: AsyncEmbeddingModel): +async def test_get_embedding_nocache(embedding_model: PydanticAIEmbeddingModel): """Test retrieving an embedding without using the cache.""" input_text = "Hello, world" embedding = await embedding_model.get_embedding_nocache(input_text) @@ -30,7 +25,7 @@ async def test_get_embedding_nocache(embedding_model: AsyncEmbeddingModel): @pytest.mark.asyncio -async def test_get_embeddings_nocache(embedding_model: AsyncEmbeddingModel): +async def test_get_embeddings_nocache(embedding_model: PydanticAIEmbeddingModel): """Test retrieving multiple embeddings without using the cache.""" inputs = ["Hello, world", "Foo bar baz"] embeddings = await embedding_model.get_embeddings_nocache(inputs) @@ -42,14 +37,14 @@ async def test_get_embeddings_nocache(embedding_model: AsyncEmbeddingModel): @pytest.mark.asyncio async def test_get_embedding_with_cache( - embedding_model: AsyncEmbeddingModel, mocker: MockerFixture + embedding_model: PydanticAIEmbeddingModel, mocker: MockerFixture ): """Test retrieving an embedding with caching.""" input_text = "Hello, world" # First call should populate the cache embedding1 = await embedding_model.get_embedding(input_text) - assert input_text in embedding_model._embedding_cache + assert input_text in embedding_model._cache # Mock the nocache method to ensure it's not called mock_get_embedding_nocache = mocker.patch.object( @@ -66,7 +61,7 @@ async def test_get_embedding_with_cache( @pytest.mark.asyncio async def test_get_embeddings_with_cache( - embedding_model: AsyncEmbeddingModel, mocker: MockerFixture + embedding_model: PydanticAIEmbeddingModel, mocker: MockerFixture ): """Test retrieving multiple embeddings with caching.""" inputs = ["Hello, world", "Foo bar baz"] @@ -74,7 +69,7 @@ async def test_get_embeddings_with_cache( # First call should populate the cache embeddings1 = await embedding_model.get_embeddings(inputs) for input_text in inputs: - assert input_text in embedding_model._embedding_cache + assert input_text in embedding_model._cache # Mock the nocache method to ensure it's not called mock_get_embeddings_nocache = mocker.patch.object( @@ -90,9 +85,9 @@ async def test_get_embeddings_with_cache( @pytest.mark.asyncio -async def test_get_embeddings_empty_input(embedding_model: AsyncEmbeddingModel): +async def test_get_embeddings_empty_input(embedding_model: PydanticAIEmbeddingModel): """Test retrieving embeddings for an empty input list.""" - inputs = [] + inputs: list[str] = [] embeddings = await embedding_model.get_embeddings(inputs) assert isinstance(embeddings, np.ndarray) @@ -101,222 +96,60 @@ async def test_get_embeddings_empty_input(embedding_model: AsyncEmbeddingModel): @pytest.mark.asyncio -async def test_add_embedding_to_cache(embedding_model: AsyncEmbeddingModel): +async def test_add_embedding_to_cache(embedding_model: PydanticAIEmbeddingModel): """Test adding an embedding to the cache.""" key = "test_key" embedding = np.array([0.1, 0.2, 0.3], dtype=np.float32) embedding_model.add_embedding(key, embedding) - assert key in embedding_model._embedding_cache - assert np.array_equal(embedding_model._embedding_cache[key], embedding) + assert key in embedding_model._cache + assert np.array_equal(embedding_model._cache[key], embedding) @pytest.mark.asyncio -async def test_get_embedding_nocache_empty_input(embedding_model: AsyncEmbeddingModel): +async def test_get_embedding_nocache_empty_input( + embedding_model: PydanticAIEmbeddingModel, +): """Test retrieving an embedding with no cache for an empty input.""" - with pytest.raises(openai.OpenAIError): + with pytest.raises(ValueError, match="Empty input text"): await embedding_model.get_embedding_nocache("") @pytest.mark.asyncio -async def test_refresh_auth( - embedding_model: AsyncEmbeddingModel, mocker: MockerFixture -): - """Test refreshing authentication when using Azure.""" - # Note that pyright doesn't understand mocking, hence the `# type: ignore` below - mocker.patch.object(embedding_model, "azure_token_provider", autospec=True) - mocker.patch.object(embedding_model, "_setup_azure", autospec=True) - - embedding_model.azure_token_provider.needs_refresh.return_value = True # type: ignore - embedding_model.azure_token_provider.refresh_token.return_value = "new_token" # type: ignore - embedding_model.azure_api_version = "2023-05-15" - embedding_model.azure_endpoint = "https://example.azure.com" - - await embedding_model.refresh_auth() +async def test_embeddings_are_normalized(embedding_model: PydanticAIEmbeddingModel): + """Test that returned embeddings are unit-normalized.""" + inputs = ["Hello, world", "Foo bar baz", "Testing normalization"] + embeddings = await embedding_model.get_embeddings_nocache(inputs) - embedding_model.azure_token_provider.refresh_token.assert_called_once() # type: ignore - assert embedding_model.async_client is not None + for i in range(len(inputs)): + norm = float(np.linalg.norm(embeddings[i])) + assert abs(norm - 1.0) < 1e-6, f"Embedding {i} not normalized: norm={norm}" @pytest.mark.asyncio -async def test_set_endpoint(monkeypatch: MonkeyPatch): - """Test creating of model with custom endpoint.""" - - monkeypatch.setenv("AZURE_OPENAI_API_KEY", "does-not-matter") - monkeypatch.delenv("OPENAI_API_KEY", raising=False) # Ensure Azure path is used - - # Default - monkeypatch.setenv( - "AZURE_OPENAI_ENDPOINT_EMBEDDING", - "http://localhost:7997?api-version=2024-06-01", - ) - embedding_model = AsyncEmbeddingModel() - assert embedding_model.embedding_size == 1536 - assert embedding_model.model_name == "text-embedding-ada-002" - assert embedding_model.endpoint_envvar == "AZURE_OPENAI_ENDPOINT_EMBEDDING" - - # 3-large - monkeypatch.setenv( - "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE", - "http://localhost:7997?api-version=2024-06-01", - ) - embedding_model = AsyncEmbeddingModel(model_name="text-embedding-3-large") - assert embedding_model.embedding_size == 3072 - assert embedding_model.model_name == "text-embedding-3-large" - assert embedding_model.endpoint_envvar == "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE" - - # 3-small - monkeypatch.setenv( - "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL", - "http://localhost:7998?api-version=2024-06-01", - ) - embedding_model = AsyncEmbeddingModel(model_name="text-embedding-3-small") - assert embedding_model.embedding_size == 1536 - assert embedding_model.model_name == "text-embedding-3-small" - assert embedding_model.endpoint_envvar == "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL" - - # Fully custom with OpenAI - monkeypatch.setenv("OPENAI_API_KEY", "does-not-matter") - monkeypatch.setenv("INFINITY_EMBEDDING_URL", "http://localhost:7997") - embedding_model = AsyncEmbeddingModel( - 1024, "custom_model", endpoint_envvar="INFINITY_EMBEDDING_URL" - ) - assert embedding_model.embedding_size == 1024 - assert embedding_model.model_name == "custom_model" - # NOTE: checking openai.AsyncOpenAI internals - assert embedding_model.async_client is not None - assert embedding_model.async_client.base_url == "http://localhost:7997" - assert embedding_model.async_client.api_key == "does-not-matter" - assert embedding_model.endpoint_envvar == "INFINITY_EMBEDDING_URL" - - # Customized 3-small with Azure (endpoint_envvar must contain "AZURE") - monkeypatch.delenv("OPENAI_API_KEY") # Force Azure path - monkeypatch.setenv( - "AZURE_ALTERNATE_ENDPOINT", - "http://localhost:7999?api-version=2024-06-01", - ) - embedding_model = AsyncEmbeddingModel( - 2000, "text-embedding-3-small", endpoint_envvar="AZURE_ALTERNATE_ENDPOINT" - ) - assert embedding_model.embedding_size == 2000 - assert embedding_model.model_name == "text-embedding-3-small" - assert embedding_model.endpoint_envvar == "AZURE_ALTERNATE_ENDPOINT" - - # Allow explicitly setting default embedding size - AsyncEmbeddingModel(1536) - - # Can't customize embedding_size for default model - with pytest.raises(ValueError): - AsyncEmbeddingModel(1024) - - # Not even when default model name specified explicitly - with pytest.raises(ValueError): - AsyncEmbeddingModel(1024, "text-embedding-ada-002") +async def test_embeddings_are_deterministic( + embedding_model: PydanticAIEmbeddingModel, +): + """Test that the same input always produces the same embedding.""" + input_text = "Deterministic test" + e1 = await embedding_model.get_embedding_nocache(input_text) + e2 = await embedding_model.get_embedding_nocache(input_text) + assert np.array_equal(e1, e2) @pytest.mark.asyncio -async def test_embeddings_batching_tiktoken( - fake_embeddings_tiktoken: FakeEmbeddings, monkeypatch: MonkeyPatch +async def test_different_inputs_produce_different_embeddings( + embedding_model: PydanticAIEmbeddingModel, ): - monkeypatch.setenv("OPENAI_API_KEY", "test_key") - - embedding_model = AsyncEmbeddingModel() - assert embedding_model.max_chunk_size == 4096 - - embedding_model.async_client.embeddings = fake_embeddings_tiktoken # type: ignore - - # Check max batch size - inputs = ["a"] * 2049 - embeddings = await embedding_model.get_embeddings(inputs) - assert len(embeddings) == 2049 - assert fake_embeddings_tiktoken.call_count == 2 - - # Check max token size - inputs = ["Very long input longer than 4096 tokens will be truncated" * 500] - embeddings = await embedding_model.get_embeddings(inputs) - assert len(embeddings) == 1 - - fake_embeddings_tiktoken.reset_counter() - - TEST_MAX_TOKEN_SIZE = 10 - TEST_MAX_TOKENS_PER_BATCH = 20 - embedding_model.max_chunk_size = TEST_MAX_TOKEN_SIZE - embedding_model.max_size_per_batch = TEST_MAX_TOKENS_PER_BATCH - fake_embeddings_tiktoken.max_elements_per_batch = TEST_MAX_TOKENS_PER_BATCH - - assert embedding_model.encoding is not None - - token = [500] * 20 # --> 20 tokens - input = [embedding_model.encoding.decode(token)] * 4 - embeddings = await embedding_model.get_embeddings_nocache(input) # type: ignore - - # each input gets truncated to 10 tokens, so 4 inputs fit in 2 batches of 20 tokens - assert fake_embeddings_tiktoken.call_count == 2 - assert len(embeddings) == 4 - - fake_embeddings_tiktoken.reset_counter() - - TEST_MAX_TOKEN_SIZE = 7 - embedding_model.max_chunk_size = TEST_MAX_TOKEN_SIZE - - token = [500] * 20 # --> 20 tokens - input = [embedding_model.encoding.decode(token)] * 5 - embeddings = await embedding_model.get_embeddings_nocache(input) # type: ignore - - # each input gets truncated to 7 tokens, so each batch can hold 2 inputs (14 tokens) - # 5 inputs require 3 batches - assert fake_embeddings_tiktoken.call_count == 3 - assert len(embeddings) == 5 + """Test that different inputs produce different embeddings.""" + e1 = await embedding_model.get_embedding_nocache("Hello") + e2 = await embedding_model.get_embedding_nocache("World") + assert not np.array_equal(e1, e2) @pytest.mark.asyncio -async def test_embeddings_batching( - fake_embeddings: FakeEmbeddings, monkeypatch: MonkeyPatch +async def test_implements_iembedding_model( + embedding_model: PydanticAIEmbeddingModel, ): - monkeypatch.setenv("OPENAI_API_KEY", "test_key") - - embedding_model = AsyncEmbeddingModel(1024, "custom_model") - embedding_model.async_client.embeddings = fake_embeddings # type: ignore - - # Check max batch size - inputs = ["a"] * 2049 - embeddings = await embedding_model.get_embeddings(inputs) - assert len(embeddings) == 2049 - assert fake_embeddings.call_count == 2 - - TEST_MAX_CHAR_SIZE = 10 - TEST_MAX_CHARS_PER_BATCH = 20 - embedding_model.max_chunk_size = TEST_MAX_CHAR_SIZE - embedding_model.max_size_per_batch = TEST_MAX_CHARS_PER_BATCH - fake_embeddings.max_elements_per_batch = TEST_MAX_CHARS_PER_BATCH - - # Check max token size - inputs = ["a" * TEST_MAX_CHAR_SIZE] - embeddings = await embedding_model.get_embeddings_nocache(inputs) - assert len(embeddings) == 1 - assert np.all(embeddings[0] == 0) - - fake_embeddings.reset_counter() - - # Check one over max token size - inputs = ["a" * (TEST_MAX_CHAR_SIZE + 1)] - embeddings = await embedding_model.get_embeddings_nocache(inputs) - assert len(embeddings) == 1 - assert fake_embeddings.call_count == 1 - - fake_embeddings.reset_counter() - - # Check input as large as max_size_per_batch - inputs = ["a" * 10, "a" * 5, "a" * 5] - embeddings = await embedding_model.get_embeddings_nocache(inputs) # type: ignore - assert fake_embeddings.call_count == 1 - assert len(embeddings) == 3 - - fake_embeddings.reset_counter() - - # Check input larger than max_size_per_batch - # max chars per batch is 20, so 10*10 chars requires 5 batches - inputs = ["a" * 10] * 10 - embeddings = await embedding_model.get_embeddings_nocache(inputs) # type: ignore - assert fake_embeddings.call_count == 5 - assert len(embeddings) == 10 + """Test that PydanticAIEmbeddingModel satisfies the IEmbeddingModel protocol.""" + assert isinstance(embedding_model, IEmbeddingModel) diff --git a/tests/test_factory.py b/tests/test_factory.py index 0f62220f..44c45e5f 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -6,7 +6,7 @@ import pytest from typeagent import create_conversation -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.transcripts.transcript import TranscriptMessage, TranscriptMessageMeta @@ -15,7 +15,7 @@ async def test_create_conversation_minimal(): """Test creating a conversation with minimal parameters.""" # Create empty conversation with test model - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) conversation = await create_conversation( None, @@ -36,7 +36,7 @@ async def test_create_conversation_minimal(): @pytest.mark.asyncio async def test_create_conversation_with_tags(): """Test creating a conversation with tags.""" - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) conversation = await create_conversation( None, @@ -54,7 +54,7 @@ async def test_create_conversation_with_tags(): async def test_create_conversation_and_add_messages(really_needs_auth): """Test the complete workflow: create conversation and add messages.""" # 1. Create empty conversation - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) conversation = await create_conversation( None, diff --git a/tests/test_incremental_index.py b/tests/test_incremental_index.py index 12f706a6..ced12f19 100644 --- a/tests/test_incremental_index.py +++ b/tests/test_incremental_index.py @@ -8,7 +8,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.storage.sqlite.provider import SqliteStorageProvider from typeagent.transcripts.transcript import ( @@ -30,7 +30,7 @@ async def test_incremental_index_building(): db_path = os.path.join(tmpdir, "test.db") # Create settings with test model (no API keys needed) - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -74,7 +74,7 @@ async def test_incremental_index_building(): # Second ingestion - add more messages and rebuild index print("\n=== Second ingestion ===") - test_model2 = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model2 = create_test_embedding_model() settings2 = ConversationSettings(model=test_model2) settings2.semantic_ref_index_settings.auto_extract_knowledge = False storage2 = SqliteStorageProvider( @@ -136,7 +136,7 @@ async def test_incremental_index_with_vtt_files(): db_path = os.path.join(tmpdir, "test.db") # Create settings with test model (no API keys needed) - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -161,9 +161,7 @@ async def test_incremental_index_with_vtt_files(): # Second VTT file ingestion into same database print("\n=== Import second VTT file ===") - settings2 = ConversationSettings( - model=AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) - ) + settings2 = ConversationSettings(model=create_test_embedding_model()) settings2.semantic_ref_index_settings.auto_extract_knowledge = False # Ingest the second transcript diff --git a/tests/test_message_text_index_population.py b/tests/test_message_text_index_population.py index 384aaa9c..13d53c00 100644 --- a/tests/test_message_text_index_population.py +++ b/tests/test_message_text_index_population.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, @@ -30,7 +30,7 @@ async def test_message_text_index_population_from_database(): try: # Use the test model that's already configured in the system - embedding_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + embedding_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(embedding_model) message_text_settings = MessageTextIndexSettings(embedding_settings) related_terms_settings = RelatedTermIndexSettings(embedding_settings) diff --git a/tests/test_messageindex.py b/tests/test_messageindex.py index b4e40dd6..ec80498f 100644 --- a/tests/test_messageindex.py +++ b/tests/test_messageindex.py @@ -42,10 +42,10 @@ def message_text_index( mock_text_location_index: MagicMock, ) -> IMessageTextEmbeddingIndex: """Fixture to create a MessageTextIndex instance with a mocked TextToTextLocationIndex.""" - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) settings = MessageTextIndexSettings(embedding_settings) index = MessageTextIndex(settings) @@ -55,10 +55,10 @@ def message_text_index( def test_message_text_index_init(needs_auth: None): """Test initialization of MessageTextIndex.""" - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) settings = MessageTextIndexSettings(embedding_settings) index = MessageTextIndex(settings) @@ -147,11 +147,11 @@ async def test_generate_embedding(needs_auth: None): """Test generating an embedding for a message without mocking.""" import numpy as np - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings # Create real MessageTextIndex with test model - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) settings = MessageTextIndexSettings(embedding_settings) index = MessageTextIndex(settings) @@ -205,14 +205,14 @@ async def test_build_message_index(needs_auth: None): ] # Create storage provider asynchronously - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, RelatedTermIndexSettings, ) - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) message_text_settings = MessageTextIndexSettings(embedding_settings) related_terms_settings = RelatedTermIndexSettings(embedding_settings) diff --git a/tests/test_podcast_incremental.py b/tests/test_podcast_incremental.py index 92d5ad32..4b1732d6 100644 --- a/tests/test_podcast_incremental.py +++ b/tests/test_podcast_incremental.py @@ -8,7 +8,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.podcasts.podcast import Podcast, PodcastMessage, PodcastMessageMeta from typeagent.storage.sqlite.provider import SqliteStorageProvider @@ -20,7 +20,7 @@ async def test_podcast_add_messages_with_indexing(): with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -57,7 +57,7 @@ async def test_podcast_add_messages_batched(): with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False diff --git a/tests/test_property_index_population.py b/tests/test_property_index_population.py index 8b751bb7..f6cc3edb 100644 --- a/tests/test_property_index_population.py +++ b/tests/test_property_index_population.py @@ -8,10 +8,9 @@ import tempfile from dotenv import load_dotenv -import numpy as np import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro import kplib from typeagent.knowpro.convsettings import ( @@ -23,16 +22,6 @@ from typeagent.storage import SqliteStorageProvider -class MockEmbeddingModel(AsyncEmbeddingModel): - def __init__(self): - super().__init__(embedding_size=3, model_name="test") - - async def get_embeddings(self, keys: list[str]) -> np.ndarray: - result = np.random.rand(len(keys), 3).astype(np.float32) - norms = np.linalg.norm(result, axis=1, keepdims=True) - return result / norms - - @pytest.mark.asyncio async def test_property_index_population_from_database(really_needs_auth): """Test that property index is correctly populated when reopening a database.""" @@ -42,7 +31,7 @@ async def test_property_index_population_from_database(really_needs_auth): temp_db_file.close() try: - embedding_model = MockEmbeddingModel() + embedding_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(embedding_model) message_text_settings = MessageTextIndexSettings(embedding_settings) related_terms_settings = RelatedTermIndexSettings(embedding_settings) @@ -98,7 +87,7 @@ async def test_property_index_population_from_database(really_needs_auth): # Reopen database and verify property index # Use the same embedding settings to avoid dimension mismatch - embedding_model2 = MockEmbeddingModel() + embedding_model2 = create_test_embedding_model() embedding_settings2 = TextEmbeddingIndexSettings(embedding_model2) message_text_settings2 = MessageTextIndexSettings(embedding_settings2) related_terms_settings2 = RelatedTermIndexSettings(embedding_settings2) diff --git a/tests/test_query_method.py b/tests/test_query_method.py index ac605823..bcbf2e00 100644 --- a/tests/test_query_method.py +++ b/tests/test_query_method.py @@ -6,7 +6,7 @@ import pytest from typeagent import create_conversation -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.transcripts.transcript import TranscriptMessage, TranscriptMessageMeta @@ -15,7 +15,7 @@ async def test_query_method_basic(really_needs_auth: None): """Test the basic query method workflow.""" # Create a conversation with some test data - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) conversation = await create_conversation( None, @@ -60,7 +60,7 @@ async def test_query_method_basic(really_needs_auth: None): @pytest.mark.asyncio async def test_query_method_empty_conversation(really_needs_auth: None): """Test query method on an empty conversation.""" - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) conversation = await create_conversation( None, diff --git a/tests/test_related_terms_fast.py b/tests/test_related_terms_fast.py index 3919666f..fbcf60c5 100644 --- a/tests/test_related_terms_fast.py +++ b/tests/test_related_terms_fast.py @@ -9,7 +9,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.interfaces import SemanticRef, TextLocation, TextRange from typeagent.knowpro.kplib import ConcreteEntity @@ -26,7 +26,7 @@ async def test_related_terms_index_minimal(): try: # Create minimal test data with test embedding model (no API keys needed) - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) # Use a simple storage provider without AI embeddings diff --git a/tests/test_related_terms_index_population.py b/tests/test_related_terms_index_population.py index 9d16936f..9de6f015 100644 --- a/tests/test_related_terms_index_population.py +++ b/tests/test_related_terms_index_population.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro import kplib from typeagent.knowpro.convsettings import ( @@ -32,7 +32,7 @@ async def test_related_terms_index_population_from_database(really_needs_auth): try: # Use the test model that's already configured in the system - embedding_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + embedding_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(embedding_model) message_text_settings = MessageTextIndexSettings(embedding_settings) related_terms_settings = RelatedTermIndexSettings(embedding_settings) diff --git a/tests/test_secindex.py b/tests/test_secindex.py index a9008aa3..39665b05 100644 --- a/tests/test_secindex.py +++ b/tests/test_secindex.py @@ -3,7 +3,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( ConversationSettings, @@ -29,9 +29,9 @@ def simple_conversation() -> FakeConversation: @pytest.fixture def conversation_settings(needs_auth: None) -> ConversationSettings: - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model - model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + model = create_test_embedding_model() return ConversationSettings(model) @@ -41,7 +41,7 @@ def test_conversation_secondary_indexes_initialization( """Test initialization of ConversationSecondaryIndexes.""" storage_provider = memory_storage # Create proper settings for testing - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) settings = RelatedTermIndexSettings(embedding_settings) indexes = ConversationSecondaryIndexes(storage_provider, settings) diff --git a/tests/test_secindex_storage_integration.py b/tests/test_secindex_storage_integration.py index a050771b..15738bb6 100644 --- a/tests/test_secindex_storage_integration.py +++ b/tests/test_secindex_storage_integration.py @@ -4,7 +4,7 @@ # Test that ConversationSecondaryIndexes now uses storage provider properly import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import RelatedTermIndexSettings from typeagent.knowpro.secindex import ConversationSecondaryIndexes @@ -19,7 +19,7 @@ async def test_secondary_indexes_use_storage_provider( storage_provider = memory_storage # Create test settings - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) related_terms_settings = RelatedTermIndexSettings(embedding_settings) diff --git a/tests/test_transcripts.py b/tests/test_transcripts.py index 0d0bfb57..0286f5bc 100644 --- a/tests/test_transcripts.py +++ b/tests/test_transcripts.py @@ -6,7 +6,8 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, IEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH from typeagent.transcripts.transcript import ( @@ -224,10 +225,8 @@ def test_transcript_message_creation(): @pytest.mark.asyncio async def test_transcript_creation(): """Test creating an empty transcript.""" - from typeagent.aitools.embeddings import TEST_MODEL_NAME - # Create a minimal transcript for testing structure - embedding_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + embedding_model = create_test_embedding_model() settings = ConversationSettings(embedding_model) transcript = await Transcript.create( diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py index be8d2c07..4a92b23e 100644 --- a/tests/test_vectorbase.py +++ b/tests/test_vectorbase.py @@ -5,10 +5,9 @@ import pytest from typeagent.aitools.embeddings import ( - AsyncEmbeddingModel, NormalizedEmbedding, - TEST_MODEL_NAME, ) +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase @@ -19,9 +18,7 @@ def vector_base() -> VectorBase: def make_vector_base() -> VectorBase: - settings = TextEmbeddingIndexSettings( - AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) - ) + settings = TextEmbeddingIndexSettings(create_test_embedding_model()) return VectorBase(settings) @@ -61,8 +58,8 @@ def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples): assert len(bulk_vector_base) == len(vector_base) np.testing.assert_array_equal(bulk_vector_base.serialize(), vector_base.serialize()) - sequential_cache = vector_base._model._embedding_cache # type: ignore[attr-defined] - bulk_cache = bulk_vector_base._model._embedding_cache # type: ignore[attr-defined] + sequential_cache = vector_base._model._cache # type: ignore[attr-defined] + bulk_cache = bulk_vector_base._model._cache # type: ignore[attr-defined] assert set(sequential_cache.keys()) == set(bulk_cache.keys()) for key in keys: np.testing.assert_array_equal(bulk_cache[key], sequential_cache[key]) @@ -85,7 +82,7 @@ async def test_add_key_no_cache(vector_base: VectorBase, sample_embeddings: Samp assert len(vector_base) == len(sample_embeddings) assert ( - vector_base._model._embedding_cache == {} # type: ignore[attr-defined] + vector_base._model._cache == {} # type: ignore[attr-defined] ), "Cache should remain empty when cache=False" @@ -106,7 +103,7 @@ async def test_add_keys_no_cache(vector_base: VectorBase, sample_embeddings: Sam assert len(vector_base) == len(sample_embeddings) assert ( - vector_base._model._embedding_cache == {} # type: ignore[attr-defined] + vector_base._model._cache == {} # type: ignore[attr-defined] ), "Cache should remain empty when cache=False" From dec2e6fe5c2c3f313a35f9b54ce0a0ae63cdeb0b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 11:46:53 -0800 Subject: [PATCH 14/23] Move in-function imports to module level in tests/ --- tests/test_mcp_server.py | 22 +++-------- tests/test_messageindex.py | 3 +- tests/test_model_adapters.py | 51 ++++++------------------- tests/test_sqlitestore.py | 3 +- tests/test_storage_providers_unified.py | 5 +-- tests/test_transcripts.py | 5 +-- tests/test_utils.py | 7 ++-- 7 files changed, 25 insertions(+), 71 deletions(-) diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 03fd0e69..24933ca2 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -3,16 +3,21 @@ """End-to-end tests for the MCP server.""" +import json import os import sys from typing import Any import pytest -from mcp import StdioServerParameters +from mcp import ClientSession, StdioServerParameters from mcp.client.session import ClientSession as ClientSessionType +from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent +from openai.types.chat import ChatCompletionMessageParam + +from typeagent.aitools.utils import create_async_openai_client from conftest import EPISODE_53_INDEX @@ -38,11 +43,6 @@ async def sampling_callback( params: CreateMessageRequestParams, ) -> CreateMessageResult: """Sampling callback that uses OpenAI to generate responses.""" - # Use OpenAI to generate a response - from openai.types.chat import ChatCompletionMessageParam - - from typeagent.aitools.utils import create_async_openai_client - client = create_async_openai_client() # Convert MCP SamplingMessage to OpenAI format @@ -91,9 +91,6 @@ async def test_mcp_server_query_conversation_slow( really_needs_auth, server_params: StdioServerParameters ): """Test the query_conversation tool end-to-end using MCP client.""" - from mcp import ClientSession - from mcp.client.stdio import stdio_client - # Pass through environment variables needed for authentication # otherwise this test will fail in the CI on Windows only if not (server_params.env) is None: @@ -135,8 +132,6 @@ async def test_mcp_server_query_conversation_slow( response_text = content_item.text # Parse response (it should be JSON with success, answer, time_used) - import json - try: response_data = json.loads(response_text) except json.JSONDecodeError as e: @@ -158,9 +153,6 @@ async def test_mcp_server_query_conversation_slow( @pytest.mark.asyncio async def test_mcp_server_empty_question(server_params: StdioServerParameters): """Test the query_conversation tool with an empty question.""" - from mcp import ClientSession - from mcp.client.stdio import stdio_client - # Create client session and connect to server async with stdio_client(server_params) as (read, write): async with ClientSession( @@ -183,8 +175,6 @@ async def test_mcp_server_empty_question(server_params: StdioServerParameters): assert isinstance(content_item, TextContent) response_text = content_item.text - import json - response_data = json.loads(response_text) assert response_data["success"] is False assert "No question provided" in response_data["answer"] diff --git a/tests/test_messageindex.py b/tests/test_messageindex.py index ec80498f..f91ac933 100644 --- a/tests/test_messageindex.py +++ b/tests/test_messageindex.py @@ -4,6 +4,7 @@ from typing import cast from unittest.mock import AsyncMock, MagicMock +import numpy as np import pytest from typeagent.knowpro.convsettings import MessageTextIndexSettings @@ -145,8 +146,6 @@ async def test_lookup_messages_in_subset( @pytest.mark.asyncio async def test_generate_embedding(needs_auth: None): """Test generating an embedding for a message without mocking.""" - import numpy as np - from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index 5e773430..dbff97cd 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -1,9 +1,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from unittest.mock import AsyncMock + import numpy as np import pytest +from pydantic_ai import Embedder +from pydantic_ai.embeddings import EmbeddingResult +from pydantic_ai.messages import ( + ModelResponse, + SystemPromptPart, + TextPart, + UserPromptPart, +) +from pydantic_ai.models import Model import typechat from typeagent.aitools.embeddings import IEmbeddingModel, NormalizedEmbedding @@ -57,11 +68,6 @@ def test_chat_model_is_typechat_model() -> None: @pytest.mark.asyncio async def test_chat_adapter_complete() -> None: """PydanticAIChatModel wraps a pydantic_ai Model.""" - from unittest.mock import AsyncMock - - from pydantic_ai.messages import ModelResponse, TextPart - from pydantic_ai.models import Model - mock_model = AsyncMock(spec=Model) mock_model.request.return_value = ModelResponse(parts=[TextPart(content="hello")]) @@ -74,11 +80,6 @@ async def test_chat_adapter_complete() -> None: @pytest.mark.asyncio async def test_chat_adapter_prompt_sections() -> None: """PydanticAIChatModel handles list[PromptSection] prompts.""" - from unittest.mock import AsyncMock - - from pydantic_ai.messages import ModelResponse, TextPart - from pydantic_ai.models import Model - mock_model = AsyncMock(spec=Model) mock_model.request.return_value = ModelResponse( parts=[TextPart(content="response")] @@ -98,8 +99,6 @@ async def test_chat_adapter_prompt_sections() -> None: messages = call_args[0][0] assert len(messages) == 1 request = messages[0] - from pydantic_ai.messages import SystemPromptPart, UserPromptPart - assert isinstance(request.parts[0], SystemPromptPart) assert isinstance(request.parts[1], UserPromptPart) @@ -117,11 +116,6 @@ def test_embedding_model_is_iembedding_model() -> None: @pytest.mark.asyncio async def test_embedding_adapter_single() -> None: """PydanticAIEmbeddingModel computes a single normalized embedding.""" - from unittest.mock import AsyncMock - - from pydantic_ai import Embedder - from pydantic_ai.embeddings import EmbeddingResult - mock_embedder = AsyncMock(spec=Embedder) raw_vec = [3.0, 4.0, 0.0] mock_embedder.embed_documents.return_value = EmbeddingResult( @@ -142,11 +136,6 @@ async def test_embedding_adapter_single() -> None: @pytest.mark.asyncio async def test_embedding_adapter_probes_size() -> None: """embedding_size is discovered from the first embedding call.""" - from unittest.mock import AsyncMock - - from pydantic_ai import Embedder - from pydantic_ai.embeddings import EmbeddingResult - mock_embedder = AsyncMock(spec=Embedder) mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0, 0.0]], @@ -165,11 +154,6 @@ async def test_embedding_adapter_probes_size() -> None: @pytest.mark.asyncio async def test_embedding_adapter_batch() -> None: """PydanticAIEmbeddingModel computes batch embeddings.""" - from unittest.mock import AsyncMock - - from pydantic_ai import Embedder - from pydantic_ai.embeddings import EmbeddingResult - mock_embedder = AsyncMock(spec=Embedder) mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0], [0.0, 1.0]], @@ -187,11 +171,6 @@ async def test_embedding_adapter_batch() -> None: @pytest.mark.asyncio async def test_embedding_adapter_caching() -> None: """Caching avoids re-computing embeddings.""" - from unittest.mock import AsyncMock - - from pydantic_ai import Embedder - from pydantic_ai.embeddings import EmbeddingResult - mock_embedder = AsyncMock(spec=Embedder) mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0, 0.0]], @@ -212,10 +191,6 @@ async def test_embedding_adapter_caching() -> None: @pytest.mark.asyncio async def test_embedding_adapter_add_embedding() -> None: """add_embedding() populates the cache.""" - from unittest.mock import AsyncMock - - from pydantic_ai import Embedder - mock_embedder = AsyncMock(spec=Embedder) adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) vec: NormalizedEmbedding = np.array([1.0, 0.0, 0.0], dtype=np.float32) @@ -229,10 +204,6 @@ async def test_embedding_adapter_add_embedding() -> None: @pytest.mark.asyncio async def test_embedding_adapter_empty_batch() -> None: """Empty batch returns empty array with known size.""" - from unittest.mock import AsyncMock - - from pydantic_ai import Embedder - mock_embedder = AsyncMock(spec=Embedder) adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 4) result = await adapter.get_embeddings_nocache([]) diff --git a/tests/test_sqlitestore.py b/tests/test_sqlitestore.py index ad784221..704ab0bb 100644 --- a/tests/test_sqlitestore.py +++ b/tests/test_sqlitestore.py @@ -3,6 +3,7 @@ from collections.abc import AsyncGenerator from dataclasses import field +from datetime import datetime import pytest import pytest_asyncio @@ -128,8 +129,6 @@ async def test_sqlite_timestamp_index( dummy_sqlite_storage_provider: SqliteStorageProvider[DummyMessage], ): """Test SqliteTimestampToTextRangeIndex functionality.""" - from datetime import datetime - from typeagent.knowpro.interfaces import DateRange # Set up database with some messages diff --git a/tests/test_storage_providers_unified.py b/tests/test_storage_providers_unified.py index 14829f03..d0ecb9c5 100644 --- a/tests/test_storage_providers_unified.py +++ b/tests/test_storage_providers_unified.py @@ -9,6 +9,8 @@ """ from dataclasses import field +import os +import tempfile from typing import assert_never, AsyncGenerator import pytest @@ -605,9 +607,6 @@ async def test_storage_provider_independence( ) # Create two sqlite providers (with different temp files) - import os - import tempfile - temp_file1 = tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) temp_path1 = temp_file1.name temp_file1.close() diff --git a/tests/test_transcripts.py b/tests/test_transcripts.py index 0286f5bc..9d98ae88 100644 --- a/tests/test_transcripts.py +++ b/tests/test_transcripts.py @@ -5,6 +5,7 @@ import os import pytest +import webvtt from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.model_adapters import create_test_embedding_model @@ -102,8 +103,6 @@ def conversation_settings( @pytest.mark.asyncio async def test_ingest_vtt_transcript(conversation_settings: ConversationSettings): """Test importing a VTT file into a Transcript object.""" - import webvtt - from typeagent.storage.memory.collections import ( MemoryMessageCollection, MemorySemanticRefCollection, @@ -253,8 +252,6 @@ async def test_transcript_knowledge_extraction_slow( 4. Verifies both mechanical extraction (entities/actions from metadata) and LLM extraction (topics from content) work correctly """ - import webvtt - from typeagent.storage.memory.collections import ( MemoryMessageCollection, MemorySemanticRefCollection, diff --git a/tests/test_utils.py b/tests/test_utils.py index cb95c93a..ceea367d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,6 +7,9 @@ from dotenv import load_dotenv +import pydantic.dataclasses +import typechat + import typeagent.aitools.utils as utils @@ -37,14 +40,10 @@ def test_load_dotenv(really_needs_auth): def test_create_translator(): - import typechat - class DummyModel(typechat.TypeChatLanguageModel): async def complete(self, *args, **kwargs) -> typechat.Result: return typechat.Failure("dummy response") - import pydantic.dataclasses - @pydantic.dataclasses.dataclass class DummySchema: pass From 909247d7c86da51019ec87d2def57188022f2844 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 12:01:56 -0800 Subject: [PATCH 15/23] Don't re-export create_typechat_model from convknowledge.py --- src/typeagent/emails/email_memory.py | 3 +-- src/typeagent/knowpro/conversation_base.py | 4 ++-- src/typeagent/knowpro/convknowledge.py | 5 +---- src/typeagent/knowpro/knowledge.py | 3 ++- tools/query.py | 3 +-- 5 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/typeagent/emails/email_memory.py b/src/typeagent/emails/email_memory.py index d6cf06cb..3da149df 100644 --- a/src/typeagent/emails/email_memory.py +++ b/src/typeagent/emails/email_memory.py @@ -12,7 +12,6 @@ from ..knowpro import ( answer_response_schema, answers, - convknowledge, search_query_schema, searchlang, ) @@ -24,7 +23,7 @@ class EmailMemorySettings: def __init__(self, conversation_settings: ConversationSettings) -> None: - self.language_model = convknowledge.create_typechat_model() + self.language_model = utils.create_typechat_model() self.query_translator = utils.create_translator( self.language_model, search_query_schema.SearchQuery ) diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 732e239a..74d34b19 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -352,12 +352,12 @@ async def query( """ # Create translators lazily (once per conversation instance) if self._query_translator is None: - model = convknowledge.create_typechat_model() + model = utils.create_typechat_model() self._query_translator = utils.create_translator( model, search_query_schema.SearchQuery ) if self._answer_translator is None: - model = convknowledge.create_typechat_model() + model = utils.create_typechat_model() self._answer_translator = utils.create_translator( model, answer_response_schema.AnswerResponse ) diff --git a/src/typeagent/knowpro/convknowledge.py b/src/typeagent/knowpro/convknowledge.py index 53a10b25..6f9000d3 100644 --- a/src/typeagent/knowpro/convknowledge.py +++ b/src/typeagent/knowpro/convknowledge.py @@ -6,10 +6,7 @@ import typechat from . import kplib -from ..aitools.utils import create_typechat_model # Re-export for backward compat - -# Re-export: callers may still do ``convknowledge.create_typechat_model()``. -__all__ = ["create_typechat_model", "KnowledgeExtractor"] +from ..aitools.utils import create_typechat_model @dataclass diff --git a/src/typeagent/knowpro/knowledge.py b/src/typeagent/knowpro/knowledge.py index bda7397f..48d5b8fa 100644 --- a/src/typeagent/knowpro/knowledge.py +++ b/src/typeagent/knowpro/knowledge.py @@ -8,6 +8,7 @@ from typechat import Result, TypeChatLanguageModel from . import convknowledge, kplib +from ..aitools import utils from .interfaces import IKnowledgeExtractor @@ -15,7 +16,7 @@ def create_knowledge_extractor( chat_model: TypeChatLanguageModel | None = None, ) -> convknowledge.KnowledgeExtractor: """Create a knowledge extractor using the given Chat Model.""" - chat_model = chat_model or convknowledge.create_typechat_model() + chat_model = chat_model or utils.create_typechat_model() extractor = convknowledge.KnowledgeExtractor( chat_model, max_chars_per_chunk=4096, merge_action_knowledge=False ) diff --git a/tools/query.py b/tools/query.py index 24b1f4c9..c9817803 100644 --- a/tools/query.py +++ b/tools/query.py @@ -36,7 +36,6 @@ from typeagent.knowpro import ( answer_response_schema, answers, - convknowledge, kplib, query, search, @@ -576,7 +575,7 @@ async def main(): "Error: non-empty --search-results required for batch mode." ) - model = convknowledge.create_typechat_model() + model = utils.create_typechat_model() query_translator = utils.create_translator(model, search_query_schema.SearchQuery) if args.alt_schema: if args.verbose: From 8807cc57e417f64983944cded8cac391adaaebaa Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 12:12:19 -0800 Subject: [PATCH 16/23] Remove redundant tests that Chat/Embedding models subclass their protocols --- tests/test_model_adapters.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index dbff97cd..c7ccf8f9 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -17,7 +17,7 @@ from pydantic_ai.models import Model import typechat -from typeagent.aitools.embeddings import IEmbeddingModel, NormalizedEmbedding +from typeagent.aitools.embeddings import NormalizedEmbedding from typeagent.aitools.model_adapters import ( configure_models, create_chat_model, @@ -60,11 +60,6 @@ def test_default_embedding_size_is_zero() -> None: # --------------------------------------------------------------------------- -def test_chat_model_is_typechat_model() -> None: - """PydanticAIChatModel inherits from TypeChatLanguageModel.""" - assert typechat.TypeChatLanguageModel in PydanticAIChatModel.__mro__ - - @pytest.mark.asyncio async def test_chat_adapter_complete() -> None: """PydanticAIChatModel wraps a pydantic_ai Model.""" @@ -108,11 +103,6 @@ async def test_chat_adapter_prompt_sections() -> None: # --------------------------------------------------------------------------- -def test_embedding_model_is_iembedding_model() -> None: - """PydanticAIEmbeddingModel inherits from IEmbeddingModel.""" - assert IEmbeddingModel in PydanticAIEmbeddingModel.__mro__ - - @pytest.mark.asyncio async def test_embedding_adapter_single() -> None: """PydanticAIEmbeddingModel computes a single normalized embedding.""" @@ -223,4 +213,3 @@ def test_configure_models_returns_correct_types( chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small") assert isinstance(chat, PydanticAIChatModel) assert isinstance(embedder, PydanticAIEmbeddingModel) - assert typechat.TypeChatLanguageModel in type(chat).__mro__ From 68d3082faafaa388bf2df2da3c47ddc1ee5fa730 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 12:28:45 -0800 Subject: [PATCH 17/23] Avoid type-ignore in favor of isinstance --- tests/test_vectorbase.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py index 4a92b23e..04c53c36 100644 --- a/tests/test_vectorbase.py +++ b/tests/test_vectorbase.py @@ -7,7 +7,10 @@ from typeagent.aitools.embeddings import ( NormalizedEmbedding, ) -from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.aitools.model_adapters import ( + create_test_embedding_model, + PydanticAIEmbeddingModel, +) from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase @@ -58,8 +61,10 @@ def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples): assert len(bulk_vector_base) == len(vector_base) np.testing.assert_array_equal(bulk_vector_base.serialize(), vector_base.serialize()) - sequential_cache = vector_base._model._cache # type: ignore[attr-defined] - bulk_cache = bulk_vector_base._model._cache # type: ignore[attr-defined] + assert isinstance(vector_base._model, PydanticAIEmbeddingModel) + assert isinstance(bulk_vector_base._model, PydanticAIEmbeddingModel) + sequential_cache = vector_base._model._cache + bulk_cache = bulk_vector_base._model._cache assert set(sequential_cache.keys()) == set(bulk_cache.keys()) for key in keys: np.testing.assert_array_equal(bulk_cache[key], sequential_cache[key]) @@ -81,9 +86,8 @@ async def test_add_key_no_cache(vector_base: VectorBase, sample_embeddings: Samp await vector_base.add_key(key, cache=False) assert len(vector_base) == len(sample_embeddings) - assert ( - vector_base._model._cache == {} # type: ignore[attr-defined] - ), "Cache should remain empty when cache=False" + assert isinstance(vector_base._model, PydanticAIEmbeddingModel) + assert vector_base._model._cache == {}, "Cache should remain empty when cache=False" @pytest.mark.asyncio @@ -102,9 +106,8 @@ async def test_add_keys_no_cache(vector_base: VectorBase, sample_embeddings: Sam await vector_base.add_keys(keys, cache=False) assert len(vector_base) == len(sample_embeddings) - assert ( - vector_base._model._cache == {} # type: ignore[attr-defined] - ), "Cache should remain empty when cache=False" + assert isinstance(vector_base._model, PydanticAIEmbeddingModel) + assert vector_base._model._cache == {}, "Cache should remain empty when cache=False" @pytest.mark.asyncio From 3697f89a043ddb5bd28b04d44435cf0513ad9179 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 15:39:37 -0800 Subject: [PATCH 18/23] Remove ModelWrapper, create_typechat_model; use create_chat_model everywhere --- AGENTS.md | 3 ++ src/typeagent/aitools/model_adapters.py | 43 +++++++++++++--- src/typeagent/aitools/utils.py | 59 ---------------------- src/typeagent/emails/email_memory.py | 4 +- src/typeagent/knowpro/conversation_base.py | 6 +-- src/typeagent/knowpro/convknowledge.py | 4 +- src/typeagent/knowpro/knowledge.py | 4 +- tools/query.py | 4 +- 8 files changed, 49 insertions(+), 78 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index f08200b2..dafd9da6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -93,3 +93,6 @@ please follow these guidelines: * **Code Validation**: Don't use `py_compile` for syntax checking. Use `pyright` or `make check` instead for proper type checking and validation. + +* **Deprecations**: Don't deprecate things -- just delete them and fix the usage sites. + Don't create backward compatibility APIs or exports or whatever. Fix the usage sites. diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 4acc0ff8..7f73d2a5 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -198,20 +198,34 @@ def _make_azure_provider( may contain a full Azure deployment URL (including path and ``api-version`` query parameter) — the same format used throughout this codebase. + + When ``AZURE_OPENAI_API_KEY`` is set to ``"identity"``, the client + uses Azure Managed Identity via a token provider callback, which + refreshes tokens automatically before each request. """ from openai import AsyncAzureOpenAI from pydantic_ai.providers.azure import AzureProvider - from .utils import get_azure_api_key, parse_azure_endpoint + from .utils import parse_azure_endpoint raw_key = os.environ[api_key_envvar] - api_key = get_azure_api_key(raw_key) azure_endpoint, api_version = parse_azure_endpoint(endpoint_envvar) - client = AsyncAzureOpenAI( - azure_endpoint=azure_endpoint, - api_version=api_version, - api_key=api_key, - ) + + if raw_key.lower() == "identity": + from .auth import get_shared_token_provider + + token_provider = get_shared_token_provider() + client = AsyncAzureOpenAI( + azure_endpoint=azure_endpoint, + api_version=api_version, + azure_ad_token_provider=token_provider.get_token, + ) + else: + client = AsyncAzureOpenAI( + azure_endpoint=azure_endpoint, + api_version=api_version, + api_key=raw_key, + ) return AzureProvider(openai_client=client) @@ -220,8 +234,11 @@ def _make_azure_provider( # --------------------------------------------------------------------------- +DEFAULT_CHAT_SPEC = "openai:gpt-4o" + + def create_chat_model( - model_spec: str, + model_spec: str | None = None, ) -> PydanticAIChatModel: """Create a chat model from a ``provider:model`` spec. @@ -229,12 +246,22 @@ def create_chat_model( If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically. + If *model_spec* is ``None``, it is constructed from the ``OPENAI_MODEL`` + environment variable (falling back to :data:`DEFAULT_CHAT_SPEC`). + Examples:: + model = create_chat_model() # uses OPENAI_MODEL or gpt-4o model = create_chat_model("openai:gpt-4o") model = create_chat_model("anthropic:claude-sonnet-4-20250514") model = create_chat_model("google:gemini-2.0-flash") """ + if model_spec is None: + openai_model = os.getenv("OPENAI_MODEL") + if openai_model: + model_spec = f"openai:{openai_model}" + else: + model_spec = DEFAULT_CHAT_SPEC provider, _, model_name = model_spec.partition(":") if _needs_azure_fallback(provider): from pydantic_ai.models.openai import OpenAIChatModel diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index 9d57b854..b9150334 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -3,7 +3,6 @@ """Utilities that are hard to fit in any specific module.""" -import asyncio from contextlib import contextmanager import difflib import os @@ -17,8 +16,6 @@ import typechat -from .auth import AzureTokenProvider, get_shared_token_provider - @contextmanager def timelog(label: str, verbose: bool = True): @@ -90,62 +87,6 @@ def create_translator[T]( return translator -# TODO: Make these parameters that can be configured (e.g. from command line). -DEFAULT_MAX_RETRY_ATTEMPTS = 0 -DEFAULT_TIMEOUT_SECONDS = 25 - - -class ModelWrapper(typechat.TypeChatLanguageModel): - """Wraps a TypeChat model to handle Azure token refresh.""" - - def __init__( - self, - base_model: typechat.TypeChatLanguageModel, - token_provider: AzureTokenProvider, - ): - self.base_model = base_model - self.token_provider = token_provider - - async def complete( - self, prompt: str | list[typechat.PromptSection] - ) -> typechat.Result[str]: - if self.token_provider.needs_refresh(): - loop = asyncio.get_running_loop() - api_key = await loop.run_in_executor( - None, self.token_provider.refresh_token - ) - env: dict[str, str | None] = dict(os.environ) - key_name = "AZURE_OPENAI_API_KEY" - env[key_name] = api_key - self.base_model = typechat.create_language_model(env) - self.base_model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS - return await self.base_model.complete(prompt) - - -def create_typechat_model() -> typechat.TypeChatLanguageModel: - """Create a TypeChat language model using OpenAI or Azure OpenAI. - - Auto-detects the provider from ``OPENAI_API_KEY`` / ``AZURE_OPENAI_API_KEY`` - environment variables. - - For explicit provider selection, use :func:`model_adapters.create_chat_model` - with a spec string like ``"openai:gpt-4o"`` or ``"azure:my-deployment"``. - """ - env: dict[str, str | None] = dict(os.environ) - key_name = "AZURE_OPENAI_API_KEY" - key = env.get(key_name) - shared_token_provider: AzureTokenProvider | None = None - if key is not None and key.lower() == "identity": - shared_token_provider = get_shared_token_provider() - env[key_name] = shared_token_provider.get_token() - model = typechat.create_language_model(env) - model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS - model.max_retry_attempts = DEFAULT_MAX_RETRY_ATTEMPTS - if shared_token_provider is not None: - model = ModelWrapper(model, shared_token_provider) - return model - - # Vibe-coded by o4-mini-high def list_diff(label_a, a, label_b, b, max_items): """Print colorized diff between two sorted list of numbers.""" diff --git a/src/typeagent/emails/email_memory.py b/src/typeagent/emails/email_memory.py index 3da149df..6dd50cc4 100644 --- a/src/typeagent/emails/email_memory.py +++ b/src/typeagent/emails/email_memory.py @@ -8,7 +8,7 @@ import typechat -from ..aitools import utils +from ..aitools import model_adapters, utils from ..knowpro import ( answer_response_schema, answers, @@ -23,7 +23,7 @@ class EmailMemorySettings: def __init__(self, conversation_settings: ConversationSettings) -> None: - self.language_model = utils.create_typechat_model() + self.language_model = model_adapters.create_chat_model() self.query_translator = utils.create_translator( self.language_model, search_query_schema.SearchQuery ) diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 74d34b19..07ea1553 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -18,7 +18,7 @@ searchlang, secindex, ) -from ..aitools import utils +from ..aitools import model_adapters, utils from ..storage.memory import semrefindex from .convsettings import ConversationSettings from .interfaces import ( @@ -352,12 +352,12 @@ async def query( """ # Create translators lazily (once per conversation instance) if self._query_translator is None: - model = utils.create_typechat_model() + model = model_adapters.create_chat_model() self._query_translator = utils.create_translator( model, search_query_schema.SearchQuery ) if self._answer_translator is None: - model = utils.create_typechat_model() + model = model_adapters.create_chat_model() self._answer_translator = utils.create_translator( model, answer_response_schema.AnswerResponse ) diff --git a/src/typeagent/knowpro/convknowledge.py b/src/typeagent/knowpro/convknowledge.py index 6f9000d3..fe1d5f5c 100644 --- a/src/typeagent/knowpro/convknowledge.py +++ b/src/typeagent/knowpro/convknowledge.py @@ -6,12 +6,12 @@ import typechat from . import kplib -from ..aitools.utils import create_typechat_model +from ..aitools.model_adapters import create_chat_model @dataclass class KnowledgeExtractor: - model: typechat.TypeChatLanguageModel = field(default_factory=create_typechat_model) + model: typechat.TypeChatLanguageModel = field(default_factory=create_chat_model) max_chars_per_chunk: int = 2048 merge_action_knowledge: bool = ( False # TODO: Implement merge_action_knowledge_into_response diff --git a/src/typeagent/knowpro/knowledge.py b/src/typeagent/knowpro/knowledge.py index 48d5b8fa..dbcfe206 100644 --- a/src/typeagent/knowpro/knowledge.py +++ b/src/typeagent/knowpro/knowledge.py @@ -8,7 +8,7 @@ from typechat import Result, TypeChatLanguageModel from . import convknowledge, kplib -from ..aitools import utils +from ..aitools import model_adapters from .interfaces import IKnowledgeExtractor @@ -16,7 +16,7 @@ def create_knowledge_extractor( chat_model: TypeChatLanguageModel | None = None, ) -> convknowledge.KnowledgeExtractor: """Create a knowledge extractor using the given Chat Model.""" - chat_model = chat_model or utils.create_typechat_model() + chat_model = chat_model or model_adapters.create_chat_model() extractor = convknowledge.KnowledgeExtractor( chat_model, max_chars_per_chunk=4096, merge_action_knowledge=False ) diff --git a/tools/query.py b/tools/query.py index c9817803..b3a35091 100644 --- a/tools/query.py +++ b/tools/query.py @@ -32,7 +32,7 @@ import typechat -from typeagent.aitools import embeddings, utils +from typeagent.aitools import embeddings, model_adapters, utils from typeagent.knowpro import ( answer_response_schema, answers, @@ -575,7 +575,7 @@ async def main(): "Error: non-empty --search-results required for batch mode." ) - model = utils.create_typechat_model() + model = model_adapters.create_chat_model() query_translator = utils.create_translator(model, search_query_schema.SearchQuery) if args.alt_schema: if args.verbose: From 087b7a3f6727d0be8f411ab35458360e1cedcfab Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 16:30:06 -0800 Subject: [PATCH 19/23] Split up *EmbeddingModel into IEmbedder and CachingEmbeddingModel --- src/typeagent/aitools/embeddings.py | 90 +++++++++++++++++++++++-- src/typeagent/aitools/model_adapters.py | 61 +++++++---------- src/typeagent/aitools/vectorbase.py | 1 - tests/test_embeddings.py | 35 +++++----- tests/test_model_adapters.py | 29 ++++---- tests/test_vectorbase.py | 10 +-- 6 files changed, 146 insertions(+), 80 deletions(-) diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py index f56db25d..e0e04b7a 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -11,16 +11,46 @@ @runtime_checkable -class IEmbeddingModel(Protocol): - """Provider-agnostic interface for embedding models. +class IEmbedder(Protocol): + """Minimal provider interface for embedding models. Implement this protocol to add support for a new embedding provider - (e.g. Anthropic, Gemini, local models). The production implementation - is :class:`~typeagent.aitools.model_adapters.PydanticAIEmbeddingModel`. + (e.g. Anthropic, Gemini, local models). Only raw embedding computation + is required; caching is handled by :class:`CachingEmbeddingModel`. + + The production implementation is + :class:`~typeagent.aitools.model_adapters.PydanticAIEmbedder`. + """ + + @property + def model_name(self) -> str: ... + + @property + def embedding_size(self) -> int: ... + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + """Compute a single embedding without caching.""" + ... + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + """Compute embeddings for a batch of strings without caching.""" + ... + + +@runtime_checkable +class IEmbeddingModel(Protocol): + """Consumer-facing interface for embedding models with caching. + + This extends the provider interface (:class:`IEmbedder`) with caching + methods. Use :class:`CachingEmbeddingModel` to wrap an :class:`IEmbedder` + and get a ready-to-use ``IEmbeddingModel``. """ - model_name: str - embedding_size: int + @property + def model_name(self) -> str: ... + + @property + def embedding_size(self) -> int: ... def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: """Cache an already-computed embedding under the given key.""" @@ -43,6 +73,54 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: ... +class CachingEmbeddingModel: + """Wraps an :class:`IEmbedder` with an in-memory embedding cache. + + This shared base class implements the caching logic once, so individual + embedding providers only need to implement the minimal :class:`IEmbedder` + protocol (``get_embedding_nocache`` / ``get_embeddings_nocache``). + """ + + def __init__(self, embedder: IEmbedder) -> None: + self._embedder = embedder + self._cache: dict[str, NormalizedEmbedding] = {} + + @property + def model_name(self) -> str: + return self._embedder.model_name + + @property + def embedding_size(self) -> int: + return self._embedder.embedding_size + + def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: + self._cache[key] = embedding + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + return await self._embedder.get_embedding_nocache(input) + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + return await self._embedder.get_embeddings_nocache(input) + + async def get_embedding(self, key: str) -> NormalizedEmbedding: + cached = self._cache.get(key) + if cached is not None: + return cached + embedding = await self._embedder.get_embedding_nocache(key) + self._cache[key] = embedding + return embedding + + async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + if not keys: + return await self._embedder.get_embeddings_nocache([]) + missing_keys = [k for k in keys if k not in self._cache] + if missing_keys: + fresh = await self._embedder.get_embeddings_nocache(missing_keys) + for i, k in enumerate(missing_keys): + self._cache[k] = fresh[i] + return np.array([self._cache[k] for k in keys], dtype=np.float32) + + DEFAULT_MODEL_NAME = "text-embedding-ada-002" DEFAULT_EMBEDDING_SIZE = 1536 # Default embedding size (required for ada-002) DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING" # We support OpenAI and Azure OpenAI diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 7f73d2a5..6ccf8292 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -46,7 +46,11 @@ from pydantic_ai.models import infer_model, Model, ModelRequestParameters import typechat -from .embeddings import IEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings +from .embeddings import ( + CachingEmbeddingModel, + NormalizedEmbedding, + NormalizedEmbeddings, +) # --------------------------------------------------------------------------- # Chat model adapter @@ -92,13 +96,13 @@ async def complete( # --------------------------------------------------------------------------- -class PydanticAIEmbeddingModel(IEmbeddingModel): - """Adapter from :class:`pydantic_ai.Embedder` to :class:`IEmbeddingModel`. +class PydanticAIEmbedder: + """Adapter from :class:`pydantic_ai.Embedder` to :class:`IEmbedder`. This lets any pydantic_ai embedding provider (OpenAI, Cohere, Google, …) - be used wherever the codebase expects an ``IEmbeddingModel``, including - :class:`~typeagent.aitools.vectorbase.VectorBase` and - :class:`~typeagent.knowpro.convsettings.ConversationSettings`. + be used wherever the codebase expects an ``IEmbedder``. Wrap in + :class:`~typeagent.aitools.embeddings.CachingEmbeddingModel` to get a + ready-to-use ``IEmbeddingModel`` with caching. If *embedding_size* is not given, it is probed automatically by making a single embedding call. @@ -116,10 +120,6 @@ def __init__( self._embedder = embedder self.model_name = model_name self.embedding_size = embedding_size - self._cache: dict[str, NormalizedEmbedding] = {} - - def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: - self._cache[key] = embedding async def _probe_embedding_size(self) -> None: """Discover embedding_size by making a single API call.""" @@ -152,26 +152,6 @@ async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings embeddings = (embeddings / norms).astype(np.float32) return embeddings - async def get_embedding(self, key: str) -> NormalizedEmbedding: - cached = self._cache.get(key) - if cached is not None: - return cached - embedding = await self.get_embedding_nocache(key) - self._cache[key] = embedding - return embedding - - async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: - if not keys: - if self.embedding_size == 0: - await self._probe_embedding_size() - return np.empty((0, self.embedding_size), dtype=np.float32) - missing_keys = [k for k in keys if k not in self._cache] - if missing_keys: - fresh = await self.get_embeddings_nocache(missing_keys) - for i, k in enumerate(missing_keys): - self._cache[k] = fresh[i] - return np.array([self._cache[k] for k in keys], dtype=np.float32) - # --------------------------------------------------------------------------- # Provider auto-detection @@ -279,7 +259,7 @@ def create_embedding_model( model_spec: str | None = None, *, embedding_size: int = 0, -) -> PydanticAIEmbeddingModel: +) -> CachingEmbeddingModel: """Create an embedding model from a ``provider:model`` spec. Delegates to :class:`pydantic_ai.Embedder` for provider wiring. @@ -290,6 +270,9 @@ def create_embedding_model( If *embedding_size* is not given, it will be probed automatically on the first embedding call. + Returns a :class:`~typeagent.aitools.embeddings.CachingEmbeddingModel` + wrapping a :class:`PydanticAIEmbedder`. + Examples:: model = create_embedding_model("openai:text-embedding-3-small") @@ -324,7 +307,9 @@ def create_embedding_model( embedder = _PydanticAIEmbedder(embedding_model) else: embedder = _PydanticAIEmbedder(model_spec) - return PydanticAIEmbeddingModel(embedder, model_name, embedding_size) + return CachingEmbeddingModel( + PydanticAIEmbedder(embedder, model_name, embedding_size) + ) # --------------------------------------------------------------------------- @@ -400,12 +385,14 @@ async def embed( def create_test_embedding_model( embedding_size: int = 3, -) -> PydanticAIEmbeddingModel: - """Create a :class:`PydanticAIEmbeddingModel` with deterministic fake +) -> CachingEmbeddingModel: + """Create a :class:`CachingEmbeddingModel` with deterministic fake embeddings for testing. No API keys or network access required.""" fake_model = _FakePydanticAIEmbeddingModel(embedding_size) - embedder = _PydanticAIEmbedder(fake_model) - return PydanticAIEmbeddingModel(embedder, "test", embedding_size) + pydantic_embedder = _PydanticAIEmbedder(fake_model) + return CachingEmbeddingModel( + PydanticAIEmbedder(pydantic_embedder, "test", embedding_size) + ) def configure_models( @@ -413,7 +400,7 @@ def configure_models( embedding_model_spec: str, *, embedding_size: int = 0, -) -> tuple[PydanticAIChatModel, PydanticAIEmbeddingModel]: +) -> tuple[PydanticAIChatModel, CachingEmbeddingModel]: """Configure both a chat model and an embedding model at once. Delegates to pydantic_ai's model registry for provider wiring. diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index df93997f..076a3476 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -175,7 +175,6 @@ def _set_embedding_size(self, size: int) -> None: """Adopt *size* when it was not known at construction time.""" assert size > 0 self._embedding_size = size - self._model.embedding_size = size self.settings.embedding_size = size def clear(self) -> None: diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 2ae09845..94f17e7a 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -5,8 +5,7 @@ import pytest from pytest_mock import MockerFixture -from typeagent.aitools.embeddings import IEmbeddingModel -from typeagent.aitools.model_adapters import PydanticAIEmbeddingModel +from typeagent.aitools.embeddings import CachingEmbeddingModel, IEmbeddingModel from conftest import ( embedding_model, # type: ignore # Magic, prevents side effects of mocking @@ -14,7 +13,7 @@ @pytest.mark.asyncio -async def test_get_embedding_nocache(embedding_model: PydanticAIEmbeddingModel): +async def test_get_embedding_nocache(embedding_model: CachingEmbeddingModel): """Test retrieving an embedding without using the cache.""" input_text = "Hello, world" embedding = await embedding_model.get_embedding_nocache(input_text) @@ -25,7 +24,7 @@ async def test_get_embedding_nocache(embedding_model: PydanticAIEmbeddingModel): @pytest.mark.asyncio -async def test_get_embeddings_nocache(embedding_model: PydanticAIEmbeddingModel): +async def test_get_embeddings_nocache(embedding_model: CachingEmbeddingModel): """Test retrieving multiple embeddings without using the cache.""" inputs = ["Hello, world", "Foo bar baz"] embeddings = await embedding_model.get_embeddings_nocache(inputs) @@ -37,7 +36,7 @@ async def test_get_embeddings_nocache(embedding_model: PydanticAIEmbeddingModel) @pytest.mark.asyncio async def test_get_embedding_with_cache( - embedding_model: PydanticAIEmbeddingModel, mocker: MockerFixture + embedding_model: CachingEmbeddingModel, mocker: MockerFixture ): """Test retrieving an embedding with caching.""" input_text = "Hello, world" @@ -46,9 +45,9 @@ async def test_get_embedding_with_cache( embedding1 = await embedding_model.get_embedding(input_text) assert input_text in embedding_model._cache - # Mock the nocache method to ensure it's not called + # Mock the nocache method on the underlying embedder to ensure it's not called mock_get_embedding_nocache = mocker.patch.object( - embedding_model, "get_embedding_nocache", autospec=True + embedding_model._embedder, "get_embedding_nocache", autospec=True ) # Second call should retrieve from the cache @@ -61,7 +60,7 @@ async def test_get_embedding_with_cache( @pytest.mark.asyncio async def test_get_embeddings_with_cache( - embedding_model: PydanticAIEmbeddingModel, mocker: MockerFixture + embedding_model: CachingEmbeddingModel, mocker: MockerFixture ): """Test retrieving multiple embeddings with caching.""" inputs = ["Hello, world", "Foo bar baz"] @@ -71,9 +70,9 @@ async def test_get_embeddings_with_cache( for input_text in inputs: assert input_text in embedding_model._cache - # Mock the nocache method to ensure it's not called + # Mock the nocache method on the underlying embedder to ensure it's not called mock_get_embeddings_nocache = mocker.patch.object( - embedding_model, "get_embeddings_nocache", autospec=True + embedding_model._embedder, "get_embeddings_nocache", autospec=True ) # Second call should retrieve from the cache @@ -85,7 +84,7 @@ async def test_get_embeddings_with_cache( @pytest.mark.asyncio -async def test_get_embeddings_empty_input(embedding_model: PydanticAIEmbeddingModel): +async def test_get_embeddings_empty_input(embedding_model: CachingEmbeddingModel): """Test retrieving embeddings for an empty input list.""" inputs: list[str] = [] embeddings = await embedding_model.get_embeddings(inputs) @@ -96,7 +95,7 @@ async def test_get_embeddings_empty_input(embedding_model: PydanticAIEmbeddingMo @pytest.mark.asyncio -async def test_add_embedding_to_cache(embedding_model: PydanticAIEmbeddingModel): +async def test_add_embedding_to_cache(embedding_model: CachingEmbeddingModel): """Test adding an embedding to the cache.""" key = "test_key" embedding = np.array([0.1, 0.2, 0.3], dtype=np.float32) @@ -108,7 +107,7 @@ async def test_add_embedding_to_cache(embedding_model: PydanticAIEmbeddingModel) @pytest.mark.asyncio async def test_get_embedding_nocache_empty_input( - embedding_model: PydanticAIEmbeddingModel, + embedding_model: CachingEmbeddingModel, ): """Test retrieving an embedding with no cache for an empty input.""" with pytest.raises(ValueError, match="Empty input text"): @@ -116,7 +115,7 @@ async def test_get_embedding_nocache_empty_input( @pytest.mark.asyncio -async def test_embeddings_are_normalized(embedding_model: PydanticAIEmbeddingModel): +async def test_embeddings_are_normalized(embedding_model: CachingEmbeddingModel): """Test that returned embeddings are unit-normalized.""" inputs = ["Hello, world", "Foo bar baz", "Testing normalization"] embeddings = await embedding_model.get_embeddings_nocache(inputs) @@ -128,7 +127,7 @@ async def test_embeddings_are_normalized(embedding_model: PydanticAIEmbeddingMod @pytest.mark.asyncio async def test_embeddings_are_deterministic( - embedding_model: PydanticAIEmbeddingModel, + embedding_model: CachingEmbeddingModel, ): """Test that the same input always produces the same embedding.""" input_text = "Deterministic test" @@ -139,7 +138,7 @@ async def test_embeddings_are_deterministic( @pytest.mark.asyncio async def test_different_inputs_produce_different_embeddings( - embedding_model: PydanticAIEmbeddingModel, + embedding_model: CachingEmbeddingModel, ): """Test that different inputs produce different embeddings.""" e1 = await embedding_model.get_embedding_nocache("Hello") @@ -149,7 +148,7 @@ async def test_different_inputs_produce_different_embeddings( @pytest.mark.asyncio async def test_implements_iembedding_model( - embedding_model: PydanticAIEmbeddingModel, + embedding_model: CachingEmbeddingModel, ): - """Test that PydanticAIEmbeddingModel satisfies the IEmbeddingModel protocol.""" + """Test that CachingEmbeddingModel satisfies the IEmbeddingModel protocol.""" assert isinstance(embedding_model, IEmbeddingModel) diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index c7ccf8f9..f089c4ba 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -17,13 +17,13 @@ from pydantic_ai.models import Model import typechat -from typeagent.aitools.embeddings import NormalizedEmbedding +from typeagent.aitools.embeddings import CachingEmbeddingModel, NormalizedEmbedding from typeagent.aitools.model_adapters import ( configure_models, create_chat_model, create_embedding_model, PydanticAIChatModel, - PydanticAIEmbeddingModel, + PydanticAIEmbedder, ) # --------------------------------------------------------------------------- @@ -99,13 +99,13 @@ async def test_chat_adapter_prompt_sections() -> None: # --------------------------------------------------------------------------- -# PydanticAIEmbeddingModel adapter +# PydanticAIEmbedder adapter # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_embedding_adapter_single() -> None: - """PydanticAIEmbeddingModel computes a single normalized embedding.""" + """PydanticAIEmbedder computes a single normalized embedding.""" mock_embedder = AsyncMock(spec=Embedder) raw_vec = [3.0, 4.0, 0.0] mock_embedder.embed_documents.return_value = EmbeddingResult( @@ -116,7 +116,7 @@ async def test_embedding_adapter_single() -> None: provider_name="test", ) - adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) + adapter = PydanticAIEmbedder(mock_embedder, "test-model", 3) result = await adapter.get_embedding_nocache("test") assert result.shape == (3,) norm = float(np.linalg.norm(result)) @@ -135,7 +135,7 @@ async def test_embedding_adapter_probes_size() -> None: provider_name="test", ) - adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model") + adapter = PydanticAIEmbedder(mock_embedder, "test-model") assert adapter.embedding_size == 0 await adapter.get_embedding_nocache("probe") assert adapter.embedding_size == 3 @@ -143,7 +143,7 @@ async def test_embedding_adapter_probes_size() -> None: @pytest.mark.asyncio async def test_embedding_adapter_batch() -> None: - """PydanticAIEmbeddingModel computes batch embeddings.""" + """PydanticAIEmbedder computes batch embeddings.""" mock_embedder = AsyncMock(spec=Embedder) mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0], [0.0, 1.0]], @@ -153,14 +153,14 @@ async def test_embedding_adapter_batch() -> None: provider_name="test", ) - adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 2) + adapter = PydanticAIEmbedder(mock_embedder, "test-model", 2) result = await adapter.get_embeddings_nocache(["a", "b"]) assert result.shape == (2, 2) @pytest.mark.asyncio async def test_embedding_adapter_caching() -> None: - """Caching avoids re-computing embeddings.""" + """CachingEmbeddingModel avoids re-computing embeddings.""" mock_embedder = AsyncMock(spec=Embedder) mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0, 0.0]], @@ -170,7 +170,8 @@ async def test_embedding_adapter_caching() -> None: provider_name="test", ) - adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) + embedder = PydanticAIEmbedder(mock_embedder, "test-model", 3) + adapter = CachingEmbeddingModel(embedder) first = await adapter.get_embedding("cached") second = await adapter.get_embedding("cached") np.testing.assert_array_equal(first, second) @@ -182,7 +183,8 @@ async def test_embedding_adapter_caching() -> None: async def test_embedding_adapter_add_embedding() -> None: """add_embedding() populates the cache.""" mock_embedder = AsyncMock(spec=Embedder) - adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) + embedder = PydanticAIEmbedder(mock_embedder, "test-model", 3) + adapter = CachingEmbeddingModel(embedder) vec: NormalizedEmbedding = np.array([1.0, 0.0, 0.0], dtype=np.float32) adapter.add_embedding("key", vec) result = await adapter.get_embedding("key") @@ -195,7 +197,8 @@ async def test_embedding_adapter_add_embedding() -> None: async def test_embedding_adapter_empty_batch() -> None: """Empty batch returns empty array with known size.""" mock_embedder = AsyncMock(spec=Embedder) - adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 4) + embedder = PydanticAIEmbedder(mock_embedder, "test-model", 4) + adapter = CachingEmbeddingModel(embedder) result = await adapter.get_embeddings_nocache([]) assert result.shape == (0, 4) @@ -212,4 +215,4 @@ def test_configure_models_returns_correct_types( monkeypatch.setenv("OPENAI_API_KEY", "test-key") chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small") assert isinstance(chat, PydanticAIChatModel) - assert isinstance(embedder, PydanticAIEmbeddingModel) + assert isinstance(embedder, CachingEmbeddingModel) diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py index 04c53c36..416ed3eb 100644 --- a/tests/test_vectorbase.py +++ b/tests/test_vectorbase.py @@ -5,11 +5,11 @@ import pytest from typeagent.aitools.embeddings import ( + CachingEmbeddingModel, NormalizedEmbedding, ) from typeagent.aitools.model_adapters import ( create_test_embedding_model, - PydanticAIEmbeddingModel, ) from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase @@ -61,8 +61,8 @@ def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples): assert len(bulk_vector_base) == len(vector_base) np.testing.assert_array_equal(bulk_vector_base.serialize(), vector_base.serialize()) - assert isinstance(vector_base._model, PydanticAIEmbeddingModel) - assert isinstance(bulk_vector_base._model, PydanticAIEmbeddingModel) + assert isinstance(vector_base._model, CachingEmbeddingModel) + assert isinstance(bulk_vector_base._model, CachingEmbeddingModel) sequential_cache = vector_base._model._cache bulk_cache = bulk_vector_base._model._cache assert set(sequential_cache.keys()) == set(bulk_cache.keys()) @@ -86,7 +86,7 @@ async def test_add_key_no_cache(vector_base: VectorBase, sample_embeddings: Samp await vector_base.add_key(key, cache=False) assert len(vector_base) == len(sample_embeddings) - assert isinstance(vector_base._model, PydanticAIEmbeddingModel) + assert isinstance(vector_base._model, CachingEmbeddingModel) assert vector_base._model._cache == {}, "Cache should remain empty when cache=False" @@ -106,7 +106,7 @@ async def test_add_keys_no_cache(vector_base: VectorBase, sample_embeddings: Sam await vector_base.add_keys(keys, cache=False) assert len(vector_base) == len(sample_embeddings) - assert isinstance(vector_base._model, PydanticAIEmbeddingModel) + assert isinstance(vector_base._model, CachingEmbeddingModel) assert vector_base._model._cache == {}, "Cache should remain empty when cache=False" From 43894bd7adaae5a980a15ffa831bb88415aeafea Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 16:56:33 -0800 Subject: [PATCH 20/23] Remove max_retries everywhere -- this is now under Pydantic control --- src/typeagent/aitools/vectorbase.py | 7 ------- src/typeagent/knowpro/knowledge.py | 12 +++--------- tests/test_knowledge.py | 6 +++--- 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index 076a3476..107c7a21 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -13,8 +13,6 @@ ) from .model_adapters import create_embedding_model -DEFAULT_MAX_RETRIES = 2 - @dataclass class ScoredInt: @@ -29,7 +27,6 @@ class TextEmbeddingIndexSettings: min_score: float # Between 0.0 and 1.0 max_matches: int | None # >= 1; None means no limit batch_size: int # >= 1 - max_retries: int def __init__( self, @@ -38,14 +35,10 @@ def __init__( min_score: float | None = None, max_matches: int | None = None, batch_size: int | None = None, - max_retries: int | None = None, ): self.min_score = min_score if min_score is not None else 0.85 self.max_matches = max_matches if max_matches and max_matches >= 1 else None self.batch_size = batch_size if batch_size and batch_size >= 1 else 8 - self.max_retries = ( - max_retries if max_retries is not None else DEFAULT_MAX_RETRIES - ) self.embedding_model = embedding_model or create_embedding_model( embedding_size=embedding_size or 0, ) diff --git a/src/typeagent/knowpro/knowledge.py b/src/typeagent/knowpro/knowledge.py index dbcfe206..60ce302d 100644 --- a/src/typeagent/knowpro/knowledge.py +++ b/src/typeagent/knowpro/knowledge.py @@ -26,7 +26,6 @@ def create_knowledge_extractor( async def extract_knowledge_from_text( knowledge_extractor: IKnowledgeExtractor, text: str, - max_retries: int, ) -> Result[kplib.KnowledgeResponse]: """Extract knowledge from a single text input with retries.""" # TODO: Add a retry mechanism to handle transient errors. @@ -37,13 +36,10 @@ async def batch_worker( q: asyncio.Queue[tuple[int, str] | None], knowledge_extractor: IKnowledgeExtractor, results: dict[int, Result[kplib.KnowledgeResponse]], - max_retries: int, ) -> None: while item := await q.get(): index, text = item - result = await extract_knowledge_from_text( - knowledge_extractor, text, max_retries - ) + result = await extract_knowledge_from_text(knowledge_extractor, text) results[index] = result @@ -51,7 +47,6 @@ async def extract_knowledge_from_text_batch( knowledge_extractor: IKnowledgeExtractor, text_batch: list[str], concurrency: int = 2, - max_retries: int = 3, ) -> list[Result[kplib.KnowledgeResponse]]: """Extract knowledge from a batch of text inputs concurrently.""" if not text_batch: @@ -64,7 +59,7 @@ async def extract_knowledge_from_text_batch( async with asyncio.TaskGroup() as tg: for _ in range(concurrency): - tg.create_task(batch_worker(q, knowledge_extractor, results, max_retries)) + tg.create_task(batch_worker(q, knowledge_extractor, results)) for index, text in enumerate(text_batch): await q.put((index, text)) @@ -203,7 +198,6 @@ async def extract_knowledge_for_text_batch_q( knowledge_extractor: convknowledge.KnowledgeExtractor, text_batch: list[str], concurrency: int = 2, - max_retries: int = 3, ) -> list[Result[kplib.KnowledgeResponse]]: """Extract knowledge for a batch of text inputs using a task queue.""" raise NotImplementedError("TODO") @@ -212,7 +206,7 @@ async def extract_knowledge_for_text_batch_q( # await run_in_batches( # task_batch, - # lambda text: extract_knowledge_from_text(knowledge_extractor, text, max_retries), + # lambda text: extract_knowledge_from_text(knowledge_extractor, text), # concurrency, # ) diff --git a/tests/test_knowledge.py b/tests/test_knowledge.py index e20ff1f2..d4f46fd1 100644 --- a/tests/test_knowledge.py +++ b/tests/test_knowledge.py @@ -44,12 +44,12 @@ async def test_extract_knowledge_from_text( mock_knowledge_extractor: convknowledge.KnowledgeExtractor, ): """Test extracting knowledge from a single text input.""" - result = await extract_knowledge_from_text(mock_knowledge_extractor, "test text", 3) + result = await extract_knowledge_from_text(mock_knowledge_extractor, "test text") assert isinstance(result, Success) assert result.value.topics[0] == "test text" failure_result = await extract_knowledge_from_text( - mock_knowledge_extractor, "error", 3 + mock_knowledge_extractor, "error" ) assert isinstance(failure_result, Failure) assert failure_result.message == "Extraction failed" @@ -62,7 +62,7 @@ async def test_extract_knowledge_from_text_batch( """Test extracting knowledge from a batch of text inputs.""" text_batch = ["text 1", "text 2", "error"] results = await extract_knowledge_from_text_batch( - mock_knowledge_extractor, text_batch, 2, 3 + mock_knowledge_extractor, text_batch, 2 ) assert len(results) == 3 From 910a99b94f1a78f5e8ceb5e92b69265f364e9567 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 19:40:57 -0800 Subject: [PATCH 21/23] Remove embedding_size argument everywhere. Handle it internally --- src/typeagent/aitools/embeddings.py | 28 ++-- src/typeagent/aitools/model_adapters.py | 41 +---- src/typeagent/aitools/vectorbase.py | 37 +++-- src/typeagent/knowpro/interfaces_storage.py | 1 - src/typeagent/storage/sqlite/messageindex.py | 3 + src/typeagent/storage/sqlite/provider.py | 104 ++++--------- tests/test_conversation_metadata.py | 48 ------ tests/test_embedding_consistency.py | 148 ++++++++++++------- tests/test_embeddings.py | 13 +- tests/test_messageindex.py | 2 +- tests/test_model_adapters.py | 53 ++----- tests/test_vectorbase.py | 25 ++++ 12 files changed, 205 insertions(+), 298 deletions(-) diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py index e0e04b7a..5d0e172c 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -25,15 +25,15 @@ class IEmbedder(Protocol): @property def model_name(self) -> str: ... - @property - def embedding_size(self) -> int: ... - async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: """Compute a single embedding without caching.""" ... async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: - """Compute embeddings for a batch of strings without caching.""" + """Compute embeddings for a batch of strings without caching. + + Raises :class:`ValueError` if *input* is empty. + """ ... @@ -49,9 +49,6 @@ class IEmbeddingModel(Protocol): @property def model_name(self) -> str: ... - @property - def embedding_size(self) -> int: ... - def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: """Cache an already-computed embedding under the given key.""" ... @@ -89,10 +86,6 @@ def __init__(self, embedder: IEmbedder) -> None: def model_name(self) -> str: return self._embedder.model_name - @property - def embedding_size(self) -> int: - return self._embedder.embedding_size - def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: self._cache[key] = embedding @@ -112,7 +105,7 @@ async def get_embedding(self, key: str) -> NormalizedEmbedding: async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: if not keys: - return await self._embedder.get_embeddings_nocache([]) + raise ValueError("Cannot embed an empty list") missing_keys = [k for k in keys if k not in self._cache] if missing_keys: fresh = await self._embedder.get_embeddings_nocache(missing_keys) @@ -122,14 +115,11 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: DEFAULT_MODEL_NAME = "text-embedding-ada-002" -DEFAULT_EMBEDDING_SIZE = 1536 # Default embedding size (required for ada-002) DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING" # We support OpenAI and Azure OpenAI TEST_MODEL_NAME = "test" -model_to_embedding_size_and_envvar: dict[str, tuple[int | None, str]] = { - DEFAULT_MODEL_NAME: (DEFAULT_EMBEDDING_SIZE, DEFAULT_ENVVAR), - "text-embedding-3-small": (1536, "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL"), - "text-embedding-3-large": (3072, "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE"), - # For testing only, not a real model (insert real embeddings above) - TEST_MODEL_NAME: (3, "SIR_NOT_APPEARING_IN_THIS_FILM"), +model_to_envvar: dict[str, str] = { + DEFAULT_MODEL_NAME: DEFAULT_ENVVAR, + "text-embedding-3-small": "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL", + "text-embedding-3-large": "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE", } diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 6ccf8292..86dca394 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -103,36 +103,23 @@ class PydanticAIEmbedder: be used wherever the codebase expects an ``IEmbedder``. Wrap in :class:`~typeagent.aitools.embeddings.CachingEmbeddingModel` to get a ready-to-use ``IEmbeddingModel`` with caching. - - If *embedding_size* is not given, it is probed automatically by making a - single embedding call. """ model_name: str - embedding_size: int def __init__( self, embedder: _PydanticAIEmbedder, model_name: str, - embedding_size: int = 0, ) -> None: self._embedder = embedder self.model_name = model_name - self.embedding_size = embedding_size - - async def _probe_embedding_size(self) -> None: - """Discover embedding_size by making a single API call.""" - result = await self._embedder.embed_documents(["probe"]) - self.embedding_size = len(result.embeddings[0]) async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: result = await self._embedder.embed_documents([input]) embedding: NDArray[np.float32] = np.array( result.embeddings[0], dtype=np.float32 ) - if self.embedding_size == 0: - self.embedding_size = len(embedding) norm = float(np.linalg.norm(embedding)) if norm > 0: embedding = (embedding / norm).astype(np.float32) @@ -140,13 +127,9 @@ async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: if not input: - if self.embedding_size == 0: - await self._probe_embedding_size() - return np.empty((0, self.embedding_size), dtype=np.float32) + raise ValueError("Cannot embed an empty list") result = await self._embedder.embed_documents(input) embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32) - if self.embedding_size == 0: - self.embedding_size = embeddings.shape[1] norms = np.linalg.norm(embeddings, axis=1, keepdims=True).astype(np.float32) norms = np.where(norms > 0, norms, np.float32(1.0)) embeddings = (embeddings / norms).astype(np.float32) @@ -257,8 +240,6 @@ def create_chat_model( def create_embedding_model( model_spec: str | None = None, - *, - embedding_size: int = 0, ) -> CachingEmbeddingModel: """Create an embedding model from a ``provider:model`` spec. @@ -267,8 +248,6 @@ def create_embedding_model( ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically. If *model_spec* is ``None``, :data:`DEFAULT_EMBEDDING_SPEC` is used. - If *embedding_size* is not given, it will be probed automatically - on the first embedding call. Returns a :class:`~typeagent.aitools.embeddings.CachingEmbeddingModel` wrapping a :class:`PydanticAIEmbedder`. @@ -287,12 +266,10 @@ def create_embedding_model( if _needs_azure_fallback(provider): from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel - from .embeddings import model_to_embedding_size_and_envvar + from .embeddings import model_to_envvar # Look up model-specific Azure endpoint, falling back to the generic one. - _, suggested_envvar = model_to_embedding_size_and_envvar.get( - model_name, (None, None) - ) + suggested_envvar = model_to_envvar.get(model_name) if suggested_envvar and os.getenv(suggested_envvar): endpoint_envvar = suggested_envvar else: @@ -307,9 +284,7 @@ def create_embedding_model( embedder = _PydanticAIEmbedder(embedding_model) else: embedder = _PydanticAIEmbedder(model_spec) - return CachingEmbeddingModel( - PydanticAIEmbedder(embedder, model_name, embedding_size) - ) + return CachingEmbeddingModel(PydanticAIEmbedder(embedder, model_name)) # --------------------------------------------------------------------------- @@ -390,16 +365,12 @@ def create_test_embedding_model( embeddings for testing. No API keys or network access required.""" fake_model = _FakePydanticAIEmbeddingModel(embedding_size) pydantic_embedder = _PydanticAIEmbedder(fake_model) - return CachingEmbeddingModel( - PydanticAIEmbedder(pydantic_embedder, "test", embedding_size) - ) + return CachingEmbeddingModel(PydanticAIEmbedder(pydantic_embedder, "test")) def configure_models( chat_model_spec: str, embedding_model_spec: str, - *, - embedding_size: int = 0, ) -> tuple[PydanticAIChatModel, CachingEmbeddingModel]: """Configure both a chat model and an embedding model at once. @@ -417,5 +388,5 @@ def configure_models( """ return ( create_chat_model(chat_model_spec), - create_embedding_model(embedding_model_spec, embedding_size=embedding_size), + create_embedding_model(embedding_model_spec), ) diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index 107c7a21..46c4dbae 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -23,7 +23,6 @@ class ScoredInt: @dataclass class TextEmbeddingIndexSettings: embedding_model: IEmbeddingModel - embedding_size: int # Set to embedding_model.embedding_size min_score: float # Between 0.0 and 1.0 max_matches: int | None # >= 1; None means no limit batch_size: int # >= 1 @@ -31,7 +30,6 @@ class TextEmbeddingIndexSettings: def __init__( self, embedding_model: IEmbeddingModel | None = None, - embedding_size: int | None = None, min_score: float | None = None, max_matches: int | None = None, batch_size: int | None = None, @@ -39,13 +37,7 @@ def __init__( self.min_score = min_score if min_score is not None else 0.85 self.max_matches = max_matches if max_matches and max_matches >= 1 else None self.batch_size = batch_size if batch_size and batch_size >= 1 else 8 - self.embedding_model = embedding_model or create_embedding_model( - embedding_size=embedding_size or 0, - ) - self.embedding_size = self.embedding_model.embedding_size - assert ( - embedding_size is None or self.embedding_size == embedding_size - ), f"Given embedding size {embedding_size} doesn't match model's embedding size {self.embedding_size}" + self.embedding_model = embedding_model or create_embedding_model() class VectorBase: @@ -57,7 +49,7 @@ class VectorBase: def __init__(self, settings: TextEmbeddingIndexSettings): self.settings = settings self._model = settings.embedding_model - self._embedding_size = self._model.embedding_size + self._embedding_size = 0 self.clear() async def get_embedding(self, key: str, cache: bool = True) -> NormalizedEmbedding: @@ -89,6 +81,11 @@ def add_embedding( if self._embedding_size == 0: self._set_embedding_size(len(embedding)) self._vectors.shape = (0, self._embedding_size) + if len(embedding) != self._embedding_size: + raise ValueError( + f"Embedding size mismatch: expected {self._embedding_size}, " + f"got {len(embedding)}" + ) embeddings = embedding.reshape(1, -1) # Make it 2D: 1xN self._vectors = np.append(self._vectors, embeddings, axis=0) if key is not None: @@ -97,23 +94,30 @@ def add_embedding( def add_embeddings( self, keys: None | list[str], embeddings: NormalizedEmbeddings ) -> None: - assert embeddings.ndim == 2 + if embeddings.ndim != 2: + raise ValueError(f"Expected 2D embeddings array, got {embeddings.ndim}D") if self._embedding_size == 0: self._set_embedding_size(embeddings.shape[1]) self._vectors.shape = (0, self._embedding_size) - assert embeddings.shape[1] == self._embedding_size + if embeddings.shape[1] != self._embedding_size: + raise ValueError( + f"Embedding size mismatch: expected {self._embedding_size}, " + f"got {embeddings.shape[1]}" + ) self._vectors = np.concatenate((self._vectors, embeddings), axis=0) if keys is not None: for key, embedding in zip(keys, embeddings): self._model.add_embedding(key, embedding) async def add_key(self, key: str, cache: bool = True) -> None: - embeddings = (await self.get_embedding(key, cache=cache)).reshape(1, -1) - self._vectors = np.append(self._vectors, embeddings, axis=0) + embedding = await self.get_embedding(key, cache=cache) + self.add_embedding(key if cache else None, embedding) async def add_keys(self, keys: list[str], cache: bool = True) -> None: + if not keys: + return embeddings = await self.get_embeddings(keys, cache=cache) - self._vectors = np.concatenate((self._vectors, embeddings), axis=0) + self.add_embeddings(keys if cache else None, embeddings) def fuzzy_lookup_embedding( self, @@ -126,6 +130,8 @@ def fuzzy_lookup_embedding( max_hits = 10 if min_score is None: min_score = 0.0 + if len(self._vectors) == 0: + return [] # This line does most of the work: scores: Iterable[float] = np.dot(self._vectors, embedding) scored_ordinals = [ @@ -168,7 +174,6 @@ def _set_embedding_size(self, size: int) -> None: """Adopt *size* when it was not known at construction time.""" assert size > 0 self._embedding_size = size - self.settings.embedding_size = size def clear(self) -> None: self._vectors = np.array([], dtype=np.float32) diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py index 877e3ba9..a82fe7ad 100644 --- a/src/typeagent/knowpro/interfaces_storage.py +++ b/src/typeagent/knowpro/interfaces_storage.py @@ -52,7 +52,6 @@ class ConversationMetadata: schema_version: int | None = None created_at: Datetime | None = None updated_at: Datetime | None = None - embedding_size: int | None = None embedding_model: str | None = None tags: list[str] | None = None extra: dict[str, str] | None = None diff --git a/src/typeagent/storage/sqlite/messageindex.py b/src/typeagent/storage/sqlite/messageindex.py index 877cbd6d..d48a9761 100644 --- a/src/typeagent/storage/sqlite/messageindex.py +++ b/src/typeagent/storage/sqlite/messageindex.py @@ -63,6 +63,9 @@ async def add_messages_starting_at( for chunk_ord, chunk in enumerate(message.text_chunks): chunks_to_embed.append((msg_ord, chunk_ord, chunk)) + if not chunks_to_embed: + return + embeddings = await self._vectorbase.get_embeddings( [chunk for _, _, chunk in chunks_to_embed], cache=False ) diff --git a/src/typeagent/storage/sqlite/provider.py b/src/typeagent/storage/sqlite/provider.py index 515c9cf0..3d5a3185 100644 --- a/src/typeagent/storage/sqlite/provider.py +++ b/src/typeagent/storage/sqlite/provider.py @@ -31,7 +31,7 @@ class SqliteStorageProvider[TMessage: interfaces.IMessage]( """SQLite-backed storage provider implementation. This provider performs consistency checks on database initialization to ensure - that existing embeddings match the configured embedding_size. If a mismatch is + that existing embeddings match the configured embedding model. If a mismatch is detected, a ValueError is raised with a descriptive error message. """ @@ -119,22 +119,16 @@ def _resolve_embedding_settings( provided_related_settings: RelatedTermIndexSettings | None, ) -> tuple[MessageTextIndexSettings, RelatedTermIndexSettings]: metadata_exists = self._conversation_metadata_exists() - stored_size_str = self._get_single_metadata_value("embedding_size") stored_name = self._get_single_metadata_value("embedding_name") - stored_size = int(stored_size_str) if stored_size_str else None if provided_message_settings is None: - if stored_size is not None or stored_name is not None: - spec = stored_name or "" + if stored_name is not None: + spec = stored_name if spec and ":" not in spec: spec = f"openai:{spec}" - embedding_model = create_embedding_model( - spec, - embedding_size=stored_size or 0, - ) + embedding_model = create_embedding_model(spec) base_embedding_settings = TextEmbeddingIndexSettings( embedding_model=embedding_model, - embedding_size=stored_size, ) else: base_embedding_settings = TextEmbeddingIndexSettings() @@ -142,13 +136,7 @@ def _resolve_embedding_settings( else: message_settings = provided_message_settings base_embedding_settings = message_settings.embedding_index_settings - provided_size = base_embedding_settings.embedding_size provided_name = base_embedding_settings.embedding_model.model_name - if stored_size is not None and stored_size != provided_size: - raise ValueError( - f"Conversation metadata embedding_size " - f"({stored_size}) does not match provided embedding size ({provided_size})." - ) if stored_name is not None and stored_name != provided_name: raise ValueError( f"Conversation metadata embedding_model " @@ -160,12 +148,7 @@ def _resolve_embedding_settings( else: related_settings = provided_related_settings related_embedding_settings = related_settings.embedding_index_settings - related_size = related_embedding_settings.embedding_size related_name = related_embedding_settings.embedding_model.model_name - if related_size != base_embedding_settings.embedding_size: - raise ValueError( - "Related term index embedding_size does not match message text index embedding_size" - ) if related_name != base_embedding_settings.embedding_model.model_name: raise ValueError( "Related term index embedding_model does not match message text index embedding_model" @@ -173,17 +156,9 @@ def _resolve_embedding_settings( if related_settings.embedding_index_settings is not base_embedding_settings: related_settings.embedding_index_settings = base_embedding_settings - actual_size = base_embedding_settings.embedding_size actual_name = base_embedding_settings.embedding_model.model_name if self._metadata is not None: - if self._metadata.embedding_size is None: - self._metadata.embedding_size = actual_size - elif self._metadata.embedding_size != actual_size: - raise ValueError( - "Conversation metadata embedding_size does not match provider settings" - ) - if self._metadata.embedding_model is None: self._metadata.embedding_model = actual_name elif self._metadata.embedding_model != actual_name: @@ -193,8 +168,6 @@ def _resolve_embedding_settings( if metadata_exists: metadata_updates: dict[str, str] = {} - if stored_size is None: - metadata_updates["embedding_size"] = str(actual_size) if stored_name is None: metadata_updates["embedding_name"] = actual_name if metadata_updates: @@ -203,51 +176,47 @@ def _resolve_embedding_settings( return message_settings, related_settings def _check_embedding_consistency(self) -> None: - """Check that existing embeddings in the database match the expected embedding size. + """Check that existing embeddings in the database are consistent. - This method is called during initialization to ensure that embeddings stored in the - database match the embedding_size specified in ConversationSettings. This prevents - runtime errors when trying to use embeddings of incompatible sizes. + This method is called during initialization to ensure that embeddings + stored in the message text index and related terms index have the same + size. This prevents runtime errors when trying to use embeddings of + incompatible sizes. Raises: - ValueError: If embeddings in the database don't match the expected size. + ValueError: If embeddings in the database have inconsistent sizes. """ from .schema import deserialize_embedding cursor = self.db.cursor() - expected_size = ( - self.message_text_index_settings.embedding_index_settings.embedding_size - ) - # Check message text index embeddings + # Get size from message text index embeddings + message_size: int | None = None cursor.execute("SELECT embedding FROM MessageTextIndex LIMIT 1") row = cursor.fetchone() if row and row[0]: embedding = deserialize_embedding(row[0]) - actual_size = len(embedding) - if actual_size != expected_size: - raise ValueError( - f"Message text index embedding size mismatch: " - f"database contains embeddings of size {actual_size}, " - f"but ConversationSettings specifies embedding_size={expected_size}. " - f"The database was likely created with a different embedding model. " - f"Please use the same embedding model or create a new database." - ) + message_size = len(embedding) - # Check related terms fuzzy index embeddings + # Get size from related terms fuzzy index embeddings + related_size: int | None = None cursor.execute("SELECT term_embedding FROM RelatedTermsFuzzy LIMIT 1") row = cursor.fetchone() if row and row[0]: embedding = deserialize_embedding(row[0]) - actual_size = len(embedding) - if actual_size != expected_size: - raise ValueError( - f"Related terms index embedding size mismatch: " - f"database contains embeddings of size {actual_size}, " - f"but ConversationSettings specifies embedding_size={expected_size}. " - f"The database was likely created with a different embedding model. " - f"Please use the same embedding model or create a new database." - ) + related_size = len(embedding) + + if ( + message_size is not None + and related_size is not None + and message_size != related_size + ): + raise ValueError( + f"Embedding size mismatch: " + f"message text index has size {message_size}, " + f"but related terms index has size {related_size}. " + f"The database may be corrupted." + ) def _init_conversation_metadata_if_needed(self) -> None: """Initialize conversation metadata if the database is new (empty metadata table). @@ -276,18 +245,10 @@ def _init_conversation_metadata_if_needed(self) -> None: tags = None extras = {} - actual_embedding_size = ( - self.message_text_index_settings.embedding_index_settings.embedding_size - ) actual_embedding_name = ( self.message_text_index_settings.embedding_index_settings.embedding_model.model_name ) - metadata_embedding_size = ( - self._metadata.embedding_size - if self._metadata and self._metadata.embedding_size is not None - else actual_embedding_size - ) metadata_embedding_name = ( self._metadata.embedding_model if self._metadata and self._metadata.embedding_model is not None @@ -309,7 +270,6 @@ def _init_conversation_metadata_if_needed(self) -> None: created_at=format_timestamp_utc(current_time), updated_at=format_timestamp_utc(current_time), tag=tags, # None or list of tags - embedding_size=str(metadata_embedding_size), embedding_name=metadata_embedding_name, **extras, ) @@ -516,9 +476,6 @@ def parse_datetime(value_str: str) -> datetime: updated_at_str = get_single("updated_at") updated_at = parse_datetime(updated_at_str) if updated_at_str else None - embedding_size_str = get_single("embedding_size") - embedding_size = int(embedding_size_str) if embedding_size_str else None - embedding_model = get_single("embedding_name") # Handle tags (multiple values allowed, None if key doesn't exist) @@ -545,7 +502,6 @@ def parse_datetime(value_str: str) -> datetime: schema_version=schema_version, created_at=created_at, updated_at=updated_at, - embedding_size=embedding_size, embedding_model=embedding_model, tags=tags, extra=extra if extra else None, @@ -592,9 +548,6 @@ async def update_conversation_timestamps( # Insert default values if no metadata exists name_tag = self._metadata.name_tag if self._metadata else "conversation" schema_version = str(CONVERSATION_SCHEMA_VERSION) - actual_embedding_size = ( - self.message_text_index_settings.embedding_index_settings.embedding_size - ) actual_embedding_name = ( self.message_text_index_settings.embedding_index_settings.embedding_model.model_name ) @@ -602,7 +555,6 @@ async def update_conversation_timestamps( metadata_kwds: dict[str, str | None] = { "name_tag": name_tag or "conversation", "schema_version": schema_version, - "embedding_size": str(actual_embedding_size), "embedding_name": actual_embedding_name, } if created_at is not None: diff --git a/tests/test_conversation_metadata.py b/tests/test_conversation_metadata.py index 887c50b2..37a194a2 100644 --- a/tests/test_conversation_metadata.py +++ b/tests/test_conversation_metadata.py @@ -106,7 +106,6 @@ async def test_get_conversation_metadata_nonexistent( assert metadata.schema_version is None assert metadata.created_at is None assert metadata.updated_at is None - assert metadata.embedding_size is None assert metadata.embedding_model is None assert metadata.tags is None assert metadata.extra is None @@ -131,9 +130,7 @@ async def test_update_conversation_timestamps_new( assert metadata.created_at == created_at assert metadata.updated_at == updated_at settings = storage_provider.message_text_index_settings.embedding_index_settings - expected_size = settings.embedding_size expected_model = settings.embedding_model.model_name - assert metadata.embedding_size == expected_size assert metadata.embedding_model == expected_model assert metadata.tags is None assert metadata.extra is None @@ -458,9 +455,7 @@ async def test_conversation_metadata_persistence( assert metadata.name_tag == "conversation_persistent_test" assert metadata.created_at == created_at assert metadata.updated_at == updated_at - expected_size = embedding_settings.embedding_size expected_model = embedding_settings.embedding_model.model_name - assert metadata.embedding_size == expected_size assert metadata.embedding_model == expected_model finally: await provider2.close() @@ -598,49 +593,6 @@ async def test_conversation_metadata_shared_access( await provider1.close() await provider2.close() - @pytest.mark.asyncio - async def test_embedding_metadata_mismatch_raises( - self, temp_db_path: str, embedding_model: IEmbeddingModel - ): - """Ensure a mismatch between stored metadata and provided settings raises.""" - embedding_settings = TextEmbeddingIndexSettings(embedding_model) - message_text_settings = MessageTextIndexSettings(embedding_settings) - related_terms_settings = RelatedTermIndexSettings(embedding_settings) - - provider = SqliteStorageProvider( - db_path=temp_db_path, - message_type=DummyMessage, - message_text_index_settings=message_text_settings, - related_term_index_settings=related_terms_settings, - ) - - await provider.update_conversation_timestamps( - created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), - updated_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), - ) - provider.db.commit() - await provider.close() - - mismatched_model = create_test_embedding_model( - embedding_size=embedding_settings.embedding_size + 1, - ) - mismatched_settings = TextEmbeddingIndexSettings( - embedding_model=mismatched_model, - embedding_size=mismatched_model.embedding_size, - ) - - with pytest.raises(ValueError, match="embedding_size"): - SqliteStorageProvider( - db_path=temp_db_path, - message_type=DummyMessage, - message_text_index_settings=MessageTextIndexSettings( - mismatched_settings - ), - related_term_index_settings=RelatedTermIndexSettings( - mismatched_settings - ), - ) - @pytest.mark.asyncio async def test_embedding_model_mismatch_raises( self, temp_db_path: str, embedding_model: IEmbeddingModel diff --git a/tests/test_embedding_consistency.py b/tests/test_embedding_consistency.py index f032c856..619c9210 100644 --- a/tests/test_embedding_consistency.py +++ b/tests/test_embedding_consistency.py @@ -1,39 +1,38 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""Test embedding consistency checks between database and settings.""" +"""Test embedding consistency checks between database indexes.""" import os +import sqlite3 import tempfile +import numpy as np import pytest from typeagent import create_conversation from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.storage.sqlite import SqliteStorageProvider +from typeagent.storage.sqlite.schema import serialize_embedding from typeagent.transcripts.transcript import TranscriptMessage, TranscriptMessageMeta @pytest.mark.asyncio -async def test_embedding_size_mismatch_in_message_index(): - """Test that opening a DB with mismatched embedding size raises an error.""" - # Create a temporary database file +async def test_same_embedding_size_no_error(): + """Test that opening a DB with the same model works fine.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name try: - # Create a conversation with test model (embedding size 3) settings1 = ConversationSettings( model=create_test_embedding_model(embedding_size=3) ) - # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False conv1 = await create_conversation( db_path, TranscriptMessage, settings=settings1 ) - # Add some messages to populate the index messages = [ TranscriptMessage( text_chunks=["Hello world"], @@ -43,108 +42,151 @@ async def test_embedding_size_mismatch_in_message_index(): await conv1.add_messages_with_indexing(messages) await conv1.storage_provider.close() - # Now try to open the same database with a different embedding size - # This should raise an error + # Reopen with same settings — should work settings2 = ConversationSettings( - model=create_test_embedding_model(embedding_size=5) + model=create_test_embedding_model(embedding_size=3) ) - - with pytest.raises(ValueError, match="embedding_size"): - provider = SqliteStorageProvider( - db_path=db_path, - message_type=TranscriptMessage, - message_text_index_settings=settings2.message_text_index_settings, - related_term_index_settings=settings2.related_term_index_settings, - ) - await provider.close() + provider = SqliteStorageProvider( + db_path=db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings2.message_text_index_settings, + related_term_index_settings=settings2.related_term_index_settings, + ) + await provider.close() finally: - # Clean up the temporary database if os.path.exists(db_path): os.unlink(db_path) @pytest.mark.asyncio -async def test_embedding_size_mismatch_in_related_terms(): - """Test that opening a DB with mismatched embedding size in related terms raises an error.""" - # Create a temporary database file +async def test_empty_db_no_error(): + """Test that opening an empty DB doesn't raise an error regardless of embedding size.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name try: - # Create a conversation with default embedding size settings1 = ConversationSettings( model=create_test_embedding_model(embedding_size=3) ) - # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False conv1 = await create_conversation( db_path, TranscriptMessage, settings=settings1 ) - - # Add some messages to populate the related terms index - messages = [ - TranscriptMessage( - text_chunks=["Apple is a fruit"], - metadata=TranscriptMessageMeta(speaker="Alice"), - ) - ] - await conv1.add_messages_with_indexing(messages) await conv1.storage_provider.close() - # Now try to open the same database with a different embedding size - # This should raise an error + # Open with different embedding size should work since DB is empty settings2 = ConversationSettings( model=create_test_embedding_model(embedding_size=5) ) + provider = SqliteStorageProvider( + db_path=db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings2.message_text_index_settings, + related_term_index_settings=settings2.related_term_index_settings, + ) + await provider.close() + + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +@pytest.mark.asyncio +async def test_embedding_size_mismatch_raises(): + """Test that mismatched embedding sizes between indexes raises ValueError.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name - with pytest.raises(ValueError, match="embedding_size"): - provider = SqliteStorageProvider( + try: + # Create a conversation so that the schema is set up + settings = ConversationSettings( + model=create_test_embedding_model(embedding_size=3) + ) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + conv = await create_conversation(db_path, TranscriptMessage, settings=settings) + await conv.storage_provider.close() + + # Manually insert embeddings of different sizes into the two tables + conn = sqlite3.connect(db_path) + msg_emb = serialize_embedding(np.array([0.1, 0.2, 0.3], dtype=np.float32)) + term_emb = serialize_embedding( + np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=np.float32) + ) + conn.execute( + "INSERT INTO MessageTextIndex " + "(msg_id, chunk_ordinal, embedding, index_position) " + "VALUES (0, 0, ?, 0)", + (msg_emb,), + ) + conn.execute( + "INSERT INTO RelatedTermsFuzzy (term, term_embedding) VALUES (?, ?)", + ("hello", term_emb), + ) + conn.commit() + conn.close() + + # Reopening should detect the mismatch + settings2 = ConversationSettings( + model=create_test_embedding_model(embedding_size=3) + ) + with pytest.raises(ValueError, match="Embedding size mismatch"): + SqliteStorageProvider( db_path=db_path, message_type=TranscriptMessage, message_text_index_settings=settings2.message_text_index_settings, related_term_index_settings=settings2.related_term_index_settings, ) - await provider.close() finally: - # Clean up the temporary database if os.path.exists(db_path): os.unlink(db_path) @pytest.mark.asyncio -async def test_empty_db_no_error(): - """Test that opening an empty DB doesn't raise an error regardless of embedding size.""" - # Create a temporary database file +async def test_adding_mismatched_embeddings_raises(): + """Test that adding messages with a different embedding size raises ValueError.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name try: - # Create an empty database + # Create and populate with size-3 embeddings settings1 = ConversationSettings( model=create_test_embedding_model(embedding_size=3) ) - # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False conv1 = await create_conversation( db_path, TranscriptMessage, settings=settings1 ) + await conv1.add_messages_with_indexing( + [ + TranscriptMessage( + text_chunks=["Hello world"], + metadata=TranscriptMessageMeta(speaker="Alice"), + ) + ] + ) await conv1.storage_provider.close() - # Open with different embedding size should work since DB is empty + # Reopen with size-5 embeddings and try to add more messages settings2 = ConversationSettings( model=create_test_embedding_model(embedding_size=5) ) - provider = SqliteStorageProvider( - db_path=db_path, - message_type=TranscriptMessage, - message_text_index_settings=settings2.message_text_index_settings, - related_term_index_settings=settings2.related_term_index_settings, + settings2.semantic_ref_index_settings.auto_extract_knowledge = False + conv2 = await create_conversation( + db_path, TranscriptMessage, settings=settings2 ) - await provider.close() + with pytest.raises(ValueError, match="Embedding size mismatch"): + await conv2.add_messages_with_indexing( + [ + TranscriptMessage( + text_chunks=["Goodbye world"], + metadata=TranscriptMessageMeta(speaker="Bob"), + ) + ] + ) + await conv2.storage_provider.close() finally: - # Clean up the temporary database if os.path.exists(db_path): os.unlink(db_path) diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 94f17e7a..24a4ff69 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -19,7 +19,6 @@ async def test_get_embedding_nocache(embedding_model: CachingEmbeddingModel): embedding = await embedding_model.get_embedding_nocache(input_text) assert isinstance(embedding, np.ndarray) - assert embedding.shape == (embedding_model.embedding_size,) assert embedding.dtype == np.float32 @@ -30,7 +29,7 @@ async def test_get_embeddings_nocache(embedding_model: CachingEmbeddingModel): embeddings = await embedding_model.get_embeddings_nocache(inputs) assert isinstance(embeddings, np.ndarray) - assert embeddings.shape == (len(inputs), embedding_model.embedding_size) + assert embeddings.shape[0] == len(inputs) assert embeddings.dtype == np.float32 @@ -85,13 +84,9 @@ async def test_get_embeddings_with_cache( @pytest.mark.asyncio async def test_get_embeddings_empty_input(embedding_model: CachingEmbeddingModel): - """Test retrieving embeddings for an empty input list.""" - inputs: list[str] = [] - embeddings = await embedding_model.get_embeddings(inputs) - - assert isinstance(embeddings, np.ndarray) - assert embeddings.shape == (0, embedding_model.embedding_size) - assert embeddings.dtype == np.float32 + """Test retrieving embeddings for an empty input list raises ValueError.""" + with pytest.raises(ValueError, match="Cannot embed an empty list"): + await embedding_model.get_embeddings([]) @pytest.mark.asyncio diff --git a/tests/test_messageindex.py b/tests/test_messageindex.py index f91ac933..7e00cc45 100644 --- a/tests/test_messageindex.py +++ b/tests/test_messageindex.py @@ -158,7 +158,7 @@ async def test_generate_embedding(needs_auth: None): embedding = await index.generate_embedding("test message") assert embedding is not None - assert len(embedding) == test_model.embedding_size # 3 for test model + assert len(embedding) == 3 # test model uses embedding size 3 dot = float(np.dot(embedding, embedding)) assert abs(dot - 1.0) < 1e-6, f"Embedding not normalized: {dot}" diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index f089c4ba..11907bd9 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -21,7 +21,6 @@ from typeagent.aitools.model_adapters import ( configure_models, create_chat_model, - create_embedding_model, PydanticAIChatModel, PydanticAIEmbedder, ) @@ -38,23 +37,6 @@ def test_spec_uses_colon_separator() -> None: create_chat_model("nonexistent_provider_xyz:fake-model") -# --------------------------------------------------------------------------- -# Embedding size -# --------------------------------------------------------------------------- - - -def test_explicit_embedding_size() -> None: - """Passing embedding_size= sets it immediately.""" - model = create_embedding_model("openai:text-embedding-3-small", embedding_size=42) - assert model.embedding_size == 42 - - -def test_default_embedding_size_is_zero() -> None: - """Without embedding_size=, it defaults to 0 (probed on first call).""" - model = create_embedding_model("openai:text-embedding-3-small") - assert model.embedding_size == 0 - - # --------------------------------------------------------------------------- # PydanticAIChatModel adapter # --------------------------------------------------------------------------- @@ -116,7 +98,7 @@ async def test_embedding_adapter_single() -> None: provider_name="test", ) - adapter = PydanticAIEmbedder(mock_embedder, "test-model", 3) + adapter = PydanticAIEmbedder(mock_embedder, "test-model") result = await adapter.get_embedding_nocache("test") assert result.shape == (3,) norm = float(np.linalg.norm(result)) @@ -124,21 +106,12 @@ async def test_embedding_adapter_single() -> None: @pytest.mark.asyncio -async def test_embedding_adapter_probes_size() -> None: - """embedding_size is discovered from the first embedding call.""" +async def test_embedding_adapter_empty_batch_raises() -> None: + """Empty batch raises ValueError.""" mock_embedder = AsyncMock(spec=Embedder) - mock_embedder.embed_documents.return_value = EmbeddingResult( - embeddings=[[1.0, 0.0, 0.0]], - inputs=["probe"], - input_type="document", - model_name="test-model", - provider_name="test", - ) - adapter = PydanticAIEmbedder(mock_embedder, "test-model") - assert adapter.embedding_size == 0 - await adapter.get_embedding_nocache("probe") - assert adapter.embedding_size == 3 + with pytest.raises(ValueError, match="Cannot embed an empty list"): + await adapter.get_embeddings_nocache([]) @pytest.mark.asyncio @@ -153,7 +126,7 @@ async def test_embedding_adapter_batch() -> None: provider_name="test", ) - adapter = PydanticAIEmbedder(mock_embedder, "test-model", 2) + adapter = PydanticAIEmbedder(mock_embedder, "test-model") result = await adapter.get_embeddings_nocache(["a", "b"]) assert result.shape == (2, 2) @@ -170,7 +143,7 @@ async def test_embedding_adapter_caching() -> None: provider_name="test", ) - embedder = PydanticAIEmbedder(mock_embedder, "test-model", 3) + embedder = PydanticAIEmbedder(mock_embedder, "test-model") adapter = CachingEmbeddingModel(embedder) first = await adapter.get_embedding("cached") second = await adapter.get_embedding("cached") @@ -183,7 +156,7 @@ async def test_embedding_adapter_caching() -> None: async def test_embedding_adapter_add_embedding() -> None: """add_embedding() populates the cache.""" mock_embedder = AsyncMock(spec=Embedder) - embedder = PydanticAIEmbedder(mock_embedder, "test-model", 3) + embedder = PydanticAIEmbedder(mock_embedder, "test-model") adapter = CachingEmbeddingModel(embedder) vec: NormalizedEmbedding = np.array([1.0, 0.0, 0.0], dtype=np.float32) adapter.add_embedding("key", vec) @@ -194,13 +167,13 @@ async def test_embedding_adapter_add_embedding() -> None: @pytest.mark.asyncio -async def test_embedding_adapter_empty_batch() -> None: - """Empty batch returns empty array with known size.""" +async def test_embedding_adapter_empty_batch_returns_empty() -> None: + """Empty batch via CachingEmbeddingModel raises ValueError.""" mock_embedder = AsyncMock(spec=Embedder) - embedder = PydanticAIEmbedder(mock_embedder, "test-model", 4) + embedder = PydanticAIEmbedder(mock_embedder, "test-model") adapter = CachingEmbeddingModel(embedder) - result = await adapter.get_embeddings_nocache([]) - assert result.shape == (0, 4) + with pytest.raises(ValueError, match="Cannot embed an empty list"): + await adapter.get_embeddings([]) # --------------------------------------------------------------------------- diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py index 416ed3eb..81ccecc6 100644 --- a/tests/test_vectorbase.py +++ b/tests/test_vectorbase.py @@ -195,3 +195,28 @@ def test_fuzzy_lookup_embedding_in_subset( # Empty subset returns empty list result = vector_base.fuzzy_lookup_embedding_in_subset(query, []) assert result == [] + + +def test_add_embedding_size_mismatch(vector_base: VectorBase) -> None: + """Adding an embedding of wrong size raises ValueError.""" + emb3 = np.array([0.1, 0.2, 0.3], dtype=np.float32) + emb5 = np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=np.float32) + vector_base.add_embedding(None, emb3) + with pytest.raises(ValueError, match="Embedding size mismatch"): + vector_base.add_embedding(None, emb5) + + +def test_add_embeddings_size_mismatch(vector_base: VectorBase) -> None: + """Adding a batch of embeddings of wrong size raises ValueError.""" + batch3 = np.array([[0.1, 0.2, 0.3]], dtype=np.float32) + batch5 = np.array([[0.1, 0.2, 0.3, 0.4, 0.5]], dtype=np.float32) + vector_base.add_embeddings(None, batch3) + with pytest.raises(ValueError, match="Embedding size mismatch"): + vector_base.add_embeddings(None, batch5) + + +def test_add_embeddings_wrong_ndim(vector_base: VectorBase) -> None: + """Adding a 1D array via add_embeddings raises ValueError.""" + emb1d = np.array([0.1, 0.2, 0.3], dtype=np.float32) + with pytest.raises(ValueError, match="Expected 2D"): + vector_base.add_embeddings(None, emb1d) From 091bd58564629d5346230af97aaf5e9e778a0c22 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 19:52:27 -0800 Subject: [PATCH 22/23] Change default embedding back to ada-002 for backwards compatibility --- src/typeagent/aitools/embeddings.py | 4 +--- src/typeagent/aitools/model_adapters.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py index 5d0e172c..8b579df2 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -114,12 +114,10 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: return np.array([self._cache[k] for k in keys], dtype=np.float32) -DEFAULT_MODEL_NAME = "text-embedding-ada-002" -DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING" # We support OpenAI and Azure OpenAI TEST_MODEL_NAME = "test" model_to_envvar: dict[str, str] = { - DEFAULT_MODEL_NAME: DEFAULT_ENVVAR, + "text-embedding-ada-002": "AZURE_OPENAI_ENDPOINT_EMBEDDING", "text-embedding-3-small": "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL", "text-embedding-3-large": "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE", } diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 86dca394..164cfdca 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -235,7 +235,7 @@ def create_chat_model( return PydanticAIChatModel(model) -DEFAULT_EMBEDDING_SPEC = "openai:text-embedding-3-small" +DEFAULT_EMBEDDING_SPEC = "openai:text-embedding-ada-002" def create_embedding_model( From 7c1c5272e58069cf35964544eb6d40846659447c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 20:57:01 -0800 Subject: [PATCH 23/23] Add OPENAI_EMBEDDING_MODEL envvar to set the text embedding (e.g. text-embedding-3-small) --- src/typeagent/aitools/model_adapters.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 164cfdca..34d5ac84 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -229,6 +229,11 @@ def create_chat_model( if _needs_azure_fallback(provider): from pydantic_ai.models.openai import OpenAIChatModel + if os.getenv("OPENAI_MODEL"): + print( + f"OPENAI_MODEL={os.getenv('OPENAI_MODEL')!r} ignored; " + f"Azure deployment is determined by AZURE_OPENAI_ENDPOINT" + ) model = OpenAIChatModel(model_name, provider=_make_azure_provider()) else: model = infer_model(model_spec) @@ -247,19 +252,26 @@ def create_embedding_model( If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically. - If *model_spec* is ``None``, :data:`DEFAULT_EMBEDDING_SPEC` is used. + If *model_spec* is ``None``, it is constructed from the + ``OPENAI_EMBEDDING_MODEL`` environment variable (falling back to + :data:`DEFAULT_EMBEDDING_SPEC`). Returns a :class:`~typeagent.aitools.embeddings.CachingEmbeddingModel` wrapping a :class:`PydanticAIEmbedder`. Examples:: + model = create_embedding_model() # uses OPENAI_EMBEDDING_MODEL or ada-002 model = create_embedding_model("openai:text-embedding-3-small") model = create_embedding_model("cohere:embed-english-v3.0") model = create_embedding_model("google:text-embedding-004") """ if model_spec is None: - model_spec = DEFAULT_EMBEDDING_SPEC + openai_embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL") + if openai_embedding_model: + model_spec = f"openai:{openai_embedding_model}" + else: + model_spec = DEFAULT_EMBEDDING_SPEC provider, _, model_name = model_spec.partition(":") if not model_name: model_name = provider # No colon in spec