torchonnx is a compiler-based tool that converts ONNX models (.onnx files) into native PyTorch models (.pth files for parameters and .py files for model structure).
It has been extensively tested on VNNCOMP 2024 benchmarks, including Vision Transformers, CNNs, and other complex neural network architectures.
While PyTorch provides the torch.onnx module to convert PyTorch models to ONNX, the reverse process—converting ONNX models back to PyTorch—is not officially supported. This tool addresses this gap for several key reasons:
- Version Fragmentation: The ONNX model format evolves across versions, with different versions supporting different operations. This creates significant compatibility challenges when working with models from various sources.
- Framework Inconsistencies: There are numerous inconsistencies between ONNX and PyTorch models in terms of naming conventions, parameter handling, and operational semantics. PyTorch does not officially support reverse conversion, likely considering it unnecessary for their ecosystem.
- Neural Network Verification Requirements: For the Neural Network Verification (NNV) community, ONNX has become the unified model format. Being able to work with these models natively in PyTorch is essential for research and verification tasks.
- Code Quality and Maintainability: ONNX's computational graph representation does not always align with logical groupings that make sense in PyTorch. We need a tool that generates clean, maintainable PyTorch code.
While other tools exist for ONNX-to-PyTorch conversion, most fall short in performance and code quality. The most well-known tool, onnx2pytorch, serves as a runtime wrapper rather than a true compiler. Its forward method iterates over ONNX nodes at runtime instead of generating static PyTorch code, and parameter conversion is inefficient.
torchonnx takes a different approach: it is a true compiler that generates clean, efficient PyTorch code. The tool converts ONNX models into two separate files:
- A .py file defining the neural network structure as native PyTorch code
- A .pth file containing the model parameters as a state dictionary
This design eliminates runtime overhead and produces code that is readable, maintainable, and performs identically to hand-written PyTorch models.
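The split can be illustrated with a toy stand-in for a converted model (`TinyNet` below is illustrative, not actual torchonnx output): the `.py` file carries the structure as ordinary PyTorch code, while the `.pth` file holds only the state dict.

```python
import io

import torch
import torch.nn as nn


# Stand-in for a generated model file: structure lives in code,
# parameters live in a separately saved state dict.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear1(x)


net = TinyNet()
buf = io.BytesIO()
torch.save(net.state_dict(), buf)  # what the .pth file would store

net2 = TinyNet()  # structure comes from the .py file
buf.seek(0)
net2.load_state_dict(torch.load(buf))
x = torch.randn(3, 4)
assert torch.equal(net(x), net2(x))  # identical results, no ONNX involved
```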
- Zero Runtime Overhead: Generated PyTorch code has no ONNX dependencies and runs at native PyTorch speed
- Static Code Generation: All operations are compiled to clean Python code, not interpreted at runtime
- Optimized Parameter Handling: Intelligent tracking eliminates unused parameters, reducing model size
- Cached Constants: Constant tensors are registered as buffers for efficient device management
- Idiomatic PyTorch: Uses native PyTorch operations, type conversions, and best practices throughout
- Complete Type Hints: All generated code includes full type annotations for Python 3.10+
- Clean Structure: Human-readable modules with proper naming, documentation, and organization
- No Dead Code: Automatic elimination of unused operations, parameters, and buffers
- Code Optimization: Post-processing removes default arguments and converts to positional arguments
- Formatted Output: All code formatted with `black` for consistency
- Pure Python Implementation: No compiled dependencies, easy to inspect and modify
- Modular Architecture: Clean 6-stage compiler pipeline with separation of concerns
- Easy to Extend: Add new operations or modify existing ones without breaking the codebase
- Well-Documented: reStructuredText docstrings with `:param:` and `:return:` annotations
- VNNCOMP 2024 Benchmarks: Extensively tested on official neural network verification competition benchmarks
- Diverse Model Coverage: Successfully converts Vision Transformers, CNNs, MLPs, and complex architectures
- Validated Output: Generated models produce numerically identical results to original ONNX models
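The "cached constants" point above can be sketched with a minimal module (illustrative, not generated output): a tensor registered via `register_buffer` is serialized with the state dict and follows device/dtype moves, yet is not a trainable parameter.

```python
import torch
import torch.nn as nn


class WithConst(nn.Module):
    def __init__(self):
        super().__init__()
        # Constant registered as a buffer, not a parameter
        self.register_buffer("c0", torch.tensor([1.0, 2.0]))

    def forward(self, x):
        return x + self.c0


m = WithConst()
assert "c0" in m.state_dict()           # serialized with the model
assert len(list(m.parameters())) == 0   # not trainable
m = m.double()                          # buffers track dtype/device changes
assert m.c0.dtype == torch.float64
```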
TorchONNX implements a 6-stage compiler pipeline that transforms ONNX models into optimized PyTorch code:
Loads and normalizes ONNX models to a consistent format:
- Model validation using ONNX checker
- Opset version conversion (target: opset 20)
- Shape inference using ONNX shape inference or shapeonnx
- Metadata cleanup
Key Files: normalize/normalize.py, normalize/utils.py
Extracts pure structural information from ONNX graph:
- Builds `ModelIR` containing a list of `NodeIR` instances
- Captures graph topology, tensor shapes, and initializers
- No semantic interpretation at this stage (pure structural representation)
Key Files: build/builder.py, build/types.py
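A hypothetical sketch of the structural IR described above; the real `NodeIR`/`ModelIR` in `build/types.py` will differ in detail.

```python
from dataclasses import dataclass, field


@dataclass
class NodeIR:
    op_type: str              # e.g. "Conv", "Gemm"
    inputs: list[str]
    outputs: list[str]
    attrs: dict = field(default_factory=dict)


@dataclass
class ModelIR:
    nodes: list[NodeIR] = field(default_factory=list)
    shapes: dict = field(default_factory=dict)        # tensor name -> shape
    initializers: dict = field(default_factory=dict)  # tensor name -> raw data


ir = ModelIR(nodes=[NodeIR("Relu", ["x"], ["y"])])
assert ir.nodes[0].op_type == "Relu"
```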
Transforms structural IR into semantic IR with PyTorch types:
- Classifies initializers into parameters (trainable), constants (buffers), and arguments (literals)
- Maps ONNX operations to PyTorch types (layers, functions, operators)
- Resolves tensor data types and shapes
- Builds `SemanticModelIR` with typed inputs (`VariableInfo`, `ParameterInfo`, `ConstantInfo`, `ArgumentInfo`)
Key Files:
- `analyze/builder.py`: Main semantic IR builder
- `analyze/types.py`: Semantic type definitions
- `analyze/tensor_classifier.py`: Tensor classification logic
- `analyze/type_mapping/`: ONNX to PyTorch type mappings
- `analyze/attr_extractor.py`: ONNX attribute extraction
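The classification into parameters, constants, and arguments might be sketched as follows; the rule and op set here are hypothetical simplifications, not the actual logic in `analyze/tensor_classifier.py`.

```python
# Initializers feeding trainable layers become parameters; scalars become
# literal arguments; everything else becomes a constant buffer.
TRAINABLE_OPS = {"Conv", "Gemm", "BatchNormalization"}


def classify_initializer(consumer_ops: set[str], is_scalar: bool) -> str:
    if is_scalar:
        return "argument"   # inlined as a literal in generated code
    if consumer_ops & TRAINABLE_OPS:
        return "parameter"  # trainable weight, goes into the state dict
    return "constant"       # registered as a buffer


assert classify_initializer({"Conv"}, False) == "parameter"
assert classify_initializer({"Concat"}, False) == "constant"
assert classify_initializer({"Mul"}, True) == "argument"
```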
Optimizes semantic IR before code generation:
- Constant folding (future)
- Dead code elimination (future)
- Operation fusion (future)
Key Files: optimize/optimizer.py
Generates PyTorch module code from semantic IR:
- `__init__` method: Parameter/constant registration and layer construction
- `forward` method: Operation-by-operation code generation using handlers
- State dict: Parameter and constant tensors
- Import statements and module structure
Key Files:
- `generate/code_generator.py`: Main orchestrator
- `generate/_init_gen.py`: `__init__` method generation
- `generate/_forward_gen.py`: `forward` method generation
- `generate/_state_dict_gen.py`: State dict building
- `generate/_templates.py`: Code templates
- `generate/_handlers/`: Operation-specific code generators
Operation Handlers:
- `_layers.py`: Layer handlers (nn.Conv2d, nn.Linear, etc.)
- `_operators.py`: Operator handlers (torch.add, torch.matmul, etc.)
- `_operations.py`: Function handlers (reshape, concat, slice, etc.)
- `_registry.py`: Handler registration system
Post-processes generated code for cleanliness:
- Removes unused buffer registrations using regex parsing
- Removes default arguments from layer constructors (e.g., `bias=True` is removed)
- Removes default arguments from functions (e.g., `F.relu(x, inplace=False)` → `F.relu(x)`)
- Converts named arguments to positional where appropriate (e.g., `nn.Conv2d(in_channels=3, out_channels=64)` → `nn.Conv2d(3, 64)`)
- Filters the state dict to exclude removed buffers
Key Files:
- `simplify/_optimizer.py`: Main optimizer orchestrator
- `simplify/_line_optimizer.py`: Line-by-line optimization
- `simplify/_rules.py`: Optimization rules and patterns
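An illustrative regex-based cleanup in the spirit of this stage: strip two known default arguments from generated lines. The real rule set in `simplify/_rules.py` is more extensive.

```python
import re

# Patterns for default arguments that can be dropped without changing behavior
DEFAULT_ARG_PATTERNS = [r",\s*bias=True", r",\s*inplace=False"]


def strip_defaults(line: str) -> str:
    for pattern in DEFAULT_ARG_PATTERNS:
        line = re.sub(pattern, "", line)
    return line


assert (
    strip_defaults("self.linear1 = nn.Linear(48, 10, bias=True)")
    == "self.linear1 = nn.Linear(48, 10)"
)
```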
torchonnx/
├── torchonnx/
│ ├── __init__.py # Exports TorchONNX class
│ ├── _torchonnx.py # TorchONNX class (main API)
│ ├── normalize/ # Stage 1: ONNX normalization
│ │ ├── __init__.py
│ │ ├── normalize.py # Model preprocessing
│ │ └── utils.py # ONNX utilities
│ ├── build/ # Stage 2: Structural IR
│ │ ├── __init__.py
│ │ ├── builder.py # IR builder
│ │ └── types.py # NodeIR, ModelIR types
│ ├── analyze/ # Stage 3: Semantic IR
│ │ ├── __init__.py
│ │ ├── builder.py # Semantic IR builder
│ │ ├── types.py # Semantic type definitions
│ │ ├── tensor_classifier.py # Tensor classification
│ │ ├── attr_extractor.py # Attribute extraction
│ │ └── type_mapping/ # ONNX → PyTorch mappings
│ │ ├── _layers.py # Layer type mappings
│ │ └── _operations.py # Operation type mappings
│ ├── optimize/ # Stage 4: IR optimization
│ │ ├── __init__.py
│ │ └── optimizer.py # IR-level optimizations
│ ├── generate/ # Stage 5: Code generation
│ │ ├── __init__.py
│ │ ├── code_generator.py # Main code generator
│ │ ├── _init_gen.py # __init__ generation
│ │ ├── _forward_gen.py # forward() generation
│ │ ├── _state_dict_gen.py # State dict building
│ │ ├── _templates.py # Code templates
│ │ ├── _utils.py # Helper utilities
│ │ └── _handlers/ # Operation-specific handlers
│ │ ├── __init__.py
│ │ ├── _registry.py # Handler registry
│ │ ├── _layers.py # Layer handlers
│ │ ├── _operators.py # Operator handlers
│ │ └── _operations.py # Function handlers
│ └── simplify/ # Stage 6: Code optimization
│ ├── __init__.py
│ ├── _optimizer.py # Main optimizer
│ ├── _line_optimizer.py # Line optimizations
│ └── _rules.py # Optimization rules
├── test/ # Testing infrastructure
│ ├── benchmarks/ # Original ONNX files
│ ├── baselines/ # Expected outputs
│ ├── results/ # Generated outputs
│ ├── analyze_model_nodes.py # Model node analyzer
│ ├── build_benchmarks.py # Benchmark builder
│ ├── test_benchmarks.py # VNNCOMP 2024 tests
│ └── utils.py # Test utilities
└── README.md
There are no complex installation steps. The tool requires:
- Python 3.10 or higher (tested with Python 3.12)
- `onnx==1.17.0` (for compatibility)
- `onnxruntime==1.20` (for model validation and testing)
- `numpy==2.2.4`
- `torch` (any recent 2.x version compatible with your system)
Please refer to the official PyTorch and ONNX installation guides for platform-specific instructions.
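Assuming the pins above, a matching `requirements.txt` might look like this (the `torch` bound is an assumption; pick the build appropriate for your platform):

```text
onnx==1.17.0
onnxruntime==1.20
numpy==2.2.4
torch>=2.0
```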
```python
from torchonnx import TorchONNX

if __name__ == "__main__":
    # Create converter instance
    converter = TorchONNX(verbose=True)

    # Convert ONNX model to PyTorch
    converter.convert(
        onnx_path="model.onnx",
        benchmark_name="mymodel",    # Optional: for module naming
        target_py_path="model.py",   # Optional: defaults to model.py
        target_pth_path="model.pth"  # Optional: defaults to model.pth
    )
```

The following example demonstrates conversion of a Vision Transformer (ViT) model from VNNCOMP 2023. Note that you should use slimonnx to simplify the model first, as the original may contain unsupported operations.
You can visualize the ONNX computational graph using netron.app.
```python
from torchonnx import TorchONNX

if __name__ == "__main__":
    file_path = "../nets/ibp_3_3_8_v22_simplified.onnx"
    converter = TorchONNX(verbose=True)
    converter.convert(file_path)
```

The following shows the generated PyTorch code for the ViT model. Note the clean structure, proper parameter registration, and readable forward pass:
```python
__all__ = ["Vit2023Ibp338"]

import torch
import torch.nn as nn


def dynamic_slice(data, starts, ends, axes=None, steps=None):
    """Dynamic slice helper for ONNX Slice operation."""
    # Ensure tensor
    starts = torch.as_tensor(starts, device=data.device)
    ends = torch.as_tensor(ends, device=data.device)
    if axes is None:
        axes = torch.arange(starts.numel(), device=data.device)
    else:
        axes = torch.as_tensor(axes, device=data.device)
    if steps is None:
        steps = torch.ones_like(starts)
    else:
        steps = torch.as_tensor(steps, device=data.device)
    # Normalize negative starts/ends
    dims = torch.as_tensor(data.shape, device=data.device)
    # axes tells where to read dim size
    dim_sizes = dims[axes]
    starts = torch.where(starts < 0, dim_sizes + starts, starts)
    ends = torch.where(ends < 0, dim_sizes + ends, ends)
    # Clip to bounds (ONNX semantics)
    # Use tensors for both min and max to avoid type mismatch
    zero = torch.zeros_like(dim_sizes)
    starts = torch.clamp(starts, min=zero, max=dim_sizes)
    ends = torch.clamp(ends, min=zero, max=dim_sizes)
    # Build index tuple dynamically
    index = [slice(None)] * data.ndim
    for i in range(axes.shape[0]):
        ax = axes[i].item()
        idx = torch.arange(starts[i], ends[i], steps[i], device=data.device)
        index[ax] = idx
    return data[tuple(index)]


class Vit2023Ibp338(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("c4", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c6", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c7", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c8", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c9", torch.empty([17, 48], dtype=torch.float32))
        self.register_buffer("c11", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c12", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c13", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c14", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c15", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c16", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c18", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c19", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c20", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c22", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c23", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c24", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c26", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c27", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c28", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c31", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c32", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c33", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c34", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c35", torch.empty([48, 96], dtype=torch.float32))
        self.register_buffer("c36", torch.empty([96], dtype=torch.float32))
        self.register_buffer("c37", torch.empty([96, 48], dtype=torch.float32))
        self.register_buffer("c38", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c40", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c41", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c42", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c43", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c44", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c45", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c47", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c48", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c49", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c51", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c52", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c53", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c55", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c56", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c57", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c60", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c61", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c62", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c63", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c64", torch.empty([48, 96], dtype=torch.float32))
        self.register_buffer("c65", torch.empty([96], dtype=torch.float32))
        self.register_buffer("c66", torch.empty([96, 48], dtype=torch.float32))
        self.register_buffer("c67", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c69", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c70", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c71", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c72", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c73", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c74", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c76", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c77", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c78", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c80", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c81", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c82", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c84", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c85", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c86", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c89", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c90", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c91", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c92", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c93", torch.empty([48, 96], dtype=torch.float32))
        self.register_buffer("c94", torch.empty([96], dtype=torch.float32))
        self.register_buffer("c95", torch.empty([96, 48], dtype=torch.float32))
        self.register_buffer("c96", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c97", torch.empty([1], dtype=torch.int64))
        self.conv2d1 = nn.Conv2d(3, 48, 8, stride=8)
        self.batchnorm2d1 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.flatten1 = nn.Flatten(3)
        self.softmax1 = nn.Softmax(-1)
        self.batchnorm2d2 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.relu1 = nn.ReLU()
        self.batchnorm2d3 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.flatten2 = nn.Flatten(3)
        self.softmax2 = nn.Softmax(-1)
        self.batchnorm2d4 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.relu2 = nn.ReLU()
        self.batchnorm2d5 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.flatten3 = nn.Flatten(3)
        self.softmax3 = nn.Softmax(-1)
        self.batchnorm2d6 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.relu3 = nn.ReLU()
        self.batchnorm2d7 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.linear1 = nn.Linear(48, 10)

    def forward(self, x0):
        x1 = torch.tensor(x0.shape, dtype=torch.int64)
        x2 = x1[0]
        x3 = self.conv2d1(x0)
        x4 = torch.tensor(x3.shape, dtype=torch.int64)
        x5 = x4[0:2]
        x6 = torch.cat([x5, self.c4])
        x7 = x3.reshape([int(x) for x in x6.tolist()])
        x8 = x7.permute((0, 2, 1))
        x9 = x2.unsqueeze(0)
        x10 = torch.cat([x9, self.c6, self.c7])
        x11 = torch.full(x10.tolist(), 0.0, dtype=torch.float32)
        x12 = x11 + self.c8
        x13 = torch.cat([x12, x8], dim=1)
        x14 = x13 + self.c9
        x15 = x14.permute((0, 2, 1))
        x16 = self.batchnorm2d1(x15.unsqueeze(2)).squeeze(2)
        x17 = x16.permute((0, 2, 1))
        x18 = torch.tensor(x17.shape, dtype=torch.int64)
        x19 = x18[0]
        x20 = x17 @ self.c11
        x21 = self.c12 + x20
        x22 = x17 @ self.c13
        x23 = self.c14 + x22
        x24 = x17 @ self.c15
        x25 = self.c16 + x24
        x26 = x19.unsqueeze(0)
        x27 = torch.cat([x26, self.c18, self.c19, self.c20])
        x28 = x19.unsqueeze(0)
        x29 = torch.cat([x28, self.c22, self.c23, self.c24])
        x30 = x19.unsqueeze(0)
        x31 = torch.cat([x30, self.c26, self.c27, self.c28])
        x32 = x21.reshape([int(x) for x in x27.tolist()])
        x33 = x32.permute((0, 2, 1, 3))
        x34 = x23.reshape([int(x) for x in x29.tolist()])
        x35 = x25.reshape([int(x) for x in x31.tolist()])
        x36 = x35.permute((0, 2, 1, 3))
        x37 = x34.permute((0, 2, 3, 1))
        x38 = x33 @ x37
        x39 = x38 * 0.25
        x40 = torch.tensor(x39.shape, dtype=torch.int64)
        x41 = self.flatten1(x39)
        x42 = self.softmax1(x41)
        x43 = x42.reshape([int(x) for x in x40.tolist()])
        x44 = x43 @ x36
        x45 = x44.permute((0, 2, 1, 3))
        x46 = x19.unsqueeze(0)
        x47 = torch.cat([x46, self.c31, self.c32])
        x48 = x45.reshape([int(x) for x in x47.tolist()])
        x49 = x48 @ self.c33
        x50 = self.c34 + x49
        x51 = x50 + x14
        x52 = x51.permute((0, 2, 1))
        x53 = self.batchnorm2d2(x52.unsqueeze(2)).squeeze(2)
        x54 = x53.permute((0, 2, 1))
        x55 = x54 @ self.c35
        x56 = self.c36 + x55
        x57 = self.relu1(x56)
        x58 = x57 @ self.c37
        x59 = self.c38 + x58
        x60 = x59 + x51
        x61 = x60.permute((0, 2, 1))
        x62 = self.batchnorm2d3(x61.unsqueeze(2)).squeeze(2)
        x63 = x62.permute((0, 2, 1))
        x64 = torch.tensor(x63.shape, dtype=torch.int64)
        x65 = x64[0]
        x66 = x63 @ self.c40
        x67 = self.c41 + x66
        x68 = x63 @ self.c42
        x69 = self.c43 + x68
        x70 = x63 @ self.c44
        x71 = self.c45 + x70
        x72 = x65.unsqueeze(0)
        x73 = torch.cat([x72, self.c47, self.c48, self.c49])
        x74 = x65.unsqueeze(0)
        x75 = torch.cat([x74, self.c51, self.c52, self.c53])
        x76 = x65.unsqueeze(0)
        x77 = torch.cat([x76, self.c55, self.c56, self.c57])
        x78 = x67.reshape([int(x) for x in x73.tolist()])
        x79 = x78.permute((0, 2, 1, 3))
        x80 = x69.reshape([int(x) for x in x75.tolist()])
        x81 = x71.reshape([int(x) for x in x77.tolist()])
        x82 = x81.permute((0, 2, 1, 3))
        x83 = x80.permute((0, 2, 3, 1))
        x84 = x79 @ x83
        x85 = x84 * 0.25
        x86 = torch.tensor(x85.shape, dtype=torch.int64)
        x87 = self.flatten2(x85)
        x88 = self.softmax2(x87)
        x89 = x88.reshape([int(x) for x in x86.tolist()])
        x90 = x89 @ x82
        x91 = x90.permute((0, 2, 1, 3))
        x92 = x65.unsqueeze(0)
        x93 = torch.cat([x92, self.c60, self.c61])
        x94 = x91.reshape([int(x) for x in x93.tolist()])
        x95 = x94 @ self.c62
        x96 = self.c63 + x95
        x97 = x96 + x60
        x98 = x97.permute((0, 2, 1))
        x99 = self.batchnorm2d4(x98.unsqueeze(2)).squeeze(2)
        x100 = x99.permute((0, 2, 1))
        x101 = x100 @ self.c64
        x102 = self.c65 + x101
        x103 = self.relu2(x102)
        x104 = x103 @ self.c66
        x105 = self.c67 + x104
        x106 = x105 + x97
        x107 = x106.permute((0, 2, 1))
        x108 = self.batchnorm2d5(x107.unsqueeze(2)).squeeze(2)
        x109 = x108.permute((0, 2, 1))
        x110 = torch.tensor(x109.shape, dtype=torch.int64)
        x111 = x110[0]
        x112 = x109 @ self.c69
        x113 = self.c70 + x112
        x114 = x109 @ self.c71
        x115 = self.c72 + x114
        x116 = x109 @ self.c73
        x117 = self.c74 + x116
        x118 = x111.unsqueeze(0)
        x119 = torch.cat([x118, self.c76, self.c77, self.c78])
        x120 = x111.unsqueeze(0)
        x121 = torch.cat([x120, self.c80, self.c81, self.c82])
        x122 = x111.unsqueeze(0)
        x123 = torch.cat([x122, self.c84, self.c85, self.c86])
        x124 = x113.reshape([int(x) for x in x119.tolist()])
        x125 = x124.permute((0, 2, 1, 3))
        x126 = x115.reshape([int(x) for x in x121.tolist()])
        x127 = x117.reshape([int(x) for x in x123.tolist()])
        x128 = x127.permute((0, 2, 1, 3))
        x129 = x126.permute((0, 2, 3, 1))
        x130 = x125 @ x129
        x131 = x130 * 0.25
        x132 = torch.tensor(x131.shape, dtype=torch.int64)
        x133 = self.flatten3(x131)
        x134 = self.softmax3(x133)
        x135 = x134.reshape([int(x) for x in x132.tolist()])
        x136 = x135 @ x128
        x137 = x136.permute((0, 2, 1, 3))
        x138 = x111.unsqueeze(0)
        x139 = torch.cat([x138, self.c89, self.c90])
        x140 = x137.reshape([int(x) for x in x139.tolist()])
        x141 = x140 @ self.c91
        x142 = self.c92 + x141
        x143 = x142 + x106
        x144 = x143.permute((0, 2, 1))
        x145 = self.batchnorm2d6(x144.unsqueeze(2)).squeeze(2)
        x146 = x145.permute((0, 2, 1))
        x147 = x146 @ self.c93
        x148 = self.c94 + x147
        x149 = self.relu3(x148)
        x150 = x149 @ self.c95
        x151 = self.c96 + x150
        x152 = x151 + x143
        x153 = torch.mean(x152, self.c97.tolist(), keepdim=False)
        x154 = self.batchnorm2d7(x153.unsqueeze(2).unsqueeze(3)).squeeze(2).squeeze(2)
        x155 = self.linear1(x154)
        return x155
```
TorchONNX is extensively tested on the VNNCOMP 2024 benchmarks, the official benchmark suite for neural network verification competitions. The test suite includes:
- Vision Transformers (ViT): Complex transformer architectures with attention mechanisms
- Convolutional Neural Networks: Various CNN architectures from traffic sign detection to autonomous control
- Feedforward Networks: MLPs with various activation functions and normalizations
- Hybrid Architectures: Models combining multiple architectural patterns
All converted models are validated to produce numerically identical outputs to their original ONNX counterparts, ensuring correctness across diverse model types and operations.
To test with VNNCOMP 2024 benchmarks, clone the vnncomp2024 repository alongside torchonnx and ensure the following structure:

```
torchonnx/
├── torchonnx/
├── README.md
├── test/
└── ...
vnncomp2024/
├── benchmarks/
└── ...
```
Then run the test suite:
```shell
cd torchonnx/test
python test_benchmarks.py
```

The tool implements most commonly used operations in feedforward neural networks and transformers:
- Convolution: Conv1d, Conv2d, ConvTranspose1d, ConvTranspose2d
- Pooling: MaxPool2d, AvgPool2d, AdaptiveAvgPool2d
- Normalization: BatchNorm2d (with automatic dimension handling)
- Activation: ReLU, LeakyReLU, Sigmoid, Tanh, Softmax, ELU, GELU
- Linear: Linear
- Dropout: Dropout
- Upsampling: Upsample
- Shape Operations: Flatten
- Convolution: F.conv, F.conv_transpose
- Linear: F.linear
- Interpolation: F.interpolate
- Padding: F.pad
- Concatenation: torch.cat
- Indexing: torch.gather, scatter_nd
- Reduction: torch.mean, torch.sum, torch.min, torch.max, torch.argmax
- Clipping: torch.clamp
- Conditional: torch.where
- Generation: torch.full, torch.arange
- Arithmetic: add (+), sub (-), mul (*), div (/), matmul (@), pow (**), neg (unary -)
- Comparison: equal (==)
- Shape: reshape, permute, squeeze, unsqueeze, shape, expand, cast
- Slicing: slice, split
- Math: sign, cos, sin, floor
Transformer-based architectures are decomposed into basic operations and handled correctly.
- ShapeONNX: Advanced shape inference for ONNX models; torchonnx falls back to it when standard ONNX shape inference is insufficient.
- TorchVNNLIB: PyTorch library for neural network verification, often used alongside torchonnx for verification tasks. It converts VNNLIB data files to `.pth` format for PyTorch or `.npz` format for NumPy.
- SlimONNX: ONNX model simplification tool that removes redundant operations and optimizes the graph before conversion.
- VNN-COMP: International Verification of Neural Networks Competition; torchonnx is tested on all VNN-COMP 2024 benchmarks.
- ONNX Simplifier: Alternative ONNX optimization tool with different optimization strategies.
Contributions are welcome from the community. Whether fixing bugs, adding features, improving documentation, or sharing ideas, all contributions are appreciated.
Note: Direct pushes to the main branch are restricted. Please fork the repository and submit a Pull Request for any changes.
This project is licensed under the MIT License. See the LICENSE file for details.