diff --git a/.semversioner/next-release/patch-20260127131016120694.json b/.semversioner/next-release/patch-20260127131016120694.json new file mode 100644 index 0000000000..516466de8f --- /dev/null +++ b/.semversioner/next-release/patch-20260127131016120694.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Add TableProvider abstraction for table-based storage operations" +} diff --git a/docs/examples_notebooks/index_migration_to_v1.ipynb b/docs/examples_notebooks/index_migration_to_v1.ipynb index c5b582d38d..3fd8f264bc 100644 --- a/docs/examples_notebooks/index_migration_to_v1.ipynb +++ b/docs/examples_notebooks/index_migration_to_v1.ipynb @@ -103,21 +103,22 @@ "source": [ "from uuid import uuid4\n", "\n", - "from graphrag.utils.storage import load_table_from_storage, write_table_to_storage\n", + "from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider\n", + "\n", + "# Create table provider from storage\n", + "table_provider = ParquetTableProvider(storage)\n", "\n", "# First we'll go through any parquet files that had model changes and update them\n", "# The new data model may have removed excess columns as well, but we will only make the minimal changes required for compatibility\n", "\n", - "final_documents = await load_table_from_storage(\"create_final_documents\", storage)\n", - "final_text_units = await load_table_from_storage(\"create_final_text_units\", storage)\n", - "final_entities = await load_table_from_storage(\"create_final_entities\", storage)\n", - "final_nodes = await load_table_from_storage(\"create_final_nodes\", storage)\n", - "final_relationships = await load_table_from_storage(\n", - " \"create_final_relationships\", storage\n", - ")\n", - "final_communities = await load_table_from_storage(\"create_final_communities\", storage)\n", - "final_community_reports = await load_table_from_storage(\n", - " \"create_final_community_reports\", storage\n", + "final_documents = await table_provider.read_dataframe(\"create_final_documents\")\n", + "final_text_units = await table_provider.read_dataframe(\"create_final_text_units\")\n", + "final_entities = await table_provider.read_dataframe(\"create_final_entities\")\n", + "final_nodes = await table_provider.read_dataframe(\"create_final_nodes\")\n", + "final_relationships = await table_provider.read_dataframe(\"create_final_relationships\")\n", + "final_communities = await table_provider.read_dataframe(\"create_final_communities\")\n", + "final_community_reports = await table_provider.read_dataframe(\n", + " \"create_final_community_reports\"\n", ")\n", "\n", "\n", @@ -187,14 +188,14 @@ " parent_df, on=\"community\", how=\"left\"\n", " )\n", "\n", - "await write_table_to_storage(final_documents, \"create_final_documents\", storage)\n", - "await write_table_to_storage(final_text_units, \"create_final_text_units\", storage)\n", - "await write_table_to_storage(final_entities, \"create_final_entities\", storage)\n", - "await write_table_to_storage(final_nodes, \"create_final_nodes\", storage)\n", - "await write_table_to_storage(final_relationships, \"create_final_relationships\", storage)\n", - "await write_table_to_storage(final_communities, \"create_final_communities\", storage)\n", - "await write_table_to_storage(\n", - " final_community_reports, \"create_final_community_reports\", storage\n", + "await table_provider.write_dataframe(\"create_final_documents\", final_documents)\n", + "await table_provider.write_dataframe(\"create_final_text_units\", final_text_units)\n", + "await 
table_provider.write_dataframe(\"create_final_entities\", final_entities)\n", + "await table_provider.write_dataframe(\"create_final_nodes\", final_nodes)\n", + "await table_provider.write_dataframe(\"create_final_relationships\", final_relationships)\n", + "await table_provider.write_dataframe(\"create_final_communities\", final_communities)\n", + "await table_provider.write_dataframe(\n", + " \"create_final_community_reports\", final_community_reports\n", ")" ] }, diff --git a/docs/examples_notebooks/index_migration_to_v2.ipynb b/docs/examples_notebooks/index_migration_to_v2.ipynb index 0681d1a0b2..c71e27d945 100644 --- a/docs/examples_notebooks/index_migration_to_v2.ipynb +++ b/docs/examples_notebooks/index_migration_to_v2.ipynb @@ -65,28 +65,25 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", - "from graphrag.utils.storage import (\n", - " delete_table_from_storage,\n", - " load_table_from_storage,\n", - " write_table_to_storage,\n", - ")\n", + "from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider\n", "\n", - "final_documents = await load_table_from_storage(\"create_final_documents\", storage)\n", - "final_text_units = await load_table_from_storage(\"create_final_text_units\", storage)\n", - "final_entities = await load_table_from_storage(\"create_final_entities\", storage)\n", - "final_covariates = await load_table_from_storage(\"create_final_covariates\", storage)\n", - "final_nodes = await load_table_from_storage(\"create_final_nodes\", storage)\n", - "final_relationships = await load_table_from_storage(\n", - " \"create_final_relationships\", storage\n", - ")\n", - "final_communities = await load_table_from_storage(\"create_final_communities\", storage)\n", - "final_community_reports = await load_table_from_storage(\n", - " \"create_final_community_reports\", storage\n", + "# Create table provider from storage\n", + "table_provider = ParquetTableProvider(storage)\n", + "\n", + "final_documents = await table_provider.read_dataframe(\"create_final_documents\")\n", + "final_text_units = await table_provider.read_dataframe(\"create_final_text_units\")\n", + "final_entities = await table_provider.read_dataframe(\"create_final_entities\")\n", + "final_covariates = await table_provider.read_dataframe(\"create_final_covariates\")\n", + "final_nodes = await table_provider.read_dataframe(\"create_final_nodes\")\n", + "final_relationships = await table_provider.read_dataframe(\"create_final_relationships\")\n", + "final_communities = await table_provider.read_dataframe(\"create_final_communities\")\n", + "final_community_reports = await table_provider.read_dataframe(\n", + " \"create_final_community_reports\"\n", ")\n", "\n", "# we've renamed document attributes as metadata\n", @@ -126,23 +123,23 @@ ")\n", "\n", "# we renamed all the output files for better clarity now that we don't have workflow naming constraints from DataShaper\n", - "await write_table_to_storage(final_documents, \"documents\", storage)\n", - "await write_table_to_storage(final_text_units, \"text_units\", storage)\n", - "await write_table_to_storage(final_entities, \"entities\", storage)\n", - "await write_table_to_storage(final_relationships, \"relationships\", storage)\n", - "await write_table_to_storage(final_covariates, \"covariates\", storage)\n", - "await write_table_to_storage(final_communities, \"communities\", storage)\n", - "await write_table_to_storage(final_community_reports, 
\"community_reports\", storage)\n", + "await table_provider.write_dataframe(\"documents\", final_documents)\n", + "await table_provider.write_dataframe(\"text_units\", final_text_units)\n", + "await table_provider.write_dataframe(\"entities\", final_entities)\n", + "await table_provider.write_dataframe(\"relationships\", final_relationships)\n", + "await table_provider.write_dataframe(\"covariates\", final_covariates)\n", + "await table_provider.write_dataframe(\"communities\", final_communities)\n", + "await table_provider.write_dataframe(\"community_reports\", final_community_reports)\n", "\n", "# delete all the old versions\n", - "await delete_table_from_storage(\"create_final_documents\", storage)\n", - "await delete_table_from_storage(\"create_final_text_units\", storage)\n", - "await delete_table_from_storage(\"create_final_entities\", storage)\n", - "await delete_table_from_storage(\"create_final_nodes\", storage)\n", - "await delete_table_from_storage(\"create_final_relationships\", storage)\n", - "await delete_table_from_storage(\"create_final_covariates\", storage)\n", - "await delete_table_from_storage(\"create_final_communities\", storage)\n", - "await delete_table_from_storage(\"create_final_community_reports\", storage)" + "await storage.delete(\"create_final_documents.parquet\")\n", + "await storage.delete(\"create_final_text_units.parquet\")\n", + "await storage.delete(\"create_final_entities.parquet\")\n", + "await storage.delete(\"create_final_nodes.parquet\")\n", + "await storage.delete(\"create_final_relationships.parquet\")\n", + "await storage.delete(\"create_final_covariates.parquet\")\n", + "await storage.delete(\"create_final_communities.parquet\")\n", + "await storage.delete(\"create_final_community_reports.parquet\")" ] } ], diff --git a/docs/examples_notebooks/index_migration_to_v3.ipynb b/docs/examples_notebooks/index_migration_to_v3.ipynb index a0e50be432..7f94dedee6 100644 --- a/docs/examples_notebooks/index_migration_to_v3.ipynb +++ b/docs/examples_notebooks/index_migration_to_v3.ipynb @@ -66,17 +66,17 @@ "metadata": {}, "outputs": [], "source": [ - "from graphrag.utils.storage import (\n", - " load_table_from_storage,\n", - " write_table_to_storage,\n", - ")\n", + "from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider\n", "\n", - "text_units = await load_table_from_storage(\"text_units\", storage)\n", + "# Create table provider from storage\n", + "table_provider = ParquetTableProvider(storage)\n", + "\n", + "text_units = await table_provider.read_dataframe(\"text_units\")\n", "\n", "text_units[\"document_id\"] = text_units[\"document_ids\"].apply(lambda ids: ids[0])\n", "remove_columns(text_units, [\"document_ids\"])\n", "\n", - "await write_table_to_storage(text_units, \"text_units\", storage)" + "await table_provider.write_dataframe(\"text_units\", text_units)" ] }, { diff --git a/packages/graphrag-storage/graphrag_storage/__init__.py b/packages/graphrag-storage/graphrag_storage/__init__.py index 2ae67be741..454842eecc 100644 --- a/packages/graphrag-storage/graphrag_storage/__init__.py +++ b/packages/graphrag-storage/graphrag_storage/__init__.py @@ -10,11 +10,13 @@ register_storage, ) from graphrag_storage.storage_type import StorageType +from graphrag_storage.tables import TableProvider __all__ = [ "Storage", "StorageConfig", "StorageType", + "TableProvider", "create_storage", "register_storage", ] diff --git a/packages/graphrag-storage/graphrag_storage/tables/__init__.py 
b/packages/graphrag-storage/graphrag_storage/tables/__init__.py new file mode 100644 index 0000000000..0210d935f3 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/tables/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Table provider module for GraphRAG storage.""" + +from .table_provider import TableProvider + +__all__ = ["TableProvider"] diff --git a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py new file mode 100644 index 0000000000..75805be23a --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parquet-based table provider implementation.""" + +import logging +import re +from io import BytesIO + +import pandas as pd + +from graphrag_storage.storage import Storage +from graphrag_storage.tables.table_provider import TableProvider + +logger = logging.getLogger(__name__) + + +class ParquetTableProvider(TableProvider): + """Table provider that stores tables as Parquet files using an underlying Storage instance. + + This provider converts between pandas DataFrames and Parquet format, + storing the data through a Storage backend (file, blob, cosmos, etc.). + """ + + def __init__(self, storage: Storage, **kwargs) -> None: + """Initialize the Parquet table provider with an underlying storage instance. + + Args + ---- + storage: Storage + The storage instance to use for reading and writing Parquet files. + **kwargs: Any + Additional keyword arguments (currently unused). + """ + self._storage = storage + + async def read_dataframe(self, table_name: str) -> pd.DataFrame: + """Read a table from storage as a pandas DataFrame. + + Args + ---- + table_name: str + The name of the table to read. The file will be accessed as '{table_name}.parquet'. + + Returns + ------- + pd.DataFrame: + The table data loaded from the Parquet file. + + Raises + ------ + ValueError: + If the table file does not exist in storage. + Exception: + If there is an error reading or parsing the Parquet file. + """ + filename = f"{table_name}.parquet" + if not await self._storage.has(filename): + msg = f"Could not find {filename} in storage!" + raise ValueError(msg) + try: + logger.info("reading table from storage: %s", filename) + return pd.read_parquet( + BytesIO(await self._storage.get(filename, as_bytes=True)) + ) + except Exception: + logger.exception("error loading table from storage: %s", filename) + raise + + async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None: + """Write a pandas DataFrame to storage as a Parquet file. + + Args + ---- + table_name: str + The name of the table to write. The file will be saved as '{table_name}.parquet'. + df: pd.DataFrame + The DataFrame to write to storage. + """ + await self._storage.set(f"{table_name}.parquet", df.to_parquet()) + + async def has_dataframe(self, table_name: str) -> bool: + """Check if a table exists in storage. + + Args + ---- + table_name: str + The name of the table to check. + + Returns + ------- + bool: + True if the table exists, False otherwise. + """ + return await self._storage.has(f"{table_name}.parquet") + + def find_tables(self) -> list[str]: + """Find all table names in storage. + + Returns + ------- + list[str]: + List of table names (without .parquet extension). 
+ """ + return [ + file.replace(".parquet", "") + for file in self._storage.find(re.compile(r"\.parquet$")) + ] diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py new file mode 100644 index 0000000000..0d48480892 --- /dev/null +++ b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Abstract base class for table providers.""" + +from abc import ABC, abstractmethod +from typing import Any + +import pandas as pd + + +class TableProvider(ABC): + """Provide a table-based storage interface with support for DataFrames and row dictionaries.""" + + @abstractmethod + def __init__(self, **kwargs: Any) -> None: + """Create a table provider instance. + + Args + ---- + **kwargs: Any + Keyword arguments for initialization, may include underlying Storage instance. + """ + + @abstractmethod + async def read_dataframe(self, table_name: str) -> pd.DataFrame: + """Read entire table as a pandas DataFrame. + + Args + ---- + table_name: str + The name of the table to read. + + Returns + ------- + pd.DataFrame: + The table data as a DataFrame. + """ + + @abstractmethod + async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None: + """Write entire table from a pandas DataFrame. + + Args + ---- + table_name: str + The name of the table to write. + df: pd.DataFrame + The DataFrame to write as a table. + """ + + @abstractmethod + async def has_dataframe(self, table_name: str) -> bool: + """Check if a table exists in the provider. + + Args + ---- + table_name: str + The name of the table to check. + + Returns + ------- + bool: + True if the table exists, False otherwise. + """ + + @abstractmethod + def find_tables(self) -> list[str]: + """Find all table names in the provider. + + Returns + ------- + list[str]: + List of table names (without file extensions). 
+ """ diff --git a/packages/graphrag/graphrag/cli/query.py b/packages/graphrag/graphrag/cli/query.py index ae06a88c95..1f808420d4 100644 --- a/packages/graphrag/graphrag/cli/query.py +++ b/packages/graphrag/graphrag/cli/query.py @@ -9,12 +9,12 @@ from typing import TYPE_CHECKING, Any from graphrag_storage import create_storage +from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider import graphrag.api as api from graphrag.callbacks.noop_query_callbacks import NoopQueryCallbacks from graphrag.config.load_config import load_config from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.utils.storage import load_table_from_storage, storage_has_table if TYPE_CHECKING: import pandas as pd @@ -378,18 +378,17 @@ def _resolve_output_files( """Read indexing output files to a dataframe dict.""" dataframe_dict = {} storage_obj = create_storage(config.output_storage) + table_provider = ParquetTableProvider(storage_obj) for name in output_list: - df_value = asyncio.run(load_table_from_storage(name=name, storage=storage_obj)) + df_value = asyncio.run(table_provider.read_dataframe(name)) dataframe_dict[name] = df_value # for optional output files, set the dict entry to None instead of erroring out if it does not exist if optional_list: for optional_file in optional_list: - file_exists = asyncio.run(storage_has_table(optional_file, storage_obj)) + file_exists = asyncio.run(table_provider.has_dataframe(optional_file)) if file_exists: - df_value = asyncio.run( - load_table_from_storage(name=optional_file, storage=storage_obj) - ) + df_value = asyncio.run(table_provider.read_dataframe(optional_file)) dataframe_dict[optional_file] = df_value else: dataframe_dict[optional_file] = None diff --git a/packages/graphrag/graphrag/index/run/run_pipeline.py b/packages/graphrag/graphrag/index/run/run_pipeline.py index a4ce17582c..24ff39cc07 100644 --- a/packages/graphrag/graphrag/index/run/run_pipeline.py +++ b/packages/graphrag/graphrag/index/run/run_pipeline.py @@ -5,7 +5,6 @@ import json import logging -import re import time from collections.abc import AsyncIterable from dataclasses import asdict @@ -13,7 +12,9 @@ import pandas as pd from graphrag_cache import create_cache -from graphrag_storage import Storage, create_storage +from graphrag_storage import create_storage +from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider +from graphrag_storage.tables.table_provider import TableProvider from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig @@ -21,7 +22,6 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.pipeline import Pipeline from graphrag.index.typing.pipeline_run_result import PipelineRunResult -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -36,6 +36,8 @@ async def run_pipeline( ) -> AsyncIterable[PipelineRunResult]: """Run all workflows using a simplified pipeline.""" input_storage = create_storage(config.input_storage) + input_table_provider = ParquetTableProvider(input_storage) + output_storage = create_storage(config.output_storage) cache = create_cache(config.cache) @@ -54,22 +56,28 @@ async def run_pipeline( update_timestamp = time.strftime("%Y%m%d-%H%M%S") timestamped_storage = update_storage.child(update_timestamp) delta_storage = timestamped_storage.child("delta") + delta_table_provider = ParquetTableProvider(delta_storage) # copy 
the previous output to a backup folder, so we can replace it with the update # we'll read from this later when we merge the old and new indexes previous_storage = timestamped_storage.child("previous") - await _copy_previous_output(output_storage, previous_storage) + previous_table_provider = ParquetTableProvider(previous_storage) + + output_table_provider = ParquetTableProvider(output_storage) + await _copy_previous_output(output_table_provider, previous_table_provider) state["update_timestamp"] = update_timestamp # if the user passes in a df directly, write directly to storage so we can skip finding/parsing later if input_documents is not None: - await write_table_to_storage(input_documents, "documents", delta_storage) + await delta_table_provider.write_dataframe("documents", input_documents) pipeline.remove("load_update_documents") context = create_run_context( input_storage=input_storage, + input_table_provider=input_table_provider, output_storage=delta_storage, - previous_storage=previous_storage, + output_table_provider=delta_table_provider, + previous_table_provider=previous_table_provider, cache=cache, callbacks=callbacks, state=state, @@ -80,12 +88,15 @@ async def run_pipeline( # if the user passes in a df directly, write directly to storage so we can skip finding/parsing later if input_documents is not None: - await write_table_to_storage(input_documents, "documents", output_storage) + output_table_provider = ParquetTableProvider(output_storage) + await output_table_provider.write_dataframe("documents", input_documents) pipeline.remove("load_input_documents") context = create_run_context( input_storage=input_storage, + input_table_provider=input_table_provider, output_storage=output_storage, + output_table_provider=ParquetTableProvider(storage=output_storage), cache=cache, callbacks=callbacks, state=state, @@ -156,10 +167,10 @@ async def _dump_json(context: PipelineRunContext) -> None: async def _copy_previous_output( - storage: Storage, - copy_storage: Storage, -): - for file in storage.find(re.compile(r"\.parquet$")): - base_name = file.replace(".parquet", "") - table = await load_table_from_storage(base_name, storage) - await write_table_to_storage(table, base_name, copy_storage) + output_table_provider: TableProvider, + previous_table_provider: TableProvider, +) -> None: + """Copy all parquet tables from output to previous storage for backup.""" + for table_name in output_table_provider.find_tables(): + table = await output_table_provider.read_dataframe(table_name) + await previous_table_provider.write_dataframe(table_name, table) diff --git a/packages/graphrag/graphrag/index/run/utils.py b/packages/graphrag/graphrag/index/run/utils.py index be6914a6d6..207e9561a0 100644 --- a/packages/graphrag/graphrag/index/run/utils.py +++ b/packages/graphrag/graphrag/index/run/utils.py @@ -7,6 +7,7 @@ from graphrag_cache.memory_cache import MemoryCache from graphrag_storage import Storage, create_storage from graphrag_storage.memory_storage import MemoryStorage +from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks @@ -19,18 +20,26 @@ def create_run_context( input_storage: Storage | None = None, + input_table_provider: ParquetTableProvider | None = None, output_storage: Storage | None = None, - previous_storage: Storage | None = None, + output_table_provider: ParquetTableProvider | None = None, + previous_table_provider: 
ParquetTableProvider | None = None, cache: Cache | None = None, callbacks: WorkflowCallbacks | None = None, stats: PipelineRunStats | None = None, state: PipelineState | None = None, ) -> PipelineRunContext: """Create the run context for the pipeline.""" + input_storage = input_storage or MemoryStorage() + output_storage = output_storage or MemoryStorage() return PipelineRunContext( - input_storage=input_storage or MemoryStorage(), - output_storage=output_storage or MemoryStorage(), - previous_storage=previous_storage or MemoryStorage(), + input_storage=input_storage, + input_table_provider=input_table_provider + or ParquetTableProvider(storage=input_storage), + output_storage=output_storage, + output_table_provider=output_table_provider + or ParquetTableProvider(storage=output_storage), + previous_table_provider=previous_table_provider, cache=cache or MemoryCache(), callbacks=callbacks or NoopWorkflowCallbacks(), stats=stats or PipelineRunStats(), @@ -48,14 +57,18 @@ def create_callback_chain( return manager -def get_update_storages( +def get_update_table_providers( config: GraphRagConfig, timestamp: str -) -> tuple[Storage, Storage, Storage]: - """Get storage objects for the update index run.""" +) -> tuple[ParquetTableProvider, ParquetTableProvider, ParquetTableProvider]: + """Get table providers for the update index run.""" output_storage = create_storage(config.output_storage) update_storage = create_storage(config.update_output_storage) timestamped_storage = update_storage.child(timestamp) delta_storage = timestamped_storage.child("delta") previous_storage = timestamped_storage.child("previous") - return output_storage, previous_storage, delta_storage + output_table_provider = ParquetTableProvider(output_storage) + previous_table_provider = ParquetTableProvider(previous_storage) + delta_table_provider = ParquetTableProvider(delta_storage) + + return output_table_provider, previous_table_provider, delta_table_provider diff --git a/packages/graphrag/graphrag/index/typing/context.py b/packages/graphrag/graphrag/index/typing/context.py index 95e7f898f9..f606218dd2 100644 --- a/packages/graphrag/graphrag/index/typing/context.py +++ b/packages/graphrag/graphrag/index/typing/context.py @@ -10,7 +10,7 @@ from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.typing.state import PipelineState from graphrag.index.typing.stats import PipelineRunStats -from graphrag_storage import Storage +from graphrag_storage import Storage, TableProvider @dataclass @@ -19,11 +19,15 @@ class PipelineRunContext: stats: PipelineRunStats input_storage: Storage - "Storage for input documents." + "Storage for reading input documents." + input_table_provider: TableProvider + "Table provider for reading input tables." output_storage: Storage "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider." - previous_storage: Storage - "Storage for previous pipeline run when running in update mode." + output_table_provider: TableProvider + "Table provider for reading and writing output tables." + previous_table_provider: TableProvider | None + "Table provider for reading previous pipeline run when running in update mode." cache: Cache "Cache instance for reading previous LLM responses." 
callbacks: WorkflowCallbacks diff --git a/packages/graphrag/graphrag/index/update/incremental_index.py b/packages/graphrag/graphrag/index/update/incremental_index.py index 81f917e187..0e7eb34684 100644 --- a/packages/graphrag/graphrag/index/update/incremental_index.py +++ b/packages/graphrag/graphrag/index/update/incremental_index.py @@ -7,12 +7,7 @@ import numpy as np import pandas as pd -from graphrag_storage import Storage - -from graphrag.utils.storage import ( - load_table_from_storage, - write_table_to_storage, -) +from graphrag_storage.tables.table_provider import TableProvider @dataclass @@ -31,22 +26,24 @@ class InputDelta: deleted_inputs: pd.DataFrame -async def get_delta_docs(input_dataset: pd.DataFrame, storage: Storage) -> InputDelta: +async def get_delta_docs( + input_dataset: pd.DataFrame, table_provider: TableProvider +) -> InputDelta: """Get the delta between the input dataset and the final documents. Parameters ---------- input_dataset : pd.DataFrame The input dataset. - storage : Storage - The Pipeline storage. + table_provider : TableProvider + The table provider for reading previous documents. Returns ------- InputDelta The input delta. With new inputs and deleted inputs. """ - final_docs = await load_table_from_storage("documents", storage) + final_docs = await table_provider.read_dataframe("documents") # Select distinct title from final docs and from dataset previous_docs: list[str] = final_docs["title"].unique().tolist() @@ -63,19 +60,19 @@ async def get_delta_docs(input_dataset: pd.DataFrame, storage: Storage) -> Input async def concat_dataframes( name: str, - previous_storage: Storage, - delta_storage: Storage, - output_storage: Storage, + previous_table_provider: TableProvider, + delta_table_provider: TableProvider, + output_table_provider: TableProvider, ) -> pd.DataFrame: """Concatenate dataframes.""" - old_df = await load_table_from_storage(name, previous_storage) - delta_df = await load_table_from_storage(name, delta_storage) + old_df = await previous_table_provider.read_dataframe(name) + delta_df = await delta_table_provider.read_dataframe(name) # Merge the final documents initial_id = old_df["human_readable_id"].max() + 1 delta_df["human_readable_id"] = np.arange(initial_id, initial_id + len(delta_df)) final_df = pd.concat([old_df, delta_df], ignore_index=True, copy=False) - await write_table_to_storage(final_df, name, output_storage) + await output_table_provider.write_dataframe(name, final_df) return final_df diff --git a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py index ec6abc2578..2d53fd8e6f 100644 --- a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py @@ -20,7 +20,6 @@ from graphrag.index.utils.hashing import gen_sha512_hash from graphrag.logger.progress import progress_ticker from graphrag.tokenizer.get_tokenizer import get_tokenizer -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -31,7 +30,7 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform base text_units.""" logger.info("Workflow started: create_base_text_units") - documents = await load_table_from_storage("documents", context.output_storage) + documents = await context.output_table_provider.read_dataframe("documents") tokenizer = get_tokenizer(encoding_model=config.chunking.encoding_model) chunker = 
create_chunker(config.chunking, tokenizer.encode, tokenizer.decode) @@ -43,7 +42,7 @@ async def run_workflow( prepend_metadata=config.chunking.prepend_metadata, ) - await write_table_to_storage(output, "text_units", context.output_storage) + await context.output_table_provider.write_dataframe("text_units", output) logger.info("Workflow completed: create_base_text_units") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/create_communities.py b/packages/graphrag/graphrag/index/workflows/create_communities.py index 4394593e99..7c3d7a6b33 100644 --- a/packages/graphrag/graphrag/index/workflows/create_communities.py +++ b/packages/graphrag/graphrag/index/workflows/create_communities.py @@ -17,7 +17,6 @@ from graphrag.index.operations.create_graph import create_graph from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -28,10 +27,8 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform final communities.""" logger.info("Workflow started: create_communities") - entities = await load_table_from_storage("entities", context.output_storage) - relationships = await load_table_from_storage( - "relationships", context.output_storage - ) + entities = await context.output_table_provider.read_dataframe("entities") + relationships = await context.output_table_provider.read_dataframe("relationships") max_cluster_size = config.cluster_graph.max_cluster_size use_lcc = config.cluster_graph.use_lcc @@ -45,7 +42,7 @@ async def run_workflow( seed=seed, ) - await write_table_to_storage(output, "communities", context.output_storage) + await context.output_table_provider.write_dataframe("communities", output) logger.info("Workflow completed: create_communities") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/create_community_reports.py b/packages/graphrag/graphrag/index/workflows/create_community_reports.py index abfdeca45a..6f8b061a30 100644 --- a/packages/graphrag/graphrag/index/workflows/create_community_reports.py +++ b/packages/graphrag/graphrag/index/workflows/create_community_reports.py @@ -30,11 +30,6 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import ( - load_table_from_storage, - storage_has_table, - write_table_to_storage, -) if TYPE_CHECKING: from graphrag_llm.completion import LLMCompletion @@ -48,14 +43,16 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform community reports.""" logger.info("Workflow started: create_community_reports") - edges = await load_table_from_storage("relationships", context.output_storage) - entities = await load_table_from_storage("entities", context.output_storage) - communities = await load_table_from_storage("communities", context.output_storage) + edges = await context.output_table_provider.read_dataframe("relationships") + entities = await context.output_table_provider.read_dataframe("entities") + communities = await context.output_table_provider.read_dataframe("communities") + claims = None - if config.extract_claims.enabled and await storage_has_table( - "covariates", context.output_storage + if ( + config.extract_claims.enabled + and await 
context.output_table_provider.has_dataframe("covariates") ): - claims = await load_table_from_storage("covariates", context.output_storage) + claims = await context.output_table_provider.read_dataframe("covariates") model_config = config.get_completion_model_config( config.community_reports.completion_model_id @@ -85,7 +82,7 @@ async def run_workflow( async_type=config.async_mode, ) - await write_table_to_storage(output, "community_reports", context.output_storage) + await context.output_table_provider.write_dataframe("community_reports", output) logger.info("Workflow completed: create_community_reports") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py b/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py index 8a6be96e68..52cb4b0f8e 100644 --- a/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py +++ b/packages/graphrag/graphrag/index/workflows/create_community_reports_text.py @@ -29,7 +29,6 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage if TYPE_CHECKING: from graphrag_llm.completion import LLMCompletion @@ -43,10 +42,9 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform community reports.""" logger.info("Workflow started: create_community_reports_text") - entities = await load_table_from_storage("entities", context.output_storage) - communities = await load_table_from_storage("communities", context.output_storage) - - text_units = await load_table_from_storage("text_units", context.output_storage) + entities = await context.output_table_provider.read_dataframe("entities") + communities = await context.output_table_provider.read_dataframe("communities") + text_units = await context.output_table_provider.read_dataframe("text_units") model_config = config.get_completion_model_config( config.community_reports.completion_model_id @@ -75,7 +73,7 @@ async def run_workflow( async_type=config.async_mode, ) - await write_table_to_storage(output, "community_reports", context.output_storage) + await context.output_table_provider.write_dataframe("community_reports", output) logger.info("Workflow completed: create_community_reports_text") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/create_final_documents.py b/packages/graphrag/graphrag/index/workflows/create_final_documents.py index 554fbc4254..c799d1bb44 100644 --- a/packages/graphrag/graphrag/index/workflows/create_final_documents.py +++ b/packages/graphrag/graphrag/index/workflows/create_final_documents.py @@ -11,7 +11,6 @@ from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -22,12 +21,12 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform final documents.""" logger.info("Workflow started: create_final_documents") - documents = await load_table_from_storage("documents", context.output_storage) - text_units = await load_table_from_storage("text_units", context.output_storage) + documents = await context.output_table_provider.read_dataframe("documents") + text_units = 
await context.output_table_provider.read_dataframe("text_units") output = create_final_documents(documents, text_units) - await write_table_to_storage(output, "documents", context.output_storage) + await context.output_table_provider.write_dataframe("documents", output) logger.info("Workflow completed: create_final_documents") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/create_final_text_units.py b/packages/graphrag/graphrag/index/workflows/create_final_text_units.py index c16e08bb7c..9c897b28f3 100644 --- a/packages/graphrag/graphrag/index/workflows/create_final_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/create_final_text_units.py @@ -11,11 +11,6 @@ from graphrag.data_model.schemas import TEXT_UNITS_FINAL_COLUMNS from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import ( - load_table_from_storage, - storage_has_table, - write_table_to_storage, -) logger = logging.getLogger(__name__) @@ -26,17 +21,19 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to transform the text units.""" logger.info("Workflow started: create_final_text_units") - text_units = await load_table_from_storage("text_units", context.output_storage) - final_entities = await load_table_from_storage("entities", context.output_storage) - final_relationships = await load_table_from_storage( - "relationships", context.output_storage + text_units = await context.output_table_provider.read_dataframe("text_units") + final_entities = await context.output_table_provider.read_dataframe("entities") + final_relationships = await context.output_table_provider.read_dataframe( + "relationships" ) + final_covariates = None - if config.extract_claims.enabled and await storage_has_table( - "covariates", context.output_storage + if ( + config.extract_claims.enabled + and await context.output_table_provider.has_dataframe("covariates") ): - final_covariates = await load_table_from_storage( - "covariates", context.output_storage + final_covariates = await context.output_table_provider.read_dataframe( + "covariates" ) output = create_final_text_units( @@ -46,7 +43,7 @@ async def run_workflow( final_covariates, ) - await write_table_to_storage(output, "text_units", context.output_storage) + await context.output_table_provider.write_dataframe("text_units", output) logger.info("Workflow completed: create_final_text_units") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/extract_covariates.py b/packages/graphrag/graphrag/index/workflows/extract_covariates.py index 18b470a8b1..f27d8590d1 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_covariates.py +++ b/packages/graphrag/graphrag/index/workflows/extract_covariates.py @@ -21,7 +21,6 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage if TYPE_CHECKING: from graphrag_llm.completion import LLMCompletion @@ -37,7 +36,7 @@ async def run_workflow( logger.info("Workflow started: extract_covariates") output = None if config.extract_claims.enabled: - text_units = await load_table_from_storage("text_units", context.output_storage) + text_units = await context.output_table_provider.read_dataframe("text_units") model_config = config.get_completion_model_config( 
config.extract_claims.completion_model_id @@ -64,7 +63,7 @@ async def run_workflow( async_type=config.async_mode, ) - await write_table_to_storage(output, "covariates", context.output_storage) + await context.output_table_provider.write_dataframe("covariates", output) logger.info("Workflow completed: extract_covariates") return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/extract_graph.py b/packages/graphrag/graphrag/index/workflows/extract_graph.py index 6d6520e401..237bbe16cc 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_graph.py +++ b/packages/graphrag/graphrag/index/workflows/extract_graph.py @@ -21,7 +21,6 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage if TYPE_CHECKING: from graphrag_llm.completion import LLMCompletion @@ -35,7 +34,7 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: extract_graph") - text_units = await load_table_from_storage("text_units", context.output_storage) + text_units = await context.output_table_provider.read_dataframe("text_units") extraction_model_config = config.get_completion_model_config( config.extract_graph.completion_model_id @@ -73,15 +72,15 @@ async def run_workflow( summarization_num_threads=config.concurrent_requests, ) - await write_table_to_storage(entities, "entities", context.output_storage) - await write_table_to_storage(relationships, "relationships", context.output_storage) + await context.output_table_provider.write_dataframe("entities", entities) + await context.output_table_provider.write_dataframe("relationships", relationships) if config.snapshots.raw_graph: - await write_table_to_storage( - raw_entities, "raw_entities", context.output_storage + await context.output_table_provider.write_dataframe( + "raw_entities", raw_entities ) - await write_table_to_storage( - raw_relationships, "raw_relationships", context.output_storage + await context.output_table_provider.write_dataframe( + "raw_relationships", raw_relationships ) logger.info("Workflow completed: extract_graph") diff --git a/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py b/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py index 38810e5de4..c0cd069ac6 100644 --- a/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py +++ b/packages/graphrag/graphrag/index/workflows/extract_graph_nlp.py @@ -19,7 +19,6 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -30,7 +29,7 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: extract_graph_nlp") - text_units = await load_table_from_storage("text_units", context.output_storage) + text_units = await context.output_table_provider.read_dataframe("text_units") text_analyzer_config = config.extract_graph_nlp.text_analyzer text_analyzer = create_noun_phrase_extractor(text_analyzer_config) @@ -44,8 +43,8 @@ async def run_workflow( async_type=config.extract_graph_nlp.async_mode, ) - await write_table_to_storage(entities, "entities", context.output_storage) - await write_table_to_storage(relationships, 
"relationships", context.output_storage) + await context.output_table_provider.write_dataframe("entities", entities) + await context.output_table_provider.write_dataframe("relationships", relationships) logger.info("Workflow completed: extract_graph_nlp") diff --git a/packages/graphrag/graphrag/index/workflows/finalize_graph.py b/packages/graphrag/graphrag/index/workflows/finalize_graph.py index 49529aea3a..64029a8cb6 100644 --- a/packages/graphrag/graphrag/index/workflows/finalize_graph.py +++ b/packages/graphrag/graphrag/index/workflows/finalize_graph.py @@ -14,7 +14,6 @@ from graphrag.index.operations.snapshot_graphml import snapshot_graphml from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -25,19 +24,17 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: finalize_graph") - entities = await load_table_from_storage("entities", context.output_storage) - relationships = await load_table_from_storage( - "relationships", context.output_storage - ) + entities = await context.output_table_provider.read_dataframe("entities") + relationships = await context.output_table_provider.read_dataframe("relationships") final_entities, final_relationships = finalize_graph( entities, relationships, ) - await write_table_to_storage(final_entities, "entities", context.output_storage) - await write_table_to_storage( - final_relationships, "relationships", context.output_storage + await context.output_table_provider.write_dataframe("entities", final_entities) + await context.output_table_provider.write_dataframe( + "relationships", final_relationships ) if config.snapshots.graphml: diff --git a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py index 16b726028e..c1e42969ee 100644 --- a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py @@ -25,10 +25,6 @@ from graphrag.index.operations.embed_text.embed_text import embed_text from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import ( - load_table_from_storage, - write_table_to_storage, -) if TYPE_CHECKING: from graphrag_llm.embedding import LLMEmbedding @@ -48,12 +44,12 @@ async def run_workflow( entities = None community_reports = None if text_unit_text_embedding in embedded_fields: - text_units = await load_table_from_storage("text_units", context.output_storage) + text_units = await context.output_table_provider.read_dataframe("text_units") if entity_description_embedding in embedded_fields: - entities = await load_table_from_storage("entities", context.output_storage) + entities = await context.output_table_provider.read_dataframe("entities") if community_full_content_embedding in embedded_fields: - community_reports = await load_table_from_storage( - "community_reports", context.output_storage + community_reports = await context.output_table_provider.read_dataframe( + "community_reports" ) model_config = config.get_embedding_model_config( @@ -84,10 +80,9 @@ async def run_workflow( if config.snapshots.embeddings: for name, table in output.items(): - await write_table_to_storage( - 
table, + await context.output_table_provider.write_dataframe( f"embeddings.{name}", - context.output_storage, + table, ) logger.info("Workflow completed: generate_text_embeddings") diff --git a/packages/graphrag/graphrag/index/workflows/load_input_documents.py b/packages/graphrag/graphrag/index/workflows/load_input_documents.py index 0a5aa65454..ed7f83c8e2 100644 --- a/packages/graphrag/graphrag/index/workflows/load_input_documents.py +++ b/packages/graphrag/graphrag/index/workflows/load_input_documents.py @@ -11,7 +11,6 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import write_table_to_storage logger = logging.getLogger(__name__) @@ -33,7 +32,7 @@ async def run_workflow( logger.info("Final # of rows loaded: %s", len(output)) context.stats.num_documents = len(output) - await write_table_to_storage(output, "documents", context.output_storage) + await context.output_table_provider.write_dataframe("documents", output) return WorkflowFunctionOutput(result=output) diff --git a/packages/graphrag/graphrag/index/workflows/load_update_documents.py b/packages/graphrag/graphrag/index/workflows/load_update_documents.py index 1cab6cabfe..3f4417d3e1 100644 --- a/packages/graphrag/graphrag/index/workflows/load_update_documents.py +++ b/packages/graphrag/graphrag/index/workflows/load_update_documents.py @@ -8,13 +8,12 @@ import pandas as pd from graphrag_input.input_reader import InputReader from graphrag_input.input_reader_factory import create_input_reader -from graphrag_storage import Storage +from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.incremental_index import get_delta_docs -from graphrag.utils.storage import write_table_to_storage logger = logging.getLogger(__name__) @@ -24,10 +23,14 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """Load and parse update-only input documents into a standard format.""" + if context.previous_table_provider is None: + msg = "previous_table_provider is required for update workflows" + raise ValueError(msg) + input_reader = create_input_reader(config.input, context.input_storage) output = await load_update_documents( input_reader, - context.previous_storage, + context.previous_table_provider, ) logger.info("Final # of update rows loaded: %s", len(output)) @@ -37,18 +40,18 @@ async def run_workflow( logger.warning("No new update documents found.") return WorkflowFunctionOutput(result=None, stop=True) - await write_table_to_storage(output, "documents", context.output_storage) + await context.output_table_provider.write_dataframe("documents", output) return WorkflowFunctionOutput(result=output) async def load_update_documents( input_reader: InputReader, - previous_storage: Storage, + previous_table_provider: TableProvider, ) -> pd.DataFrame: """Load and parse update-only input documents into a standard format.""" input_documents = pd.DataFrame(await input_reader.read_files()) - # previous storage is the output of the previous run + # previous table provider has the output of the previous run # we'll use this to diff the input from the prior - delta_documents = await get_delta_docs(input_documents, previous_storage) + 
delta_documents = await get_delta_docs(input_documents, previous_table_provider) return delta_documents.new_inputs diff --git a/packages/graphrag/graphrag/index/workflows/prune_graph.py b/packages/graphrag/graphrag/index/workflows/prune_graph.py index 5653eef49b..483c9b18b3 100644 --- a/packages/graphrag/graphrag/index/workflows/prune_graph.py +++ b/packages/graphrag/graphrag/index/workflows/prune_graph.py @@ -14,7 +14,6 @@ from graphrag.index.operations.prune_graph import prune_graph as prune_graph_operation from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -25,10 +24,8 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" logger.info("Workflow started: prune_graph") - entities = await load_table_from_storage("entities", context.output_storage) - relationships = await load_table_from_storage( - "relationships", context.output_storage - ) + entities = await context.output_table_provider.read_dataframe("entities") + relationships = await context.output_table_provider.read_dataframe("relationships") pruned_entities, pruned_relationships = prune_graph( entities, @@ -36,9 +33,9 @@ async def run_workflow( pruning_config=config.prune_graph, ) - await write_table_to_storage(pruned_entities, "entities", context.output_storage) - await write_table_to_storage( - pruned_relationships, "relationships", context.output_storage + await context.output_table_provider.write_dataframe("entities", pruned_entities) + await context.output_table_provider.write_dataframe( + "relationships", pruned_relationships ) logger.info("Workflow completed: prune_graph") diff --git a/packages/graphrag/graphrag/index/workflows/update_communities.py b/packages/graphrag/graphrag/index/workflows/update_communities.py index da4fdef147..7887706a86 100644 --- a/packages/graphrag/graphrag/index/workflows/update_communities.py +++ b/packages/graphrag/graphrag/index/workflows/update_communities.py @@ -5,14 +5,13 @@ import logging -from graphrag_storage import Storage +from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.communities import _update_and_merge_communities -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -23,12 +22,12 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the communities from a incremental index run.""" logger.info("Workflow started: update_communities") - output_storage, previous_storage, delta_storage = get_update_storages( - config, context.state["update_timestamp"] + output_table_provider, previous_table_provider, delta_table_provider = ( + get_update_table_providers(config, context.state["update_timestamp"]) ) community_id_mapping = await _update_communities( - previous_storage, delta_storage, output_storage + previous_table_provider, delta_table_provider, output_table_provider ) context.state["incremental_update_community_id_mapping"] = community_id_mapping @@ -38,17 +37,17 @@ async def run_workflow( async def 
_update_communities( - previous_storage: Storage, - delta_storage: Storage, - output_storage: Storage, + previous_table_provider: TableProvider, + delta_table_provider: TableProvider, + output_table_provider: TableProvider, ) -> dict: """Update the communities output.""" - old_communities = await load_table_from_storage("communities", previous_storage) - delta_communities = await load_table_from_storage("communities", delta_storage) + old_communities = await previous_table_provider.read_dataframe("communities") + delta_communities = await delta_table_provider.read_dataframe("communities") merged_communities, community_id_mapping = _update_and_merge_communities( old_communities, delta_communities ) - await write_table_to_storage(merged_communities, "communities", output_storage) + await output_table_provider.write_dataframe("communities", merged_communities) return community_id_mapping diff --git a/packages/graphrag/graphrag/index/workflows/update_community_reports.py b/packages/graphrag/graphrag/index/workflows/update_community_reports.py index 790f9fc296..9c9b0f2fec 100644 --- a/packages/graphrag/graphrag/index/workflows/update_community_reports.py +++ b/packages/graphrag/graphrag/index/workflows/update_community_reports.py @@ -6,14 +6,13 @@ import logging import pandas as pd -from graphrag_storage import Storage +from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.communities import _update_and_merge_community_reports -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -24,14 +23,17 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the community reports from a incremental index run.""" logger.info("Workflow started: update_community_reports") - output_storage, previous_storage, delta_storage = get_update_storages( - config, context.state["update_timestamp"] + output_table_provider, previous_table_provider, delta_table_provider = ( + get_update_table_providers(config, context.state["update_timestamp"]) ) community_id_mapping = context.state["incremental_update_community_id_mapping"] merged_community_reports = await _update_community_reports( - previous_storage, delta_storage, output_storage, community_id_mapping + previous_table_provider, + delta_table_provider, + output_table_provider, + community_id_mapping, ) context.state["incremental_update_merged_community_reports"] = ( @@ -43,24 +45,24 @@ async def run_workflow( async def _update_community_reports( - previous_storage: Storage, - delta_storage: Storage, - output_storage: Storage, + previous_table_provider: TableProvider, + delta_table_provider: TableProvider, + output_table_provider: TableProvider, community_id_mapping: dict, ) -> pd.DataFrame: """Update the community reports output.""" - old_community_reports = await load_table_from_storage( - "community_reports", previous_storage + old_community_reports = await previous_table_provider.read_dataframe( + "community_reports" ) - delta_community_reports = await load_table_from_storage( - "community_reports", delta_storage + delta_community_reports = await delta_table_provider.read_dataframe( + "community_reports" ) merged_community_reports = 
_update_and_merge_community_reports( old_community_reports, delta_community_reports, community_id_mapping ) - await write_table_to_storage( - merged_community_reports, "community_reports", output_storage + await output_table_provider.write_dataframe( + "community_reports", merged_community_reports ) return merged_community_reports diff --git a/packages/graphrag/graphrag/index/workflows/update_covariates.py b/packages/graphrag/graphrag/index/workflows/update_covariates.py index 09f8b4053d..a2c1a834fb 100644 --- a/packages/graphrag/graphrag/index/workflows/update_covariates.py +++ b/packages/graphrag/graphrag/index/workflows/update_covariates.py @@ -7,17 +7,12 @@ import numpy as np import pandas as pd -from graphrag_storage import Storage +from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import ( - load_table_from_storage, - storage_has_table, - write_table_to_storage, -) logger = logging.getLogger(__name__) @@ -28,31 +23,33 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the covariates from a incremental index run.""" logger.info("Workflow started: update_covariates") - output_storage, previous_storage, delta_storage = get_update_storages( - config, context.state["update_timestamp"] + output_table_provider, previous_table_provider, delta_table_provider = ( + get_update_table_providers(config, context.state["update_timestamp"]) ) - if await storage_has_table( - "covariates", previous_storage - ) and await storage_has_table("covariates", delta_storage): + if await previous_table_provider.has_dataframe( + "covariates" + ) and await delta_table_provider.has_dataframe("covariates"): logger.info("Updating Covariates") - await _update_covariates(previous_storage, delta_storage, output_storage) + await _update_covariates( + previous_table_provider, delta_table_provider, output_table_provider + ) logger.info("Workflow completed: update_covariates") return WorkflowFunctionOutput(result=None) async def _update_covariates( - previous_storage: Storage, - delta_storage: Storage, - output_storage: Storage, + previous_table_provider: TableProvider, + delta_table_provider: TableProvider, + output_table_provider: TableProvider, ) -> None: """Update the covariates output.""" - old_covariates = await load_table_from_storage("covariates", previous_storage) - delta_covariates = await load_table_from_storage("covariates", delta_storage) + old_covariates = await previous_table_provider.read_dataframe("covariates") + delta_covariates = await delta_table_provider.read_dataframe("covariates") merged_covariates = _merge_covariates(old_covariates, delta_covariates) - await write_table_to_storage(merged_covariates, "covariates", output_storage) + await output_table_provider.write_dataframe("covariates", merged_covariates) def _merge_covariates( diff --git a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py index 225c12d9b9..c7d1bcc416 100644 --- a/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py +++ b/packages/graphrag/graphrag/index/workflows/update_entities_relationships.py @@ -8,18 +8,17 @@ import pandas as pd from 
graphrag_cache import Cache from graphrag_llm.completion import create_completion -from graphrag_storage import Storage +from graphrag_storage.tables.table_provider import TableProvider from graphrag.cache.cache_key_creator import cache_key_creator from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.entities import _group_and_resolve_entities from graphrag.index.update.relationships import _update_and_merge_relationships from graphrag.index.workflows.extract_graph import get_summarized_entities_relationships -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -30,8 +29,8 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the entities and relationships from a incremental index run.""" logger.info("Workflow started: update_entities_relationships") - output_storage, previous_storage, delta_storage = get_update_storages( - config, context.state["update_timestamp"] + output_table_provider, previous_table_provider, delta_table_provider = ( + get_update_table_providers(config, context.state["update_timestamp"]) ) ( @@ -39,9 +38,9 @@ async def run_workflow( merged_relationships_df, entity_id_mapping, ) = await _update_entities_and_relationships( - previous_storage, - delta_storage, - output_storage, + previous_table_provider, + delta_table_provider, + output_table_provider, config, context.cache, context.callbacks, @@ -56,24 +55,24 @@ async def run_workflow( async def _update_entities_and_relationships( - previous_storage: Storage, - delta_storage: Storage, - output_storage: Storage, + previous_table_provider: TableProvider, + delta_table_provider: TableProvider, + output_table_provider: TableProvider, config: GraphRagConfig, cache: Cache, callbacks: WorkflowCallbacks, ) -> tuple[pd.DataFrame, pd.DataFrame, dict]: """Update Final Entities and Relationships output.""" - old_entities = await load_table_from_storage("entities", previous_storage) - delta_entities = await load_table_from_storage("entities", delta_storage) + old_entities = await previous_table_provider.read_dataframe("entities") + delta_entities = await delta_table_provider.read_dataframe("entities") merged_entities_df, entity_id_mapping = _group_and_resolve_entities( old_entities, delta_entities ) # Update Relationships - old_relationships = await load_table_from_storage("relationships", previous_storage) - delta_relationships = await load_table_from_storage("relationships", delta_storage) + old_relationships = await previous_table_provider.read_dataframe("relationships") + delta_relationships = await delta_table_provider.read_dataframe("relationships") merged_relationships_df = _update_and_merge_relationships( old_relationships, delta_relationships, @@ -104,10 +103,9 @@ async def _update_entities_and_relationships( ) # Save the updated entities back to storage - await write_table_to_storage(merged_entities_df, "entities", output_storage) - - await write_table_to_storage( - merged_relationships_df, "relationships", output_storage + await output_table_provider.write_dataframe("entities", merged_entities_df) + await output_table_provider.write_dataframe( + "relationships", merged_relationships_df ) 
return merged_entities_df, merged_relationships_df, entity_id_mapping diff --git a/packages/graphrag/graphrag/index/workflows/update_final_documents.py b/packages/graphrag/graphrag/index/workflows/update_final_documents.py index b684beba94..7f473096d3 100644 --- a/packages/graphrag/graphrag/index/workflows/update_final_documents.py +++ b/packages/graphrag/graphrag/index/workflows/update_final_documents.py @@ -6,7 +6,7 @@ import logging from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.incremental_index import concat_dataframes @@ -20,12 +20,15 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the documents from a incremental index run.""" logger.info("Workflow started: update_final_documents") - output_storage, previous_storage, delta_storage = get_update_storages( - config, context.state["update_timestamp"] + output_table_provider, previous_table_provider, delta_table_provider = ( + get_update_table_providers(config, context.state["update_timestamp"]) ) final_documents = await concat_dataframes( - "documents", previous_storage, delta_storage, output_storage + "documents", + previous_table_provider, + delta_table_provider, + output_table_provider, ) context.state["incremental_update_final_documents"] = final_documents diff --git a/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py index 375bb69df4..4a3cf1a673 100644 --- a/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/update_text_embeddings.py @@ -9,11 +9,10 @@ from graphrag.cache.cache_key_creator import cache_key_creator from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings -from graphrag.utils.storage import write_table_to_storage logger = logging.getLogger(__name__) @@ -24,9 +23,10 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the text embeddings from a incremental index run.""" logger.info("Workflow started: update_text_embeddings") - output_storage, _, _ = get_update_storages( + output_table_provider, _, _ = get_update_table_providers( config, context.state["update_timestamp"] ) + merged_text_units = context.state["incremental_update_merged_text_units"] merged_entities_df = context.state["incremental_update_merged_entities"] merged_community_reports = context.state[ @@ -62,11 +62,7 @@ async def run_workflow( ) if config.snapshots.embeddings: for name, table in result.items(): - await write_table_to_storage( - table, - f"embeddings.{name}", - output_storage, - ) + await output_table_provider.write_dataframe(f"embeddings.{name}", table) logger.info("Workflow completed: update_text_embeddings") return WorkflowFunctionOutput(result=None) diff --git a/packages/graphrag/graphrag/index/workflows/update_text_units.py b/packages/graphrag/graphrag/index/workflows/update_text_units.py index c97f89ce7a..02592b8aa4 100644 
--- a/packages/graphrag/graphrag/index/workflows/update_text_units.py +++ b/packages/graphrag/graphrag/index/workflows/update_text_units.py @@ -7,13 +7,12 @@ import numpy as np import pandas as pd -from graphrag_storage import Storage +from graphrag_storage.tables.table_provider import TableProvider from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.run.utils import get_update_storages +from graphrag.index.run.utils import get_update_table_providers from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -24,13 +23,16 @@ async def run_workflow( ) -> WorkflowFunctionOutput: """Update the text units from a incremental index run.""" logger.info("Workflow started: update_text_units") - output_storage, previous_storage, delta_storage = get_update_storages( - config, context.state["update_timestamp"] + output_table_provider, previous_table_provider, delta_table_provider = ( + get_update_table_providers(config, context.state["update_timestamp"]) ) entity_id_mapping = context.state["incremental_update_entity_id_mapping"] merged_text_units = await _update_text_units( - previous_storage, delta_storage, output_storage, entity_id_mapping + previous_table_provider, + delta_table_provider, + output_table_provider, + entity_id_mapping, ) context.state["incremental_update_merged_text_units"] = merged_text_units @@ -40,19 +42,19 @@ async def run_workflow( async def _update_text_units( - previous_storage: Storage, - delta_storage: Storage, - output_storage: Storage, + previous_table_provider: TableProvider, + delta_table_provider: TableProvider, + output_table_provider: TableProvider, entity_id_mapping: dict, ) -> pd.DataFrame: """Update the text units output.""" - old_text_units = await load_table_from_storage("text_units", previous_storage) - delta_text_units = await load_table_from_storage("text_units", delta_storage) + old_text_units = await previous_table_provider.read_dataframe("text_units") + delta_text_units = await delta_table_provider.read_dataframe("text_units") merged_text_units = _update_and_merge_text_units( old_text_units, delta_text_units, entity_id_mapping ) - await write_table_to_storage(merged_text_units, "text_units", output_storage) + await output_table_provider.write_dataframe("text_units", merged_text_units) return merged_text_units diff --git a/packages/graphrag/graphrag/utils/storage.py b/packages/graphrag/graphrag/utils/storage.py deleted file mode 100644 index 852d066091..0000000000 --- a/packages/graphrag/graphrag/utils/storage.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Storage functions for the GraphRAG run module.""" - -import logging -from io import BytesIO - -import pandas as pd -from graphrag_storage import Storage - -logger = logging.getLogger(__name__) - - -async def load_table_from_storage(name: str, storage: Storage) -> pd.DataFrame: - """Load a parquet from the storage instance.""" - filename = f"{name}.parquet" - if not await storage.has(filename): - msg = f"Could not find {filename} in storage!" 
- raise ValueError(msg) - try: - logger.info("reading table from storage: %s", filename) - return pd.read_parquet(BytesIO(await storage.get(filename, as_bytes=True))) - except Exception: - logger.exception("error loading table from storage: %s", filename) - raise - - -async def write_table_to_storage( - table: pd.DataFrame, name: str, storage: Storage -) -> None: - """Write a table to storage.""" - await storage.set(f"{name}.parquet", table.to_parquet()) - - -async def delete_table_from_storage(name: str, storage: Storage) -> None: - """Delete a table to storage.""" - await storage.delete(f"{name}.parquet") - - -async def storage_has_table(name: str, storage: Storage) -> bool: - """Check if a table exists in storage.""" - return await storage.has(f"{name}.parquet") diff --git a/tests/unit/storage/__init__.py b/tests/unit/storage/__init__.py new file mode 100644 index 0000000000..0a3e38adfb --- /dev/null +++ b/tests/unit/storage/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License diff --git a/tests/unit/storage/test_parquet_table_provider.py b/tests/unit/storage/test_parquet_table_provider.py new file mode 100644 index 0000000000..781735224b --- /dev/null +++ b/tests/unit/storage/test_parquet_table_provider.py @@ -0,0 +1,90 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +import unittest +from io import BytesIO + +import pandas as pd +import pytest +from graphrag_storage import ( + StorageConfig, + StorageType, + create_storage, +) +from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider + + +class TestParquetTableProvider(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.storage = create_storage( + StorageConfig( + type=StorageType.Memory, + ) + ) + self.table_provider = ParquetTableProvider(storage=self.storage) + + async def asyncTearDown(self): + await self.storage.clear() + + async def test_write_and_read(self): + df = pd.DataFrame({ + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "age": [30, 25, 35], + }) + + await self.table_provider.write_dataframe("users", df) + result = await self.table_provider.read_dataframe("users") + + pd.testing.assert_frame_equal(result, df) + + async def test_read_nonexistent_table_raises_error(self): + with pytest.raises( + ValueError, match=r"Could not find nonexistent\.parquet in storage!" 
+ ): + await self.table_provider.read_dataframe("nonexistent") + + async def test_empty_dataframe(self): + df = pd.DataFrame() + + await self.table_provider.write_dataframe("empty", df) + result = await self.table_provider.read_dataframe("empty") + + pd.testing.assert_frame_equal(result, df) + + async def test_dataframe_with_multiple_types(self): + df = pd.DataFrame({ + "int_col": [1, 2, 3], + "float_col": [1.1, 2.2, 3.3], + "str_col": ["a", "b", "c"], + "bool_col": [True, False, True], + }) + + await self.table_provider.write_dataframe("mixed", df) + result = await self.table_provider.read_dataframe("mixed") + + pd.testing.assert_frame_equal(result, df) + + async def test_storage_persistence(self): + df = pd.DataFrame({"x": [1, 2, 3]}) + + await self.table_provider.write_dataframe("test", df) + + assert await self.storage.has("test.parquet") + + parquet_bytes = await self.storage.get("test.parquet", as_bytes=True) + loaded_df = pd.read_parquet(BytesIO(parquet_bytes)) + + pd.testing.assert_frame_equal(loaded_df, df) + + async def test_has_dataframe(self): + df = pd.DataFrame({"a": [1, 2, 3]}) + + # Table doesn't exist yet + assert not await self.table_provider.has_dataframe("test_table") + + # Write the table + await self.table_provider.write_dataframe("test_table", df) + + # Now it exists + assert await self.table_provider.has_dataframe("test_table") diff --git a/tests/verbs/test_create_base_text_units.py b/tests/verbs/test_create_base_text_units.py index 34bad99dc7..b7ad0543ed 100644 --- a/tests/verbs/test_create_base_text_units.py +++ b/tests/verbs/test_create_base_text_units.py @@ -2,7 +2,6 @@ # Licensed under the MIT License from graphrag.index.workflows.create_base_text_units import run_workflow -from graphrag.utils.storage import load_table_from_storage from tests.unit.config.utils import get_default_graphrag_config @@ -23,7 +22,7 @@ async def test_create_base_text_units(): await run_workflow(config, context) - actual = await load_table_from_storage("text_units", context.output_storage) + actual = await context.output_table_provider.read_dataframe("text_units") print("EXPECTED") print(expected.columns) diff --git a/tests/verbs/test_create_communities.py b/tests/verbs/test_create_communities.py index d5505d7a31..072e878e2c 100644 --- a/tests/verbs/test_create_communities.py +++ b/tests/verbs/test_create_communities.py @@ -5,7 +5,6 @@ from graphrag.index.workflows.create_communities import ( run_workflow, ) -from graphrag.utils.storage import load_table_from_storage from tests.unit.config.utils import get_default_graphrag_config @@ -33,7 +32,7 @@ async def test_create_communities(): context, ) - actual = await load_table_from_storage("communities", context.output_storage) + actual = await context.output_table_provider.read_dataframe("communities") columns = list(expected.columns.values) # don't compare period since it is created with the current date each time diff --git a/tests/verbs/test_create_community_reports.py b/tests/verbs/test_create_community_reports.py index a36b6c7a66..68d8d1be9c 100644 --- a/tests/verbs/test_create_community_reports.py +++ b/tests/verbs/test_create_community_reports.py @@ -10,7 +10,6 @@ from graphrag.index.workflows.create_community_reports import ( run_workflow, ) -from graphrag.utils.storage import load_table_from_storage from tests.unit.config.utils import get_default_graphrag_config @@ -56,7 +55,7 @@ async def test_create_community_reports(): await run_workflow(config, context) - actual = await load_table_from_storage("community_reports", 
context.output_storage) + actual = await context.output_table_provider.read_dataframe("community_reports") assert len(actual.columns) == len(expected.columns) diff --git a/tests/verbs/test_create_final_documents.py b/tests/verbs/test_create_final_documents.py index 6031ccd09c..586ad5b31c 100644 --- a/tests/verbs/test_create_final_documents.py +++ b/tests/verbs/test_create_final_documents.py @@ -5,7 +5,6 @@ from graphrag.index.workflows.create_final_documents import ( run_workflow, ) -from graphrag.utils.storage import load_table_from_storage from tests.unit.config.utils import get_default_graphrag_config @@ -27,7 +26,7 @@ async def test_create_final_documents(): await run_workflow(config, context) - actual = await load_table_from_storage("documents", context.output_storage) + actual = await context.output_table_provider.read_dataframe("documents") compare_outputs(actual, expected) diff --git a/tests/verbs/test_create_final_text_units.py b/tests/verbs/test_create_final_text_units.py index 56c0e72a81..c97cba2bcd 100644 --- a/tests/verbs/test_create_final_text_units.py +++ b/tests/verbs/test_create_final_text_units.py @@ -5,7 +5,6 @@ from graphrag.index.workflows.create_final_text_units import ( run_workflow, ) -from graphrag.utils.storage import load_table_from_storage from tests.unit.config.utils import get_default_graphrag_config @@ -33,7 +32,7 @@ async def test_create_final_text_units(): await run_workflow(config, context) - actual = await load_table_from_storage("text_units", context.output_storage) + actual = await context.output_table_provider.read_dataframe("text_units") for column in TEXT_UNITS_FINAL_COLUMNS: assert column in actual.columns diff --git a/tests/verbs/test_extract_covariates.py b/tests/verbs/test_extract_covariates.py index 5a87c121b3..4cf3a79d77 100644 --- a/tests/verbs/test_extract_covariates.py +++ b/tests/verbs/test_extract_covariates.py @@ -5,7 +5,6 @@ from graphrag.index.workflows.extract_covariates import ( run_workflow, ) -from graphrag.utils.storage import load_table_from_storage from graphrag_llm.config import LLMProviderType from pandas.testing import assert_series_equal @@ -41,7 +40,7 @@ async def test_extract_covariates(): await run_workflow(config, context) - actual = await load_table_from_storage("covariates", context.output_storage) + actual = await context.output_table_provider.read_dataframe("covariates") for column in COVARIATES_FINAL_COLUMNS: assert column in actual.columns diff --git a/tests/verbs/test_extract_graph.py b/tests/verbs/test_extract_graph.py index b62bd9c77f..504baaac31 100644 --- a/tests/verbs/test_extract_graph.py +++ b/tests/verbs/test_extract_graph.py @@ -1,10 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.index.workflows.extract_graph import ( - run_workflow, -) -from graphrag.utils.storage import load_table_from_storage +from graphrag.index.workflows.extract_graph import run_workflow from tests.unit.config.utils import get_default_graphrag_config @@ -54,10 +51,8 @@ async def test_extract_graph(): await run_workflow(config, context) - nodes_actual = await load_table_from_storage("entities", context.output_storage) - edges_actual = await load_table_from_storage( - "relationships", context.output_storage - ) + nodes_actual = await context.output_table_provider.read_dataframe("entities") + edges_actual = await context.output_table_provider.read_dataframe("relationships") assert len(nodes_actual.columns) == 5 assert len(edges_actual.columns) == 5 diff --git a/tests/verbs/test_extract_graph_nlp.py b/tests/verbs/test_extract_graph_nlp.py index 9c758dda61..55ab376689 100644 --- a/tests/verbs/test_extract_graph_nlp.py +++ b/tests/verbs/test_extract_graph_nlp.py @@ -1,10 +1,7 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.index.workflows.extract_graph_nlp import ( - run_workflow, -) -from graphrag.utils.storage import load_table_from_storage +from graphrag.index.workflows.extract_graph_nlp import run_workflow from tests.unit.config.utils import get_default_graphrag_config @@ -22,10 +19,8 @@ async def test_extract_graph_nlp(): await run_workflow(config, context) - nodes_actual = await load_table_from_storage("entities", context.output_storage) - edges_actual = await load_table_from_storage( - "relationships", context.output_storage - ) + nodes_actual = await context.output_table_provider.read_dataframe("entities") + edges_actual = await context.output_table_provider.read_dataframe("relationships") # this will be the raw count of entities and edges with no pruning # with NLP it is deterministic, so we can assert exact row counts diff --git a/tests/verbs/test_finalize_graph.py b/tests/verbs/test_finalize_graph.py index 055ec76768..72513a8293 100644 --- a/tests/verbs/test_finalize_graph.py +++ b/tests/verbs/test_finalize_graph.py @@ -5,10 +5,7 @@ ENTITIES_FINAL_COLUMNS, RELATIONSHIPS_FINAL_COLUMNS, ) -from graphrag.index.workflows.finalize_graph import ( - run_workflow, -) -from graphrag.utils.storage import load_table_from_storage, write_table_to_storage +from graphrag.index.workflows.finalize_graph import run_workflow from tests.unit.config.utils import get_default_graphrag_config @@ -25,10 +22,8 @@ async def test_finalize_graph(): await run_workflow(config, context) - nodes_actual = await load_table_from_storage("entities", context.output_storage) - edges_actual = await load_table_from_storage( - "relationships", context.output_storage - ) + nodes_actual = await context.output_table_provider.read_dataframe("entities") + edges_actual = await context.output_table_provider.read_dataframe("relationships") for column in ENTITIES_FINAL_COLUMNS: assert column in nodes_actual.columns @@ -44,8 +39,8 @@ async def _prep_tables(): # edit the tables to eliminate final fields that wouldn't be on the inputs entities = load_test_table("entities") entities.drop(columns=["degree"], inplace=True) - await write_table_to_storage(entities, "entities", context.output_storage) + await context.output_table_provider.write_dataframe("entities", entities) relationships = load_test_table("relationships") relationships.drop(columns=["combined_degree"], inplace=True) - await write_table_to_storage(relationships, "relationships", 
context.output_storage) + await context.output_table_provider.write_dataframe("relationships", relationships) return context diff --git a/tests/verbs/test_generate_text_embeddings.py b/tests/verbs/test_generate_text_embeddings.py index 14fd163d87..f25ed52d34 100644 --- a/tests/verbs/test_generate_text_embeddings.py +++ b/tests/verbs/test_generate_text_embeddings.py @@ -7,7 +7,6 @@ from graphrag.index.workflows.generate_text_embeddings import ( run_workflow, ) -from graphrag.utils.storage import load_table_from_storage from tests.unit.config.utils import get_default_graphrag_config @@ -45,8 +44,8 @@ async def test_generate_text_embeddings(): assert f"embeddings.{field}.parquet" in parquet_files # entity description should always be here, let's assert its format - entity_description_embeddings = await load_table_from_storage( - "embeddings.entity_description", context.output_storage + entity_description_embeddings = await context.output_table_provider.read_dataframe( + "embeddings.entity_description" ) assert len(entity_description_embeddings.columns) == 2 diff --git a/tests/verbs/test_prune_graph.py b/tests/verbs/test_prune_graph.py index fa66f98bde..1df5cadebb 100644 --- a/tests/verbs/test_prune_graph.py +++ b/tests/verbs/test_prune_graph.py @@ -2,10 +2,7 @@ # Licensed under the MIT License from graphrag.config.models.prune_graph_config import PruneGraphConfig -from graphrag.index.workflows.prune_graph import ( - run_workflow, -) -from graphrag.utils.storage import load_table_from_storage +from graphrag.index.workflows.prune_graph import run_workflow from tests.unit.config.utils import get_default_graphrag_config @@ -26,6 +23,6 @@ async def test_prune_graph(): await run_workflow(config, context) - nodes_actual = await load_table_from_storage("entities", context.output_storage) + nodes_actual = await context.output_table_provider.read_dataframe("entities") assert len(nodes_actual) == 29 diff --git a/tests/verbs/util.py b/tests/verbs/util.py index 65b6b906b7..741e8e3b1a 100644 --- a/tests/verbs/util.py +++ b/tests/verbs/util.py @@ -4,7 +4,6 @@ import pandas as pd from graphrag.index.run.utils import create_run_context from graphrag.index.typing.context import PipelineRunContext -from graphrag.utils.storage import write_table_to_storage from pandas.testing import assert_series_equal pd.set_option("display.max_columns", None) @@ -17,12 +16,12 @@ async def create_test_context(storage: list[str] | None = None) -> PipelineRunCo # always set the input docs, but since our stored table is final, drop what wouldn't be in the original source input input = load_test_table("documents") input.drop(columns=["text_unit_ids"], inplace=True) - await write_table_to_storage(input, "documents", context.output_storage) + await context.output_table_provider.write_dataframe("documents", input) if storage: for name in storage: table = load_test_table(name) - await write_table_to_storage(table, name, context.output_storage) + await context.output_table_provider.write_dataframe(name, table) return context
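
For reference, the workflows and tests above only ever call three methods on the new abstraction. Below is a minimal sketch of the TableProvider contract, reconstructed purely from those call sites; the actual abstract class in graphrag_storage.tables.table_provider may declare more than this.

    # Illustrative only: the TableProvider shape implied by the call sites in this patch.
    from abc import ABC, abstractmethod

    import pandas as pd


    class TableProvider(ABC):
        """Read and write named tables without exposing the underlying file format."""

        @abstractmethod
        async def read_dataframe(self, name: str) -> pd.DataFrame:
            """Load the table stored under `name` as a DataFrame."""

        @abstractmethod
        async def write_dataframe(self, name: str, table: pd.DataFrame) -> None:
            """Persist `table` under `name`, overwriting any existing copy."""

        @abstractmethod
        async def has_dataframe(self, name: str) -> bool:
            """Return True if a table named `name` exists."""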
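The ParquetTableProvider used throughout this patch can be read as a thin wrapper that folds the deleted helpers from graphrag/utils/storage.py into a single object. The following is a sketch under that assumption, matching the behavior asserted in tests/unit/storage/test_parquet_table_provider.py; the shipped implementation may add logging or extra error handling.

    # A minimal sketch mirroring the removed load_table_from_storage /
    # write_table_to_storage / storage_has_table helpers. In the real package this
    # class would subclass TableProvider.
    from io import BytesIO

    import pandas as pd
    from graphrag_storage import Storage


    class ParquetTableProvider:
        """Store each named table as `<name>.parquet` in the wrapped Storage."""

        def __init__(self, storage: Storage) -> None:
            self._storage = storage

        async def read_dataframe(self, name: str) -> pd.DataFrame:
            filename = f"{name}.parquet"
            if not await self._storage.has(filename):
                msg = f"Could not find {filename} in storage!"
                raise ValueError(msg)
            data = await self._storage.get(filename, as_bytes=True)
            return pd.read_parquet(BytesIO(data))

        async def write_dataframe(self, name: str, table: pd.DataFrame) -> None:
            # DataFrame.to_parquet() with no path returns the serialized bytes.
            await self._storage.set(f"{name}.parquet", table.to_parquet())

        async def has_dataframe(self, name: str) -> bool:
            return await self._storage.has(f"{name}.parquet")

Keeping the `<name>.parquet` naming and serialization behind the provider means workflows, tests, and notebooks no longer hard-code the file format, so an alternative table format could later be introduced without touching call sites.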
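Similarly, update_final_documents now hands the three providers to concat_dataframes. That helper is not shown in this diff, so the following is only an illustrative guess at its provider-based shape, inferred from the call site and the function name.

    # Hypothetical sketch of concat_dataframes after this change; the actual
    # implementation in graphrag.index.update.incremental_index may differ.
    import pandas as pd
    from graphrag_storage.tables.table_provider import TableProvider


    async def concat_dataframes(
        name: str,
        previous_table_provider: TableProvider,
        delta_table_provider: TableProvider,
        output_table_provider: TableProvider,
    ) -> pd.DataFrame:
        """Concatenate the previous and delta copies of a table and persist the result."""
        previous_df = await previous_table_provider.read_dataframe(name)
        delta_df = await delta_table_provider.read_dataframe(name)
        final_df = pd.concat([previous_df, delta_df], ignore_index=True)
        await output_table_provider.write_dataframe(name, final_df)
        return final_df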