4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20260127131016120694.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add TableProvider abstraction for table-based storage operations"
}
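
For context, this is the shape of the API change the patch note describes, sketched from the diffs below; `storage` stands in for any configured graphrag_storage Storage instance, and the awaits assume a notebook or other async context:

# Before (removed in this PR):
#   from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
#   entities = await load_table_from_storage("entities", storage)
#   await write_table_to_storage(entities, "entities", storage)

# After (added in this PR):
from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider

table_provider = ParquetTableProvider(storage)
entities = await table_provider.read_dataframe("entities")
await table_provider.write_dataframe("entities", entities)

Note that read_dataframe/write_dataframe take the bare table name; the provider appends the .parquet suffix itself.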
39 changes: 20 additions & 19 deletions docs/examples_notebooks/index_migration_to_v1.ipynb
@@ -103,21 +103,22 @@
 "source": [
 "from uuid import uuid4\n",
 "\n",
-"from graphrag.utils.storage import load_table_from_storage, write_table_to_storage\n",
+"from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider\n",
 "\n",
+"# Create table provider from storage\n",
+"table_provider = ParquetTableProvider(storage)\n",
+"\n",
 "# First we'll go through any parquet files that had model changes and update them\n",
 "# The new data model may have removed excess columns as well, but we will only make the minimal changes required for compatibility\n",
 "\n",
-"final_documents = await load_table_from_storage(\"create_final_documents\", storage)\n",
-"final_text_units = await load_table_from_storage(\"create_final_text_units\", storage)\n",
-"final_entities = await load_table_from_storage(\"create_final_entities\", storage)\n",
-"final_nodes = await load_table_from_storage(\"create_final_nodes\", storage)\n",
-"final_relationships = await load_table_from_storage(\n",
-" \"create_final_relationships\", storage\n",
-")\n",
-"final_communities = await load_table_from_storage(\"create_final_communities\", storage)\n",
-"final_community_reports = await load_table_from_storage(\n",
-" \"create_final_community_reports\", storage\n",
+"final_documents = await table_provider.read_dataframe(\"create_final_documents\")\n",
+"final_text_units = await table_provider.read_dataframe(\"create_final_text_units\")\n",
+"final_entities = await table_provider.read_dataframe(\"create_final_entities\")\n",
+"final_nodes = await table_provider.read_dataframe(\"create_final_nodes\")\n",
+"final_relationships = await table_provider.read_dataframe(\"create_final_relationships\")\n",
+"final_communities = await table_provider.read_dataframe(\"create_final_communities\")\n",
+"final_community_reports = await table_provider.read_dataframe(\n",
+" \"create_final_community_reports\"\n",
 ")\n",
 "\n",
 "\n",
@@ -187,14 +188,14 @@
 " parent_df, on=\"community\", how=\"left\"\n",
 " )\n",
 "\n",
-"await write_table_to_storage(final_documents, \"create_final_documents\", storage)\n",
-"await write_table_to_storage(final_text_units, \"create_final_text_units\", storage)\n",
-"await write_table_to_storage(final_entities, \"create_final_entities\", storage)\n",
-"await write_table_to_storage(final_nodes, \"create_final_nodes\", storage)\n",
-"await write_table_to_storage(final_relationships, \"create_final_relationships\", storage)\n",
-"await write_table_to_storage(final_communities, \"create_final_communities\", storage)\n",
-"await write_table_to_storage(\n",
-" final_community_reports, \"create_final_community_reports\", storage\n",
+"await table_provider.write_dataframe(\"create_final_documents\", final_documents)\n",
+"await table_provider.write_dataframe(\"create_final_text_units\", final_text_units)\n",
+"await table_provider.write_dataframe(\"create_final_entities\", final_entities)\n",
+"await table_provider.write_dataframe(\"create_final_nodes\", final_nodes)\n",
+"await table_provider.write_dataframe(\"create_final_relationships\", final_relationships)\n",
+"await table_provider.write_dataframe(\"create_final_communities\", final_communities)\n",
+"await table_provider.write_dataframe(\n",
+" \"create_final_community_reports\", final_community_reports\n",
 ")"
 ]
 },
61 changes: 29 additions & 32 deletions docs/examples_notebooks/index_migration_to_v2.ipynb
@@ -65,28 +65,25 @@
 },
 {
 "cell_type": "code",
-"execution_count": 22,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
 "import numpy as np\n",
-"from graphrag.utils.storage import (\n",
-" delete_table_from_storage,\n",
-" load_table_from_storage,\n",
-" write_table_to_storage,\n",
-")\n",
+"from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider\n",
 "\n",
-"final_documents = await load_table_from_storage(\"create_final_documents\", storage)\n",
-"final_text_units = await load_table_from_storage(\"create_final_text_units\", storage)\n",
-"final_entities = await load_table_from_storage(\"create_final_entities\", storage)\n",
-"final_covariates = await load_table_from_storage(\"create_final_covariates\", storage)\n",
-"final_nodes = await load_table_from_storage(\"create_final_nodes\", storage)\n",
-"final_relationships = await load_table_from_storage(\n",
-" \"create_final_relationships\", storage\n",
-")\n",
-"final_communities = await load_table_from_storage(\"create_final_communities\", storage)\n",
-"final_community_reports = await load_table_from_storage(\n",
-" \"create_final_community_reports\", storage\n",
+"# Create table provider from storage\n",
+"table_provider = ParquetTableProvider(storage)\n",
+"\n",
+"final_documents = await table_provider.read_dataframe(\"create_final_documents\")\n",
+"final_text_units = await table_provider.read_dataframe(\"create_final_text_units\")\n",
+"final_entities = await table_provider.read_dataframe(\"create_final_entities\")\n",
+"final_covariates = await table_provider.read_dataframe(\"create_final_covariates\")\n",
+"final_nodes = await table_provider.read_dataframe(\"create_final_nodes\")\n",
+"final_relationships = await table_provider.read_dataframe(\"create_final_relationships\")\n",
+"final_communities = await table_provider.read_dataframe(\"create_final_communities\")\n",
+"final_community_reports = await table_provider.read_dataframe(\n",
+" \"create_final_community_reports\"\n",
 ")\n",
 "\n",
 "# we've renamed document attributes as metadata\n",
@@ -126,23 +123,23 @@
 ")\n",
 "\n",
 "# we renamed all the output files for better clarity now that we don't have workflow naming constraints from DataShaper\n",
-"await write_table_to_storage(final_documents, \"documents\", storage)\n",
-"await write_table_to_storage(final_text_units, \"text_units\", storage)\n",
-"await write_table_to_storage(final_entities, \"entities\", storage)\n",
-"await write_table_to_storage(final_relationships, \"relationships\", storage)\n",
-"await write_table_to_storage(final_covariates, \"covariates\", storage)\n",
-"await write_table_to_storage(final_communities, \"communities\", storage)\n",
-"await write_table_to_storage(final_community_reports, \"community_reports\", storage)\n",
+"await table_provider.write_dataframe(\"documents\", final_documents)\n",
+"await table_provider.write_dataframe(\"text_units\", final_text_units)\n",
+"await table_provider.write_dataframe(\"entities\", final_entities)\n",
+"await table_provider.write_dataframe(\"relationships\", final_relationships)\n",
+"await table_provider.write_dataframe(\"covariates\", final_covariates)\n",
+"await table_provider.write_dataframe(\"communities\", final_communities)\n",
+"await table_provider.write_dataframe(\"community_reports\", final_community_reports)\n",
 "\n",
 "# delete all the old versions\n",
-"await delete_table_from_storage(\"create_final_documents\", storage)\n",
-"await delete_table_from_storage(\"create_final_text_units\", storage)\n",
-"await delete_table_from_storage(\"create_final_entities\", storage)\n",
-"await delete_table_from_storage(\"create_final_nodes\", storage)\n",
-"await delete_table_from_storage(\"create_final_relationships\", storage)\n",
-"await delete_table_from_storage(\"create_final_covariates\", storage)\n",
-"await delete_table_from_storage(\"create_final_communities\", storage)\n",
-"await delete_table_from_storage(\"create_final_community_reports\", storage)"
+"await storage.delete(\"create_final_documents.parquet\")\n",
+"await storage.delete(\"create_final_text_units.parquet\")\n",
+"await storage.delete(\"create_final_entities.parquet\")\n",
+"await storage.delete(\"create_final_nodes.parquet\")\n",
+"await storage.delete(\"create_final_relationships.parquet\")\n",
+"await storage.delete(\"create_final_covariates.parquet\")\n",
+"await storage.delete(\"create_final_communities.parquet\")\n",
+"await storage.delete(\"create_final_community_reports.parquet\")"
 ]
 }
 ],
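
One observation on the delete block above: TableProvider defines no delete method in this PR, so the notebook removes the old tables through the underlying Storage directly, spelling out the .parquet suffix by hand. If that suffix logic should live in one place, a hypothetical helper (not part of this PR) could wrap it:

async def delete_table(storage, table_name: str) -> None:
    # Mirrors ParquetTableProvider's naming convention: '{table_name}.parquet'.
    await storage.delete(f"{table_name}.parquet")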
12 changes: 6 additions & 6 deletions docs/examples_notebooks/index_migration_to_v3.ipynb
@@ -66,17 +66,17 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from graphrag.utils.storage import (\n",
-" load_table_from_storage,\n",
-" write_table_to_storage,\n",
-")\n",
+"from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider\n",
 "\n",
-"text_units = await load_table_from_storage(\"text_units\", storage)\n",
+"# Create table provider from storage\n",
+"table_provider = ParquetTableProvider(storage)\n",
+"\n",
+"text_units = await table_provider.read_dataframe(\"text_units\")\n",
 "\n",
 "text_units[\"document_id\"] = text_units[\"document_ids\"].apply(lambda ids: ids[0])\n",
 "remove_columns(text_units, [\"document_ids\"])\n",
 "\n",
-"await write_table_to_storage(text_units, \"text_units\", storage)"
+"await table_provider.write_dataframe(\"text_units\", text_units)"
 ]
 },
 {
2 changes: 2 additions & 0 deletions packages/graphrag-storage/graphrag_storage/__init__.py
@@ -10,11 +10,13 @@
 register_storage,
 )
 from graphrag_storage.storage_type import StorageType
+from graphrag_storage.tables import TableProvider

 __all__ = [
 "Storage",
 "StorageConfig",
 "StorageType",
+"TableProvider",
 "create_storage",
 "register_storage",
 ]
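
With the re-export above, callers can import the abstraction from the package root rather than the submodule, e.g.:

from graphrag_storage import TableProvider  # same class as graphrag_storage.tables.TableProvider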
8 changes: 8 additions & 0 deletions packages/graphrag-storage/graphrag_storage/tables/__init__.py
@@ -0,0 +1,8 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Table provider module for GraphRAG storage."""

from .table_provider import TableProvider

__all__ = ["TableProvider"]
108 changes: 108 additions & 0 deletions packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py
@@ -0,0 +1,108 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Parquet-based table provider implementation."""

import logging
import re
from io import BytesIO

import pandas as pd

from graphrag_storage.storage import Storage
from graphrag_storage.tables.table_provider import TableProvider

logger = logging.getLogger(__name__)


class ParquetTableProvider(TableProvider):
"""Table provider that stores tables as Parquet files using an underlying Storage instance.

This provider converts between pandas DataFrames and Parquet format,
storing the data through a Storage backend (file, blob, cosmos, etc.).
"""

def __init__(self, storage: Storage, **kwargs) -> None:
"""Initialize the Parquet table provider with an underlying storage instance.

Args
----
storage: Storage
The storage instance to use for reading and writing Parquet files.
**kwargs: Any
Additional keyword arguments (currently unused).
"""
self._storage = storage

async def read_dataframe(self, table_name: str) -> pd.DataFrame:
"""Read a table from storage as a pandas DataFrame.

Args
----
table_name: str
The name of the table to read. The file will be accessed as '{table_name}.parquet'.

Returns
-------
pd.DataFrame:
The table data loaded from the Parquet file.

Raises
------
ValueError:
If the table file does not exist in storage.
Exception:
If there is an error reading or parsing the Parquet file.
"""
filename = f"{table_name}.parquet"
if not await self._storage.has(filename):
msg = f"Could not find {filename} in storage!"
raise ValueError(msg)
try:
logger.info("reading table from storage: %s", filename)
return pd.read_parquet(
BytesIO(await self._storage.get(filename, as_bytes=True))
)
except Exception:
logger.exception("error loading table from storage: %s", filename)
raise

async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None:
"""Write a pandas DataFrame to storage as a Parquet file.

Args
----
table_name: str
The name of the table to write. The file will be saved as '{table_name}.parquet'.
df: pd.DataFrame
The DataFrame to write to storage.
"""
await self._storage.set(f"{table_name}.parquet", df.to_parquet())

async def has_dataframe(self, table_name: str) -> bool:
"""Check if a table exists in storage.

Args
----
table_name: str
The name of the table to check.

Returns
-------
bool:
True if the table exists, False otherwise.
"""
return await self._storage.has(f"{table_name}.parquet")

def find_tables(self) -> list[str]:
"""Find all table names in storage.

Returns
-------
list[str]:
List of table names (without .parquet extension).
"""
return [
file.replace(".parquet", "")
for file in self._storage.find(re.compile(r"\.parquet$"))
]
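
A minimal round-trip sketch of the provider above; `storage` is assumed to be any configured graphrag_storage Storage instance, and only methods defined in this file are used:

import pandas as pd

from graphrag_storage.tables.parquet_table_provider import ParquetTableProvider


async def roundtrip(storage) -> pd.DataFrame:
    provider = ParquetTableProvider(storage)
    df = pd.DataFrame({"id": ["a", "b"], "degree": [3, 1]})
    await provider.write_dataframe("entities", df)  # persisted as entities.parquet
    assert await provider.has_dataframe("entities")
    return await provider.read_dataframe("entities")  # raises ValueError if absent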
75 changes: 75 additions & 0 deletions packages/graphrag-storage/graphrag_storage/tables/table_provider.py
@@ -0,0 +1,75 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Abstract base class for table providers."""

from abc import ABC, abstractmethod
from typing import Any

import pandas as pd


class TableProvider(ABC):
"""Provide a table-based storage interface with support for DataFrames and row dictionaries."""

@abstractmethod
def __init__(self, **kwargs: Any) -> None:
"""Create a table provider instance.

Args
----
**kwargs: Any
Keyword arguments for initialization; may include an underlying Storage instance.
"""

@abstractmethod
async def read_dataframe(self, table_name: str) -> pd.DataFrame:
"""Read entire table as a pandas DataFrame.

Args
----
table_name: str
The name of the table to read.

Returns
-------
pd.DataFrame:
The table data as a DataFrame.
"""

@abstractmethod
async def write_dataframe(self, table_name: str, df: pd.DataFrame) -> None:
"""Write entire table from a pandas DataFrame.

Args
----
table_name: str
The name of the table to write.
df: pd.DataFrame
The DataFrame to write as a table.
"""

@abstractmethod
async def has_dataframe(self, table_name: str) -> bool:
"""Check if a table exists in the provider.

Args
----
table_name: str
The name of the table to check.

Returns
-------
bool:
True if the table exists, False otherwise.
"""

@abstractmethod
def find_tables(self) -> list[str]:
"""Find all table names in the provider.

Returns
-------
list[str]:
List of table names (without file extensions).
"""