diff --git a/CITATION.cff b/CITATION.cff
index cd06753..42199b3 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -1,6 +1,6 @@
cff-version: 1.2.0
message: "If you use SemHash in your research, please cite it as below."
-title: "SemHash: Fast Semantic Text Deduplication & Filtering"
+title: "SemHash: Fast Multimodal Semantic Deduplication & Filtering"
authors:
- family-names: "van Dongen"
given-names: "Thomas"
@@ -14,7 +14,7 @@ date-released: "2025-01-05"
preferred-citation:
type: software
- title: "SemHash: Fast Semantic Text Deduplication & Filtering"
+ title: "SemHash: Fast Multimodal Semantic Deduplication & Filtering"
authors:
- family-names: "van Dongen"
given-names: "Thomas"
diff --git a/Makefile b/Makefile
index 2e79ab7..5a53d3b 100644
--- a/Makefile
+++ b/Makefile
@@ -9,10 +9,18 @@ install: venv
uv run pre-commit install
install-no-pre-commit:
- uv pip install ".[dev]"
+ uv pip install ".[dev,all]"
fix:
uv run pre-commit run --all-files
test:
uv run pytest --cov=semhash --cov-report=term-missing
+
+benchmark-text:
+ uv run python -m benchmarks.run_text_benchmarks
+
+benchmark-image:
+ uv run python -m benchmarks.run_image_benchmarks
+
+benchmark: benchmark-text benchmark-image
diff --git a/README.md b/README.md
index fd2ed6a..abb230f 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@

- Fast Semantic Text Deduplication & Filtering
+ Fast Multimodal Semantic Deduplication & Filtering
@@ -38,9 +38,9 @@
-SemHash is a lightweight and flexible tool for deduplicating datasets, filtering outliers, and finding representative samples using semantic similarity. It combines fast embedding generation from [Model2Vec](https://github.com/MinishLab/model2vec) with efficient ANN-based similarity search through [Vicinity](https://github.com/MinishLab/vicinity).
+SemHash is a lightweight, multimodal library for semantic deduplication, outlier filtering, and representative sample selection. Text works out of the box with fast [Model2Vec](https://github.com/MinishLab/model2vec) embeddings, and images, audio, and other modalities are supported with custom encoders.
-SemHash supports both single-dataset deduplication & filtering (e.g., cleaning up a train set by removing duplicates and outliers) and multi-dataset deduplication & filtering (e.g., ensuring no overlap between a test set and a train set). It works with simple datasets, such as text lists, and more complex ones, like multi-column QA datasets. Additionally, it includes functions to inspect deduplication results, making it easier to understand and refine your data cleaning process.
+SemHash supports both single-dataset operations (clean a training set) and cross-dataset operations (deduplicate test against train). It works with simple lists and complex multi-column datasets, and includes inspection tools to help you understand and refine results. All operations use [Vicinity](https://github.com/MinishLab/vicinity) for efficient similarity search.
## Quickstart
@@ -49,6 +49,8 @@ Install the package with:
pip install semhash
```
+### Text Deduplication, Filtering & Representative Sampling
+
Deduplicate a single dataset, filter outliers, and find representative samples with the following code (note: the examples assume you have `datasets` installed, which you can install with `pip install datasets`):
```python
@@ -71,7 +73,35 @@ filtered_texts = semhash.self_filter_outliers().selected
representative_texts = semhash.self_find_representative().selected
```
-Or, deduplicate across two datasets, filter outliers, and find representative samples with the following code (e.g., eliminating train/test leakage):
+### Image Deduplication, Filtering & Representative Sampling
+
+Deduplicate an image dataset, filter outliers, and find representative samples using a vision model (requires `pip install sentence-transformers`):
+
+```python
+from datasets import load_dataset
+from sentence_transformers import SentenceTransformer
+from semhash import SemHash
+
+# Load an image dataset and vision model
+model = SentenceTransformer('clip-ViT-B-32')
+dataset = load_dataset("uoft-cs/cifar10", split="test")
+
+# Initialize a SemHash instance with the 'img' column
+semhash = SemHash.from_records(list(dataset), columns=["img"], model=model)
+
+# Deduplicate the images
+deduplicated_images = semhash.self_deduplicate().selected
+
+# Filter outliers
+filtered_images = semhash.self_filter_outliers().selected
+
+# Find representative images
+representative_images = semhash.self_find_representative().selected
+```
+
+### Cross-Dataset Deduplication, Filtering & Representative Sampling
+
+Deduplicate across two datasets, filter outliers, and find representative samples (e.g., eliminating train/test leakage):
```python
from datasets import load_dataset
@@ -93,13 +123,12 @@ filtered_test_texts = semhash.filter_outliers(records=test_texts, outlier_percen
# Find representative texts in the test data against the training data,
# optionally with a specific selection size
-representative_test_texts = semhash.find_representative(
- records=test_texts, selection_size=10).selected
-
-
+representative_test_texts = semhash.find_representative(records=test_texts, selection_size=10).selected
```
-Or, deduplicate multi-column dataset, filter outliers, and find representative samples with the following code (e.g., deduplicating a QA dataset):
+### Multi-Column Deduplication
+
+Deduplicate multi-column datasets (e.g., a QA dataset):
```python
from datasets import load_dataset
@@ -116,15 +145,9 @@ semhash = SemHash.from_records(records=records, columns=["question", "context"])
# Deduplicate the records
deduplicated_records = semhash.self_deduplicate().selected
-
-# Filter outliers from the records
-filtered_texts = semhash.self_filter_outliers().selected
-
-# Find representative texts in the records
-representative_texts = semhash.self_find_representative().selected
```
-The `deduplicate` and `self_deduplicate` functions return a [DeduplicationResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#L58). This object stores the deduplicated corpus, a set of duplicate object (along with the objects that caused duplication), and several useful functions to further inspect the deduplication result.
+The `deduplicate` and `self_deduplicate` functions return a [DeduplicationResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#L58). This object stores the deduplicated corpus, a set of duplicate objects (along with the objects that caused duplication), and several useful functions to further inspect the deduplication result.
The `filter_outliers`, `self_filter_outliers`, `find_representative`, and `self_find_representative` functions return a [FilterResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#179). This object stores the found outliers/representative samples.
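+
+Both result objects expose a few convenience attributes for quick inspection. A minimal sketch, reusing the `semhash` instance from the quickstart above:
+
+```python
+result = semhash.self_deduplicate()
+
+# Fraction of records flagged as duplicates, and the exact-match subset of that
+print(result.duplicate_ratio)
+print(result.exact_duplicate_ratio)
+
+# Least similar pairs among the duplicates, handy for tuning the threshold
+print(result.get_least_similar_from_duplicates(1))
+
+# FilterResult exposes similar ratios
+outliers = semhash.self_filter_outliers()
+print(outliers.filter_ratio, outliers.selected_ratio)
+```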
@@ -212,14 +235,11 @@ The following code snippet shows how to deduplicate across two datasets, filter
from datasets import load_dataset
from semhash import SemHash
-# Initialize a SemHash instance
-semhash = SemHash()
-
# Load two datasets to deduplicate
train_texts = load_dataset("ag_news", split="train")["text"]
test_texts = load_dataset("ag_news", split="test")["text"]
-# Initialize a SemHash instance
+# Initialize a SemHash instance with the training data
semhash = SemHash.from_records(records=train_texts)
# Deduplicate the test data against the training data
@@ -265,6 +285,70 @@ representative_records = semhash.self_find_representative().selected
+
+ Deduplicate, filter outliers, and find representative samples on image datasets
+
+
+You can bring your own encoder for any modality by implementing the `Encoder` protocol. Here's an example using a vision model from `timm` for image deduplication:
+
+```python
+from datasets import load_dataset
+import timm
+import torch
+from semhash import SemHash
+
+# Requires: pip install timm torch datasets
+
+# Create a custom image encoder
+class VisionEncoder:
+ """Custom encoder using timm models. Implements the Encoder protocol."""
+
+ def __init__(self, model_name: str = "mobilenetv3_small_100.lamb_in1k"):
+ self.model = timm.create_model(model_name, pretrained=True, num_classes=0).eval()
+ data_config = timm.data.resolve_model_data_config(self.model)
+ self.transform = timm.data.create_transform(**data_config, is_training=False)
+
+ def encode(self, inputs, batch_size: int = 128):
+ """Encode a batch of PIL images into embeddings."""
+ import numpy as np
+
+ # Convert grayscale to RGB if needed
+ rgb_inputs = [img.convert("RGB") if img.mode != "RGB" else img for img in inputs]
+
+ # Process in batches to avoid memory issues
+ all_embeddings = []
+ with torch.no_grad():
+ for i in range(0, len(rgb_inputs), batch_size):
+ batch_inputs = rgb_inputs[i : i + batch_size]
+ batch = torch.stack([self.transform(img) for img in batch_inputs])
+ embeddings = self.model(batch).numpy()
+ all_embeddings.append(embeddings)
+
+ return np.vstack(all_embeddings)
+
+# Load image dataset
+dataset = load_dataset("uoft-cs/cifar10", split="test")
+train_data = [{"img": img, "id": i} for i, img in enumerate(dataset["img"][:100])]
+test_data = [{"img": img, "id": i} for i, img in enumerate(dataset["img"][100:150])]
+
+# Initialize SemHash with the custom vision encoder
+semhash = SemHash.from_records(train_data, columns=["img"], model=VisionEncoder())
+
+# Single-dataset operations
+deduplicated = semhash.self_deduplicate().selected
+outliers = semhash.self_filter_outliers().selected
+representatives = semhash.self_find_representative().selected
+
+# Cross-dataset operations
+test_deduplicated = semhash.deduplicate(test_data).selected
+test_outliers = semhash.filter_outliers(test_data).selected
+test_representatives = semhash.find_representative(test_data, selection_size=10).selected
+```
+
+The `Encoder` protocol requires only an `encode(inputs, **kwargs)` method that returns a numpy array. This makes it easy to integrate any embedding model for any modality.
+
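+If you just need to see the shape of the protocol, here is a minimal sketch of a custom encoder (the hashing featurizer below is only a stand-in for a real embedding model):
+
+```python
+import numpy as np
+
+from semhash import SemHash
+
+
+class MinimalEncoder:
+    """Smallest possible encoder: turns each input into a fixed-size vector."""
+
+    def encode(self, inputs, **kwargs) -> np.ndarray:
+        # Toy featurization: bucket the string form of each input.
+        # Swap in a real model to get meaningful semantic similarity.
+        vectors = np.zeros((len(inputs), 8), dtype=np.float32)
+        for row, item in enumerate(inputs):
+            vectors[row, hash(str(item)) % 8] = 1.0
+        return vectors
+
+
+semhash = SemHash.from_records(["apple", "banana", "apple"], model=MinimalEncoder())
+print(semhash.self_deduplicate().selected)
+```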
+
+
Using custom encoders
@@ -400,14 +484,65 @@ representative_texts = semhash.self_find_representative().selected
```
+
+ Initializing from a HuggingFace Dataset
+
+You can easily use SemHash with HuggingFace Datasets by converting them to a list:
+
+```python
+from datasets import load_dataset
+from semhash import SemHash
+
+# Load a HuggingFace dataset
+dataset = load_dataset("ag_news", split="train")
+
+# Convert to list and initialize SemHash
+semhash = SemHash.from_records(records=list(dataset), columns=["text"])
+
+# Deduplicate, filter outliers, and find representative samples
+deduplicated_texts = semhash.self_deduplicate().selected
+filtered_texts = semhash.self_filter_outliers().selected
+representative_texts = semhash.self_find_representative().selected
+```
+
+This also works with multi-column datasets:
+
+```python
+from datasets import load_dataset
+from semhash import SemHash
+
+# Load a multi-column dataset
+dataset = load_dataset("squad_v2", split="train")
+
+# Convert to list and initialize with multiple columns
+semhash = SemHash.from_records(records=list(dataset), columns=["question", "context"])
+
+# Deduplicate the records
+deduplicated_records = semhash.self_deduplicate().selected
+```
+
+
## Benchmarks
-SemHash is extremely fast and scales to large datasets with millions of records. We've benchmarked both single-dataset deduplication and train/test deduplication across a variety of datasets. For example, deduplicating 1.8M records takes only ~83 seconds on CPU.
+SemHash is extremely fast and scales to large datasets with millions of records. We've benchmarked both text and image deduplication across a variety of datasets. For example, deduplicating 1.8M text records takes only ~83 seconds on CPU.
+
+For detailed benchmark results and analysis, see the [benchmarks directory](benchmarks/README.md).
-For detailed benchmark results including performance metrics across 17 datasets, as well as code to reproduce the benchmarks, see the [benchmarks directory](benchmarks/README.md).
+### Running Benchmarks
+
+```bash
+# Run text benchmarks
+make benchmark-text
+
+# Run image benchmarks
+make benchmark-image
+
+# Run all benchmarks
+make benchmark
+```
## License
@@ -419,7 +554,7 @@ If you use SemHash in your research, please cite the following:
```bibtex
@software{minishlab2025semhash,
author = {{van Dongen}, Thomas and Stephan Tulkens},
- title = {SemHash: Fast Semantic Text Deduplication \& Filtering},
+ title = {SemHash: Fast Multimodal Semantic Deduplication \& Filtering},
year = {2025},
publisher = {Zenodo},
doi = {10.5281/zenodo.17265942},
diff --git a/benchmarks/README.md b/benchmarks/README.md
index 7078f3b..7ebf432 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -1,16 +1,33 @@
# SemHash Benchmarks
-This directory contains the benchmarking code and results for SemHash. The benchmarks measure deduplication performance and speed across a variety of datasets.
+This directory contains the benchmarking code and results for SemHash. The benchmarks measure deduplication performance and speed across a variety of text and image datasets.
-## Setup
+## Table of Contents
-All benchmarks were run with the following configuration:
+- [Text Benchmarks](#text-benchmarks)
+ - [Setup](#setup)
+ - [Results](#results)
+ - [Key Findings](#key-findings)
+ - [Running Text Benchmarks](#running-text-benchmarks)
+- [Image Benchmarks](#image-benchmarks)
+ - [Setup](#setup-1)
+ - [Results](#results-1)
+ - [Key Findings](#key-findings-1)
+ - [Running Image Benchmarks](#running-image-benchmarks)
+- [Running All Benchmarks](#running-all-benchmarks)
+
+## Text Benchmarks
+
+### Setup
+
+All text benchmarks were run with the following configuration:
- **CPU-only**: All benchmarks run on CPU (no GPU acceleration)
- **ANN backend**: Default backend (USearch)
- **Encoder**: Default encoder ([potion-base-8M](https://huggingface.co/minishlab/potion-base-8M))
- **Timing**: Includes encoding time, index building time, and deduplication time
+- **Dependencies**: Requires `datasets` package (`pip install datasets`)
-## Results
+### Results
### Train Deduplication Benchmark
@@ -60,7 +77,7 @@ This benchmark measures the performance of deduplicating a test dataset against
| squad_v2 | 130319 | 11873 | 11863 | 0.08 | 7.13 |
| wikitext | 1801350 | 4358 | 2139 | 50.92 | 40.32 |
-## Key Findings
+### Key Findings
SemHash is extremely fast and scales to large datasets with millions of records. Some notable findings include:
@@ -70,12 +87,77 @@ SemHash is extremely fast and scales to large datasets with millions of records.
- `student`: 52% of test data overlaps with training data
- `wikitext`: 51% of test data overlaps with training data
-## Running the Benchmarks
+### Running Text Benchmarks
+
+To run the text benchmarks yourself:
+
+```bash
+# Install dependencies
+pip install datasets
+
+# Run benchmarks
+python -m benchmarks.run_text_benchmarks
+# Or using make
+make benchmark-text
+```
+
+## Image Benchmarks
+
+### Setup
+
+All image benchmarks were run with the following configuration:
+- **Device**: Apple Silicon GPU (MPS)
+- **ANN backend**: Default backend (USearch)
+- **Encoder**: MobileNetV3-Small ([mobilenetv3_small_100.lamb_in1k](https://huggingface.co/timm/mobilenetv3_small_100.lamb_in1k))
+- **Batch size**: 128 images per batch
+- **Timing**: Includes encoding time, index building time, and deduplication time
+
+### Results
+
+#### Train Deduplication Benchmark
+
+This benchmark measures the performance of deduplicating within a single training dataset.
+
+| Dataset | Original Train Size | Deduplicated Train Size | % Removed | Deduplication Time (s) |
+|----------------------|----------------------|--------------------------|------------|--------------------------|
+| cifar10 | 50000 | 48274 | 3.45 | 61.20 |
+| fashion_mnist | 60000 | 16714 | 72.14 | 86.61 |
+
+#### Train/Test Deduplication Benchmark
+
+This benchmark measures the performance of deduplicating a test dataset against a training dataset.
+
+| Dataset | Train Size | Test Size | Deduplicated Test Size | % Removed | Deduplication Time (s) |
+|----------------------|--------------|--------------|--------------------------|------------|--------------------------|
+| cifar10 | 50000 | 10000 | 9397 | 6.03 | 67.43 |
+| fashion_mnist | 60000 | 10000 | 2052 | 79.48 | 72.14 |
+
+### Key Findings
-To run the benchmarks yourself:
+- **Fashion-MNIST high duplication**: Fashion-MNIST shows very high duplication rates (72% train, 79% test) due to the simple nature of the dataset (10 clothing categories with many similar items)
+- **CIFAR-10 moderate duplication**: CIFAR-10 shows lower duplication (3.45% train, 6.03% test) as it contains more diverse natural images
+- **Speed**: Image deduplication is fast even for large datasets (60k images in ~87 seconds on MPS); note that the actual deduplication step is quick, with most time spent on encoding images
+
+### Running Image Benchmarks
+
+To run the image benchmarks yourself:
```bash
-python -m benchmarks.run_benchmarks
+# Install dependencies
+pip install timm torch datasets
+
+# Run benchmarks
+python -m benchmarks.run_image_benchmarks
+# Or using make
+make benchmark-image
```
-The datasets can be customized by editing `benchmarks/data.py`.
+The image datasets can be customized by editing `benchmarks/data.py` (see `IMAGE_DATASET_DICT`).
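+
+For instance, registering one more image dataset is a single extra entry in that dict. A sketch (the `mnist` entry and its dataset id are hypothetical; adjust the name and image column to your dataset):
+
+```python
+# benchmarks/data.py
+IMAGE_DATASET_DICT: dict[str, DatasetRecord] = {
+    "cifar10": DatasetRecord(name="uoft-cs/cifar10", columns=["img"], modality="image"),
+    "fashion_mnist": DatasetRecord(name="fashion_mnist", columns=["image"], modality="image"),
+    # Hypothetical extra entry: adjust the dataset id and image column as needed
+    "mnist": DatasetRecord(name="ylecun/mnist", columns=["image"], modality="image"),
+}
+```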
+
+## Running All Benchmarks
+
+To run both text and image benchmarks:
+
+```bash
+make benchmark
+```
diff --git a/benchmarks/data.py b/benchmarks/data.py
index 49b988c..f0eb6db 100644
--- a/benchmarks/data.py
+++ b/benchmarks/data.py
@@ -12,6 +12,7 @@ class DatasetRecord:
columns: list[str] | None = None
split_one: str = "train"
split_two: str = "test"
+ modality: str = "text"
DATASET_DICT: dict[str, DatasetRecord] = {
@@ -41,3 +42,8 @@ class DatasetRecord:
name="Salesforce/wikitext", text_name="text", label_name="text", sub_directory="wikitext-103-raw-v1"
),
}
+
+IMAGE_DATASET_DICT: dict[str, DatasetRecord] = {
+ "cifar10": DatasetRecord(name="uoft-cs/cifar10", columns=["img"], modality="image"),
+ "fashion_mnist": DatasetRecord(name="fashion_mnist", columns=["image"], modality="image"),
+}
diff --git a/benchmarks/results/image_train_benchmark_results.json b/benchmarks/results/image_train_benchmark_results.json
new file mode 100644
index 0000000..53fa71b
--- /dev/null
+++ b/benchmarks/results/image_train_benchmark_results.json
@@ -0,0 +1,20 @@
+[
+ {
+ "dataset": "cifar10",
+ "original_train_size": 50000,
+ "deduplicated_train_size": 48274,
+ "percent_removed": 3.4519999999999995,
+ "build_time_seconds": 56.00128899999254,
+ "deduplication_time_seconds": 5.201297917010379,
+ "time_seconds": 61.20258691700292
+ },
+ {
+ "dataset": "fashion_mnist",
+ "original_train_size": 60000,
+ "deduplicated_train_size": 16714,
+ "percent_removed": 72.14333333333333,
+ "build_time_seconds": 61.14413262500602,
+ "deduplication_time_seconds": 25.46288070900482,
+ "time_seconds": 86.60701333401084
+ }
+]
diff --git a/benchmarks/results/image_train_test_benchmark_results.json b/benchmarks/results/image_train_test_benchmark_results.json
new file mode 100644
index 0000000..290c7e3
--- /dev/null
+++ b/benchmarks/results/image_train_test_benchmark_results.json
@@ -0,0 +1,22 @@
+[
+ {
+ "dataset": "cifar10",
+ "train_size": 50000,
+ "test_size": 10000,
+ "deduplicated_test_size": 9397,
+ "percent_removed": 6.030000000000002,
+ "build_time_seconds": 56.00128899999254,
+ "deduplication_time_seconds": 11.428115875009098,
+ "time_seconds": 67.42940487500164
+ },
+ {
+ "dataset": "fashion_mnist",
+ "train_size": 60000,
+ "test_size": 10000,
+ "deduplicated_test_size": 2052,
+ "percent_removed": 79.47999999999999,
+ "build_time_seconds": 61.14413262500602,
+ "deduplication_time_seconds": 10.998616750002839,
+ "time_seconds": 72.14274937500886
+ }
+]
diff --git a/benchmarks/run_image_benchmarks.py b/benchmarks/run_image_benchmarks.py
new file mode 100644
index 0000000..245a803
--- /dev/null
+++ b/benchmarks/run_image_benchmarks.py
@@ -0,0 +1,194 @@
+import json
+import logging
+from time import perf_counter
+
+import numpy as np
+import timm
+import torch
+from datasets import load_dataset
+
+from benchmarks.data import IMAGE_DATASET_DICT
+from semhash import SemHash
+
+# Set up logging
+logger = logging.getLogger(__name__)
+
+
+class VisionEncoder:
+ """Custom encoder using timm models for image embeddings."""
+
+ def __init__(self, model_name: str = "mobilenetv3_small_100.lamb_in1k") -> None:
+ """Initialize the vision encoder with a timm model."""
+ if torch.cuda.is_available():
+ self.device = torch.device("cuda")
+ elif torch.backends.mps.is_available():
+ self.device = torch.device("mps")
+ else:
+ self.device = torch.device("cpu")
+ logger.info(f"Using device: {self.device}")
+
+ self.model = timm.create_model(model_name, pretrained=True, num_classes=0).eval()
+ self.model = self.model.to(self.device)
+
+ data_config = timm.data.resolve_model_data_config(self.model)
+ self.transform = timm.data.create_transform(**data_config, is_training=False)
+
+ def encode(self, inputs: list, batch_size: int = 128) -> np.ndarray:
+ """Encode a batch of PIL images into embeddings."""
+ # Convert grayscale to RGB if needed
+ rgb_inputs = [img.convert("RGB") if img.mode != "RGB" else img for img in inputs]
+
+ # Process in batches
+ all_embeddings = []
+ with torch.no_grad():
+ for i in range(0, len(rgb_inputs), batch_size):
+ batch_inputs = rgb_inputs[i : i + batch_size]
+ batch = torch.stack([self.transform(img) for img in batch_inputs]).to(self.device)
+ embeddings = self.model(batch).cpu().numpy()
+ all_embeddings.append(embeddings)
+
+ return np.vstack(all_embeddings)
+
+
+def main() -> None: # noqa: C901
+ """Run the image benchmarks."""
+ # Prepare lists to hold benchmark results
+ train_dedup_results = []
+ train_test_dedup_results = []
+
+ # Initialize vision encoder
+ encoder = VisionEncoder()
+
+ for dataset_name, record in IMAGE_DATASET_DICT.items():
+ logger.info(f"Loading dataset: {dataset_name} from {record.name}")
+
+ # Load train and test splits
+ if record.sub_directory:
+ train_ds = load_dataset(record.name, record.sub_directory, split=record.split_one)
+ test_ds = load_dataset(record.name, record.sub_directory, split=record.split_two)
+ else:
+ train_ds = load_dataset(record.name, split=record.split_one)
+ test_ds = load_dataset(record.name, split=record.split_two)
+
+ # Convert to list of dicts with image column
+ train_records = list(train_ds)
+ test_records = list(test_ds)
+ columns = record.columns
+
+ # Build the SemHash instance
+ build_start = perf_counter()
+ semhash = SemHash.from_records(model=encoder, records=train_records, columns=columns)
+ build_end = perf_counter()
+ build_time = build_end - build_start
+
+ # Time how long it takes to deduplicate the train set
+ train_only_start = perf_counter()
+ deduplicated_train = semhash.self_deduplicate()
+ train_only_end = perf_counter()
+
+ train_only_dedup_time = train_only_end - train_only_start
+ original_train_size = len(train_records)
+ dedup_train_size = len(deduplicated_train.selected)
+
+ percent_removed_train = deduplicated_train.duplicate_ratio * 100
+ train_dedup_results.append(
+ {
+ "dataset": dataset_name,
+ "original_train_size": original_train_size,
+ "deduplicated_train_size": dedup_train_size,
+ "percent_removed": percent_removed_train,
+ "build_time_seconds": build_time,
+ "deduplication_time_seconds": train_only_dedup_time,
+ "time_seconds": train_only_dedup_time + build_time,
+ }
+ )
+
+ logger.info(
+ f"[TRAIN DEDUPLICATION] Dataset: {dataset_name}\n"
+ f" - Original Train Size: {original_train_size}\n"
+ f" - Deduplicated Train Size: {dedup_train_size}\n"
+ f" - % Removed: {percent_removed_train:.2f}\n"
+ f" - Deduplication Time (seconds): {train_only_dedup_time:.2f}\n"
+ f" - Build Time (seconds): {build_time:.2f}\n"
+ f" - Total Time (seconds): {train_only_dedup_time + build_time:.2f}\n"
+ )
+
+ # Time how long it takes to deduplicate the test set
+ train_test_start = perf_counter()
+ deduplicated_test = semhash.deduplicate(
+ records=test_records,
+ )
+ train_test_end = perf_counter()
+ train_test_dedup_time = train_test_end - train_test_start
+ original_test_size = len(test_records)
+ deduped_test_size = len(deduplicated_test.selected)
+ percent_removed_test = deduplicated_test.duplicate_ratio * 100
+
+ train_test_dedup_results.append(
+ {
+ "dataset": dataset_name,
+ "train_size": original_train_size,
+ "test_size": original_test_size,
+ "deduplicated_test_size": deduped_test_size,
+ "percent_removed": percent_removed_test,
+ "build_time_seconds": build_time,
+ "deduplication_time_seconds": train_test_dedup_time,
+ "time_seconds": train_test_dedup_time + build_time,
+ }
+ )
+
+ logger.info(
+ f"[TRAIN/TEST DEDUPLICATION] Dataset: {dataset_name}\n"
+ f" - Train Size: {original_train_size}\n"
+ f" - Test Size: {original_test_size}\n"
+ f" - Deduplicated Test Size: {deduped_test_size}\n"
+ f" - % Removed: {percent_removed_test:.2f}\n"
+ f" - Deduplication Time (seconds): {train_test_dedup_time:.2f}\n"
+ f" - Build Time (seconds): {build_time:.2f}\n"
+ f" - Total Time (seconds): {train_test_dedup_time + build_time:.2f}\n"
+ )
+
+ # Write the results to JSON files
+ with open("benchmarks/results/image_train_benchmark_results.json", "w", encoding="utf-8") as f:
+ json.dump(train_dedup_results, f, ensure_ascii=False, indent=2)
+
+ with open("benchmarks/results/image_train_test_benchmark_results.json", "w", encoding="utf-8") as f:
+ json.dump(train_test_dedup_results, f, ensure_ascii=False, indent=2)
+
+ # Print the train table
+ print("### Image Train Deduplication Benchmark\n") # noqa T201
+ print( # noqa T201
+ f"| {'Dataset':<20} | {'Original Train Size':>20} | {'Deduplicated Train Size':>24} | {'% Removed':>10} | {'Deduplication Time (s)':>24} |"
+ ) # noqa T201
+ print("|" + "-" * 22 + "|" + "-" * 22 + "|" + "-" * 26 + "|" + "-" * 12 + "|" + "-" * 26 + "|") # noqa T201
+ for r in train_dedup_results:
+ print( # noqa T201
+ f"| {r['dataset']:<20} "
+ f"| {r['original_train_size']:>20} "
+ f"| {r['deduplicated_train_size']:>24} "
+ f"| {r['percent_removed']:>10.2f} "
+ f"| {r['time_seconds']:>24.2f} |"
+ )
+
+ print("\n") # noqa T201
+
+ # Print the train/test table
+ print("### Image Train/Test Deduplication Benchmark\n") # noqa T201
+ print( # noqa T201
+ f"| {'Dataset':<20} | {'Train Size':>12} | {'Test Size':>12} | {'Deduplicated Test Size':>24} | {'% Removed':>10} | {'Deduplication Time (s)':>24} |"
+ ) # noqa T201
+ print("|" + "-" * 22 + "|" + "-" * 14 + "|" + "-" * 14 + "|" + "-" * 26 + "|" + "-" * 12 + "|" + "-" * 26 + "|") # noqa T201
+ for r in train_test_dedup_results:
+ print( # noqa T201
+ f"| {r['dataset']:<20} "
+ f"| {r['train_size']:>12} "
+ f"| {r['test_size']:>12} "
+ f"| {r['deduplicated_test_size']:>24} "
+ f"| {r['percent_removed']:>10.2f} "
+ f"| {r['time_seconds']:>24.2f} |"
+ )
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO)
+ main()
diff --git a/benchmarks/run_benchmarks.py b/benchmarks/run_text_benchmarks.py
similarity index 100%
rename from benchmarks/run_benchmarks.py
rename to benchmarks/run_text_benchmarks.py
diff --git a/pyproject.toml b/pyproject.toml
index 966df8a..e8ba058 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "semhash"
-description = "Fast Semantic Text Deduplication & Filtering"
+description = "Fast Multimodal Semantic Deduplication & Filtering"
authors = [{name = "Thomas van Dongen", email = "thomas123@live.nl"}, { name = "Stéphan Tulkens", email = "stephantul@gmail.com"}]
readme = { file = "README.md", content-type = "text/markdown" }
dynamic = ["version"]
@@ -43,6 +43,7 @@ dev = [
"ruff",
]
+
[project.urls]
"Homepage" = "https://github.com/MinishLab"
"Bug Reports" = "https://github.com/MinishLab/semhash/issues"
diff --git a/semhash/records.py b/semhash/records.py
index c2bbb27..0835606 100644
--- a/semhash/records.py
+++ b/semhash/records.py
@@ -1,6 +1,134 @@
+from collections import defaultdict
from collections.abc import Sequence
+from typing import Any
+
+from frozendict import frozendict
from semhash.datamodels import DeduplicationResult, DuplicateRecord
+from semhash.utils import Record, coerce_value, to_frozendict
+
+
+def group_records_by_key(
+ records: Sequence[dict[str, Any]],
+ columns: Sequence[str],
+) -> tuple[list[dict[str, Any]], list[list[dict[str, Any]]]]:
+ """
+ Group records by exact match on columns, preserving first-occurrence order.
+
+ :param records: Records to group.
+ :param columns: Columns to use as grouping key.
+ :return: Tuple of (deduplicated_records, items) where:
+ - deduplicated_records: first record from each unique group
+ - items: list of groups, each group is a list of exact duplicates
+ """
+ # Track buckets by key and preserve first-occurrence order
+ buckets: dict[frozendict[str, Any], list[dict[str, Any]]] = {}
+ order: list[frozendict[str, Any]] = []
+
+ for record in records:
+ key = to_frozendict(record, columns)
+ bucket = buckets.get(key)
+ if bucket is None:
+ # First occurrence: create new bucket and track order
+ buckets[key] = [record]
+ order.append(key)
+ else:
+ # Duplicate: add to existing bucket
+ bucket.append(record)
+
+ # Reconstruct in first-occurrence order
+ items = [buckets[k] for k in order]
+ deduplicated_records = [bucket[0] for bucket in items]
+ return deduplicated_records, items
+
+
+def remove_exact_duplicates(
+ records: Sequence[dict[str, Any]],
+ columns: Sequence[str],
+ reference_records: list[list[dict[str, Any]]] | None = None,
+) -> tuple[list[dict[str, Any]], list[tuple[dict[str, Any], list[dict[str, Any]]]]]:
+ """
+ Remove exact duplicates based on the hashable representation of each record.
+
+ If reference_records is None, the function will only check for duplicates within the records list.
+
+ :param records: A list of records to check for exact duplicates.
+ :param columns: Columns to unpack.
+ :param reference_records: A list of record groups to compare against. These are already grouped by exact-match key.
+ :return: A list of deduplicated records and a list of duplicates.
+ """
+ deduplicated: list[dict[str, Any]] = []
+ duplicates: list[tuple[dict[str, Any], list[dict[str, Any]]]] = []
+
+ column_set = set(columns)
+
+ # Build seen set from reference_records (cross-dataset mode) or empty (single-dataset mode)
+ seen: defaultdict[frozendict[str, Any], list[dict[str, Any]]] = defaultdict(list)
+ if reference_records is not None:
+ for record_set in reference_records:
+ key = to_frozendict(record_set[0], column_set)
+ seen[key] = list(record_set)
+
+ for record in records:
+ frozen_record = to_frozendict(record, column_set)
+ if duplicated_records := seen.get(frozen_record):
+ duplicates.append((record, duplicated_records))
+ else:
+ deduplicated.append(record)
+ # Single-dataset mode: track this record for future comparisons
+ if reference_records is None:
+ seen[frozen_record].append(record)
+
+ return deduplicated, duplicates
+
+
+def prepare_records(
+ records: Sequence[Record], columns: Sequence[str] | None
+) -> tuple[list[dict[str, Any]], Sequence[str], bool]:
+ """
+ Validate and prepare records for processing.
+
+ :param records: A list of records (strings or dictionaries).
+ :param columns: Columns to use if records are dictionaries.
+ :return: Tuple of (dict_records, columns, was_string).
+ :raises ValueError: If records are empty.
+ :raises ValueError: If columns are not provided for dictionary records.
+ :raises ValueError: If dict record contains None values.
+ :raises ValueError: If records are not homogeneous (mixed strings and dicts).
+ """
+ if len(records) == 0:
+ raise ValueError("records must not be empty")
+
+ if columns is None and isinstance(records[0], dict):
+ raise ValueError("Columns must be specified when passing dictionaries.")
+
+ # String path: convert to dicts with "text" column
+ if isinstance(records[0], str):
+ if not all(isinstance(r, str) for r in records):
+ raise ValueError("All records must be strings when the first record is a string.")
+ columns = ["text"]
+ dict_records: list[dict[str, Any]] = [{"text": record} for record in records]
+ was_string = True
+ # Dict path: validate and coerce values
+ else:
+ if not all(isinstance(r, dict) for r in records):
+ raise ValueError("All records must be dicts when the first record is a dict.")
+ assert columns is not None
+
+ # Coerce values: stringify primitives, keep complex types raw (for images, etc.)
+ dict_records_typed: list[dict[str, Any]] = list(records)
+ dict_records = []
+ for record in dict_records_typed:
+ coerced: dict[str, Any] = {}
+ for column in columns:
+ val = record.get(column)
+ if val is None:
+ raise ValueError(f"Column '{column}' has None value in record {record}")
+ coerced[column] = coerce_value(val)
+ dict_records.append(coerced)
+ was_string = False
+
+ return dict_records, columns, was_string
def dict_to_string(record: dict[str, str], columns: Sequence[str]) -> str:
diff --git a/semhash/semhash.py b/semhash/semhash.py
index 2baaabe..e866bc4 100644
--- a/semhash/semhash.py
+++ b/semhash/semhash.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-from collections import defaultdict
from collections.abc import Sequence
from math import ceil
from typing import Any, Generic, Literal
@@ -11,15 +10,21 @@
from pyversity import Strategy, diversify
from vicinity import Backend
-from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult, Record
+from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult
from semhash.index import Index
-from semhash.records import add_scores_to_records, map_deduplication_result_to_strings
+from semhash.records import (
+ add_scores_to_records,
+ group_records_by_key,
+ map_deduplication_result_to_strings,
+ prepare_records,
+ remove_exact_duplicates,
+)
from semhash.utils import (
Encoder,
+ Record,
+ coerce_value,
compute_candidate_limit,
featurize,
- prepare_records,
- remove_exact_duplicates,
to_frozendict,
)
@@ -52,7 +57,7 @@ def from_records(
"""
Initialize a SemHash instance from records.
- This removes exact duplicates, featurizes the records, and fits a vicinity index.
+ Removes exact duplicates, featurizes the records, and fits a vicinity index.
:param records: A list of records (strings or dictionaries).
:param columns: Columns to featurize if records are dictionaries.
@@ -65,37 +70,17 @@ def from_records(
dict_records, columns, was_string = prepare_records(records, columns)
# If no model is provided, load the default model
- if model is None:
+ if model is None: # pragma: no cover
model = StaticModel.from_pretrained("minishlab/potion-base-8M")
- # Remove exact duplicates
- deduplicated_records, duplicates = remove_exact_duplicates(dict_records, columns)
-
- col_set = set(columns)
- duplicate_map = defaultdict(list)
- for x, _ in duplicates:
- frozen_record = to_frozendict(x, col_set)
- duplicate_map[frozen_record].append(x)
-
- items: list[list[dict[str, str]]] = []
- for record in deduplicated_records:
- i = [record]
- frozen_record = to_frozendict(record, col_set)
- i.extend(duplicate_map[frozen_record])
- items.append(i)
+ # Group by exact match, preserving first-occurrence order
+ deduplicated_records, items = group_records_by_key(dict_records, columns)
# Create embeddings for deduplicated records only
embeddings = featurize(deduplicated_records, columns, model)
- # Build the Vicinity index
- index = Index.from_vectors_and_items(
- vectors=embeddings,
- items=items,
- backend_type=ann_backend,
- **kwargs,
- )
-
- return cls(index=index, columns=columns, model=model, was_string=was_string)
+ index = Index.from_vectors_and_items(vectors=embeddings, items=items, backend_type=ann_backend, **kwargs)
+ return cls(index=index, model=model, columns=columns, was_string=was_string)
@classmethod
def from_embeddings(
@@ -110,7 +95,7 @@ def from_embeddings(
"""
Initialize a SemHash instance from pre-computed embeddings.
- This removes exact duplicates and fits a vicinity index using the provided embeddings.
+ Removes exact duplicates and fits a vicinity index using the provided embeddings.
:param embeddings: Pre-computed embeddings as a numpy array of shape (n_records, embedding_dim).
:param records: A list of records (strings or dictionaries) corresponding to the embeddings.
@@ -160,10 +145,7 @@ def from_embeddings(
deduplicated_embeddings = embeddings[keep_embedding_indices]
index = Index.from_vectors_and_items(
- vectors=deduplicated_embeddings,
- items=items,
- backend_type=ann_backend,
- **kwargs,
+ vectors=deduplicated_embeddings, items=items, backend_type=ann_backend, **kwargs
)
return cls(index=index, model=model, columns=columns, was_string=was_string)
@@ -267,8 +249,8 @@ def self_deduplicate(
duplicate_records.append(DuplicateRecord(record=curr_record, duplicates=items_with_score, exact=True))
# If we don't see any similar_items, we know the record is not a duplicate.
- # in rare cases, the item itself might not be a duplicate of itself.
- if not similar_items:
+ # In rare cases, the item itself might not be returned by the index.
+ if not similar_items: # pragma: no cover
deduplicated_records.append(record)
continue
items, _ = zip(*similar_items)
@@ -299,30 +281,45 @@ def self_deduplicate(
return result
- def _validate_if_strings(self, records: Sequence[dict[str, str] | str]) -> Sequence[dict[str, str]]:
+ def _validate_if_strings(self, records: Sequence[dict[str, Any] | str]) -> list[dict[str, Any]]:
"""
Validate if the records are strings.
If the records are strings, they are converted to dictionaries with a single column.
+ If the records are dicts, primitives are stringified and complex types (images, etc.) are kept raw.
:param records: The records to validate.
:return: The records as a list of dictionaries.
:raises ValueError: If records are empty.
:raises ValueError: If the records are strings but were not originally strings.
- :raises ValueError: If the records are not all strings or dictionaries.
+ :raises ValueError: If the records are not all strings or all dictionaries.
+ :raises ValueError: If dict record contains None values.
"""
if len(records) == 0:
raise ValueError("records must not be empty")
+ # String path
if isinstance(records[0], str):
if not self._was_string:
raise ValueError("Records were not originally strings, but you passed strings.")
- dict_records = [{"text": record} for record in records if isinstance(record, str)]
- else:
- dict_records = [record for record in records if isinstance(record, dict)]
- if len(dict_records) != len(records):
- raise ValueError("Records must be either strings or dictionaries.")
- return dict_records
+ if not all(isinstance(r, str) for r in records):
+ raise ValueError("Records must be all strings.")
+ return [{"text": r} for r in records]
+
+ # Dict path
+ if not all(isinstance(r, dict) for r in records):
+ raise ValueError("Records must be all dictionaries.")
+
+ dict_records: Sequence[dict[str, Any]] = records # type: ignore[assignment]
+ result: list[dict[str, Any]] = []
+ for r in dict_records:
+ out = {}
+ for c in self.columns:
+ if (val := r.get(c)) is None:
+ raise ValueError(f"Column '{c}' has None value in record {r}")
+ out[c] = coerce_value(val)
+ result.append(out)
+ return result
def find_representative(
self,
diff --git a/semhash/utils.py b/semhash/utils.py
index f3abb00..c2f54e0 100644
--- a/semhash/utils.py
+++ b/semhash/utils.py
@@ -1,4 +1,4 @@
-from collections import defaultdict
+import hashlib
from collections.abc import Sequence
from typing import Any, Protocol, TypeAlias, TypeVar
@@ -11,26 +11,78 @@
class Encoder(Protocol):
- """An encoder protocol for SemHash."""
+ """An encoder protocol for SemHash. Supports text, images, or any encodable data."""
def encode(
self,
- sentences: list[str] | str | Sequence[str],
+ inputs: Sequence[Any] | Any,
**kwargs: Any,
) -> np.ndarray:
"""
- Encode a list of sentences into embeddings.
+ Encode a list of inputs into embeddings.
- :param sentences: A list of sentences to encode.
+ :param inputs: A list of inputs to encode (strings, images, etc.).
:param **kwargs: Additional keyword arguments.
- :return: The embeddings of the sentences.
+ :return: The embeddings of the inputs.
"""
... # pragma: no cover
-def to_frozendict(record: dict[str, str], columns: set[str]) -> frozendict[str, str]:
- """Convert a record to a frozendict."""
- return frozendict({k: record.get(k, "") for k in columns})
+def make_hashable(value: Any) -> Any:
+ """
+ Convert a value to a hashable representation for use as dict keys.
+
+ Strings and other hashable types are returned as-is.
+ Non-hashable types (like PIL images, numpy arrays) are hashed to a string.
+
+ :param value: The value to make hashable.
+ :return: A hashable representation of the value.
+ """
+ # Fast path: most values are strings or already hashable
+ if isinstance(value, (str, int, float, bool, type(None))):
+ return value
+ # Handle objects with tobytes() (PIL Image, numpy array, etc.)
+ if hasattr(value, "tobytes"):
+ return hashlib.md5(value.tobytes()).hexdigest()
+ # Fallback: try to hash, otherwise stringify
+ try:
+ hash(value)
+ return value
+ except TypeError:
+ return str(value)
+
+
+def coerce_value(value: Any) -> Any:
+ """
+ Coerce a value for encoding: stringify primitives, keep complex types raw.
+
+ This ensures primitives (int, float, bool) work with text encoders,
+ while complex types (PIL images, tensors, etc.) are passed through for multimodal encoders.
+
+ :param value: The value to coerce.
+ :return: The coerced value.
+ """
+ if isinstance(value, (str, bytes)):
+ return value
+ if isinstance(value, (int, float, bool)):
+ return str(value)
+ return value # Complex types (images, tensors, etc.)
+
+
+def to_frozendict(record: dict[str, Any], columns: Sequence[str] | set[str]) -> frozendict[str, Any]:
+ """
+ Convert a record to a frozendict with hashable values.
+
+ :param record: The record to convert.
+ :param columns: The columns to include.
+ :return: A frozendict with only the specified columns (values made hashable).
+ :raises ValueError: If a column is missing from the record.
+ """
+ try:
+ return frozendict({k: make_hashable(record[k]) for k in columns})
+ except KeyError as e:
+ missing = e.args[0]
+ raise ValueError(f"Missing column '{missing}' in record {record}") from e
def compute_candidate_limit(
@@ -62,7 +114,7 @@ def compute_candidate_limit(
def featurize(
- records: Sequence[dict[str, str]],
+ records: Sequence[dict[str, Any]],
columns: Sequence[str],
model: Encoder,
) -> np.ndarray:
@@ -73,81 +125,25 @@ def featurize(
:param columns: Columns to featurize.
:param model: An Encoder model.
:return: The embeddings of the records.
+ :raises ValueError: If a column is missing from one or more records.
+ :raises TypeError: If encoding fails due to incompatible data types.
"""
# Extract the embeddings for each column across all records
embeddings_per_col = []
for col in columns:
- col_texts = [r[col] for r in records]
- col_emb = model.encode(col_texts)
+ try:
+ col_texts = [r[col] for r in records]
+ except KeyError as e:
+ raise ValueError(f"Missing column '{col}' in one or more records") from e
+ try:
+ col_emb = model.encode(col_texts)
+ except TypeError as e:
+ sample_type = type(col_texts[0]).__name__ if col_texts else "unknown"
+ raise TypeError(
+ f"Failed to encode column '{col}' (data type: {sample_type}). "
+ f"If encoding non-text data, provide a compatible encoder via the `model` parameter. "
+ f"See the SemHash documentation for more info."
+ ) from e
embeddings_per_col.append(np.asarray(col_emb))
return np.concatenate(embeddings_per_col, axis=1)
-
-
-def remove_exact_duplicates(
- records: Sequence[dict[str, str]],
- columns: Sequence[str],
- reference_records: list[list[dict[str, str]]] | None = None,
-) -> tuple[list[dict[str, str]], list[tuple[dict[str, str], list[dict[str, str]]]]]:
- """
- Remove exact duplicates based on the unpacked string representation of each record.
-
- If reference_records is None, the function will only check for duplicates within the records list.
-
- :param records: A list of records to check for exact duplicates.
- :param columns: Columns to unpack.
- :param reference_records: A list of records to compare against. These are already unpacked
- :return: A list of deduplicated records and a list of duplicates.
- """
- deduplicated = []
- duplicates = []
-
- column_set = set(columns)
- # Build a seen set from reference_records if provided
- seen: defaultdict[frozendict[str, str], list[dict[str, str]]] = defaultdict(list)
- if reference_records is not None:
- for record_set in reference_records:
- key = to_frozendict(record_set[0], column_set)
- seen[key] = list(record_set)
- in_one_set = reference_records is None
-
- for record in records:
- frozen_record = to_frozendict(record, column_set)
- if duplicated_records := seen.get(frozen_record):
- duplicates.append((record, duplicated_records))
- else:
- deduplicated.append(record)
- # Only add current documents to seen if no reference set is used
- if in_one_set:
- seen[frozen_record].append(record)
-
- return deduplicated, duplicates
-
-
-def prepare_records(
- records: Sequence[Record], columns: Sequence[str] | None
-) -> tuple[list[dict[str, str]], Sequence[str], bool]:
- """
- Validate and prepare records for processing.
-
- :param records: A list of records (strings or dictionaries).
- :param columns: Columns to use if records are dictionaries.
- :return: Tuple of (dict_records, columns, was_string).
- :raises ValueError: If records are empty.
- :raises ValueError: If columns are not provided for dictionary records.
- """
- if len(records) == 0:
- raise ValueError("records must not be empty")
-
- if columns is None and isinstance(records[0], dict):
- raise ValueError("Columns must be specified when passing dictionaries.")
-
- if isinstance(records[0], str):
- columns = ["text"]
- dict_records: list[dict[str, str]] = [{"text": str(record)} for record in records]
- was_string = True
- else:
- dict_records = list(records)
- was_string = False
-
- return dict_records, columns, was_string
diff --git a/semhash/version.py b/semhash/version.py
index 9bfefb0..9dbcf97 100644
--- a/semhash/version.py
+++ b/semhash/version.py
@@ -1,2 +1,2 @@
-__version_triple__ = (0, 3, 3)
-__version__ = ".".join(map(str, __version_triple__))
+__version_triple__ = (0, 4, 0) # pragma: no cover
+__version__ = ".".join(map(str, __version_triple__)) # pragma: no cover
diff --git a/tests/test_datamodels.py b/tests/test_datamodels.py
index 59c2563..307eeeb 100644
--- a/tests/test_datamodels.py
+++ b/tests/test_datamodels.py
@@ -1,8 +1,6 @@
import pytest
-import semhash
-import semhash.version
-from semhash.datamodels import DeduplicationResult, DuplicateRecord, SelectedWithDuplicates
+from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult, SelectedWithDuplicates
def test_deduplication_scoring() -> None:
@@ -25,34 +23,27 @@ def test_deduplication_scoring_exact() -> None:
assert d.exact_duplicate_ratio == 0.2
-def test_deduplication_scoring_exact_empty() -> None:
- """Test the deduplication scoring."""
- d = DeduplicationResult([], [], 0.8, columns=["text"])
- assert d.exact_duplicate_ratio == 0.0
-
-
def test_deduplication_scoring_empty() -> None:
- """Test the deduplication scoring."""
+ """Test the deduplication scoring with empty results."""
d = DeduplicationResult([], [], 0.8, columns=["text"])
assert d.duplicate_ratio == 0.0
+ assert d.exact_duplicate_ratio == 0.0
def test_rethreshold() -> None:
- """Test rethresholding the duplicates."""
+ """Test rethresholding the duplicates, including empty case."""
d = DuplicateRecord("a", False, [("b", 0.9), ("c", 0.8)])
d._rethreshold(0.85)
assert d.duplicates == [("b", 0.9)]
-
-def test_rethreshold_empty() -> None:
- """Test rethresholding the duplicates."""
- d = DuplicateRecord("a", False, [])
- d._rethreshold(0.85)
- assert d.duplicates == []
+ # Empty case
+ d_empty = DuplicateRecord("a", False, [])
+ d_empty._rethreshold(0.85)
+ assert d_empty.duplicates == []
def test_get_least_similar_from_duplicates() -> None:
- """Test getting the least similar duplicates."""
+ """Test getting the least similar duplicates, including empty case."""
d = DeduplicationResult(
["a", "b", "c"],
[DuplicateRecord("a", False, [("b", 0.9), ("c", 0.7)]), DuplicateRecord("b", False, [("c", 0.8)])],
@@ -61,11 +52,9 @@ def test_get_least_similar_from_duplicates() -> None:
result = d.get_least_similar_from_duplicates(1)
assert result == [("a", "c", 0.7)]
-
-def test_get_least_similar_from_duplicates_empty() -> None:
- """Test getting the least similar duplicates."""
- d = DeduplicationResult([], [], 0.8, columns=["text"])
- assert d.get_least_similar_from_duplicates(1) == []
+ # Empty case
+ d_empty = DeduplicationResult([], [], 0.8, columns=["text"])
+ assert d_empty.get_least_similar_from_duplicates(1) == []
def test_rethreshold_deduplication_result() -> None:
@@ -243,3 +232,10 @@ def test_selected_with_duplicates_cache_invalidation_on_rethreshold() -> None:
assert result2[0].duplicates[0][0] == "duplicate_1"
# Results should be different objects
assert result1 is not result2
+
+
+def test_filter_result_empty() -> None:
+ """Test FilterResult ratios with empty lists."""
+ result = FilterResult(selected=[], filtered=[])
+ assert result.filter_ratio == 0.0
+ assert result.selected_ratio == 1.0
diff --git a/tests/test_semhash.py b/tests/test_semhash.py
index 1b4105c..55119fd 100644
--- a/tests/test_semhash.py
+++ b/tests/test_semhash.py
@@ -141,27 +141,31 @@ def test_deduplicate_with_only_exact_duplicates(model: Encoder) -> None:
def test_self_find_representative(model: Encoder, train_texts: list[str]) -> None:
"""Test the self_find_representative method."""
semhash = SemHash.from_records(records=train_texts, model=model)
- result = semhash.self_find_representative(
- candidate_limit=5,
- selection_size=3,
- diversity=0.5,
- )
+
+ # Test with explicit candidate_limit
+ result = semhash.self_find_representative(candidate_limit=5, selection_size=3, diversity=0.5)
assert len(result.selected) == 3, "Expected 3 representatives"
selected = {r["text"] for r in result.selected}
- assert selected == {
- "blueberry",
- "pineapple",
- "grape",
- }, "Expected representatives to be blueberry, pineapple, and grape"
+ assert selected == {"blueberry", "pineapple", "grape"}
+
+ # Test with auto candidate_limit (default)
+ result_auto = semhash.self_find_representative(selection_size=3, diversity=0.5)
+ assert len(result_auto.selected) == 3
def test_find_representative(model: Encoder, train_texts: list[str], test_texts: list[str]) -> None:
"""Test the find_representative method."""
semhash = SemHash.from_records(records=train_texts, model=model)
+
+ # Test with explicit candidate_limit
result = semhash.find_representative(records=test_texts, candidate_limit=5, selection_size=3, diversity=0.5)
assert len(result.selected) == 3, "Expected 3 representatives"
selected = {r["text"] for r in result.selected}
- assert selected == {"grapefruit", "banana", "apple"}, "Expected representatives to be grapefruit, banana, and apple"
+ assert selected == {"grapefruit", "banana", "apple"}
+
+ # Test with auto candidate_limit (default)
+ result_auto = semhash.find_representative(records=test_texts, selection_size=3, diversity=0.5)
+ assert len(result_auto.selected) == 3
def test_filter_outliers(model: Encoder, train_texts: list[str], test_texts: list[str]) -> None:
@@ -173,10 +177,23 @@ def test_filter_outliers(model: Encoder, train_texts: list[str], test_texts: lis
filtered = {r["text"] for r in result.filtered}
assert filtered == {"motorcycle", "plane"}, "Expected outliers to be motorcycle and plane"
+ # Test FilterResult ratio properties
+ assert result.filter_ratio == len(result.filtered) / len(test_texts)
+ assert result.selected_ratio == len(result.selected) / len(test_texts)
+ assert result.filter_ratio + result.selected_ratio == 1.0
+
# Test with outlier_percentage=0.0 (should return no outliers)
result_zero = semhash.filter_outliers(records=test_texts, outlier_percentage=0.0)
assert result_zero.filtered == []
assert len(result_zero.selected) == len(test_texts)
+ assert result_zero.filter_ratio == 0.0
+ assert result_zero.selected_ratio == 1.0
+
+ # Invalid outlier_percentage raises ValueError
+ with pytest.raises(ValueError, match="outlier_percentage must be between 0 and 1"):
+ semhash.filter_outliers(records=test_texts, outlier_percentage=-0.1)
+ with pytest.raises(ValueError, match="outlier_percentage must be between 0 and 1"):
+ semhash.filter_outliers(records=test_texts, outlier_percentage=1.5)
def test_self_filter_outliers(model: Encoder, train_texts: list[str]) -> None:
@@ -193,6 +210,12 @@ def test_self_filter_outliers(model: Encoder, train_texts: list[str]) -> None:
assert result_zero.filtered == []
assert len(result_zero.selected) == len(train_texts)
+ # Invalid outlier_percentage raises ValueError
+ with pytest.raises(ValueError, match="outlier_percentage must be between 0 and 1"):
+ semhash.self_filter_outliers(outlier_percentage=-0.1)
+ with pytest.raises(ValueError, match="outlier_percentage must be between 0 and 1"):
+ semhash.self_filter_outliers(outlier_percentage=1.5)
+
def test__diversify(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test the _diversify method."""
@@ -226,30 +249,80 @@ def test__diversify(monkeypatch: pytest.MonkeyPatch) -> None:
def test_from_embeddings(model: Encoder, train_texts: list[str]) -> None:
"""Test from_embeddings constructor with validation and comparison to from_records."""
- # Test validation: mismatched shapes
+ # Validation: empty records
+ with pytest.raises(ValueError, match="records must not be empty"):
+ SemHash.from_embeddings(embeddings=np.array([[]]), records=[], model=model)
+
+ # Validation: non-2D embeddings
+ with pytest.raises(ValueError, match="must be a 2D array"):
+ SemHash.from_embeddings(embeddings=np.array([1, 2, 3]), records=["a", "b", "c"], model=model)
+
+ # Validation: mismatched shapes
with pytest.raises(ValueError, match="Number of embeddings"):
wrong_embeddings = model.encode(["apple", "banana"])
SemHash.from_embeddings(embeddings=wrong_embeddings, records=train_texts, model=model)
# Test that from_embeddings behaves same as from_records
semhash_from_records = SemHash.from_records(records=train_texts, model=model)
-
embeddings = model.encode(train_texts)
semhash_from_embeddings = SemHash.from_embeddings(embeddings=embeddings, records=train_texts, model=model)
- # Both should give same deduplication results
result1 = semhash_from_records.self_deduplicate(threshold=0.95)
result2 = semhash_from_embeddings.self_deduplicate(threshold=0.95)
-
assert len(result1.selected) == len(result2.selected)
- assert len(result1.filtered) == len(result2.filtered)
# Test that from_embeddings keeps first-occurrence embeddings and drops duplicates
records = ["apple", "banana", "apple", "cherry"]
embeddings = np.array([[0.0], [1.0], [2.0], [3.0]], dtype=np.float32)
-
semhash = SemHash.from_embeddings(embeddings=embeddings, records=records, model=model)
-
assert semhash.index.vectors.shape == (3, 1)
- # Should keep embeddings at indices 0, 1, 3 (first occurrences of img1, img2, img3)
assert semhash.index.vectors.tolist() == [[0.0], [1.0], [3.0]]
+
+
+def test_from_records_edge_cases(model: Encoder) -> None:
+ """Test from_records edge cases: coercion, order preservation, None rejection."""
+ # Coerces non-string dict values to strings
+ records = [{"id": 1}, {"id": 2}, {"id": 1}] # Integers, with duplicate
+ semhash = SemHash.from_records(records, columns=["id"], model=model)
+ assert semhash.index.vectors.shape[0] == 2 # Deduplicated
+ assert 2 in [len(bucket) for bucket in semhash.index.items] # id=1 bucket has 2
+
+ # Preserves first-occurrence order (deterministic)
+ texts = ["zebra", "apple", "zebra", "banana", "apple", "cherry"]
+ semhash = SemHash.from_records(texts, model=model)
+ firsts = [bucket[0]["text"] for bucket in semhash.index.items]
+ assert firsts == ["zebra", "apple", "banana", "cherry"]
+
+ # Rejects None values in dict records
+ with pytest.raises(ValueError, match="has None value"):
+ SemHash.from_records([{"text": "apple"}, {"text": None}], columns=["text"], model=model)
+
+
+def test_deduplicate_edge_cases(model: Encoder) -> None:
+ """Test deduplicate() edge cases: coercion, None rejection, empty records, type mismatches."""
+ semhash = SemHash.from_records(["1", "2", "3"], model=model)
+
+ # Coerces non-string dict values
+ result = semhash.deduplicate([{"text": 1}, {"text": 4}], threshold=0.95)
+ assert len(result.filtered) + len(result.selected) == 2
+
+ # Rejects None values
+ with pytest.raises(ValueError, match="has None value"):
+ semhash.deduplicate([{"text": "cherry"}, {"text": None}], threshold=0.95)
+
+ # Rejects empty records
+ with pytest.raises(ValueError, match="records must not be empty"):
+ semhash.deduplicate([], threshold=0.95)
+
+ # Type mismatch: strings passed to dict-based index
+ semhash_dict = SemHash.from_records([{"col": "a"}, {"col": "b"}], columns=["col"], model=model)
+ with pytest.raises(ValueError, match="Records were not originally strings"):
+ semhash_dict.deduplicate(["x", "y"], threshold=0.95)
+
+ # Type mismatch: mixed strings
+ with pytest.raises(ValueError, match="Records must be all strings"):
+ semhash.deduplicate(["a", {"text": "b"}], threshold=0.95)
+
+ # Type mismatch: mixed dicts
+ with pytest.raises(ValueError, match="Records must be all dictionaries"):
+ semhash_dict.deduplicate([{"col": "a"}, "b"], threshold=0.95)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 371fb16..b161367 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2,23 +2,74 @@
import pytest
from frozendict import frozendict
-from semhash.utils import (
- Encoder,
- compute_candidate_limit,
- featurize,
- prepare_records,
- remove_exact_duplicates,
- to_frozendict,
-)
+from semhash.records import prepare_records, remove_exact_duplicates
+from semhash.utils import Encoder, coerce_value, compute_candidate_limit, featurize, make_hashable, to_frozendict
+
+
+def test_make_hashable() -> None:
+ """Test make_hashable with various types."""
+ # Fast path: primitives
+ assert make_hashable("hello") == "hello"
+ assert make_hashable(42) == 42
+ assert make_hashable(3.14) == 3.14
+ assert make_hashable(True) is True
+ assert make_hashable(None) is None
+
+ # Objects with tobytes() (simulate PIL Image or numpy array)
+ class MockImage:
+ def tobytes(self) -> bytes:
+ return b"fake_image_data"
+
+ img = MockImage()
+ result = make_hashable(img)
+ assert isinstance(result, str)
+ assert len(result) == 32 # MD5 hex digest
+
+ # Hashable objects (like tuples)
+ assert make_hashable((1, 2, 3)) == (1, 2, 3)
+
+ # Non-hashable fallback to string
+ unhashable = {"key": "value"}
+ result = make_hashable(unhashable)
+ assert result == "{'key': 'value'}"
+
+
+def test_coerce_value() -> None:
+ """Test coerce_value for encoding preparation."""
+ # Strings and bytes pass through
+ assert coerce_value("hello") == "hello"
+ assert coerce_value(b"bytes") == b"bytes"
+
+ # Primitives converted to strings
+ assert coerce_value(42) == "42"
+ assert coerce_value(3.14) == "3.14"
+ assert coerce_value(True) == "True"
+
+ # Complex types pass through unchanged
+ class MockImage:
+ pass
+
+ img = MockImage()
+ assert coerce_value(img) is img
def test_to_frozendict() -> None:
- """Test converting dict to frozendict."""
+ """Test converting dict to frozendict, including error cases."""
record = {"a": "1", "b": "2", "c": "3"}
+
+ # Basic case: select subset of columns
result = to_frozendict(record, {"a", "c"})
assert result == frozendict({"a": "1", "c": "3"})
assert "b" not in result
+ # Works with Sequence (not just set)
+ result = to_frozendict(record, ["a", "b"])
+ assert result == frozendict({"a": "1", "b": "2"})
+
+ # Missing column raises ValueError
+ with pytest.raises(ValueError, match="Missing column 'missing'"):
+ to_frozendict(record, {"a", "missing"})
+
def test_compute_candidate_limit() -> None:
"""Test candidate limit computation."""
@@ -33,30 +84,61 @@ def test_compute_candidate_limit() -> None:
def test_featurize(model: Encoder) -> None:
- """Test featurizing records."""
+ """Test featurizing records, including error cases."""
records = [{"text": "hello"}, {"text": "world"}]
embeddings = featurize(records, ["text"], model)
assert embeddings.shape == (2, 128) # Model has 128 dims
assert isinstance(embeddings, np.ndarray)
+ # Missing column raises ValueError
+ with pytest.raises(ValueError, match="Missing column 'missing'"):
+ featurize(records, ["missing"], model)
+
+ # Non-text data with text encoder raises helpful TypeError
+ class FakeImage:
+ pass
+
+ records_with_images = [{"img": FakeImage()}, {"img": FakeImage()}]
+ with pytest.raises(TypeError, match="Failed to encode column 'img'"):
+ featurize(records_with_images, ["img"], model)
+ with pytest.raises(TypeError, match="data type: FakeImage"):
+ featurize(records_with_images, ["img"], model)
+
def test_remove_exact_duplicates() -> None:
- """Test exact duplicate removal."""
+ """Test exact duplicate removal, with and without reference records."""
+ # Basic case: remove duplicates within same list
records = [
{"text": "hello", "id": "1"},
{"text": "world", "id": "2"},
{"text": "hello", "id": "3"},
]
deduplicated, duplicates = remove_exact_duplicates(records, ["text"])
-
assert len(deduplicated) == 2
assert len(duplicates) == 1
assert duplicates[0][0] == {"text": "hello", "id": "3"}
+ # With reference_records: cross-dataset filtering
+ reference_records = [
+ [{"text": "apple"}],
+ [{"text": "banana"}, {"text": "banana"}],
+ ]
+ new_records = [
+ {"text": "cherry"}, # New
+ {"text": "apple"}, # Exists in reference
+ {"text": "date"}, # New
+ {"text": "banana"}, # Exists in reference
+ ]
+ deduplicated, duplicates = remove_exact_duplicates(new_records, ["text"], reference_records=reference_records)
+ assert len(deduplicated) == 2
+ assert {"text": "cherry"} in deduplicated
+ assert {"text": "date"} in deduplicated
+ assert len(duplicates) == 2
+
def test_prepare_records() -> None:
- """Test preparing records."""
- # String records
+ """Test preparing records, including validation and edge cases."""
+ # String records -> converts to dicts with "text" column
records = ["hello", "world"]
dict_records, columns, was_string = prepare_records(records, None)
assert was_string is True
@@ -71,6 +153,15 @@ def test_prepare_records() -> None:
assert dict_records == records
# Dict records without columns raises ValueError
- records = [{"text": "hello"}]
with pytest.raises(ValueError, match="Columns must be specified"):
- prepare_records(records, None)
+ prepare_records([{"text": "hello"}], None)
+
+ # Empty records raises ValueError
+ with pytest.raises(ValueError, match="records must not be empty"):
+ prepare_records([], None)
+
+ # Mixed types rejected
+ with pytest.raises(ValueError, match="All records must be"):
+ prepare_records(["a", {"text": "b"}], None)
+ with pytest.raises(ValueError, match="All records must be"):
+ prepare_records([{"text": "a"}, "b"], ["text"])
diff --git a/uv.lock b/uv.lock
index 0bd321f..9bd57ea 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,6 +1,11 @@
version = 1
revision = 3
requires-python = ">=3.10"
+resolution-markers = [
+ "python_full_version >= '3.12'",
+ "python_full_version == '3.11.*'",
+ "python_full_version < '3.11'",
+]
[[package]]
name = "asttokens"