Static shape inference for ONNX models where standard tools fail.
Tested on all models from VNN-COMP 2024 with 100% success rate.
ONNX's built-in onnx.shape_inference.infer_shapes handles most models correctly, but fails in several critical scenarios:
- Models with inconsistent ONNX versions or opset mismatches
- Non-standard conversions from PyTorch or other frameworks
- Dynamic shape operations where shape computations depend on data
- Shape operator chains (Shape → Gather → Add → Reshape); see the sketch below
- Models with custom shape manipulations
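To make the shape-chain scenario concrete, here is a minimal sketch (written for this README, not taken from ShapeONNX) that builds the Shape → Gather → Add → Reshape pattern with onnx.helper. With data propagation disabled (the default), standard inference can recover at best the rank of y, not its concrete dimensions:

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

# Shape -> Gather -> Add -> Reshape: the Reshape target is computed at runtime.
nodes = [
    helper.make_node("Shape", ["x"], ["shp"]),                     # shp = [1, 3, 4, 4]
    helper.make_node("Gather", ["shp", "idx"], ["dims"], axis=0),  # dims = shp[idx]
    helper.make_node("Add", ["dims", "zeros"], ["target"]),        # target = dims + 0
    helper.make_node("Reshape", ["x", "target"], ["y"]),
]
graph = helper.make_graph(
    nodes,
    "shape_chain",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 4, 4])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, None)],
    initializer=[
        numpy_helper.from_array(np.arange(4, dtype=np.int64), "idx"),
        numpy_helper.from_array(np.zeros(4, dtype=np.int64), "zeros"),
    ],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)])

# Standard inference leaves y's dims symbolic; ShapeONNX resolves them to [1, 3, 4, 4].
inferred = onnx.shape_inference.infer_shapes(model)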
ShapeONNX solves these problems through static shape computation:
- Simulates shape calculations through a mini computation graph
- Propagates static values through shape operator chains
- Resolves intermediate shape tensors to compile-time constants
- Converts dynamic shape operations to static equivalents
- Provides reliable shape inference for neural network verification
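As a rough illustration of the idea (a plain-Python sketch of the technique, not ShapeONNX's actual code), static value propagation resolves the chain built above like this:

# Plain-Python sketch of static value propagation through a shape chain.
input_shape = [1, 3, 4, 4]                       # known static input shape

shape_value = list(input_shape)                  # Shape: output *value* is known
dims = [shape_value[i] for i in range(4)]        # Gather(axis=0, idx=[0, 1, 2, 3])
target = [d + z for d, z in zip(dims, [0] * 4)]  # Add with a constant initializer

# Reshape: the target is now a compile-time constant, so the output
# shape is static too (resolve any -1 entry from the element count).
assert -1 not in target
output_shape = target                            # [1, 3, 4, 4]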
Neural network verification tools require precise static shapes for:
- Layer-by-layer bound propagation
- Memory allocation for symbolic execution
- Constraint generation for SMT solvers
- Model optimization and fusion (SlimONNX)
When ONNX shape inference fails, verification pipelines break. ShapeONNX fills this gap by providing robust static shape inference for the complex models encountered in verification research.
- Robust Shape Inference: Handles models where onnx.shape_inference fails
- Shape Operator Chains: Resolves Shape → Gather → Slice → Add patterns
- Dynamic to Static: Converts runtime shape computations to compile-time constants
- 46 Operators: Comprehensive coverage across 10 operator categories
- Fast Performance: Single-pass forward propagation with O(1) work per operator
- Pure Python: No C/C++ dependencies, easy integration
- Production Ready: Tested on 140 VNN-COMP 2024 models
ShapeONNX is essential for:
- Neural Network Verification: Tools requiring static shapes (α,β-CROWN, ERAN, Marabou)
- Model Optimization: Pre-optimization shape resolution (SlimONNX)
- Shape-Dependent Transformations: Operations requiring known tensor dimensions
- Complex Model Analysis: Understanding shape propagation in non-standard models
- Python 3.10 or higher
- onnx 1.17.0
- numpy 2.2.4
Important: ONNX version compatibility matters. Use the specified versions to avoid opset incompatibilities.
pip install onnx==1.17.0 numpy==2.2.4

- ONNX 1.17.0: Tested opset range 17-21
- NumPy 2.2.4: Required for Python 3.10+ compatibility
Models should be converted to ONNX IR version 21 using onnx.version_converter for maximum compatibility.
import onnx
from shapeonnx import infer_onnx_shape
from shapeonnx.shapeonnx.utils import (
    get_initializers,
    get_input_nodes,
    get_output_nodes,
    convert_constant_to_initializer,
)
# Load and prepare model
model = onnx.load("model.onnx")
model = onnx.version_converter.convert_version(model, target_version=21)
# Extract model components
initializers = get_initializers(model)
input_nodes = get_input_nodes(model, initializers, has_batch_dim=True)
output_nodes = get_output_nodes(model, has_batch_dim=True)
# Convert Constant nodes to initializers (required preprocessing)
nodes = convert_constant_to_initializer(list(model.graph.node), initializers)
# Infer shapes
shapes = infer_onnx_shape(
    input_nodes,
    output_nodes,
    nodes,
    initializers,
    has_batch_dim=True,
    verbose=False,
)
# Access inferred shapes
for tensor_name, shape in shapes.items():
    print(f"{tensor_name}: {shape}")

Main shape inference function.
def infer_onnx_shape(
    input_nodes: list[ValueInfoProto],
    output_nodes: list[ValueInfoProto],
    nodes: list[NodeProto],
    initializers: dict[str, TensorProto],
    has_batch_dim: bool = True,
    verbose: bool = False,
) -> dict[str, list[int]]

Parameters:
- input_nodes (list[ValueInfoProto]): Model input value infos
- output_nodes (list[ValueInfoProto]): Model output value infos
- nodes (list[NodeProto]): Model computation nodes (Constant nodes must be converted to initializers)
- initializers (dict[str, TensorProto]): Model initializers (weights and constants)
- has_batch_dim (bool): Whether the model has a batch dimension (default: True)
- verbose (bool): Print debug information during inference (default: False)
Returns: dict[str, list[int]] - Dictionary mapping tensor names to inferred shapes
Note: Constant nodes must be converted to initializers before calling this function using convert_constant_to_initializer().
Extract shapes from model input/output nodes.
def extract_io_shapes(
    nodes: list[ValueInfoProto],
    has_batch_dim: bool
) -> dict[str, list[int]]

Parameters:
- nodes (list[ValueInfoProto]): Input or output value infos
- has_batch_dim (bool): Whether tensors have a batch dimension
Returns: dict[str, list[int]] - Dictionary mapping names to shapes
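A minimal usage sketch; the import location of extract_io_shapes is not shown in this README, so the top-level import below is an assumption:

import onnx
from shapeonnx import extract_io_shapes  # assumed export; adjust to the actual module
from shapeonnx.shapeonnx.utils import get_initializers, get_input_nodes

model = onnx.load("model.onnx")
initializers = get_initializers(model)
input_nodes = get_input_nodes(model, initializers, has_batch_dim=True)

# Map input names to shapes without running full graph inference.
io_shapes = extract_io_shapes(input_nodes, has_batch_dim=True)
print(io_shapes)  # e.g. {"input": [1, 3, 224, 224]}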
Convert Constant nodes to initializers (required preprocessing step).
def convert_constant_to_initializer(
    nodes: list[NodeProto],
    initializers: dict[str, TensorProto]
) -> list[NodeProto]

Parameters:
- nodes (list[NodeProto]): Model nodes
- initializers (dict[str, TensorProto]): Initializer dictionary (modified in place)
Returns: list[NodeProto] - Nodes with Constant nodes removed
Extract initializers from model.
def get_initializers(model: ModelProto) -> dict[str, TensorProto]

Extract input nodes with proper shape formatting.
def get_input_nodes(
    model: ModelProto,
    initializers: dict[str, TensorProto],
    has_batch_dim: bool
) -> list[ValueInfoProto]

Extract output nodes with proper shape formatting.
def get_output_nodes(
    model: ModelProto,
    has_batch_dim: bool
) -> list[ValueInfoProto]

ShapeONNX supports 46 operators across 10 categories:
- Arithmetic: Add, Sub, Mul, Div, Pow, Neg
- Activation: Relu, LeakyRelu, Sigmoid, Tanh, Clip, Sin, Cos
- Convolution and Pooling: Conv, ConvTranspose, MaxPool, AveragePool, GlobalAveragePool
- Normalization: BatchNormalization
- Tensor Operations: Reshape, Transpose, Squeeze, Unsqueeze, Flatten, Expand
- Slicing and Indexing: Slice, Split, Gather, Concat
- Shape Operations: Shape, ConstantOfShape, Range
- Reduction: ReduceMean, ReduceSum, ArgMax
- Comparison and Selection: Equal, Where, Max, Min
- Matrix: MatMul, Gemm
- Other: Cast, Dropout, Pad, Resize, Scatter, ScatterElements, ScatterND, Softmax, Floor, Sign
- Immutable Context: Frozen dataclass for shape inference context
- Pure Functions: All shape inference functions are stateless with explicit inputs
- Direct Dictionary Access: Minimal abstraction for performance
- Full Type Hints: Complete type annotations using Python 3.10+ syntax
- Single-Pass Forward Propagation: each shape is computed once, with O(1) work per operator
- Pre-Converted Initializers: Integer tensors converted once at initialization
- Efficient Operator Dispatch: Dictionary-based operator function mapping (sketched below)
- Minimal Memory Allocations: Shape lists reused where possible
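The following is a hedged sketch of this design; the field and helper names are illustrative, not ShapeONNX's actual definitions (see context.py and infer_shape.py for the real ones):

from collections.abc import Callable
from dataclasses import dataclass

from onnx import NodeProto

@dataclass(frozen=True)
class ShapeInferenceContext:
    # Illustrative fields only; see context.py for the actual definition.
    shapes: dict[str, list[int]]   # tensor name -> inferred static shape
    values: dict[str, list[int]]   # tensors resolved to compile-time constants
    has_batch_dim: bool = True

# Dictionary-based dispatch: one hash lookup per node instead of if/elif chains.
def infer_node_shape(
    node: NodeProto,
    ctx: ShapeInferenceContext,
    mapping: dict[str, Callable],
) -> list:
    return mapping[node.op_type](node, ctx)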
Benchmark: 140 VNN-COMP 2024 models processed in approximately 6.5 seconds on Intel i5-12400F.
shapeonnx/
├── __init__.py # Public API
├── infer_shape.py # Main shape inference
├── utils.py # Utility functions
├── context.py # Shape inference context
└── operators/ # Operator-specific inference
├── arithmetic.py # Add, Sub, Mul, Div, Pow, Neg
├── activation.py # Relu, Sigmoid, Tanh, etc.
├── conv_pool.py # Conv, Pool operations
├── normalization.py # BatchNormalization
├── tensor_ops.py # Reshape, Transpose, etc.
├── slicing.py # Slice, Gather, Concat
├── shape_ops.py # Shape, ConstantOfShape
├── reduction.py # ReduceMean, ReduceSum
├── comparison.py # Equal, Where, Max, Min
└── matrix.py # MatMul, Gemm
import onnx
from shapeonnx import infer_onnx_shape
from shapeonnx.shapeonnx.utils import (
    get_initializers,
    get_input_nodes,
    get_output_nodes,
    convert_constant_to_initializer,
)
# Load model
model = onnx.load("resnet18.onnx")
# Prepare components
initializers = get_initializers(model)
input_nodes = get_input_nodes(model, initializers, has_batch_dim=True)
output_nodes = get_output_nodes(model, has_batch_dim=True)
nodes = convert_constant_to_initializer(list(model.graph.node), initializers)
# Infer shapes
shapes = infer_onnx_shape(
    input_nodes, output_nodes, nodes, initializers,
    has_batch_dim=True, verbose=True
)
# Print all tensor shapes
for name, shape in sorted(shapes.items()):
    print(f"{name}: {shape}")

import onnx
from shapeonnx import infer_onnx_shape
from shapeonnx.shapeonnx.utils import (
    get_initializers,
    get_input_nodes,
    get_output_nodes,
    convert_constant_to_initializer,
)
# Load and prepare model
model = onnx.load("model.onnx")
model = onnx.version_converter.convert_version(model, target_version=21)
initializers = get_initializers(model)
input_nodes = get_input_nodes(model, initializers, has_batch_dim=True)
output_nodes = get_output_nodes(model, has_batch_dim=True)
nodes = convert_constant_to_initializer(list(model.graph.node), initializers)
# Infer shapes for optimization
shapes = infer_onnx_shape(
    input_nodes, output_nodes, nodes, initializers,
    has_batch_dim=True
)
# Use shapes for optimization decisions
for node in nodes:
    for input_name in node.input:
        if input_name in shapes:
            input_shape = shapes[input_name]
            # Make optimization decisions based on shape
            if len(input_shape) == 2:
                # Can apply matrix-specific optimizations
                pass

import onnx
from shapeonnx import infer_onnx_shape
from shapeonnx.shapeonnx.utils import (
    get_initializers,
    get_input_nodes,
    get_output_nodes,
    convert_constant_to_initializer,
)
# Model with Shape → Gather → Add → Reshape pattern
model = onnx.load("dynamic_reshape_model.onnx")
initializers = get_initializers(model)
input_nodes = get_input_nodes(model, initializers, has_batch_dim=True)
output_nodes = get_output_nodes(model, has_batch_dim=True)
nodes = convert_constant_to_initializer(list(model.graph.node), initializers)
# ShapeONNX resolves shape chains to static values
shapes = infer_onnx_shape(
    input_nodes, output_nodes, nodes, initializers,
    has_batch_dim=True, verbose=True
)
# Dynamic reshape operations now have static target shapes
print("Resolved static shapes for all tensors")ShapeONNX has been extensively tested on models from VNN-COMP 2024:
- Total Models Tested: 140 diverse neural networks
- Success Rate: 100% (all models successfully processed)
- Model Types: CNNs, ResNets, VGG, GANs, Transformers, Graph Neural Networks
- Opset Coverage: Opset 17-21
cd shapeonnx/test
python test_baseline.py

Expected output: Tested: 140/140, Passed: 140/140
The test suite includes baseline comparison to detect regressions. Shapes for all 140 models are stored in test/baselines/ and compared against current inference results.
Hardware: Intel i5-12400F (6 cores, 12 threads)
Results:
- 140 VNN-COMP models: ~6.5 seconds total
- Average per model: ~46 milliseconds
- Complex models (ViT, ResNet50): <200ms
- Simple models (ACAS Xu): <10ms
Memory: Typical peak memory usage under 500MB for largest models.
- Constant nodes must be converted to initializers before shape inference
- Asymmetric padding in Conv/Pool operations not supported
- Control flow operators (If, Loop, Scan) not supported
- Some operators have limited attribute support
- Assumes static input shapes (dynamic batch size handled via the has_batch_dim flag)
- SlimONNX: ONNX model optimization. Uses ShapeONNX for shape-dependent optimizations like constant folding and redundant operation removal.
- TorchVNNLIB: VNN-LIB to tensor converter for neural network verification.
- VNN-COMP: International Verification of Neural Networks Competition.
Contributions are welcome. Please:
- Fork the repository
- Create a feature branch
- Implement the operator following existing patterns in the operators/ directory
- Add tests and verify baseline tests pass
- Run the black formatter on all modified files
- Submit a pull request
Direct pushes to main branch are restricted.
To add a new operator:
def infer_<operator>_shape(
    node: NodeProto,
    ctx: ShapeInferenceContext
) -> list[tuple[int | list[int] | None, int | list[int] | None]]:
    """Infer shape for <Operator> node.

    :param node: ONNX node
    :param ctx: Shape inference context
    :return: List of (data_shape, explicit_shape) tuples
    """
    # Implementation
    return [(output_shape, None)]

Then register the function in the INFER_SHAPE_FUNC_MAPPING dictionary and add test cases; a concrete sketch follows.
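As a hedged illustration of this recipe, here is how a simple elementwise operator might look. Abs is not in the supported-operator list above, and the ctx.shapes accessor is an assumption; mirror the real accessors used by the existing functions in operators/:

from onnx import NodeProto

def infer_abs_shape(
    node: NodeProto,
    ctx: ShapeInferenceContext,
) -> list[tuple[int | list[int] | None, int | list[int] | None]]:
    """Infer shape for an Abs node (elementwise: output shape equals input shape).

    :param node: ONNX node
    :param ctx: Shape inference context
    :return: List of (data_shape, explicit_shape) tuples
    """
    # Elementwise unary op: propagate the input shape unchanged.
    # ctx.shapes is an assumed accessor; follow the existing operators.
    output_shape = ctx.shapes[node.input[0]]
    return [(output_shape, None)]

INFER_SHAPE_FUNC_MAPPING["Abs"] = infer_abs_shape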
MIT License. See LICENSE file for details.