Closed
Changes from all commits
45 commits
eb3f227
feat: 2 vbench dimensions and vbench dependencies
begumcig Oct 9, 2025
c21518d
test: vbench metric tests
begumcig Oct 10, 2025
ba2de41
docs: add more comprehensive docstring explanations for important par…
begumcig Oct 10, 2025
bda23c5
feat: add additional helper tools to utilities
begumcig Oct 13, 2025
2dc64e8
refactor: small updates to utilities and docstrings
begumcig Oct 14, 2025
e2099ce
refactor: add support for more calltypes in video eval utils
begumcig Oct 15, 2025
66b9d5f
refactor: make utilities more vbench independent and fix small things…
begumcig Oct 28, 2025
016af55
refactor: address PR comments
begumcig Nov 3, 2025
2191777
test: adding more tests for dynamic degree and background consistency
begumcig Nov 10, 2025
b814f24
feat: artifact saving and vbench related agent updates
begumcig Oct 6, 2025
48332b0
test: add tests for the artifact savers
begumcig Oct 7, 2025
00b56dc
test: add artifact related evaluation tests and task modality tests
begumcig Oct 8, 2025
46ac6e1
refactor: add some comments
begumcig Oct 8, 2025
07410b8
refactor: better initialization for artifact savers
begumcig Oct 23, 2025
2123845
test: add more dtype tests for artifact saver
begumcig Oct 24, 2025
c829596
feat: metric modalities as sets
begumcig Oct 24, 2025
341987c
refactor: comments tests task modality
begumcig Oct 28, 2025
7c62325
feat: add video inference support and seeding strategies to inference…
begumcig Oct 6, 2025
c5403af
feat: remove per evaluation seed and add tests
begumcig Oct 6, 2025
f838b8c
chore: add comments
begumcig Oct 6, 2025
be68471
fix: bfloats cannot be moved to cpu error in cmmd metric
begumcig Oct 14, 2025
e005d26
fix: pre commit file fix
begumcig Oct 14, 2025
3b1b0ea
refactor: configure seeding and tests
begumcig Oct 28, 2025
8a7fded
refactor: algorithm compatibility (#401)
johannaSommer Nov 5, 2025
1e8352c
feat:stratification to vbench datasets
begumcig Nov 27, 2025
0122c81
feat: data stratification by indexing
begumcig Nov 27, 2025
8a51bb2
Add image artifactsaver and modified utils to use it for algo sweeper…
Dec 3, 2025
16ee6c4
Changes to vbench-utils via ruff
Dec 3, 2025
07e4445
Add filename sanitizer which modifies invalid filename aliases
Dec 10, 2025
0cbe55e
Change file format via ruff
Dec 10, 2025
b20c3f0
Add helper function for creating aliases as prompt names for outputs
Dec 10, 2025
9a5a98d
feat: 0 vbench dimensions and vbench dependencies
begumcig Oct 9, 2025
20d0844
Undo limiting prompt length as name for image logging
Dec 22, 2025
9b1696c
Integrate uncommented method from task.py
Dec 22, 2025
88ebf4d
Update function in evaluation agent for generating metadata-json file
Dec 22, 2025
e3612c0
Add evaluation-agent parameter for optional JSON metadata creation; i…
Marius-Graml Dec 24, 2025
a56dde2
Adjust doc string for _maybe_create_input_output_metadat() in evaluat…
Marius-Graml Dec 24, 2025
36677a6
Add model role in metadata json when applying pairwise metrics
Marius-Graml Dec 29, 2025
af19fe3
Change name and if logic of json file creation function in evaluation…
Marius-Graml Jan 13, 2026
5ef344b
Copy the tests from image-artifactsaver to feat/img-saver-extended to…
Marius-Graml Jan 13, 2026
bce5e79
Remove unused import
Marius-Graml Jan 13, 2026
07173d8
Kid metric added (#435)
minettekaum Dec 10, 2025
0ce3cb6
Pin `torchao==0.12.0` to avoid PyTorch ABI warnings, also pin `numpyd…
ParagEkbote Dec 15, 2025
3637967
feat: metric modalities as sets
begumcig Oct 24, 2025
4a9fb5d
Make prepare_inputs in diffuser handler more robust. Models such as F…
Marius-Graml Jan 19, 2026
40 changes: 39 additions & 1 deletion .pre-commit-config.yaml
@@ -53,6 +53,16 @@ repos:
pass_filenames: false
files: \.py$

- repo: local
hooks:
- id: ty
name: type checking using ty
entry: uvx ty check .
language: system
types: [python]
pass_filenames: false
files: \.py$

- repo: local
hooks:
- id: check-pruna-pro
@@ -63,13 +73,41 @@
grep -v "^D" |
cut -f2- |
while IFS= read -r file; do
if [ -f "$file" ] && [ "$file" != ".pre-commit-config.yaml" ] && grep -q "pruna_pro" "$file"; then
echo "Error: pruna_pro found in staged file $file"
exit 1
fi
done
'
language: system
stages: [pre-commit]
types: [python]
files: \.py$
files: \.py$
7 changes: 6 additions & 1 deletion pyproject.toml
@@ -31,10 +31,13 @@ unsupported-operator = "ignore" # mypy supports | syntax with from __future__ im
invalid-argument-type = "ignore" # mypy is more permissive with argument types
invalid-return-type = "ignore" # mypy is more permissive with return types
invalid-parameter-default = "ignore" # mypy is more permissive with parameter defaults
possibly-missing-attribute = "ignore" # mypy is more permissive with attribute access
possibly-unbound-attribute = "ignore"
possibly-missing-import = "ignore" # mypy is more permissive with imports
no-matching-overload = "ignore" # mypy is more permissive with overloads
unresolved-reference = "ignore" # mypy is more permissive with references
possibly-unbound-import = "ignore"
missing-argument = "ignore"
possibly-unbound-import = "ignore"

[tool.coverage.run]
source = ["src/pruna"]
@@ -75,6 +78,7 @@ gptqmodel = [
{ index = "pruna_internal", marker = "sys_platform != 'darwin' or platform_machine != 'arm64'"},
{ index = "pypi", marker = "sys_platform == 'darwin' and platform_machine == 'arm64'"},
]
clip = {git = "https://github.com/openai/CLIP.git", rev = "dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1"}

[project]
name = "pruna"
@@ -187,6 +191,7 @@ dev = [
]
cpu = []


[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
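The new clip source above pins OpenAI's CLIP repository to a fixed revision so that `import clip` resolves to that exact commit. Below is a minimal usage sketch of the pinned package; the checkpoint name and the printed shape follow the standard openai/CLIP API and are not specific to this PR.

import clip
import torch

# Load the ViT-B/32 checkpoint on CPU (weights are downloaded on first use).
model, preprocess = clip.load("ViT-B/32", device="cpu")
text = clip.tokenize(["a photo of a cat", "a photo of a dog"])
with torch.no_grad():
    text_features = model.encode_text(text)
print(text_features.shape)  # torch.Size([2, 512])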
15 changes: 12 additions & 3 deletions src/pruna/data/utils.py
@@ -184,7 +184,7 @@
return texts


def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42) -> Dataset:
def stratify_dataset(dataset: Dataset, sample_size: int, seed: int = 42, partition_strategy: str = "random", partition_index: int = 0) -> Dataset:

Check failure on line 187 in src/pruna/data/utils.py (GitHub Actions / linting (3.10)): Ruff E501 Line too long (146 > 121).
"""
Stratify the dataset into a specific size.

@@ -196,6 +196,10 @@
The size to stratify.
seed : int
The seed to use for sampling the dataset.
partition_strategy : str
The strategy to use for partitioning the dataset. Can be "indexed" or "random".
partition_index : int
The zero-based index of the partition to select when partition_strategy is "indexed". Ignored for "random".

Returns
-------
@@ -211,8 +215,13 @@
return dataset

indices = list(range(dataset_length))
random.Random(seed).shuffle(indices)
selected_indices = indices[:sample_size]
if partition_strategy == "indexed":
selected_indices = indices[sample_size*partition_index:sample_size*(partition_index+1)]

Check failures on line 219 in src/pruna/data/utils.py (GitHub Actions / linting (3.10)): Ruff E226 Missing whitespace around arithmetic operator, at columns 47, 75, and 92.
elif partition_strategy == "random":
random.Random(seed).shuffle(indices)
selected_indices = indices[:sample_size]
else:
raise ValueError(f"Invalid partition strategy: {partition_strategy}")
dataset = dataset.select(selected_indices)
return dataset

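To illustrate the two partition strategies added to stratify_dataset, here is a minimal sketch. The toy dataset and its contents are made up for illustration; only the function and its parameters come from the diff above.

from datasets import Dataset
from pruna.data.utils import stratify_dataset

# Toy dataset with 10 rows (contents are illustrative only).
toy = Dataset.from_dict({"prompt": [f"prompt {i}" for i in range(10)]})

# "indexed": take the contiguous, non-overlapping slice that belongs to partition_index.
part = stratify_dataset(toy, sample_size=3, partition_strategy="indexed", partition_index=1)
print(part["prompt"])  # ['prompt 3', 'prompt 4', 'prompt 5']

# "random": shuffle the indices with the given seed, then keep the first sample_size rows.
rand = stratify_dataset(toy, sample_size=3, seed=42, partition_strategy="random")
print(len(rand))  # 3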
62 changes: 46 additions & 16 deletions src/pruna/engine/handler/handler_diffuser.py
@@ -15,10 +15,9 @@
from __future__ import annotations

import inspect
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple

import torch
from torchvision import transforms

from pruna.engine.handler.handler_inference import InferenceHandler
from pruna.logging.logger import pruna_logger
@@ -28,10 +27,6 @@ class DiffuserHandler(InferenceHandler):
"""
Handle inference arguments, inputs and outputs for diffusers models.

A generator with a fixed seed (42) is passed as an argument to the model for reproducibility.
The first element of the batch is passed as input to the model.
The generated outputs are expected to have .images attribute.

Parameters
----------
call_signature : inspect.Signature
@@ -40,12 +35,18 @@
The arguments to pass to the model.
"""

def __init__(self, call_signature: inspect.Signature, model_args: Optional[Dict[str, Any]] = None) -> None:
default_args = {"generator": torch.Generator("cpu").manual_seed(42)}
def __init__(
self,
call_signature: inspect.Signature,
model_args: Optional[Dict[str, Any]] = None,
seed_strategy: Literal["per_sample", "no_seed"] = "no_seed",
global_seed: int | None = None,
) -> None:
self.call_signature = call_signature
if model_args:
default_args.update(model_args)
self.model_args = default_args
self.model_args = model_args if model_args else {}
# We want the default output type to be pytorch tensors.
self.model_args["output_type"] = "pt"
self.configure_seed(seed_strategy, global_seed)

def prepare_inputs(
self, batch: List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor | dict[str, Any], ...] | dict[str, Any]
@@ -63,8 +64,14 @@ def prepare_inputs(
Any
The prepared inputs.
"""
# Many diffusers pipelines accept `prompt`, but it is not always the first positional argument
# (e.g. some pipelines have `image` first and `prompt` second). To be robust, pass prompts
# as a keyword argument whenever possible.
if "prompt" in self.call_signature.parameters or "args" in self.call_signature.parameters:
x, _ = batch
if "prompt" in self.call_signature.parameters:
return {"prompt": x}
# Fallback: pipelines that use *args without a named `prompt`
return x
else: # Unconditional generation models
return None
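A small sketch of the keyword-argument behavior described in the comment above. The fake_pipeline signature is a stand-in for a pipeline whose first positional argument is not `prompt`; it is not a real diffusers pipeline, and the batch shape simply mirrors the `x, _ = batch` unpacking in the handler.

import inspect
from pruna.engine.handler.handler_diffuser import DiffuserHandler

# Illustrative signature: `image` first, `prompt` second.
def fake_pipeline(image=None, prompt=None, num_inference_steps=50):
    ...

handler = DiffuserHandler(call_signature=inspect.signature(fake_pipeline))
batch = (["an astronaut riding a horse"], None)  # (inputs, labels)
print(handler.prepare_inputs(batch))  # {'prompt': ['an astronaut riding a horse']}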
@@ -83,13 +90,36 @@
torch.Tensor
The processed images.
"""
generated = output.images
return torch.stack([transforms.PILToTensor()(g) for g in generated])
if hasattr(output, "images"):
generated = output.images
# For video models.
elif hasattr(output, "frames"):
generated = output.frames
else:
# Maybe the user is calling the pipeline with return_dict = False,
# which then returns the generated image / video in a tuple
generated = output[0]
return generated.float()

def log_model_info(self) -> None:
"""Log information about the inference handler."""
pruna_logger.info(
"Detected diffusers model. Using DiffuserHandler with fixed seed.\n"
"- The first element of the batch is passed as input.\n"
"- The generated outputs are expected to have .images attribute."
"Detected diffusers model. Using DiffuserHandler.\n- The first element of the batch is passed as input.\n"
"Inference outputs are expected to have either have an `images` attribute or a `frames` attribute."
"Or be a tuple with the generated image / video as the first element."
)

def set_seed(self, seed: int) -> None:
"""
Set the random seed for the current process.

Parameters
----------
seed : int
The seed to set.
"""
self.model_args["generator"] = torch.Generator("cpu").manual_seed(seed)

def remove_seed(self) -> None:
"""Remove the seed from the current process."""
self.model_args["generator"] = None
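The sketch below exercises the new output handling and seeding paths with dummy objects. The tensor shapes and the lambda signature are illustrative assumptions; the handler API itself comes from the diff above.

import inspect
import torch
from types import SimpleNamespace
from pruna.engine.handler.handler_diffuser import DiffuserHandler

handler = DiffuserHandler(call_signature=inspect.signature(lambda prompt=None: None))

img_out = SimpleNamespace(images=torch.rand(1, 3, 64, 64))     # pipelines exposing .images
vid_out = SimpleNamespace(frames=torch.rand(1, 8, 3, 64, 64))  # video pipelines exposing .frames
tup_out = (torch.rand(1, 3, 64, 64),)                          # pipelines called with return_dict=False

for out in (img_out, vid_out, tup_out):
    print(handler.process_output(out).shape)

# Seeding: "per_sample" requires a global seed and installs a seeded CPU generator in model_args.
handler.configure_seed("per_sample", global_seed=42)
print(handler.model_args["generator"].initial_seed())  # 42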
76 changes: 75 additions & 1 deletion src/pruna/engine/handler/handler_inference.py
@@ -14,9 +14,11 @@

from __future__ import annotations

import random
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Literal, Tuple

import numpy as np
import torch

from pruna.data.utils import move_batch_to_device
@@ -98,3 +100,75 @@ def move_inputs_to_device(
return move_batch_to_device(inputs, device)
except torch.cuda.OutOfMemoryError as e:
raise e

def configure_seed(self, seed_strategy: Literal["per_sample", "no_seed"], global_seed: int | None) -> None:
"""
Set the random seed according to the chosen strategy.

- If `seed_strategy="per_sample"`, the `global_seed` is used as a base to derive a different seed for each
sample. This ensures reproducibility while still producing variation across samples,
making it the preferred option for benchmarking.
- If `seed_strategy="no_seed"`, no seed is set internally.
The user is responsible for managing seeds if reproducibility is required.

Parameters
----------
seed_strategy : Literal["per_sample", "no_seed"]
The seeding strategy to apply.
global_seed : int | None
The base seed value to use (if applicable).
"""
self.seed_strategy = seed_strategy
validate_seed_strategy(seed_strategy, global_seed)
if global_seed is not None:
self.global_seed = global_seed
self.set_seed(global_seed)
else:
self.remove_seed()

def set_seed(self, seed: int) -> None:
"""
Set the random seed for the current process.

Parameters
----------
seed : int
The seed to set.
"""
# With the default handler, we can't assume anything about the model,
# so we are setting the seed for all RNGs available.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)

def remove_seed(self) -> None:
"""Remove the seed from the current process."""
random.seed(None)
np.random.seed(None)
# We can't really remove the seed from the PyTorch RNG, so we are reseeding with torch.seed().
# torch.seed() creates a non-deterministic random number.
torch.manual_seed(torch.seed())
if torch.cuda.is_available():
torch.cuda.manual_seed_all(torch.seed())


def validate_seed_strategy(seed_strategy: Literal["per_sample", "no_seed"], global_seed: int | None) -> None:
"""
Check the consistency of the seed strategy and the global seed.

If the seed strategy is "no_seed", the global seed must be None.
If the seed strategy is "per_sample", the user must provide a global seed.

Parameters
----------
seed_strategy : Literal["per_sample", "no_seed"]
The seeding strategy to apply.
global_seed : int | None
The base seed value to use (if applicable).
"""
if seed_strategy != "no_seed" and global_seed is None:
raise ValueError("Global seed must be provided if seed strategy is not 'no_seed'.")
elif global_seed is not None and seed_strategy == "no_seed":
raise ValueError("Seed strategy cannot be 'no_seed' if global seed is provided.")
4 changes: 3 additions & 1 deletion src/pruna/engine/pruna_model.py
@@ -24,7 +24,7 @@

from pruna.config.smash_config import SmashConfig
from pruna.engine.handler.handler_utils import register_inference_handler
from pruna.engine.load import load_pruna_model, load_pruna_model_from_pretrained
from pruna.engine.load import filter_load_kwargs, load_pruna_model, load_pruna_model_from_pretrained
from pruna.engine.save import save_pruna_model, save_pruna_model_to_hub
from pruna.engine.utils import get_device, get_nn_modules, set_to_eval
from pruna.logging.filter import apply_warning_filter
@@ -108,6 +108,8 @@ def run_inference(self, batch: Any) -> Any:
)
inference_function = getattr(self, inference_function_name)

self.inference_handler.model_args = filter_load_kwargs(self.model.__call__, self.inference_handler.model_args)

if prepared_inputs is None:
outputs = inference_function(**self.inference_handler.model_args)
elif isinstance(prepared_inputs, dict):
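The exact behavior of filter_load_kwargs is not shown in this diff. The sketch below is only a hypothetical illustration of the idea behind the call above, dropping handler model_args that the pipeline's __call__ does not accept; _filter_kwargs and toy_pipeline_call are illustrative stand-ins, not the real pruna helpers.

import inspect
from typing import Any, Callable, Dict

def _filter_kwargs(fn: Callable, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical stand-in: keep only the kwargs that `fn` actually accepts.
    params = inspect.signature(fn).parameters
    return {k: v for k, v in kwargs.items() if k in params}

def toy_pipeline_call(num_inference_steps=50, generator=None):
    ...

model_args = {"output_type": "pt", "generator": None, "num_inference_steps": 50}
print(_filter_kwargs(toy_pipeline_call, model_args))
# {'generator': None, 'num_inference_steps': 50}  ('output_type' is dropped)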