diff --git a/modelopt/onnx/op_types.py b/modelopt/onnx/op_types.py index cc94a221f..60bf91178 100644 --- a/modelopt/onnx/op_types.py +++ b/modelopt/onnx/op_types.py @@ -303,3 +303,76 @@ def is_data_dependent_shape_op(op_type: str): "NonZero", "RoiAlign", ] + + +def get_bool_operations(): + """Returns set of boolean/comparison operations that are not quantizable.""" + return { + "Not", + "And", + "Or", + "Xor", + "BitwiseAnd", + "BitwiseOr", + "BitwiseXor", + "BitShift", + "IsNaN", + "IsInf", + "Sign", + "Abs", + "Equal", + "Greater", + "GreaterOrEqual", + "Less", + "LessOrEqual", + "Where", + "Max", + "Min", + "Mean", + "Median", + "ArgMax", + "ArgMin", + "ReduceMax", + "ReduceMin", + "ReduceSum", + "ReduceMean", + "All", + "Any", + "Unique", + "NonZero", + "TopK", + } + + +def get_autotuner_skip_ops(): + """Returns set of shape/structural operations that are not quantizable. + + These operations manipulate tensor structure or perform indexing rather than + computing on tensor values, making them unsuitable for quantization. + """ + indexing_ops = {"Compress", "Gather", "GatherElements", "GatherND", "Slice"} + scatter_ops = {"Scatter", "ScatterND"} + reshape_ops = {"ExpandDims", "Flatten", "Squeeze", "Unsqueeze", "View"} + rearrangement_ops = {"Concat", "Pad", "Split", "Tile", "Transpose"} + utility_ops = {"Cast", "Ceil", "Clip", "Identity", "Range", "Shape"} + + return indexing_ops | scatter_ops | reshape_ops | rearrangement_ops | utility_ops + + +def get_autotuner_quantizable_operations(): + """Returns set of key operations that benefit from quantization.""" + return { + "Conv", + "ConvTranspose", + "Gemm", + "MatMul", + "AveragePool", + "MaxPool", + "GlobalAveragePool", + "GlobalMaxPool", + "Resize", + "Add", + "Sum", + "Mul", + "Relu", + } diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py new file mode 100644 index 000000000..f32206c86 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/common.py @@ -0,0 +1,717 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common data structures and types for the QDQ Autotuner. + +This module provides the foundational classes used throughout the autotuner: + +**Exceptions:** +- Region-related: RegionError +- Autotuner-related: AutotunerError, AutotunerNotInitializedError, InvalidSchemeError + +**Region Hierarchy:** +- Region: Hierarchical subgraph representation with parent/child relationships +- RegionType: Enumeration for LEAF, COMPOSITE, and ROOT regions + +**Q/DQ Insertion Specifications:** +- InsertionScheme: Collection of insertion points with performance metrics + +**Scheme Management:** +- PatternSchemes: Multiple insertion schemes for a pattern (applies to all matching regions) +- PatternCache: Collection of top schemes for multiple patterns, used as autotuning seeds + +**Configuration:** +- Config: Autotuning parameters and Q/DQ default values +""" + +import hashlib +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Optional + +import onnx_graphsurgeon as gs + +from modelopt.onnx.quantization.autotune.insertion_points import ( + ChildRegionInputInsertionPoint, + NodeInputInsertionPoint, + RegionOutputInsertionPoint, +) + +# Module logger +logger = logging.getLogger(__name__) + + +# Region-related Exceptions +class RegionError(Exception): + """Base exception for region-related errors.""" + + +# Autotuner-related Exceptions +class AutotunerError(Exception): + """Base exception for autotuner-related errors.""" + + +class AutotunerNotInitializedError(AutotunerError): + """Exception raised when autotuner is used without initialization.""" + + +class InvalidSchemeError(AutotunerError): + """Exception raised when an invalid scheme is referenced.""" + + +class RegionType(Enum): + """Region type enumeration for hierarchical graph structure. + + - LEAF: Atomic region containing direct nodes with no child regions + - COMPOSITE: Hierarchical region containing child regions (and optionally direct nodes) + - ROOT: Top-level region encompassing the entire computation graph + """ + + LEAF = "LEAF" + COMPOSITE = "COMPOSITE" + ROOT = "ROOT" + + +class Region: + """Hierarchical subgraph region in an ONNX computation graph. + + A Region represents a cohesive subgraph with well-defined boundaries, supporting: + + **Hierarchical Structure:** + - Parent/child relationships forming a multi-level hierarchy + - LEAF regions contain only direct nodes + - COMPOSITE regions contain child regions (and optionally direct nodes) + - ROOT regions encompass the entire graph + + **Node Management:** + - Direct nodes: Operations directly in this region (not in children) + - Recursive nodes: All operations including those in descendant regions + + **Boundary Tracking:** + - Input tensors: Data entering the region from outside + - Output tensors: Data leaving the region to outside consumers + + **Pattern Matching:** + - Regions with identical structure share the same pattern signature + - Pattern-based optimization applies schemes to all matching regions + + Regions are the fundamental unit for Q/DQ insertion and optimization. + """ + + def __init__(self, region_id: int, level: int, region_type: RegionType): + """Initialize a new region. + + Args: + region_id: Unique identifier within the region hierarchy + level: Hierarchical level (0 = leaf, higher = more composite) + region_type: Type classification (LEAF, COMPOSITE, or ROOT) + """ + self.id = region_id + self.level = level + self.type = region_type + self.parent: Region | None = None + self.children: list[Region] = [] + self.nodes: set[int] = set() + self.inputs: list[str] = [] + self.outputs: list[str] = [] + self.metadata: dict[str, str] = {} + + # ========================================================================= + # Basic Accessors + # ========================================================================= + + def get_id(self) -> int: + """Get region ID.""" + return self.id + + def set_id(self, region_id: int) -> None: + """Set region ID (for RegionBuilder use).""" + self.id = region_id + + def get_level(self) -> int: + """Get region level in hierarchy.""" + return self.level + + def set_level(self, level: int) -> None: + """Set region level in hierarchy (for RegionBuilder use).""" + self.level = level + + def get_type(self) -> RegionType: + """Get region type.""" + return self.type + + def set_type(self, region_type: RegionType) -> None: + """Set region type (for RegionBuilder use).""" + self.type = region_type + + # ========================================================================= + # Hierarchy Management + # ========================================================================= + + def get_parent(self) -> Optional["Region"]: + """Get parent region.""" + return self.parent + + def set_parent(self, parent: Optional["Region"]) -> None: + """Set parent region.""" + self.parent = parent + + def get_children(self, *, sort: bool = False) -> list["Region"]: + """Get all child regions. + + Args: + sort: If True, return children sorted by (-level, size_of_region_and_descendants). + This ordering ensures deterministic pattern matching. Defaults to False. + + Returns: + List of child regions, optionally sorted. + """ + if not sort: + return self.children + return sorted( + self.children, key=lambda r: (-r.get_level(), r.get_size_of_region_and_descendants()) + ) + + def remove_child(self, child: "Region") -> bool: + """Remove a child region from this region's children list. + + Args: + child: The child region to remove + + Returns: + True if child was found and removed, False otherwise + """ + try: + self.children.remove(child) + if child.parent and child.parent.get_id() == self.id: + child.set_parent(None) + return True + except ValueError: + return False + + def add_child(self, child: "Region") -> None: + """Add a child sub-region.""" + # Prevent adding self as child + if child.get_id() == self.id: + logger.warning(f"Cannot add region {self.id} as its own child") + return + + # Prevent creating cycles: check if self is already a descendant of child + if self.is_descendant_of(child): + logger.warning( + f"Cycle detected: region {self.id} is already a descendant of region {child.get_id()}" + ) + return + + # Check if child already has a different parent + if child.parent is not None and child.parent.get_id() != self.id: + old_parent_id = child.parent.get_id() + logger.debug( + f"Re-parenting region {child.get_id()}: moving from parent {old_parent_id} to {self.id}" + ) + # Remove from old parent to maintain tree structure + child.parent.remove_child(child) + + # Check if child is already in children list + if any(c.get_id() == child.get_id() for c in self.children): + logger.debug(f"Region {child.get_id()} already child of {self.id}") + return + + self.children.append(child) + child.set_parent(self) + + def is_descendant_of(self, potential_ancestor: "Region") -> bool: + """Check if this region is a descendant of potential_ancestor.""" + visited = set() + current = self.parent + while current: + if current.get_id() in visited: + # Already visited, there's a cycle in parents + return False + visited.add(current.get_id()) + if current.get_id() == potential_ancestor.get_id(): + return True + current = current.parent + return False + + # ========================================================================= + # Node Management + # ========================================================================= + + def add_node(self, node_index: int) -> None: + """Add a node index to this region.""" + self.nodes.add(node_index) + + def add_nodes(self, node_indices: list[int]) -> None: + """Add multiple node indices to this region.""" + self.nodes.update(node_indices) + + def get_nodes(self, *, sort: bool = False) -> list[int]: + """Get direct node indices in this region only. + + Returns only nodes directly owned by this region, excluding nodes + in child regions. Use get_region_nodes_and_descendants() for complete coverage. + + Args: + sort: If True, return nodes in sorted order. Defaults to False. + + Returns: + List of node indices (absolute positions in the graph). + """ + if sort: + return sorted(self.nodes) + return list(self.nodes) + + def get_region_nodes_and_descendants(self, _visited: set[int] | None = None) -> set[int]: + """Get all node indices recursively, including descendants. + + Traverses the entire subtree rooted at this region, collecting nodes + from this region and all child regions recursively. + + Args: + _visited: Internal parameter for cycle detection (do not use) + + Returns: + Set of all node indices in this region and its descendants + """ + if _visited is None: + _visited = set() + + # Detect cycles + assert self.id not in _visited, f"Cycle detected in region {self.id} during node traversal" + + _visited.add(self.id) + all_nodes = set(self.nodes) + for child in self.children: + all_nodes.update(child.get_region_nodes_and_descendants(_visited)) + return all_nodes + + def contains_node(self, node_index: int) -> bool: + """Check if region contains a specific node (direct only).""" + return node_index in self.nodes + + def contains_node_within_region_and_descendants( + self, node_index: int, _visited: set[int] | None = None + ) -> bool: + """Check if region contains a node recursively.""" + if _visited is None: + _visited = set() + + # Detect cycles + assert self.id not in _visited, f"Cycle detected in region {self.id} during node traversal" + + _visited.add(self.id) + + if self.contains_node(node_index): + return True + return any( + child.contains_node_within_region_and_descendants(node_index, _visited) + for child in self.children + ) + + # ========================================================================= + # Input/Output Management + # ========================================================================= + + def add_input(self, tensor_name: str) -> None: + """Add an input tensor name.""" + if tensor_name not in self.inputs: + self.inputs.append(tensor_name) + + def add_output(self, tensor_name: str) -> None: + """Add an output tensor name.""" + if tensor_name not in self.outputs: + self.outputs.append(tensor_name) + + def get_inputs(self) -> list[str]: + """Get region input tensors.""" + return self.inputs + + def get_outputs(self) -> list[str]: + """Get region output tensors.""" + return self.outputs + + # ========================================================================= + # Size and Query Methods + # ========================================================================= + + def get_size(self) -> int: + """Get the number of direct nodes in this region. + + Returns: + Count of nodes directly in this region (excludes child regions) + """ + return len(self.nodes) + + def get_size_of_region_and_descendants(self, _visited: set[int] | None = None) -> int: + """Get total node count recursively including all descendants. + + Computes the sum of nodes in this region and all child regions, + providing the total footprint of the region subtree. + + Args: + _visited: Internal parameter for cycle detection (do not use) + + Returns: + Total number of nodes in this region and all descendants + """ + if _visited is None: + _visited = set() + + # Detect cycles + assert self.id not in _visited, ( + f"Cycle detected in region {self.id} during size calculation" + ) + + _visited.add(self.id) + total = len(self.nodes) + for child in self.children: + total += child.get_size_of_region_and_descendants(_visited) + return total + + # ========================================================================= + # Region Operations + # ========================================================================= + + def merge(self, other: "Region") -> None: + """Merge another region into this one. + + Combines the nodes and children from the other region into this region. + The other region's children become children of this region, updating + their parent references accordingly. + + Args: + other: Region to merge into this one + """ + if not other: + return + # Merge direct nodes + self.nodes.update(other.nodes) + # Merge children (updates their parent references) + for child in other.children: + self.add_child(child) + + # ========================================================================= + # Metadata Management + # ========================================================================= + + def set_metadata(self, key: str, value: str) -> None: + """Set region metadata.""" + self.metadata[key] = value + + def get_metadata(self, key: str) -> str: + """Get region metadata.""" + return self.metadata.get(key, "") + + # ========================================================================= + # String Representation + # ========================================================================= + + def to_string(self) -> str: + """Print region information for debugging.""" + type_str = self.type.value + return ( + f"Region[id={self.id}, level={self.level}, type={type_str}, " + f"nodes={len(self.nodes)}, children={len(self.children)}, " + f"inputs={len(self.inputs)}, outputs={len(self.outputs)}]" + ) + + def __str__(self) -> str: + return self.to_string() + + def __repr__(self) -> str: + return self.to_string() + + def compute_structural_signature(self, graph: gs.Graph) -> str: + """Compute deterministic structural signature for pattern matching. + + Creates a signature that uniquely identifies the region's topology, + node operations, and hierarchical structure. Regions with identical + signatures can share Q/DQ insertion schemes. + + The signature captures: + - Node operation types and key parameters + - Hierarchical structure (child regions) + - Deterministic ordering (sorted for consistency) + + Args: + graph: The ONNX graph containing the region's nodes + + Returns: + Signature string (e.g., "Conv->BatchNorm->Relu" or "COMPOSITE(...)") + """ + raise NotImplementedError("Not implemented") + + +# ============================================================================= +# Autotuner Q/DQ Insertion Specifications +# ============================================================================= + + +@dataclass +class InsertionScheme: + """Complete Q/DQ insertion specification for a region pattern. + + An InsertionScheme defines a complete Q/DQ configuration for a pattern, + combining both node-level and region-level insertion points. The scheme + is applied to all regions matching the pattern. + + **Scheme Identity:** + - Uniquely identified by the combination of insertion points (computed hash) + - latency_ms is a measured performance metric, not part of identity + - Two schemes with same insertion points but different latencies are considered identical + + **Application:** + - Node insertion points: Q/DQ at node inputs within the pattern + - Region insertion points: Q/DQ at child region boundaries (COMPOSITE only) + - All are resolved to actual configurations for each matching region + + **Performance Tracking:** + - latency_ms: Measured performance (inf = not yet measured) + - error: Whether this scheme encountered an error during measurement + - Used to select the best scheme for each pattern + + **Attributes:** + node_inputs: Q/DQ insertions at node inputs (list of NodeInputInsertionPoint) + child_region_inputs: Q/DQ insertions at child boundaries (list of ChildRegionInputInsertionPoint) + region_outputs: Q/DQ insertions at region outputs (list of RegionOutputInsertionPoint) + latency_ms: Measured latency in milliseconds (inf if not measured) + error: True if scheme measurement failed, False otherwise + profile_timestamp: ISO format timestamp when this scheme was profiled (None if not yet profiled) + """ + + node_inputs: list[NodeInputInsertionPoint] = field(default_factory=list) + child_region_inputs: list[ChildRegionInputInsertionPoint] = field(default_factory=list) + region_outputs: list[RegionOutputInsertionPoint] = field(default_factory=list) + latency_ms: float = float("inf") + error: bool = False + profile_timestamp: str | None = None + + @property + def hash(self) -> str: + """Compute deterministic hash for scheme identity. + + The hash uniquely identifies this scheme configuration based on its + insertion points. Two schemes with identical insertion points produce + the same hash, regardless of their measured latencies. + + **Hash Input:** + - Sorted node_inputs (for deterministic ordering) + - Sorted child_region_inputs (for deterministic ordering) + - Sorted region_outputs (for deterministic ordering) + + **Use Cases:** + - Detect duplicate schemes before measurement + - Group schemes by configuration + - Efficient scheme comparison + + Returns: + 32-character hexadecimal string (SHA-256 truncated to 128 bits) + """ + # Sort points for deterministic hashing + sorted_nodes = sorted([(pt.node_index, pt.input_index) for pt in self.node_inputs]) + sorted_regions = sorted( + [(pt.region_index, pt.input_index) for pt in self.child_region_inputs] + ) + sorted_region_outputs = sorted( + [(pt.region_index, pt.node_index, pt.output_index) for pt in self.region_outputs] + ) + + # Create hash input string + hash_input = f"{sorted_nodes}|{sorted_regions}|{sorted_region_outputs}" + + # Compute SHA-256 hash (128 bits) + return hashlib.sha256(hash_input.encode("utf-8")).hexdigest()[:32] + + @property + def is_empty(self) -> bool: + """Check if this is a baseline scheme with no Q/DQ insertions. + + Returns: + True if scheme has no node/region insertion points + """ + return ( + len(self.node_inputs) == 0 + and len(self.child_region_inputs) == 0 + and len(self.region_outputs) == 0 + ) + + @property + def has_error(self) -> bool: + """Check if this scheme encountered an error during measurement. + + Returns: + True if scheme has error=True, False otherwise + """ + return self.error + + @property + def is_profiled(self) -> bool: + """Check if this scheme has been profiled (measured). + + A scheme is considered profiled if it has been measured (has non-infinite latency) + or has encountered an error during measurement. + + Returns: + True if scheme has been measured (latency_ms != inf) or has error, + False if scheme is waiting to be profiled (error=False and latency_ms=inf) + """ + return self.error or self.latency_ms != float("inf") + + @property + def num_node_insertions(self) -> int: + """Get count of node-level Q/DQ insertion points. + + Returns: + Number of NodeInputInsertionPoint entries + """ + return len(self.node_inputs) + + @property + def num_region_insertions(self) -> int: + """Get count of region-level Q/DQ insertion points. + + These specify Q/DQ insertions at child region boundaries within + COMPOSITE regions. + + Returns: + Number of ChildRegionInputInsertionPoint entries + """ + return len(self.child_region_inputs) + + @property + def num_region_output_insertions(self) -> int: + """Get count of region output insertion points. + + These specify Q/DQ insertions at outputs from child regions or nodes. + + Returns: + Number of RegionOutputInsertionPoint entries + """ + return len(self.region_outputs) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "latency_ms": self.latency_ms, + "error": self.error, + "profile_timestamp": self.profile_timestamp, + "nodes_insertion_points": [pt.to_dict() for pt in self.node_inputs], + "child_region_inputs": [pt.to_dict() for pt in self.child_region_inputs], + "region_outputs": [pt.to_dict() for pt in self.region_outputs], + "hash": self.hash, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme": + """Create InsertionScheme from serialized dictionary. + + Reconstructs the insertion scheme from saved data, including node and + region insertion points. The hash is automatically recomputed from all + components to ensure consistency. + + Args: + data: Dictionary containing 'latency_ms', 'nodes_insertion_points', + 'child_region_inputs', and 'region_outputs' keys + + Returns: + Reconstructed InsertionScheme instance + """ + scheme = cls() + scheme.latency_ms = data.get("latency_ms", float("inf")) + scheme.error = data.get("error", False) + scheme.profile_timestamp = data.get("profile_timestamp") + + scheme.node_inputs = [ + NodeInputInsertionPoint.from_dict(pt) for pt in data.get("nodes_insertion_points", []) + ] + scheme.child_region_inputs = [ + ChildRegionInputInsertionPoint.from_dict(pt) + for pt in data.get("child_region_inputs", []) + ] + scheme.region_outputs = [ + RegionOutputInsertionPoint.from_dict(pt) for pt in data.get("region_outputs", []) + ] + + # Note: hash is computed from points, so we don't load it from dict + # This ensures consistency even if stored hash differs + + return scheme + + def distance(self, other: "InsertionScheme") -> int: + """Compute edit distance between this scheme and another scheme. + + The edit distance is the minimum number of add/remove operations needed + to transform this scheme into the other scheme. This is computed as the + symmetric difference between the insertion point sets. + + **Distance Calculation:** + - Counts insertion points in self but not in other (need to be removed) + - Counts insertion points in other but not in self (need to be added) + - Considers all three types of insertion points: + * node_inputs + * child_region_inputs + * region_outputs + + Args: + other: InsertionScheme to compare against + + Returns: + Total edit distance (number of add + remove operations) + + Example: + >>> scheme1 = InsertionScheme( + ... node_inputs=[ + ... NodeInputInsertionPoint(0, 0), + ... NodeInputInsertionPoint(1, 0), + ... ] + ... ) + >>> scheme2 = InsertionScheme( + ... node_inputs=[ + ... NodeInputInsertionPoint(0, 0), + ... NodeInputInsertionPoint(2, 0), + ... ] + ... ) + >>> scheme1.distance(scheme2) # 2 (remove (1,0), add (2,0)) + 2 + """ + # Convert insertion points to sets for efficient set operations + self_nodes = set(self.node_inputs) + other_nodes = set(other.node_inputs) + + self_regions = set(self.child_region_inputs) + other_regions = set(other.child_region_inputs) + + self_region_outputs = set(self.region_outputs) + other_region_outputs = set(other.region_outputs) + + # Compute symmetric difference (elements in either set but not both) + # This gives us the total number of add + remove operations + node_distance = len(self_nodes.symmetric_difference(other_nodes)) + region_distance = len(self_regions.symmetric_difference(other_regions)) + region_output_distance = len(self_region_outputs.symmetric_difference(other_region_outputs)) + + return node_distance + region_distance + region_output_distance + + def __str__(self) -> str: + """String representation for debugging.""" + error_str = ", error=True" if self.error else "" + return ( + f"InsertionScheme(node_insertions={self.num_node_insertions}, " + f"region_insertions={self.num_region_insertions}, " + f"region_output_insertions={self.num_region_output_insertions}, " + f"latency={self.latency_ms:.3f}ms{error_str})" + ) diff --git a/modelopt/onnx/quantization/autotune/insertion_points.py b/modelopt/onnx/quantization/autotune/insertion_points.py new file mode 100644 index 000000000..e60a3af99 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/insertion_points.py @@ -0,0 +1,822 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Q/DQ Insertion Point Management for ONNX Quantization. + +This module provides data structures and utilities for managing Quantization/Dequantization (Q/DQ) +insertion points in ONNX computational graphs during autotune optimization. It enables pattern-based +Q/DQ insertion that can be reused across multiple matching regions in a model. + +Core Concepts: +-------------- +1. **Pattern-Relative Insertion Points**: Insertion points are defined relative to region patterns + rather than absolute node IDs, enabling scheme reuse across all matching regions. + +2. **Resolution Process**: Pattern-relative indices are resolved to actual tensor names for each + specific region instance, then Q/DQ pairs are inserted at the resolved locations. + +3. **Hierarchical Support**: Supports Q/DQ insertion at multiple levels: + - Node inputs within regions + - Child region boundaries (inputs/outputs) + - Region outputs + +Classes: +-------- +- InsertionPoint: Abstract base class for pattern-relative insertion points +- ResolvedInsertionPoint: Resolved Q/DQ insertion point with actual tensor name +- NodeInputInsertionPoint: Pattern-relative insertion point at node inputs +- ChildRegionInputInsertionPoint: Pattern-relative insertion point at child region inputs +- RegionOutputInsertionPoint: Pattern-relative insertion point at region/node outputs + +Utilities: +---------- +- skip_invalid_insertion_points(): Filter out non-quantizable tensors +- has_quantizable_operations(): Check if region contains major quantizable ops +- resolve_region_io_insertion_points(): Resolve region I/O to actual insertion points +- merge_resolved_insertion_points(): Merge insertion points when all users are quantized +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import numpy as np +import onnx_graphsurgeon as gs + +if TYPE_CHECKING: + from modelopt.onnx.quantization.autotune.common import Region + +from modelopt.onnx.op_types import ( + get_autotuner_quantizable_operations, + get_autotuner_skip_ops, + get_bool_operations, +) +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + + +class InsertionPoint(ABC): + """Abstract base class for pattern-relative Q/DQ insertion points. + + Defines the common interface for all insertion point types that specify + where to insert Q/DQ pairs within region patterns. Concrete implementations + include: + - NodeInputInsertionPoint: Insertion at node inputs + - ChildRegionInputInsertionPoint: Insertion at child region input boundaries + - RegionOutputInsertionPoint: Insertion at region/node outputs + + All insertion points support: + - Serialization to/from dictionaries + - Resolution to actual tensor names for specific region instances + - Collection of valid insertion points from regions + """ + + @abstractmethod + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + ... + + @classmethod + @abstractmethod + def from_dict(cls, data: dict[str, Any]) -> "InsertionPoint": + """Create from dictionary.""" + ... + + @abstractmethod + def resolve(self, region: "Region", graph: gs.Graph) -> set["ResolvedInsertionPoint"]: + """Resolve pattern-relative insertion point to actual tensor names. + + Args: + region: The region instance matching this pattern + graph: The ONNX graph containing the nodes + + Returns: + Set of ResolvedInsertionPoint objects with actual tensor names + """ + ... + + @staticmethod + @abstractmethod + def collect_from_region(region: "Region", graph: gs.Graph) -> list["InsertionPoint"]: + """Collect all valid insertion points of this type from a region. + + Args: + region: The region to collect insertion points from + graph: The ONNX graph containing the nodes + + Returns: + List of insertion point objects representing valid insertion locations + """ + ... + + +@dataclass(frozen=True) +class ResolvedInsertionPoint: + """Resolved Q/DQ insertion point with actual tensor name and optional node context. + + After resolving pattern-relative insertion points, this class represents the + actual location where Q/DQ pairs should be inserted in the graph. + + **Insertion Modes:** + 1. Node-specific insertion (node_index and input_index are set): + - Inserts Q/DQ at a specific input of a specific node + - More precise control over where quantization happens + 2. Tensor-level insertion (node_index and input_index are None): + - Inserts Q/DQ for all users of the tensor + - Used when all consumers of a tensor should be quantized together + + **Attributes:** + - tensor_name: Name of the tensor where Q/DQ should be inserted + - node_index: Absolute graph node index (not pattern-relative), or None for tensor-level insertion + - input_index: Input tensor index of that node, or None for tensor-level insertion + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + tensor_name: str + node_index: int | None = None # Absolute graph node index (or None for tensor-level insertion) + input_index: int | None = None # Input tensor index of that node (or None) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "tensor_name": self.tensor_name, + "node_index": self.node_index, + "input_index": self.input_index, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ResolvedInsertionPoint": + """Create from dictionary.""" + return cls( + tensor_name=data["tensor_name"], + node_index=data["node_index"], + input_index=data.get("input_index"), + ) + + +@dataclass(frozen=True) +class NodeInputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a node's input. + + Specifies where to insert a Q/DQ pair within a region pattern using + pattern-relative indices rather than absolute node IDs. This enables + insertion scheme reuse across all regions matching the same pattern. + + **Resolution Process:** + 1. Pattern-relative indices (node_index, input_index) are defined once + 2. For each matching region, indices are resolved to actual tensor names + 3. Q/DQ pairs are inserted at the resolved tensor locations + + **Example:** + - NodeInputInsertionPoint(node_index=0, input_index=1) + - Resolves to: the second input (index 1) of the first node (index 0) in the pattern + - Actual tensor name depends on the specific region instance + + **Attributes:** + - node_index: Index of the node within the pattern's sorted node list (0-based) + - input_index: Index of the input tensor for that node (0-based) + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + node_index: int # Pattern-relative node index + input_index: int # Input tensor index of that node + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return {"node_index": self.node_index, "input_index": self.input_index} + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeInputInsertionPoint": + """Create from dictionary.""" + return cls(node_index=data["node_index"], input_index=data["input_index"]) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a node input insertion point to actual tensor names for a matching region. + + Converts pattern-relative node/input indices to absolute node indices and actual + tensor names in the graph. Special handling for Conv/ConvTranspose operations + automatically includes weight quantization when input is quantized. + + Args: + region: The region instance matching this pattern + graph: The ONNX graph containing the nodes + + Returns: + Set of ResolvedInsertionPoint objects with actual tensor names + """ + nodes_list = list(graph.nodes) + node_indices = region.get_nodes(sort=True) + resolved_ips = set() + + # Map from pattern-relative node index to absolute graph node index + assert self.node_index < len(node_indices), "Node index out of range" + actual_node_idx = node_indices[self.node_index] + assert actual_node_idx < len(nodes_list), "Node index out of range" + node = nodes_list[actual_node_idx] + assert self.input_index < len(node.inputs), "Input index out of range" + + # Resolve the input tensor name using input_index + inp = node.inputs[self.input_index] + if hasattr(inp, "name") and inp.name: + ip = ResolvedInsertionPoint( + tensor_name=inp.name, node_index=actual_node_idx, input_index=self.input_index + ) + resolved_ips.add(ip) + + if node.op in ["Conv", "ConvTranspose"]: + assert self.input_index == 0, ( + "Conv and ConvTranspose inputs and weights should be quantized at same time" + ) + assert len(node.inputs) >= 2, "Conv and ConvTranspose should have at least 2 inputs" + inp = node.inputs[1] + if hasattr(inp, "name") and inp.name: + ip = ResolvedInsertionPoint( + tensor_name=inp.name, node_index=actual_node_idx, input_index=1 + ) + resolved_ips.add(ip) + + return resolved_ips + + @staticmethod + def collect_from_region(region: "Region", graph: gs.Graph) -> list["NodeInputInsertionPoint"]: + """Collect all valid node input insertion points from a region. + + Analyzes each node in the region and identifies all valid input tensors + where Q/DQ pairs could be inserted. Filters out invalid insertion points + using skip_invalid_insertion_points(). + + Args: + region: The region to collect insertion points from + graph: The ONNX graph containing the nodes + + Returns: + List of NodeInputInsertionPoint objects representing valid insertion locations + """ + nodes_list = list(graph.nodes) + node_indices = region.get_nodes(sort=True) + + node_input_insertion_points = [] + for local_idx, node_idx in enumerate(node_indices): + assert node_idx < len(nodes_list), "Node index out of range" + node = nodes_list[node_idx] + # Analyze each input of the node + for input_idx, inp in enumerate(node.inputs): + # Skip if tensor doesn't have a valid name + if not (hasattr(inp, "name") and inp.name): + continue + # Skip if insertion point is invalid (wrong dtype, small size, special input, etc.) + if skip_invalid_insertion_points(graph, inp.name, node): + continue + # Create insertion point for valid tensor + ip = NodeInputInsertionPoint( + # Pattern-relative node index + node_index=local_idx, + input_index=input_idx, + ) + node_input_insertion_points.append(ip) + + return node_input_insertion_points + + +@dataclass(frozen=True) +class ChildRegionInputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a child region's input boundary. + + Specifies where to insert Q/DQ pairs at the input boundaries of child regions + within COMPOSITE regions. This allows parent regions to control quantization + at child boundaries, potentially overriding or complementing child region + optimizations. + + **Use Case:** + Parent regions can insert Q/DQ pairs at child region inputs to: + - Add quantization at child boundaries even if the child has no internal Q/DQ + - Override or supplement the child's own boundary Q/DQ decisions + - Apply different quantization schemes based on the parent context + + **Resolution Process:** + 1. Pattern-relative indices (region_index, input_index) are defined once + 2. For each matching parent region, indices resolve to actual child boundaries: + - region_index identifies which child region (in parent's sorted child list) + - input_index identifies which input tensor of that child region + 3. Q/DQ pairs are inserted at the resolved child input tensor locations + + **Example:** + - ChildRegionInputInsertionPoint(region_index=0, input_index=1) + - Resolves to: the second input tensor (index 1) of the first child region (index 0) + - Actual tensor name depends on the specific parent/child region instances + + **Note:** Only applies to COMPOSITE regions. LEAF regions have no children, + so child region insertion points have no effect there. + + **Attributes:** + - region_index: Index of the child region within the parent pattern's sorted child list (0-based) + - input_index: Index of the input tensor for that child region (0-based) + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + # Pattern-relative child region index + region_index: int + # Input tensor index of that child region + input_index: int + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return {"region_index": self.region_index, "input_index": self.input_index} + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ChildRegionInputInsertionPoint": + """Create from dictionary. + + Backward compatible: Ignores obsolete fields like 'child_region_id' + from older serialization formats. + + Args: + data: Dictionary with 'region_index' and 'input_index' keys + + Returns: + ChildRegionInputInsertionPoint instance + """ + # Ignore child_region_id if present in old data + return cls(region_index=data["region_index"], input_index=data["input_index"]) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a child region input insertion point to actual tensor names for a matching region. + + Converts pattern-relative child region index and input index to the actual tensor + name at that child region's input boundary, then resolves to all node inputs that + consume that tensor. + + Args: + region: The parent region instance matching this pattern + graph: The ONNX graph containing the nodes + + Returns: + Set of ResolvedInsertionPoint objects with actual tensor names. + Returns empty set for LEAF regions (no children). + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + if graph is None: + raise ValueError("graph parameter is required") + + # LEAF regions have no child boundaries + if region.get_type() == RegionType.LEAF: + return set() + + # Get sorted child regions (must match order in RegionPattern._compute_signature_recursive) + children_regions = region.get_children(sort=True) + # Map from pattern-relative child index to actual child region + resolved_ips = set() + assert self.region_index < len(children_regions), "Child region index out of range" + child_region = children_regions[self.region_index] + assert self.input_index < len(child_region.get_inputs()), "Input index out of range" + # Resolve the input tensor name using input_index + tensor_name = child_region.get_inputs()[self.input_index] + assert tensor_name is not None, "Tensor name is required" + resolved_ips.update(resolve_region_io_insertion_points(child_region, graph, tensor_name)) + + return resolved_ips + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["ChildRegionInputInsertionPoint"]: + """Collect all valid child region input insertion points from a region. + + For COMPOSITE regions, analyzes each child region and identifies all valid + input tensors where Q/DQ pairs could be inserted at child boundaries. + Returns empty list for LEAF regions (no children). + + Args: + region: The parent region to collect insertion points from + graph: The ONNX graph containing the nodes + + Returns: + List of ChildRegionInputInsertionPoint objects representing valid insertion locations + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + child_region_input_insertion_points = [] + + # Only COMPOSITE regions have child boundaries for Q/DQ insertion + if region.get_type() != RegionType.LEAF: + # Get all child regions, sorted for deterministic ordering + # Must match sorting in _compute_signature_recursive to ensure + # insertion point indices align with pattern structure + children_regions = region.get_children(sort=True) + + for local_idx, child_region in enumerate(children_regions): + # Create insertion point for each input tensor of the child region + for input_idx, inp in enumerate(child_region.get_inputs()): + if skip_invalid_insertion_points(graph, inp, child_region): + continue + point = ChildRegionInputInsertionPoint( + # Child region index within parent pattern + region_index=local_idx, + # Input index within child region + input_index=input_idx, + ) + child_region_input_insertion_points.append(point) + + return child_region_input_insertion_points + + +@dataclass(frozen=True) +class RegionOutputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at an output location. + + Specifies where to insert Q/DQ pairs at output boundaries. This can be either: + 1. Output from a child region (in COMPOSITE regions) + 2. Output from a node within the region + + **Use Case:** + Parent regions can: + - Add Q/DQ at child region output boundaries + - Add Q/DQ at node outputs within the region + - Control quantization precision as data flows through the region hierarchy + + **Resolution Process:** + 1. Pattern-relative indices are defined once + 2. If output is from a child region: use region_index (node_index is None) + - region_index identifies which child region (in sorted order) + - output_index identifies which output tensor of that child region + 3. If output is from a node: use node_index (region_index is None) + - node_index identifies which node (in sorted order) + - output_index identifies which output tensor of that node + 4. Resolves to the actual tensor name at that output location + + **Examples:** + - RegionOutputInsertionPoint(region_index=0, node_index=None, output_index=0) + → First output of the first child region + - RegionOutputInsertionPoint(region_index=None, node_index=2, output_index=1) + → Second output of the third node + + **Note:** Exactly one of region_index or node_index must be set (the other must be None). + + **Attributes:** + - region_index: Index of child region within parent pattern (0-based), or None + - node_index: Index of node within the region (0-based), or None + - output_index: Index of the output tensor (0-based) + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + region_index: int | None # Pattern-relative child region index (or None) + node_index: int | None # Pattern-relative node index (or None) + output_index: int # Output tensor index + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "region_index": self.region_index, + "node_index": self.node_index, + "output_index": self.output_index, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "RegionOutputInsertionPoint": + """Create from dictionary. + + Args: + data: Dictionary with 'region_index', 'node_index', and 'output_index' keys + + Returns: + RegionOutputInsertionPoint instance + """ + return cls( + region_index=data.get("region_index"), + node_index=data.get("node_index"), + output_index=data["output_index"], + ) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a region output insertion point to actual tensor names for a matching region. + + Converts pattern-relative indices to the actual tensor name at an output location: + - If region_index is set: Resolves to a child region's output tensor + - If node_index is set: Resolves to a node's output tensor + + Then identifies all node inputs that consume that output tensor. + + Args: + region: The region instance matching this pattern + graph: The ONNX graph containing the nodes + + Returns: + Set of ResolvedInsertionPoint objects with actual tensor names + """ + if graph is None: + raise ValueError("graph parameter is required") + + # Get sorted nodes for node output resolution + nodes_list = list(graph.nodes) + node_indices = region.get_nodes(sort=True) + children_regions = region.get_children(sort=True) + + # Resolve each region output insertion point from the scheme to actual tensor names + resolved_ips = set() + # Handle child region outputs (region_index is set) + if self.region_index is not None: + assert self.region_index < len(children_regions), "Region index out of range" + child_region = children_regions[self.region_index] + assert self.output_index < len(child_region.get_outputs()), "Output index out of range" + tensor_name = child_region.get_outputs()[self.output_index] + assert tensor_name is not None, "Invalid tensor name" + resolved_ips.update( + resolve_region_io_insertion_points(child_region, graph, tensor_name) + ) + # Handle node outputs (node_index is set) + elif self.node_index is not None: + assert self.node_index < len(node_indices), "Node index out of range" + node_idx = node_indices[self.node_index] + assert node_idx < len(nodes_list), "Node index out of range" + node = nodes_list[node_idx] + assert self.output_index < len(node.outputs), "Output index out of range" + tensor = node.outputs[self.output_index] + assert tensor is not None, "Invalid tensor name" + assert hasattr(tensor, "name") and tensor.name, "Tensor name is required" + resolved_ips.update(resolve_region_io_insertion_points(None, graph, tensor.name)) + return resolved_ips + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["RegionOutputInsertionPoint"]: + """Collect all valid region output insertion points from a region. + + Identifies all valid output tensors (from child regions or nodes) that leave + the region boundary and could have Q/DQ pairs inserted. Only includes outputs + that are actual region outputs (not consumed internally). + + For COMPOSITE regions: + - Collects child region outputs that are also region outputs + - Collects node outputs that are region outputs + + For LEAF regions: + - Only collects node outputs that are region outputs + + Args: + region: The region to collect insertion points from + graph: The ONNX graph containing the nodes + + Returns: + List of RegionOutputInsertionPoint objects representing valid insertion locations + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + nodes_list = list(graph.nodes) + node_indices = region.get_nodes(sort=True) + region_outputs_set = set(region.get_outputs()) + + # Only include outputs that are actual region outputs (leave the region) + region_output_insertion_points = [] + if region.get_type() != RegionType.LEAF: + # For COMPOSITE regions: check if child region output is a region output + children_regions = region.get_children(sort=True) + for local_idx, child_region in enumerate(children_regions): + for output_idx, out in enumerate(child_region.get_outputs()): + if out not in region_outputs_set: + continue + if skip_invalid_insertion_points(graph, out, child_region): + continue + point = RegionOutputInsertionPoint( + region_index=local_idx, + node_index=None, + output_index=output_idx, + ) + region_output_insertion_points.append(point) + # For all regions: check if node output is a region output + for local_idx, node_idx in enumerate(node_indices): + assert node_idx < len(nodes_list), "Node index out of range" + node = nodes_list[node_idx] + for output_idx, out in enumerate(node.outputs): + # Skip if tensor doesn't have a valid name + if not (hasattr(out, "name") and out.name): + continue + # Skip if this output is not a region output (i.e., it's consumed internally) + if out.name not in region_outputs_set: + continue + # Skip if insertion point is invalid (wrong dtype, small size, etc.) + if skip_invalid_insertion_points(graph, out.name, node): + continue + # Create insertion point for valid output tensor + point = RegionOutputInsertionPoint( + region_index=None, + node_index=local_idx, + output_index=output_idx, + ) + region_output_insertion_points.append(point) + + return region_output_insertion_points + + +def skip_invalid_insertion_points( + graph: gs.Graph, tensor_name: str, region_or_node: "Region | gs.Node" +) -> bool: + """Determine if a tensor should be skipped for Q/DQ insertion. + + Filters out tensors that are not suitable for quantization based on various criteria: + - Boolean and shape operations (not quantizable) + - Fused operation patterns (Conv->BatchNorm->ReLU) + - Operation-specific non-quantizable inputs (weights, biases, BN parameters) + - Non-floating-point tensors (indices, masks) + - Small tensors (scalars, small vectors with < 8 elements) + + Args: + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor to evaluate + region_or_node: Either a Region or a Node to check for usage of this tensor + + Returns: + True if the insertion point should be skipped, False if it's valid for quantization + """ + from modelopt.onnx.quantization.autotune.common import Region + + if isinstance(region_or_node, Region): + node_indices = region_or_node.get_region_nodes_and_descendants() + nodes: list[gs.Node] = [graph.nodes[node_idx] for node_idx in node_indices] + else: + assert isinstance(region_or_node, gs.Node) + nodes = [region_or_node] + + for node in nodes: + for input_idx, inp in enumerate(node.inputs): + if hasattr(inp, "name") and inp.name == tensor_name: + # Skip weights of Conv and ConvTranspose, they should be quantized with inputs at same time + if node.op in ["Conv", "ConvTranspose"] and input_idx >= 1: + return True + if node.op in ["Relu", "Softmax"]: + # Conv -> ReLU + if len(node.inputs) == 1 and len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + # Conv -> BatchNormalization -> ReLU + if len(node.inputs) == 1 and len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op == "BatchNormalization": + assert len(producer.inputs) >= 1, ( + "BN node should have more than one inputs" + ) + if len(producer.inputs[0].inputs) == 1: + producer = producer.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + # Conv -> BatchNormalization + if node.op == "BatchNormalization": + assert len(node.inputs) >= 1, "BN node should have more than one inputs" + if len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + # Filter 1: out boolean operations + if node.op in get_bool_operations(): + return True + # Filter 2: out shape operations + if node.op in get_autotuner_skip_ops(): + return True + # Filter 3: Skip operation-specific non-quantizable inputs + if node.op in ["BatchNormalization", "Resize"] and input_idx >= 1: + return True + if node.op in ["Conv", "Gemm"] and input_idx >= 2: + return True + # Filter 4: Skip non-floating-point tensors (int/bool indices, masks, etc.) + if hasattr(inp, "dtype") and inp.dtype not in [ + None, + np.float32, + np.float16, + np.float64, + ]: + return True + # Filter 5: Skip small tensors (scalars, small vectors) + if hasattr(inp, "shape") and inp.shape is not None: + if all(isinstance(s, int) for s in inp.shape): + if np.prod(inp.shape) < 8: + return True + return False + + +def has_quantizable_operations(region: "Region", graph: gs.Graph) -> bool: + """Check if a region contains major quantizable operations. + + Args: + region: The region to check + graph: The ONNX graph containing the nodes + + Returns: + True if the region contains major quantizable operations, False otherwise + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + # only check leaf regions for quantizable operations + if region.get_type() == RegionType.LEAF: + region_ops = {graph.nodes[idx].op for idx in region.get_nodes()} + return bool(region_ops.intersection(get_autotuner_quantizable_operations())) + return True + + +def resolve_region_io_insertion_points( + region: "Region | None", graph: gs.Graph, tensor_name: str +) -> set[ResolvedInsertionPoint]: + """Resolve region input/output boundaries to actual Q/DQ insertion points. + + For a given tensor at a region boundary (input or output), this function + identifies all the actual node inputs where Q/DQ pairs should be inserted. + It considers both nodes within the region (if provided) and all users of + the tensor in the graph. + + **Use Cases:** + - Child region inputs: Find all nodes inside the child that consume the input tensor + - Child region outputs: Find all nodes outside the child that consume the output tensor + - Node outputs: Find all nodes that consume the tensor (region can be None) + + Args: + region: The region to search within (or None to search entire graph) + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor at the region boundary + + Returns: + Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ pairs + """ + resolved_insertion_points = set() + tensor_users_map: dict[str, list[int]] = {} + if hasattr(graph, "tensor_users_map"): + tensor_users_map = graph.tensor_users_map + if not tensor_users_map: + tensor_users_map = get_tensor_consumer_node_indices(graph) + + node_indices: set[int] = set() + if region is not None: + node_indices.update(region.get_region_nodes_and_descendants()) + if tensor_name in tensor_users_map: + node_indices.update(tensor_users_map[tensor_name]) + + for node_idx in node_indices: + assert node_idx < len(graph.nodes), "Node index out of range" + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + if inp.name == tensor_name: + ip = ResolvedInsertionPoint( + tensor_name=tensor_name, node_index=node_idx, input_index=input_idx + ) + resolved_insertion_points.add(ip) + + return resolved_insertion_points + + +def merge_resolved_insertion_points( + graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint] +) -> set[ResolvedInsertionPoint]: + """Optimize insertion points by merging node-specific insertions into tensor-level insertions. + + When all consumers (users) of a tensor have Q/DQ insertion points, it's more efficient + to insert Q/DQ once at the tensor level rather than at each individual node input. + This reduces the number of Q/DQ nodes in the graph and simplifies the quantization scheme. + + **Optimization Logic:** + - For each tensor with multiple node-specific insertion points: + - If ALL users of the tensor have insertion points → merge to tensor-level insertion + - If SOME users have insertion points → keep node-specific insertions + + Args: + graph: The ONNX graph containing the nodes + resolved_insertion_points: Set of resolved insertion points to optimize + + Returns: + Optimized set of insertion points with merged tensor-level insertions where possible + """ + tensor_users_map = get_tensor_consumer_node_indices(graph) + node_input_insertion_points = { + ip for ip in resolved_insertion_points if ip.node_index is not None + } + tensor_names = {ip.tensor_name for ip in node_input_insertion_points} + + results = resolved_insertion_points.difference(node_input_insertion_points) + for tensor_name in tensor_names: + all_users = set(tensor_users_map[tensor_name]) + qdq_users = { + user for user in node_input_insertion_points if user.tensor_name == tensor_name + } + qdq_user_ids = set({user.node_index for user in qdq_users}) + if all_users == qdq_user_ids: + results.add( + ResolvedInsertionPoint(tensor_name=tensor_name, node_index=None, input_index=None) + ) + else: + results.update(qdq_users) + + return results diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index 67596d5df..63633279e 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -302,6 +302,24 @@ def get_tensor_consumer_nodes( return tensor_consumers +def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[str, list[int]]: + """Build a mapping from tensor names to the indices of nodes that use them. + + Args: + graph: ONNX GraphSurgeon graph to analyze + + Returns: + Dictionary mapping tensor names to lists of node indices that consume them + """ + tensor_consumer_map: dict[str, list[int]] = defaultdict(list) + nodes = graph.nodes if isinstance(graph, gs.Graph) else graph.node + for node_idx, node in enumerate(nodes): + inputs = node.inputs if isinstance(node, gs.Node) else node.input + for tensor in inputs: + tensor_consumer_map[tensor.name].append(node_idx) + return tensor_consumer_map + + def filter_quantizable_kgen_heads( cask_fusible_partitions: list[list[Node]], kgen_partitions: list[list[Node]], diff --git a/tests/unit/onnx/quantization/autotune/test_insertion_points.py b/tests/unit/onnx/quantization/autotune/test_insertion_points.py new file mode 100644 index 000000000..8d6bf5973 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_insertion_points.py @@ -0,0 +1,1314 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Comprehensive tests for common data structures in the autotuner. + +Tests: +1. InsertionPoint classes (NodeInputInsertionPoint, RegionOutputInsertionPoint, ChildRegionInputInsertionPoint) +2. InsertionScheme serialization/deserialization +3. InsertionScheme hashing and equality +4. InsertionScheme properties and methods +5. PatternSchemes management +6. Utility functions (skip_invalid_insertion_points, has_quantizable_operations, etc.) +7. Resolve and collect_from methods for all InsertionPoint types +""" + +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np +import onnx_graphsurgeon as gs +import pytest + +from modelopt.onnx.quantization.autotune.common import ( + ChildRegionInputInsertionPoint, + InsertionScheme, + NodeInputInsertionPoint, + Region, + RegionOutputInsertionPoint, + RegionType, +) +from modelopt.onnx.quantization.autotune.insertion_points import ( + ResolvedInsertionPoint, + has_quantizable_operations, + merge_resolved_insertion_points, + resolve_region_io_insertion_points, + skip_invalid_insertion_points, +) + + +class TestNodeInputInsertionPoint(unittest.TestCase): + """Test NodeInputInsertionPoint functionality.""" + + def test_creation(self): + """Test creating NodeInputInsertionPoint.""" + point = NodeInputInsertionPoint(node_index=5, input_index=2) + assert point.node_index == 5 + assert point.input_index == 2 + + def test_immutability(self): + """Test that NodeInputInsertionPoint is immutable (frozen).""" + point = NodeInputInsertionPoint(node_index=1, input_index=0) + with pytest.raises(AttributeError): + point.node_index = 2 + + def test_equality(self): + """Test equality comparison.""" + point1 = NodeInputInsertionPoint(node_index=3, input_index=1) + point2 = NodeInputInsertionPoint(node_index=3, input_index=1) + point3 = NodeInputInsertionPoint(node_index=3, input_index=2) + + assert point1 == point2 + assert point1 != point3 + + def test_hashable(self): + """Test that points can be used in sets and dicts.""" + point1 = NodeInputInsertionPoint(node_index=1, input_index=0) + point2 = NodeInputInsertionPoint(node_index=1, input_index=0) + point3 = NodeInputInsertionPoint(node_index=2, input_index=0) + + point_set = {point1, point2, point3} + assert len(point_set) == 2 # point1 and point2 are the same + + def test_serialization(self): + """Test to_dict and from_dict.""" + point = NodeInputInsertionPoint(node_index=7, input_index=3) + + data = point.to_dict() + assert data["node_index"] == 7 + assert data["input_index"] == 3 + + restored = NodeInputInsertionPoint.from_dict(data) + assert point == restored + + def test_string_representation(self): + """Test __str__ method.""" + point = NodeInputInsertionPoint(node_index=2, input_index=1) + s = str(point) + assert "2" in s + assert "1" in s + + +class TestRegionOutputInsertionPoint(unittest.TestCase): + """Test RegionOutputInsertionPoint functionality.""" + + def test_creation_with_region_index(self): + """Test creating with region_index (child region output).""" + point = RegionOutputInsertionPoint(region_index=2, node_index=None, output_index=1) + assert point.region_index == 2 + assert point.node_index is None + assert point.output_index == 1 + + def test_creation_with_node_index(self): + """Test creating with node_index (node output).""" + point = RegionOutputInsertionPoint(region_index=None, node_index=5, output_index=0) + assert point.region_index is None + assert point.node_index == 5 + assert point.output_index == 0 + + def test_immutability(self): + """Test that RegionOutputInsertionPoint is immutable (frozen).""" + point = RegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) + with pytest.raises(AttributeError): + point.region_index = 2 + + def test_equality(self): + """Test equality comparison.""" + point1 = RegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) + point2 = RegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) + point3 = RegionOutputInsertionPoint(region_index=None, node_index=1, output_index=0) + + assert point1 == point2 + assert point1 != point3 + + def test_hashable(self): + """Test that points can be used in sets and dicts.""" + point1 = RegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) + point2 = RegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) + point3 = RegionOutputInsertionPoint(region_index=None, node_index=1, output_index=0) + + point_set = {point1, point2, point3} + assert len(point_set) == 2 # point1 and point2 are the same + + def test_serialization_region_index(self): + """Test serialization with region_index.""" + point = RegionOutputInsertionPoint(region_index=3, node_index=None, output_index=2) + + data = point.to_dict() + assert data["region_index"] == 3 + assert data["node_index"] is None + assert data["output_index"] == 2 + + restored = RegionOutputInsertionPoint.from_dict(data) + assert point == restored + + def test_serialization_node_index(self): + """Test serialization with node_index.""" + point = RegionOutputInsertionPoint(region_index=None, node_index=7, output_index=1) + + data = point.to_dict() + assert data["region_index"] is None + assert data["node_index"] == 7 + assert data["output_index"] == 1 + + restored = RegionOutputInsertionPoint.from_dict(data) + assert point == restored + + def test_string_representation(self): + """Test __str__ method.""" + point1 = RegionOutputInsertionPoint(region_index=2, node_index=None, output_index=1) + s1 = str(point1) + assert "region" in s1.lower() + assert "2" in s1 + + point2 = RegionOutputInsertionPoint(region_index=None, node_index=5, output_index=0) + s2 = str(point2) + assert "node" in s2.lower() + assert "5" in s2 + + +class TestChildRegionInputInsertionPoint(unittest.TestCase): + """Test ChildRegionInputInsertionPoint functionality.""" + + def test_creation(self): + """Test creating ChildRegionInputInsertionPoint.""" + point = ChildRegionInputInsertionPoint(region_index=3, input_index=1) + assert point.region_index == 3 + assert point.input_index == 1 + + def test_immutability(self): + """Test that ChildRegionInputInsertionPoint is immutable (frozen).""" + point = ChildRegionInputInsertionPoint(region_index=1, input_index=0) + with pytest.raises(AttributeError): + point.region_index = 2 + + def test_equality(self): + """Test equality comparison.""" + point1 = ChildRegionInputInsertionPoint(region_index=2, input_index=0) + point2 = ChildRegionInputInsertionPoint(region_index=2, input_index=0) + point3 = ChildRegionInputInsertionPoint(region_index=2, input_index=1) + + assert point1 == point2 + assert point1 != point3 + + def test_hashable(self): + """Test that points can be used in sets and dicts.""" + point1 = ChildRegionInputInsertionPoint(region_index=1, input_index=0) + point2 = ChildRegionInputInsertionPoint(region_index=1, input_index=0) + point3 = ChildRegionInputInsertionPoint(region_index=2, input_index=0) + + point_set = {point1, point2, point3} + assert len(point_set) == 2 # point1 and point2 are the same + + def test_serialization(self): + """Test to_dict and from_dict.""" + point = ChildRegionInputInsertionPoint(region_index=5, input_index=2) + + data = point.to_dict() + assert data["region_index"] == 5 + assert data["input_index"] == 2 + + restored = ChildRegionInputInsertionPoint.from_dict(data) + assert point == restored + + def test_string_representation(self): + """Test __str__ method.""" + point = ChildRegionInputInsertionPoint(region_index=2, input_index=1) + s = str(point) + assert "2" in s + assert "1" in s + + +class TestInsertionScheme(unittest.TestCase): + """Test InsertionScheme functionality.""" + + def test_empty_scheme(self): + """Test empty InsertionScheme.""" + scheme = InsertionScheme() + + assert scheme.is_empty + assert scheme.num_node_insertions == 0 + assert scheme.num_region_insertions == 0 + assert scheme.num_region_output_insertions == 0 + assert not scheme.error + + def test_scheme_with_node_inputs(self): + """Test scheme with node input insertion points.""" + scheme = InsertionScheme() + scheme.node_inputs = [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)] + + assert not scheme.is_empty + assert scheme.num_node_insertions == 2 + + def test_scheme_with_region_outputs(self): + """Test scheme with region output insertion points.""" + scheme = InsertionScheme() + scheme.region_outputs = [ + RegionOutputInsertionPoint(None, 0, 0), + RegionOutputInsertionPoint(1, None, 0), + ] + + assert not scheme.is_empty + assert scheme.num_region_output_insertions == 2 + + def test_scheme_with_composite_regions(self): + """Test scheme with composite region insertion points.""" + scheme = InsertionScheme() + scheme.child_region_inputs = [ + ChildRegionInputInsertionPoint(0, 0), + ChildRegionInputInsertionPoint(1, 0), + ] + + assert not scheme.is_empty + assert scheme.num_region_insertions == 2 + + def test_scheme_hash_empty(self): + """Test hash of empty scheme.""" + scheme1 = InsertionScheme() + scheme2 = InsertionScheme() + + assert scheme1.hash == scheme2.hash + + def test_scheme_hash_with_points(self): + """Test hash with insertion points.""" + scheme1 = InsertionScheme() + scheme1.node_inputs = [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)] + + scheme2 = InsertionScheme() + scheme2.node_inputs = [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)] + + scheme3 = InsertionScheme() + scheme3.node_inputs = [ + NodeInputInsertionPoint(0, 0), + NodeInputInsertionPoint(2, 0), # Different + ] + + assert scheme1.hash == scheme2.hash + assert scheme1.hash != scheme3.hash + + def test_scheme_hash_order_independent(self): + """Test that hash is independent of insertion point order.""" + scheme1 = InsertionScheme() + scheme1.node_inputs = [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)] + + scheme2 = InsertionScheme() + scheme2.node_inputs = [ + NodeInputInsertionPoint(1, 0), + NodeInputInsertionPoint(0, 0), # Reversed order + ] + + # Hash should be the same regardless of order + assert scheme1.hash == scheme2.hash + + def test_serialization_empty(self): + """Test serialization of empty scheme.""" + scheme = InsertionScheme() + + data = scheme.to_dict() + restored = InsertionScheme.from_dict(data) + + assert restored.is_empty + assert restored.latency_ms == float("inf") + assert not restored.error + + def test_serialization_full(self): + """Test serialization with all types of insertion points.""" + scheme = InsertionScheme() + scheme.node_inputs = [NodeInputInsertionPoint(0, 0)] + scheme.child_region_inputs = [ChildRegionInputInsertionPoint(0, 0)] + scheme.region_outputs = [RegionOutputInsertionPoint(None, 0, 0)] + scheme.latency_ms = 12.5 + scheme.error = False + + data = scheme.to_dict() + restored = InsertionScheme.from_dict(data) + + assert len(restored.node_inputs) == 1 + assert len(restored.child_region_inputs) == 1 + assert len(restored.region_outputs) == 1 + assert restored.latency_ms == 12.5 + assert not restored.error + + def test_serialization_with_error(self): + """Test serialization with error flag.""" + scheme = InsertionScheme() + scheme.error = True + scheme.latency_ms = float("inf") + + data = scheme.to_dict() + restored = InsertionScheme.from_dict(data) + + assert restored.error + assert restored.latency_ms == float("inf") + + +# ============================================================================= +# Helper functions for creating mock graphs +# ============================================================================= + + +def _create_mock_tensor(name: str, dtype=np.float32, shape=None): + """Create a mock tensor with the specified properties.""" + tensor = MagicMock() + tensor.name = name + tensor.dtype = dtype + tensor.shape = shape if shape is not None else [1, 3, 224, 224] + tensor.inputs = [] + return tensor + + +def _create_mock_node(op: str, inputs: list, outputs: list, name: str = ""): + """Create a mock node with the specified properties.""" + node = MagicMock(spec=gs.Node) + node.op = op + node.name = name + node.inputs = inputs + node.outputs = outputs + return node + + +def _create_simple_graph(): + """Create a mock graph with Conv -> BatchNorm -> Relu -> MaxPool pattern. + + Graph structure: + input -> Conv -> conv_out -> BatchNorm -> bn_out -> Relu -> relu_out -> MaxPool -> pool_out + + Node indices: + 0: Conv + 1: BatchNormalization + 2: Relu + 3: MaxPool + """ + # Create tensors with realistic shapes + input_tensor = _create_mock_tensor("input", np.float32, [1, 3, 224, 224]) + weight_tensor = _create_mock_tensor("conv_weight", np.float32, [64, 3, 3, 3]) + bias_tensor = _create_mock_tensor("conv_bias", np.float32, [64]) + conv_output = _create_mock_tensor("conv_out", np.float32, [1, 64, 222, 222]) + + # BatchNorm parameters + bn_scale = _create_mock_tensor("bn_scale", np.float32, [64]) + bn_bias = _create_mock_tensor("bn_bias", np.float32, [64]) + bn_mean = _create_mock_tensor("bn_mean", np.float32, [64]) + bn_var = _create_mock_tensor("bn_var", np.float32, [64]) + bn_output = _create_mock_tensor("bn_out", np.float32, [1, 64, 222, 222]) + + relu_output = _create_mock_tensor("relu_out", np.float32, [1, 64, 222, 222]) + pool_output = _create_mock_tensor("pool_out", np.float32, [1, 64, 111, 111]) + + # Create nodes + conv_node = _create_mock_node( + "Conv", [input_tensor, weight_tensor, bias_tensor], [conv_output], "conv1" + ) + bn_node = _create_mock_node( + "BatchNormalization", + [conv_output, bn_scale, bn_bias, bn_mean, bn_var], + [bn_output], + "bn1", + ) + relu_node = _create_mock_node("Relu", [bn_output], [relu_output], "relu1") + pool_node = _create_mock_node("MaxPool", [relu_output], [pool_output], "pool1") + + # Link tensors to their producer nodes + conv_output.inputs = [conv_node] + bn_output.inputs = [bn_node] + relu_output.inputs = [relu_node] + pool_output.inputs = [pool_node] + input_tensor.inputs = [] + weight_tensor.inputs = [] + bias_tensor.inputs = [] + + # Create graph + graph = MagicMock(spec=gs.Graph) + graph.nodes = [conv_node, bn_node, relu_node, pool_node] + + tensors = { + "input": input_tensor, + "conv_weight": weight_tensor, + "conv_bias": bias_tensor, + "conv_out": conv_output, + "bn_out": bn_output, + "relu_out": relu_output, + "pool_out": pool_output, + } + + return graph, tensors + + +def _create_residual_graph(): + """Create a mock graph with a residual block pattern (skip connection). + + Graph structure: + input ─────────────────────────────┐ + │ │ + ▼ │ + Conv1 -> conv1_out │ + │ │ + ▼ │ + Relu1 -> relu1_out │ + │ │ + ▼ │ + Conv2 -> conv2_out │ + │ │ + ▼ ▼ + Add (conv2_out + input) -> add_out + │ + ▼ + Relu2 -> output + + Node indices: + 0: Conv1 + 1: Relu1 + 2: Conv2 + 3: Add + 4: Relu2 + """ + # Create tensors + input_tensor = _create_mock_tensor("input", np.float32, [1, 64, 56, 56]) + + # First conv branch + weight1 = _create_mock_tensor("conv1_weight", np.float32, [64, 64, 3, 3]) + conv1_out = _create_mock_tensor("conv1_out", np.float32, [1, 64, 56, 56]) + relu1_out = _create_mock_tensor("relu1_out", np.float32, [1, 64, 56, 56]) + + # Second conv + weight2 = _create_mock_tensor("conv2_weight", np.float32, [64, 64, 3, 3]) + conv2_out = _create_mock_tensor("conv2_out", np.float32, [1, 64, 56, 56]) + + # Add and final relu + add_out = _create_mock_tensor("add_out", np.float32, [1, 64, 56, 56]) + output = _create_mock_tensor("output", np.float32, [1, 64, 56, 56]) + + # Create nodes + conv1_node = _create_mock_node("Conv", [input_tensor, weight1], [conv1_out], "conv1") + relu1_node = _create_mock_node("Relu", [conv1_out], [relu1_out], "relu1") + conv2_node = _create_mock_node("Conv", [relu1_out, weight2], [conv2_out], "conv2") + add_node = _create_mock_node("Add", [conv2_out, input_tensor], [add_out], "add1") + relu2_node = _create_mock_node("Relu", [add_out], [output], "relu2") + + # Link tensors to their producer nodes + conv1_out.inputs = [conv1_node] + relu1_out.inputs = [relu1_node] + conv2_out.inputs = [conv2_node] + add_out.inputs = [add_node] + output.inputs = [relu2_node] + input_tensor.inputs = [] + weight1.inputs = [] + weight2.inputs = [] + + # Create graph + graph = MagicMock(spec=gs.Graph) + graph.nodes = [conv1_node, relu1_node, conv2_node, add_node, relu2_node] + + tensors = { + "input": input_tensor, + "conv1_weight": weight1, + "conv1_out": conv1_out, + "relu1_out": relu1_out, + "conv2_weight": weight2, + "conv2_out": conv2_out, + "add_out": add_out, + "output": output, + } + + return graph, tensors + + +# ============================================================================= +# Utility Function Tests +# ============================================================================= + + +class TestSkipInvalidInsertionPoints(unittest.TestCase): + """Test skip_invalid_insertion_points function.""" + + def test_skip_bool_operations(self): + """Test that boolean operations are skipped.""" + graph, _ = _create_simple_graph() + + # Create a node with boolean operation + bool_tensor = _create_mock_tensor("bool_input", np.float32) + bool_node = _create_mock_node("Equal", [bool_tensor], []) + + result = skip_invalid_insertion_points(graph, "bool_input", bool_node) + assert result is True + + def test_skip_shape_operations(self): + """Test that shape operations are skipped.""" + graph, _ = _create_simple_graph() + + shape_tensor = _create_mock_tensor("shape_input", np.float32) + shape_node = _create_mock_node("Shape", [shape_tensor], []) + + result = skip_invalid_insertion_points(graph, "shape_input", shape_node) + assert result is True + + def test_skip_conv_weight_input(self): + """Test that Conv weight inputs (index >= 1) are skipped.""" + graph, tensors = _create_simple_graph() + conv_node = graph.nodes[0] + + # Weight is at input index 1 + result = skip_invalid_insertion_points(graph, "conv_weight", conv_node) + assert result is True + + def test_allow_conv_data_input(self): + """Test that Conv data input (index 0) is allowed.""" + graph, tensors = _create_simple_graph() + + # Create a MatMul node that consumes the input tensor (not Conv-related skip) + input_tensor = _create_mock_tensor("matmul_input", np.float32, [1, 3, 224, 224]) + matmul_node = _create_mock_node("MatMul", [input_tensor], []) + + result = skip_invalid_insertion_points(graph, "matmul_input", matmul_node) + assert result is False + + def test_skip_non_float_tensors(self): + """Test that non-floating-point tensors are skipped.""" + graph, _ = _create_simple_graph() + + # Create int tensor + int_tensor = _create_mock_tensor("int_input", np.int32) + node = _create_mock_node("Add", [int_tensor], []) + + result = skip_invalid_insertion_points(graph, "int_input", node) + assert result is True + + def test_skip_small_tensors(self): + """Test that small tensors (< 8 elements) are skipped.""" + graph, _ = _create_simple_graph() + + # Create small tensor (scalar) + small_tensor = _create_mock_tensor("small", np.float32, [1]) + node = _create_mock_node("Add", [small_tensor], []) + + result = skip_invalid_insertion_points(graph, "small", node) + assert result is True + + def test_allow_large_float_tensors(self): + """Test that large floating-point tensors are allowed.""" + graph, _ = _create_simple_graph() + + # Create large float tensor + large_tensor = _create_mock_tensor("large", np.float32, [1, 64, 32, 32]) + node = _create_mock_node("Add", [large_tensor], []) + + result = skip_invalid_insertion_points(graph, "large", node) + assert result is False + + def test_skip_bn_non_data_inputs(self): + """Test that BatchNormalization non-data inputs are skipped.""" + graph, tensors = _create_simple_graph() + bn_node = graph.nodes[1] # BatchNormalization node + + # Scale is at input index 1, should be skipped + result = skip_invalid_insertion_points(graph, "bn_scale", bn_node) + assert result is True + + def test_with_region(self): + """Test skip_invalid_insertion_points with a Region containing multiple nodes.""" + graph, tensors = _create_simple_graph() + + # Create a region containing Conv and BatchNorm nodes + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv node + region.add_node(1) # BatchNorm node + + # Create a shape operation node and add to graph + shape_tensor = _create_mock_tensor("shape_input", np.float32) + shape_node = _create_mock_node("Shape", [shape_tensor], []) + graph.nodes.append(shape_node) + region.add_node(4) # Add the shape node to region + + result = skip_invalid_insertion_points(graph, "shape_input", region) + assert result is True + + def test_skip_conv_bn_relu_fusion(self): + """Test that Conv->BN->Relu fusion patterns are skipped at intermediate points.""" + graph, tensors = _create_simple_graph() + relu_node = graph.nodes[2] # Relu node + + # Relu input (bn_out) should be skipped when preceded by Conv->BN + result = skip_invalid_insertion_points(graph, "bn_out", relu_node) + assert result is True + + def test_residual_block_add_inputs(self): + """Test insertion points in a residual block with skip connection.""" + graph, tensors = _create_residual_graph() + add_node = graph.nodes[3] # Add node + + # Add's first input (conv2_out) should be allowed + result = skip_invalid_insertion_points(graph, "conv2_out", add_node) + assert result is False + + # Add's second input (skip connection input) should also be allowed + result = skip_invalid_insertion_points(graph, "input", add_node) + assert result is False + + +class TestHasQuantizableOperations(unittest.TestCase): + """Test has_quantizable_operations function.""" + + def test_leaf_with_conv(self): + """Test LEAF region with Conv operation.""" + graph, _ = _create_simple_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv node + + result = has_quantizable_operations(region, graph) + assert result is True + + def test_leaf_with_maxpool(self): + """Test LEAF region with MaxPool (a major quantizable op).""" + graph, _ = _create_simple_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(3) # MaxPool node + + result = has_quantizable_operations(region, graph) + assert result is True + + def test_leaf_with_relu_only(self): + """Test LEAF region with only Relu.""" + graph, _ = _create_simple_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(2) # Relu node only (index 2 in new graph) + + result = has_quantizable_operations(region, graph) + assert result is True # Relu is in MAJOR_QUANTIZABLE_OPERATIONS + + def test_leaf_with_conv_bn_relu(self): + """Test LEAF region with Conv->BN->Relu pattern.""" + graph, _ = _create_simple_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv + region.add_node(1) # BatchNorm + region.add_node(2) # Relu + + result = has_quantizable_operations(region, graph) + assert result is True + + def test_leaf_without_quantizable_ops(self): + """Test LEAF region without major quantizable operations.""" + # Create graph with only shape operations + shape_tensor = _create_mock_tensor("input", np.float32) + output_tensor = _create_mock_tensor("output", np.float32) + shape_node = _create_mock_node("Shape", [shape_tensor], [output_tensor]) + transpose_node = _create_mock_node("Transpose", [output_tensor], []) + + graph = MagicMock(spec=gs.Graph) + graph.nodes = [shape_node, transpose_node] + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) + region.add_node(1) + + result = has_quantizable_operations(region, graph) + assert result is False + + def test_composite_region_always_true(self): + """Test that COMPOSITE regions always return True.""" + graph, _ = _create_simple_graph() + + region = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + # Don't add any nodes - COMPOSITE regions assume children have quantizable ops + + result = has_quantizable_operations(region, graph) + assert result is True + + def test_residual_block_has_quantizable_ops(self): + """Test residual block with Add operation.""" + graph, _ = _create_residual_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(3) # Add node + + result = has_quantizable_operations(region, graph) + assert result is True # Add is in MAJOR_QUANTIZABLE_OPERATIONS + + +class TestResolveRegionIOInsertionPoints(unittest.TestCase): + """Test resolve_region_io_insertion_points function.""" + + def test_resolve_with_region(self): + """Test resolving with a region containing Conv->BN->Relu.""" + graph, tensors = _create_simple_graph() + + # Set up tensor_users_map: conv_out is consumed by BatchNorm (node 1) + graph.tensor_users_map = {"conv_out": [1]} + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(1) # BatchNorm node + + result = resolve_region_io_insertion_points(region, graph, "conv_out") + + assert len(result) >= 1 + assert any(ip.tensor_name == "conv_out" for ip in result) + + def test_resolve_without_region(self): + """Test resolving without a region (None) for tensor-level insertion.""" + graph, _ = _create_simple_graph() + + # Set up tensor_users_map: bn_out is consumed by Relu (node 2) + graph.tensor_users_map = {"bn_out": [2]} + + result = resolve_region_io_insertion_points(None, graph, "bn_out") + + assert len(result) == 1 + ip = next(iter(result)) + assert ip.tensor_name == "bn_out" + assert ip.node_index == 2 + assert ip.input_index == 0 + + def test_resolve_tensor_not_found(self): + """Test resolving a tensor that has no users.""" + graph, _ = _create_simple_graph() + graph.tensor_users_map = {} + + result = resolve_region_io_insertion_points(None, graph, "nonexistent") + + assert len(result) == 0 + + def test_resolve_residual_skip_connection(self): + """Test resolving input tensor used by both Conv1 and Add (skip connection).""" + graph, tensors = _create_residual_graph() + + # Input tensor is used by Conv1 (node 0) and Add (node 3) + graph.tensor_users_map = {"input": [0, 3]} + + result = resolve_region_io_insertion_points(None, graph, "input") + + # Should find both consumers + assert len(result) == 2 + node_indices = {ip.node_index for ip in result} + assert 0 in node_indices # Conv1 + assert 3 in node_indices # Add + + def test_resolve_with_multiple_consumers(self): + """Test resolving tensor with multiple consumers in a region.""" + graph, tensors = _create_residual_graph() + + # relu1_out feeds conv2 (node 2) + graph.tensor_users_map = {"relu1_out": [2]} + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(2) # Conv2 + + result = resolve_region_io_insertion_points(region, graph, "relu1_out") + + assert len(result) == 1 + ip = next(iter(result)) + assert ip.tensor_name == "relu1_out" + assert ip.node_index == 2 + + +class TestMergeResolvedInsertionPoints(unittest.TestCase): + """Test merge_resolved_insertion_points function.""" + + def test_merge_all_users(self): + """Test merging when all users have insertion points.""" + graph, _ = _create_simple_graph() + + # Setup: tensor "conv_out" is used by BatchNorm (node 1) + resolved = { + ResolvedInsertionPoint(tensor_name="conv_out", node_index=1, input_index=0), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"conv_out": [1]} + + result = merge_resolved_insertion_points(graph, resolved) + + # Should be merged to tensor-level insertion + assert len(result) == 1 + merged = next(iter(result)) + assert merged.tensor_name == "conv_out" + assert merged.node_index is None + assert merged.input_index is None + + def test_no_merge_partial_users(self): + """Test no merging when only some users have insertion points.""" + graph, _ = _create_simple_graph() + + # Setup: tensor "conv_out" is used by nodes 1 and 2, but only node 1 has IP + resolved = { + ResolvedInsertionPoint(tensor_name="conv_out", node_index=1, input_index=0), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"conv_out": [1, 2]} + + result = merge_resolved_insertion_points(graph, resolved) + + # Should NOT be merged - keep node-specific + assert len(result) == 1 + ip = next(iter(result)) + assert ip.node_index == 1 # Still node-specific + + def test_preserve_tensor_level_insertions(self): + """Test that existing tensor-level insertions are preserved.""" + graph, _ = _create_simple_graph() + + # Already tensor-level insertion + resolved = { + ResolvedInsertionPoint(tensor_name="input", node_index=None, input_index=None), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"conv_out": [1]} + + result = merge_resolved_insertion_points(graph, resolved) + + assert len(result) == 1 + ip = next(iter(result)) + assert ip.tensor_name == "input" + assert ip.node_index is None + + def test_merge_residual_skip_connection(self): + """Test merging with residual block where input has two users.""" + graph, _ = _create_residual_graph() + + # Input tensor used by Conv1 (node 0) and Add (node 3) + # If we have insertion points for both, they should merge + resolved = { + ResolvedInsertionPoint(tensor_name="input", node_index=0, input_index=0), + ResolvedInsertionPoint(tensor_name="input", node_index=3, input_index=1), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"input": [0, 3]} + + result = merge_resolved_insertion_points(graph, resolved) + + # Should be merged to tensor-level insertion + assert len(result) == 1 + merged = next(iter(result)) + assert merged.tensor_name == "input" + assert merged.node_index is None + + def test_no_merge_residual_partial(self): + """Test no merging in residual block when only one branch has insertion point.""" + graph, _ = _create_residual_graph() + + # Input tensor used by Conv1 (node 0) and Add (node 3) + # Only Conv1 has an insertion point + resolved = { + ResolvedInsertionPoint(tensor_name="input", node_index=0, input_index=0), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"input": [0, 3]} + + result = merge_resolved_insertion_points(graph, resolved) + + # Should NOT merge - only one of two users has IP + assert len(result) == 1 + ip = next(iter(result)) + assert ip.node_index == 0 # Still node-specific + + +# ============================================================================= +# Resolve Method Tests +# ============================================================================= + + +class TestNodeInputInsertionPointResolve(unittest.TestCase): + """Test NodeInputInsertionPoint.resolve() method.""" + + def test_resolve_simple(self): + """Test resolving a simple node input for Conv->BN->Relu->Pool.""" + graph, tensors = _create_simple_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv node + region.add_node(1) # BatchNorm node + region.add_node(2) # Relu node + region.add_node(3) # MaxPool node + + # Create insertion point for first input of first node (Conv) + ip = NodeInputInsertionPoint(node_index=0, input_index=0) + + result = ip.resolve(region, graph) + + assert len(result) >= 1 + assert any(rip.tensor_name == "input" for rip in result) + + def test_resolve_conv_includes_weight(self): + """Test that resolving Conv input also includes weight.""" + graph, tensors = _create_simple_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv node + + # Create insertion point for first input of Conv (should also add weight) + ip = NodeInputInsertionPoint(node_index=0, input_index=0) + + result = ip.resolve(region, graph) + + # Should include both data input and weight + assert len(result) == 2 + tensor_names = {rip.tensor_name for rip in result} + assert "input" in tensor_names + assert "conv_weight" in tensor_names + + def test_resolve_relu_input(self): + """Test resolving Relu input in the middle of the chain.""" + graph, tensors = _create_simple_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv + region.add_node(1) # BatchNorm + region.add_node(2) # Relu + + # Relu is at local index 2, input 0 is bn_out + ip = NodeInputInsertionPoint(node_index=2, input_index=0) + + result = ip.resolve(region, graph) + + assert len(result) == 1 + rip = next(iter(result)) + assert rip.tensor_name == "bn_out" + + def test_resolve_residual_conv_input(self): + """Test resolving Conv input in residual block.""" + graph, tensors = _create_residual_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv1 + region.add_node(1) # Relu1 + region.add_node(2) # Conv2 + + # Conv2 is at local index 2, input 0 is relu1_out + ip = NodeInputInsertionPoint(node_index=2, input_index=0) + + result = ip.resolve(region, graph) + + # Conv includes both data and weight + assert len(result) == 2 + tensor_names = {rip.tensor_name for rip in result} + assert "relu1_out" in tensor_names + assert "conv2_weight" in tensor_names + + +class TestChildRegionInputInsertionPointResolve(unittest.TestCase): + """Test ChildRegionInputInsertionPoint.resolve() method.""" + + def test_resolve_composite_region(self): + """Test resolving child region input in COMPOSITE region.""" + graph, tensors = _create_simple_graph() + graph.tensor_users_map = {"input": [0]} + + # Create parent (COMPOSITE) with child (LEAF) containing Conv->BN->Relu + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child.inputs = ["input"] + child.add_node(0) # Conv + child.add_node(1) # BatchNorm + child.add_node(2) # Relu + parent.add_child(child) + + ip = ChildRegionInputInsertionPoint(region_index=0, input_index=0) + + result = ip.resolve(parent, graph) + + assert len(result) >= 1 + assert any(rip.tensor_name == "input" for rip in result) + + def test_resolve_leaf_returns_empty(self): + """Test that LEAF regions return empty set.""" + graph, _ = _create_simple_graph() + + leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF) + leaf.add_node(0) + + ip = ChildRegionInputInsertionPoint(region_index=0, input_index=0) + + result = ip.resolve(leaf, graph) + + assert len(result) == 0 + + def test_resolve_multiple_children(self): + """Test resolving child inputs in COMPOSITE with multiple children.""" + graph, tensors = _create_residual_graph() + # input is consumed by Conv1 (node 0) and Add (node 3) + graph.tensor_users_map = {"input": [0, 3], "conv1_out": [1]} + + # Create parent with two child regions + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + + # First child: Conv1 (consumes "input") + child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child1.inputs = ["input"] + child1.add_node(0) # Conv1 + + # Second child: Relu1 (consumes "conv1_out") + child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + child2.inputs = ["conv1_out"] + child2.add_node(1) # Relu1 + + parent.add_child(child1) + parent.add_child(child2) + + # Resolve input of first child (region_index=0) - "input" tensor + ip1 = ChildRegionInputInsertionPoint(region_index=0, input_index=0) + result1 = ip1.resolve(parent, graph) + + assert len(result1) >= 1 + assert any(rip.tensor_name == "input" for rip in result1) + + # Resolve input of second child (region_index=1) - "conv1_out" tensor + ip2 = ChildRegionInputInsertionPoint(region_index=1, input_index=0) + result2 = ip2.resolve(parent, graph) + + assert len(result2) >= 1 + assert any(rip.tensor_name == "conv1_out" for rip in result2) + + +class TestRegionOutputInsertionPointResolve(unittest.TestCase): + """Test RegionOutputInsertionPoint.resolve() method.""" + + def test_resolve_node_output(self): + """Test resolving a node output.""" + graph, tensors = _create_simple_graph() + graph.tensor_users_map = {"conv_out": [1]} + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv + region.add_node(1) # BatchNorm + region.add_node(2) # Relu + region.outputs = ["conv_out"] + + # Output of first node (Conv) + ip = RegionOutputInsertionPoint(region_index=None, node_index=0, output_index=0) + + result = ip.resolve(region, graph) + + assert len(result) >= 1 + assert any(rip.tensor_name == "conv_out" for rip in result) + + def test_resolve_child_region_output(self): + """Test resolving a child region output.""" + graph, tensors = _create_simple_graph() + graph.tensor_users_map = {"relu_out": [3]} + + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child.outputs = ["relu_out"] + child.add_node(0) # Conv + child.add_node(1) # BatchNorm + child.add_node(2) # Relu + parent.add_child(child) + + ip = RegionOutputInsertionPoint(region_index=0, node_index=None, output_index=0) + + result = ip.resolve(parent, graph) + + assert len(result) >= 1 + assert any(rip.tensor_name == "relu_out" for rip in result) + + def test_resolve_residual_add_output(self): + """Test resolving Add output in residual block.""" + graph, tensors = _create_residual_graph() + graph.tensor_users_map = {"add_out": [4]} + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv1 + region.add_node(1) # Relu1 + region.add_node(2) # Conv2 + region.add_node(3) # Add + region.add_node(4) # Relu2 + region.outputs = ["add_out"] + + # Add is at local index 3, output 0 + ip = RegionOutputInsertionPoint(region_index=None, node_index=3, output_index=0) + + result = ip.resolve(region, graph) + + assert len(result) >= 1 + assert any(rip.tensor_name == "add_out" for rip in result) + + +# ============================================================================= +# Collect From Region Tests +# ============================================================================= + + +class TestNodeInputInsertionPointCollectFrom(unittest.TestCase): + """Test NodeInputInsertionPoint.collect_from_region() method.""" + + def test_collect_valid_inputs(self): + """Test collecting valid node input insertion points from Conv->BN->Relu->Pool.""" + graph, tensors = _create_simple_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv + region.add_node(1) # BatchNorm + region.add_node(2) # Relu + region.add_node(3) # MaxPool + + result = NodeInputInsertionPoint.collect_from_region(region, graph) + + # Should have collected some insertion points + assert len(result) >= 1 + # All should be NodeInputInsertionPoint + assert all(isinstance(ip, NodeInputInsertionPoint) for ip in result) + + def test_collect_from_residual_block(self): + """Test collecting from residual block with skip connection.""" + graph, tensors = _create_residual_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv1 + region.add_node(1) # Relu1 + region.add_node(2) # Conv2 + region.add_node(3) # Add + region.add_node(4) # Relu2 + + result = NodeInputInsertionPoint.collect_from_region(region, graph) + + # Should have collected insertion points from Conv1, Add inputs, etc. + assert len(result) >= 1 + assert all(isinstance(ip, NodeInputInsertionPoint) for ip in result) + + # Check that we have insertion points for different nodes + node_indices = {ip.node_index for ip in result} + assert len(node_indices) >= 1 # At least one node has valid inputs + + +class TestChildRegionInputInsertionPointCollectFrom(unittest.TestCase): + """Test ChildRegionInputInsertionPoint.collect_from_region() method.""" + + def test_collect_from_composite(self): + """Test collecting from COMPOSITE region with children.""" + graph, tensors = _create_simple_graph() + + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child.inputs = ["input"] + child.add_node(0) # Conv + child.add_node(1) # BatchNorm + child.add_node(2) # Relu + parent.add_child(child) + + result = ChildRegionInputInsertionPoint.collect_from_region(parent, graph) + + # Should find the child's input + assert len(result) >= 0 # May be filtered by skip_invalid_insertion_points + assert all(isinstance(ip, ChildRegionInputInsertionPoint) for ip in result) + + def test_collect_from_leaf_returns_empty(self): + """Test that LEAF regions return empty list.""" + graph, _ = _create_simple_graph() + + leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF) + leaf.add_node(0) + + result = ChildRegionInputInsertionPoint.collect_from_region(leaf, graph) + + assert len(result) == 0 + + def test_collect_from_composite_with_multiple_children(self): + """Test collecting from COMPOSITE with multiple child regions.""" + graph, tensors = _create_residual_graph() + + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + + child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child1.inputs = ["input"] + child1.add_node(0) # Conv1 + child1.add_node(1) # Relu1 + + child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + child2.inputs = ["relu1_out", "input"] # Two inputs including skip connection + child2.add_node(2) # Conv2 + child2.add_node(3) # Add + + parent.add_child(child1) + parent.add_child(child2) + + result = ChildRegionInputInsertionPoint.collect_from_region(parent, graph) + + # Should find inputs from both children + assert all(isinstance(ip, ChildRegionInputInsertionPoint) for ip in result) + + +class TestRegionOutputInsertionPointCollectFrom(unittest.TestCase): + """Test RegionOutputInsertionPoint.collect_from_region() method.""" + + def test_collect_node_outputs(self): + """Test collecting node output insertion points.""" + graph, tensors = _create_simple_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv + region.add_node(1) # BatchNorm + region.add_node(2) # Relu + region.add_node(3) # MaxPool + region.outputs = ["pool_out"] # Only pool_out is a region output + + result = RegionOutputInsertionPoint.collect_from_region(region, graph) + + # Should find the node output that matches region output + assert len(result) >= 0 # May be filtered + assert all(isinstance(ip, RegionOutputInsertionPoint) for ip in result) + + def test_collect_child_region_outputs(self): + """Test collecting child region output insertion points.""" + graph, tensors = _create_simple_graph() + + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child.outputs = ["relu_out"] + child.add_node(0) # Conv + child.add_node(1) # BatchNorm + child.add_node(2) # Relu + parent.add_child(child) + parent.outputs = ["relu_out"] # Child output is also parent output + + result = RegionOutputInsertionPoint.collect_from_region(parent, graph) + + # Should find the child region output + assert all(isinstance(ip, RegionOutputInsertionPoint) for ip in result) + + def test_collect_residual_block_outputs(self): + """Test collecting outputs from residual block.""" + graph, tensors = _create_residual_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv1 + region.add_node(1) # Relu1 + region.add_node(2) # Conv2 + region.add_node(3) # Add + region.add_node(4) # Relu2 + region.outputs = ["output"] # Final output + + result = RegionOutputInsertionPoint.collect_from_region(region, graph) + + # Should find the output + assert all(isinstance(ip, RegionOutputInsertionPoint) for ip in result) diff --git a/tests/unit/onnx/quantization/autotune/test_region.py b/tests/unit/onnx/quantization/autotune/test_region.py new file mode 100644 index 000000000..1481a3ad1 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_region.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the Region class in the autotuner. + +Tests region creation, hierarchy, and boundary management. +""" + +import unittest + +from modelopt.onnx.quantization.autotune.common import Region, RegionType + + +class TestRegion(unittest.TestCase): + """Test Region class functionality.""" + + def test_region_creation(self): + """Test creating regions of all types.""" + test_cases = [ + {"region_id": 1, "level": 0, "region_type": RegionType.LEAF}, + {"region_id": 2, "level": 1, "region_type": RegionType.COMPOSITE}, + {"region_id": 0, "level": 2, "region_type": RegionType.ROOT}, + ] + + for params in test_cases: + with self.subTest(**params): + region = Region(**params) + assert region.get_id() == params["region_id"] + assert region.get_level() == params["level"] + assert region.get_type() == params["region_type"] + + def test_parent_child_relationship(self): + """Test parent-child relationships.""" + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + + parent.add_child(child1) + parent.add_child(child2) + + assert len(parent.get_children()) == 2 + assert child1.get_parent() == parent + assert child2.get_parent() == parent + assert child1 in parent.get_children() + assert child2 in parent.get_children() + + def test_add_nodes(self): + """Test adding nodes to a region.""" + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + + region.add_node(0) + region.add_node(1) + region.add_node(2) + + assert region.get_size() == 3 + assert 0 in region.get_nodes() + assert 1 in region.get_nodes() + assert 2 in region.get_nodes() + + def test_input_output_tensors(self): + """Test setting input and output tensors.""" + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + + # Directly assign to inputs/outputs attributes + region.inputs = ["input_tensor_1", "input_tensor_2"] + region.outputs = ["output_tensor_1"] + + assert len(region.get_inputs()) == 2 + assert len(region.get_outputs()) == 1 + assert "input_tensor_1" in region.get_inputs() + assert "output_tensor_1" in region.get_outputs() + + def test_region_size_recursive(self): + """Test recursive size calculation.""" + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + + # Add nodes to children + child1.add_node(0) + child1.add_node(1) + child2.add_node(2) + child2.add_node(3) + child2.add_node(4) + + # Add children to parent + parent.add_child(child1) + parent.add_child(child2) + + # Parent itself might have direct nodes + parent.add_node(5) + + # Recursive count should include all nodes + assert len(parent.get_region_nodes_and_descendants()) == 6 + + def test_metadata(self): + """Test region metadata storage.""" + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + + region.metadata["pattern"] = "Conv->Relu" + region.metadata["quantizable"] = "true" + + assert region.metadata["pattern"] == "Conv->Relu" + assert region.metadata["quantizable"] == "true" + + def test_region_type_checks(self): + """Test checking region types (LEAF and COMPOSITE).""" + leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF) + composite = Region(region_id=2, level=1, region_type=RegionType.COMPOSITE) + + assert leaf.get_type() == RegionType.LEAF + assert leaf.get_type() != RegionType.COMPOSITE + assert composite.get_type() == RegionType.COMPOSITE + assert composite.get_type() != RegionType.LEAF + + def test_hierarchical_structure(self): + """Test complex hierarchical structure.""" + root = Region(region_id=0, level=2, region_type=RegionType.ROOT) + composite1 = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + composite2 = Region(region_id=2, level=1, region_type=RegionType.COMPOSITE) + leaf1 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + leaf2 = Region(region_id=4, level=0, region_type=RegionType.LEAF) + leaf3 = Region(region_id=5, level=0, region_type=RegionType.LEAF) + + # Build hierarchy + root.add_child(composite1) + root.add_child(composite2) + composite1.add_child(leaf1) + composite1.add_child(leaf2) + composite2.add_child(leaf3) + + # Add some nodes + leaf1.add_node(0) + leaf2.add_node(1) + leaf3.add_node(2) + + # Verify structure + assert len(root.get_children()) == 2 + assert len(composite1.get_children()) == 2 + assert len(composite2.get_children()) == 1 + assert len(root.get_region_nodes_and_descendants()) == 3 + + def test_remove_child(self): + """Test removing a child region.""" + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + + parent.add_child(child) + assert len(parent.get_children()) == 1 + + parent.remove_child(child) + assert len(parent.get_children()) == 0 + assert child.get_parent() is None