
ZhongkuiMa/torchonnx


TorchONNX: Convert ONNX Model to PyTorch Model

Python 3.10+ | PyTorch 2.x | ONNX 1.17 | License: MIT | Code style: black | PRs Welcome

torchonnx is a compiler-based tool that converts ONNX models (.onnx files) into native PyTorch models (.pth files for parameters and .py files for model structure).

Extensively tested on the VNNCOMP 2024 benchmarks, including Vision Transformers, CNNs, and other complex neural network architectures.

Motivation

While PyTorch provides the torch.onnx module to export PyTorch models to ONNX, the reverse direction (converting ONNX models back to PyTorch) is not officially supported. This tool fills that gap, for several reasons:

  • Version Fragmentation: The ONNX format evolves across opset versions, and different versions support different operations. This creates significant compatibility challenges when working with models from various sources.

  • Framework Inconsistencies: ONNX and PyTorch differ in naming conventions, parameter handling, and operator semantics. PyTorch does not officially support reverse conversion, likely considering it unnecessary for its ecosystem.

  • Neural Network Verification Requirements: For the Neural Network Verification (NNV) community, ONNX has become the unified model format. Being able to work with these models natively in PyTorch is essential for research and verification tasks.

  • Code Quality and Maintainability: ONNX's computational graph representation does not always align with logical groupings that make sense in PyTorch. We need a tool that generates clean, maintainable PyTorch code.

Why TorchONNX?

While other tools exist for ONNX-to-PyTorch conversion, most fall short in performance and code quality. The most well-known tool, onnx2pytorch, serves as a runtime wrapper rather than a true compiler. Its forward method iterates over ONNX nodes at runtime instead of generating static PyTorch code, and parameter conversion is inefficient.
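To make the runtime-wrapper versus compiler distinction concrete, here is a minimal sketch in plain PyTorch. The toy node list, `InterpretedModel`, and `CompiledModel` are hypothetical illustrations, not code from onnx2pytorch or torchonnx: the first re-walks a graph description inside `forward` on every call, while the second is the same graph written out once as straight-line code.

```python
import torch
import torch.nn as nn

# Hypothetical toy graph: (op_type, input names, output name).
NODES = [
    ("MatMul", ["x", "w"], "h"),
    ("Relu", ["h"], "y"),
]
OPS = {"MatMul": torch.matmul, "Relu": torch.relu}


class InterpretedModel(nn.Module):
    """Runtime-wrapper style: forward() re-walks the node list on every call."""

    def __init__(self, w: torch.Tensor):
        super().__init__()
        self.register_buffer("w", w)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        env = {"x": x, "w": self.w}  # tensor name -> value, rebuilt per call
        for op_type, inputs, output in NODES:
            env[output] = OPS[op_type](*(env[name] for name in inputs))
        return env["y"]


class CompiledModel(nn.Module):
    """Compiler style: the same graph emitted once as straight-line code."""

    def __init__(self, w: torch.Tensor):
        super().__init__()
        self.register_buffer("w", w)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x @ self.w
        return torch.relu(h)


w = torch.randn(4, 3)
x = torch.randn(2, 4)
assert torch.allclose(InterpretedModel(w)(x), CompiledModel(w)(x))
```

Both produce the same result; the compiled form avoids the per-call loop and dictionary lookups and reads like hand-written PyTorch.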

torchonnx takes a different approach: it is a true compiler that generates clean, efficient PyTorch code. The tool converts ONNX models into two separate files:

  1. A .py file defining the neural network structure as native PyTorch code
  2. A .pth file containing the model parameters as a state dictionary

This design eliminates runtime overhead and produces code that is readable, maintainable, and performs identically to hand-written PyTorch models.
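A minimal sketch of how the two output files fit together, using a hand-written stand-in for a generated `.py` file. `TinyNet` and `save_and_reload` are hypothetical names; in real use the class would be imported from the generated module, and the `.pth` file would be the state dictionary torchonnx writes.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Stand-in for a generated model.py (class name and layers are hypothetical)."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 2)

    def forward(self, x0: torch.Tensor) -> torch.Tensor:
        return self.linear1(x0)


def save_and_reload(model: nn.Module, path: str) -> nn.Module:
    # The .pth file holds only tensors (a state dict), not the architecture.
    torch.save(model.state_dict(), path)
    restored = TinyNet()  # the architecture comes from the .py file
    restored.load_state_dict(torch.load(path, weights_only=True))
    return restored.eval()


with tempfile.TemporaryDirectory() as tmp:
    src = TinyNet().eval()
    dst = save_and_reload(src, os.path.join(tmp, "model.pth"))
    x = torch.randn(3, 4)
    assert torch.equal(src(x), dst(x))
```

Keeping structure and weights apart means the `.py` file can be read, diffed, and edited like any other source file, while the `.pth` file is loaded with the standard `load_state_dict` mechanism.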

Key Advantages

True Compiler Architecture

  • Zero Runtime Overhead: Generated PyTorch code has no ONNX dependencies and runs at native PyTorch speed
  • Static Code Generation: All operations are compiled to clean Python code, not interpreted at runtime
  • Optimized Parameter Handling: Intelligent tracking eliminates unused parameters, reducing model size
  • Cached Constants: Constant tensors are registered as buffers for efficient device management
  • Idiomatic PyTorch: Uses native PyTorch operations, type conversions, and best practices throughout
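The "Cached Constants" point relies on standard `nn.Module` buffer behavior. This small sketch (the module `WithConstant` and buffer name `c0` are illustrative, chosen to echo the `cN` naming in the generated code) shows that a registered buffer is saved in the state dict and follows `.to()` calls, but is not returned by `.parameters()`, so an optimizer will not update it.

```python
import torch
import torch.nn as nn


class WithConstant(nn.Module):
    def __init__(self):
        super().__init__()
        # A constant tensor registered as a buffer: saved in the state dict and
        # moved/cast by .to() or .cuda(), but not returned by .parameters().
        self.register_buffer("c0", torch.tensor([1.0, 2.0, 3.0]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.c0


m = WithConstant()
print(list(m.state_dict().keys()))  # ['c0']
print(len(list(m.parameters())))    # 0
m = m.to(torch.float64)             # floating-point buffers follow the module
print(m.c0.dtype)                   # torch.float64
```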

Production-Ready Code Quality

  • Complete Type Hints: All generated code includes full type annotations for Python 3.10+
  • Clean Structure: Human-readable modules with proper naming, documentation, and organization
  • No Dead Code: Automatic elimination of unused operations, parameters, and buffers
  • Code Optimization: Post-processing removes arguments that match their defaults and converts keyword arguments to positional ones
  • Formatted Output: All code formatted with black for consistency

Extensible and Maintainable

  • Pure Python Implementation: No compiled dependencies, easy to inspect and modify
  • Modular Architecture: Clean 6-stage compiler pipeline with separation of concerns
  • Easy to Extend: Add new operations or modify existing ones without breaking the codebase
  • Well-Documented: reStructuredText docstrings with :param: and :return: annotations

Comprehensive Testing

  • VNNCOMP 2024 Benchmarks: Extensively tested on official neural network verification competition benchmarks
  • Diverse Model Coverage: Successfully converts Vision Transformers, CNNs, MLPs, and complex architectures
  • Validated Output: Generated models produce numerically identical results to original ONNX models

Compiler Architecture

TorchONNX implements a 6-stage compiler pipeline that transforms ONNX models into optimized PyTorch code:

Stage 1: Normalization

Loads and normalizes ONNX models to a consistent format:

  • Model validation using ONNX checker
  • Opset version conversion (target: opset 20)
  • Shape inference using ONNX shape inference or shapeonnx
  • Metadata cleanup

Key Files: normalize/normalize.py, normalize/utils.py

Stage 2: Structural IR Building

Extracts pure structural information from ONNX graph:

  • Builds ModelIR containing list of NodeIR instances
  • Captures graph topology, tensor shapes, and initializers
  • No semantic interpretation at this stage (pure structural representation)
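A structural IR of this kind can be as simple as two dataclasses plus topology queries. The class names `NodeIR` and `ModelIR` come from the text above, but every field and method here is a hypothetical sketch; the real definitions in `build/types.py` may differ. Note that nothing is interpreted yet: a Conv is just an op type string with named inputs.

```python
from dataclasses import dataclass, field


@dataclass
class NodeIR:
    # Hypothetical fields: purely structural, no PyTorch semantics attached.
    name: str
    op_type: str
    inputs: tuple[str, ...]
    outputs: tuple[str, ...]
    attrs: dict = field(default_factory=dict)


@dataclass
class ModelIR:
    nodes: list[NodeIR]  # in topological order
    inputs: list[str]
    outputs: list[str]
    initializers: dict[str, tuple[int, ...]]  # tensor name -> shape

    def producers(self) -> dict[str, str]:
        """Map each tensor name to the node that produces it."""
        return {out: n.name for n in self.nodes for out in n.outputs}

    def consumers(self) -> dict[str, list[str]]:
        """Map each tensor name to the nodes that read it."""
        result: dict[str, list[str]] = {}
        for n in self.nodes:
            for inp in n.inputs:
                result.setdefault(inp, []).append(n.name)
        return result


ir = ModelIR(
    nodes=[
        NodeIR("conv", "Conv", ("x", "W", "B"), ("h",), {"strides": [8, 8]}),
        NodeIR("relu", "Relu", ("h",), ("y",)),
    ],
    inputs=["x"],
    outputs=["y"],
    initializers={"W": (48, 3, 8, 8), "B": (48,)},
)
print(ir.producers())       # {'h': 'conv', 'y': 'relu'}
print(ir.consumers()["h"])  # ['relu']
```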

Key Files: build/builder.py, build/types.py

Stage 3: Semantic IR Building

Transforms structural IR into semantic IR with PyTorch types:

  • Classifies initializers into parameters (trainable), constants (buffers), and arguments (literals)
  • Maps ONNX operations to PyTorch types (layers, functions, operators)
  • Resolves tensor data types and shapes
  • Builds SemanticModelIR with typed inputs (VariableInfo, ParameterInfo, ConstantInfo, ArgumentInfo)
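As a rough illustration of the three-way split, the sketch below classifies an initializer from how it is used. The rule tables, `Use`, and `classify` are all hypothetical, not the logic in `analyze/tensor_classifier.py`. In fact, the generated ViT code shown later keeps MatMul weights as buffers, so the real rules clearly differ.

```python
from dataclasses import dataclass

# Hypothetical rule tables: which input slots of which ONNX ops get which treatment.
LAYER_WEIGHT_SLOTS = {"Conv": {1, 2}, "Gemm": {1, 2}}  # weight, bias
LITERAL_SLOTS = {"Reshape": {1}, "Unsqueeze": {1}, "Squeeze": {1}, "ReduceMean": {1}}


@dataclass
class Use:
    op_type: str
    slot: int  # position of the initializer in the consuming node's inputs


def classify(uses: list[Use], numel: int) -> str:
    """Return 'parameter', 'argument', or 'constant' for one initializer."""
    if uses and all(u.slot in LAYER_WEIGHT_SLOTS.get(u.op_type, set()) for u in uses):
        return "parameter"  # absorbed into an nn.Module layer (trainable)
    if uses and numel <= 8 and all(
        u.slot in LITERAL_SLOTS.get(u.op_type, set()) for u in uses
    ):
        return "argument"  # small shape/axes tensor, inlined as a Python literal
    return "constant"  # everything else becomes a registered buffer


print(classify([Use("Conv", 1)], numel=9216))  # parameter
print(classify([Use("Reshape", 1)], numel=3))  # argument
print(classify([Use("Add", 1)], numel=48))     # constant
```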

Key Files:

  • analyze/builder.py - Main semantic IR builder
  • analyze/types.py - Semantic type definitions
  • analyze/tensor_classifier.py - Tensor classification logic
  • analyze/type_mapping/ - ONNX to PyTorch type mappings
  • analyze/attr_extractor.py - ONNX attribute extraction

Stage 4: IR Optimization

Optimizes semantic IR before code generation:

  • Constant folding (future)
  • Dead code elimination (future)
  • Operation fusion (future)
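These IR optimizations are listed as future work. To show roughly what the dead code elimination item means on a node list, here is a hypothetical backward liveness pass. It is not code from `optimize/optimizer.py`.

```python
Node = tuple[str, list[str], list[str]]  # (name, inputs, outputs)


def eliminate_dead_nodes(nodes: list[Node], graph_outputs: list[str]) -> list[Node]:
    """Keep only nodes whose outputs are (transitively) needed by graph_outputs.

    Assumes `nodes` is in topological order, so a single backward sweep suffices.
    """
    live = set(graph_outputs)
    kept: list[Node] = []
    for name, inputs, outputs in reversed(nodes):
        if live.intersection(outputs):
            kept.append((name, inputs, outputs))
            live.update(inputs)  # this node's inputs are now needed too
    return list(reversed(kept))


nodes: list[Node] = [
    ("shape", ["x"], ["s"]),     # result never used
    ("relu", ["x"], ["h"]),
    ("neg", ["h"], ["unused"]),  # result never used
    ("add", ["h", "x"], ["y"]),
]
print([n[0] for n in eliminate_dead_nodes(nodes, ["y"])])  # ['relu', 'add']
```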

Key Files: optimize/optimizer.py

Stage 5: Code Generation

Generates PyTorch module code from semantic IR:

  • __init__ method: Parameter/constant registration and layer construction
  • forward method: Operation-by-operation code generation using handlers
  • State dict: Parameter and constant tensors
  • Import statements and module structure
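A minimal sketch of template-based module generation: `__init__` lines (buffer registration, layer construction) and `forward` lines are spliced into a class template, and the resulting source is executed to confirm it is valid PyTorch. `TEMPLATE` and `render` are hypothetical stand-ins for `generate/_templates.py` and `generate/code_generator.py`; the emitted lines mimic the style of the generated ViT code.

```python
import textwrap

import torch
import torch.nn as nn

# Hypothetical mini-template for a generated module.
TEMPLATE = """\
class {name}(nn.Module):
    def __init__(self):
        super().__init__()
{init_body}

    def forward(self, x0):
{forward_body}
"""


def render(name: str, init_lines: list[str], forward_lines: list[str]) -> str:
    """Fill the template, indenting both bodies to method-body level."""
    return TEMPLATE.format(
        name=name,
        init_body=textwrap.indent("\n".join(init_lines), " " * 8),
        forward_body=textwrap.indent("\n".join(forward_lines), " " * 8),
    )


source = render(
    "Generated",
    init_lines=[
        'self.register_buffer("c1", torch.empty([4], dtype=torch.float32))',
        "self.linear1 = nn.Linear(4, 2)",
    ],
    forward_lines=["x1 = x0 + self.c1", "x2 = self.linear1(x1)", "return x2"],
)
namespace = {"torch": torch, "nn": nn}
exec(source, namespace)  # a real generator would write `source` to model.py instead
model = namespace["Generated"]()
print(sorted(model.state_dict()))  # ['c1', 'linear1.bias', 'linear1.weight']
```

The buffer is created with `torch.empty`, as in the generated ViT code, because its real values arrive later from the `.pth` state dict.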

Key Files:

  • generate/code_generator.py - Main orchestrator
  • generate/_init_gen.py - __init__ method generation
  • generate/_forward_gen.py - forward method generation
  • generate/_state_dict_gen.py - State dict building
  • generate/_templates.py - Code templates
  • generate/_handlers/ - Operation-specific code generators

Operation Handlers:

  • _layers.py - Layer handlers (nn.Conv2d, nn.Linear, etc.)
  • _operators.py - Operator handlers (torch.add, torch.matmul, etc.)
  • _operations.py - Function handlers (reshape, concat, slice, etc.)
  • _registry.py - Handler registration system

Stage 6: Code Optimization

Post-processes generated code for cleanliness:

  • Removes unused buffer registrations using regex parsing
  • Removes default arguments from layer constructors (e.g., bias=True → removed)
  • Removes default arguments from functions (e.g., F.relu(x, inplace=False) → F.relu(x))
  • Converts named arguments to positional where appropriate (e.g., nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3) → nn.Conv2d(3, 64, 3))
  • Filters state dict to exclude removed buffers
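The two argument rewrites above can be sketched with line-by-line regular expressions. The rule tables and function names are hypothetical, not the patterns in `simplify/_rules.py`. The sketch only handles simple scalar arguments; anything it cannot match, such as a tuple `kernel_size`, is left unchanged.

```python
import re

# Hypothetical rule tables: keyword defaults to drop, and leading keywords to make positional.
DEFAULT_KWARGS = {
    "F.relu": {"inplace": "False"},
    "nn.Conv2d": {"stride": "1", "padding": "0", "bias": "True"},
}
POSITIONAL_ORDER = {"nn.Conv2d": ["in_channels", "out_channels", "kernel_size"]}


def strip_default_kwargs(line: str) -> str:
    for func, defaults in DEFAULT_KWARGS.items():
        if func + "(" not in line:
            continue
        for key, value in defaults.items():
            # Drop ", key=value" only when the value is exactly the default.
            line = re.sub(rf",\s*{key}={re.escape(value)}(?=\s*[,)])", "", line)
    return line


def to_positional(line: str) -> str:
    for func, names in POSITIONAL_ORDER.items():
        # Match the leading keywords only if they appear in signature order.
        pattern = re.escape(func) + r"\(" + r",\s*".join(rf"{n}=([^,()]+)" for n in names)
        match = re.search(pattern, line)
        if match:
            args = ", ".join(match.groups())
            line = line[: match.start()] + f"{func}({args}" + line[match.end():]
    return line


def simplify_line(line: str) -> str:
    return to_positional(strip_default_kwargs(line))


print(simplify_line("x2 = F.relu(x1, inplace=False)"))
# x2 = F.relu(x1)
print(simplify_line("self.conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, bias=True)"))
# self.conv = nn.Conv2d(3, 64, 3)
```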

Key Files:

  • simplify/_optimizer.py - Main optimizer orchestrator
  • simplify/_line_optimizer.py - Line-by-line optimization
  • simplify/_rules.py - Optimization rules and patterns

Module Structure

torchonnx/
├── torchonnx/
│   ├── __init__.py                   # Exports TorchONNX class
│   ├── _torchonnx.py                 # TorchONNX class (main API)
│   ├── normalize/                    # Stage 1: ONNX normalization
│   │   ├── __init__.py
│   │   ├── normalize.py              # Model preprocessing
│   │   └── utils.py                  # ONNX utilities
│   ├── build/                        # Stage 2: Structural IR
│   │   ├── __init__.py
│   │   ├── builder.py                # IR builder
│   │   └── types.py                  # NodeIR, ModelIR types
│   ├── analyze/                      # Stage 3: Semantic IR
│   │   ├── __init__.py
│   │   ├── builder.py                # Semantic IR builder
│   │   ├── types.py                  # Semantic type definitions
│   │   ├── tensor_classifier.py      # Tensor classification
│   │   ├── attr_extractor.py         # Attribute extraction
│   │   └── type_mapping/             # ONNX → PyTorch mappings
│   │       ├── _layers.py            # Layer type mappings
│   │       └── _operations.py        # Operation type mappings
│   ├── optimize/                     # Stage 4: IR optimization
│   │   ├── __init__.py
│   │   └── optimizer.py              # IR-level optimizations
│   ├── generate/                     # Stage 5: Code generation
│   │   ├── __init__.py
│   │   ├── code_generator.py         # Main code generator
│   │   ├── _init_gen.py              # __init__ generation
│   │   ├── _forward_gen.py           # forward() generation
│   │   ├── _state_dict_gen.py        # State dict building
│   │   ├── _templates.py             # Code templates
│   │   ├── _utils.py                 # Helper utilities
│   │   └── _handlers/                # Operation-specific handlers
│   │       ├── __init__.py
│   │       ├── _registry.py          # Handler registry
│   │       ├── _layers.py            # Layer handlers
│   │       ├── _operators.py         # Operator handlers
│   │       └── _operations.py        # Function handlers
│   └── simplify/                     # Stage 6: Code optimization
│       ├── __init__.py
│       ├── _optimizer.py             # Main optimizer
│       ├── _line_optimizer.py        # Line optimizations
│       └── _rules.py                 # Optimization rules
├── test/                             # Testing infrastructure
│   ├── benchmarks/                   # Original ONNX files
│   ├── baselines/                    # Expected outputs
│   ├── results/                      # Generated outputs
│   ├── analyze_model_nodes.py        # Model node analyzer
│   ├── build_benchmarks.py           # Benchmark builder
│   ├── test_benchmarks.py            # VNNCOMP 2024 tests
│   └── utils.py                      # Test utilities
└── README.md

Installation

There are no complex installation steps. The tool requires:

  • Python 3.10 or higher (tested with Python 3.12)
  • onnx==1.17.0 (for compatibility)
  • onnxruntime==1.20 (for model validation and testing)
  • numpy==2.2.4
  • torch (any recent 2.x version compatible with your system)

Please refer to the official PyTorch and ONNX installation guides for platform-specific instructions.

Usage

Basic Example

from torchonnx import TorchONNX

if __name__ == "__main__":
    # Create converter instance
    converter = TorchONNX(verbose=True)

    # Convert ONNX model to PyTorch
    converter.convert(
        onnx_path="model.onnx",
        benchmark_name="mymodel",  # Optional: for module naming
        target_py_path="model.py",  # Optional: defaults to model.py
        target_pth_path="model.pth"  # Optional: defaults to model.pth
    )

Advanced Example: ViT Model Conversion

The following example converts a Vision Transformer (ViT) model from VNNCOMP 2023. Simplify the model with SlimONNX first, as the original may contain unsupported operations.

You can visualize the ONNX computational graph using netron.app.

from torchonnx import TorchONNX

if __name__ == "__main__":
    file_path = "../nets/ibp_3_3_8_v22_simplified.onnx"
    converter = TorchONNX(verbose=True)
    converter.convert(file_path)

Generated Code Example

The following shows generated PyTorch code for the ViT model. Note the clean structure, proper parameter registration, and readable forward pass:

__all__ = ["Vit2023Ibp338"]

import torch
import torch.nn as nn


def dynamic_slice(data, starts, ends, axes=None, steps=None):
    """Dynamic slice helper for ONNX Slice operation."""
    # Ensure tensor
    starts = torch.as_tensor(starts, device=data.device)
    ends = torch.as_tensor(ends, device=data.device)
    if axes is None:
        axes = torch.arange(starts.numel(), device=data.device)
    else:
        axes = torch.as_tensor(axes, device=data.device)
    if steps is None:
        steps = torch.ones_like(starts)
    else:
        steps = torch.as_tensor(steps, device=data.device)

    # Normalize negative starts/ends
    dims = torch.as_tensor(data.shape, device=data.device)
    # axes tells where to read dim size
    dim_sizes = dims[axes]

    starts = torch.where(starts < 0, dim_sizes + starts, starts)
    ends = torch.where(ends < 0, dim_sizes + ends, ends)

    # Clip to bounds (ONNX semantics)
    # Use tensors for both min and max to avoid type mismatch
    zero = torch.zeros_like(dim_sizes)
    starts = torch.clamp(starts, min=zero, max=dim_sizes)
    ends = torch.clamp(ends, min=zero, max=dim_sizes)

    # Build index tuple dynamically
    index = [slice(None)] * data.ndim
    for i in range(axes.shape[0]):
        ax = axes[i].item()
        idx = torch.arange(starts[i], ends[i], steps[i], device=data.device)
        index[ax] = idx

    return data[tuple(index)]


class Vit2023Ibp338(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("c4", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c6", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c7", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c8", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c9", torch.empty([17, 48], dtype=torch.float32))
        self.register_buffer("c11", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c12", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c13", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c14", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c15", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c16", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c18", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c19", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c20", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c22", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c23", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c24", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c26", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c27", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c28", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c31", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c32", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c33", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c34", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c35", torch.empty([48, 96], dtype=torch.float32))
        self.register_buffer("c36", torch.empty([96], dtype=torch.float32))
        self.register_buffer("c37", torch.empty([96, 48], dtype=torch.float32))
        self.register_buffer("c38", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c40", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c41", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c42", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c43", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c44", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c45", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c47", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c48", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c49", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c51", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c52", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c53", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c55", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c56", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c57", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c60", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c61", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c62", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c63", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c64", torch.empty([48, 96], dtype=torch.float32))
        self.register_buffer("c65", torch.empty([96], dtype=torch.float32))
        self.register_buffer("c66", torch.empty([96, 48], dtype=torch.float32))
        self.register_buffer("c67", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c69", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c70", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c71", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c72", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c73", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c74", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c76", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c77", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c78", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c80", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c81", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c82", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c84", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c85", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c86", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c89", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c90", torch.empty([1], dtype=torch.int64))
        self.register_buffer("c91", torch.empty([48, 48], dtype=torch.float32))
        self.register_buffer("c92", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c93", torch.empty([48, 96], dtype=torch.float32))
        self.register_buffer("c94", torch.empty([96], dtype=torch.float32))
        self.register_buffer("c95", torch.empty([96, 48], dtype=torch.float32))
        self.register_buffer("c96", torch.empty([48], dtype=torch.float32))
        self.register_buffer("c97", torch.empty([1], dtype=torch.int64))
        self.conv2d1 = nn.Conv2d(3, 48, 8, stride=8)
        self.batchnorm2d1 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.flatten1 = nn.Flatten(3)
        self.softmax1 = nn.Softmax(-1)
        self.batchnorm2d2 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.relu1 = nn.ReLU()
        self.batchnorm2d3 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.flatten2 = nn.Flatten(3)
        self.softmax2 = nn.Softmax(-1)
        self.batchnorm2d4 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.relu2 = nn.ReLU()
        self.batchnorm2d5 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.flatten3 = nn.Flatten(3)
        self.softmax3 = nn.Softmax(-1)
        self.batchnorm2d6 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.relu3 = nn.ReLU()
        self.batchnorm2d7 = nn.BatchNorm2d(
            48, eps=9.999999747378752e-06, momentum=0.10000002384185791
        )
        self.linear1 = nn.Linear(48, 10)

    def forward(self, x0):
        x1 = torch.tensor(x0.shape, dtype=torch.int64)
        x2 = x1[0]
        x3 = self.conv2d1(x0)
        x4 = torch.tensor(x3.shape, dtype=torch.int64)
        x5 = x4[0:2]
        x6 = torch.cat([x5, self.c4])
        x7 = x3.reshape([int(x) for x in x6.tolist()])
        x8 = x7.permute((0, 2, 1))
        x9 = x2.unsqueeze(0)
        x10 = torch.cat([x9, self.c6, self.c7])
        x11 = torch.full(x10.tolist(), 0.0, dtype=torch.float32)
        x12 = x11 + self.c8
        x13 = torch.cat([x12, x8], dim=1)
        x14 = x13 + self.c9
        x15 = x14.permute((0, 2, 1))
        x16 = self.batchnorm2d1(x15.unsqueeze(2)).squeeze(2)
        x17 = x16.permute((0, 2, 1))
        x18 = torch.tensor(x17.shape, dtype=torch.int64)
        x19 = x18[0]
        x20 = x17 @ self.c11
        x21 = self.c12 + x20
        x22 = x17 @ self.c13
        x23 = self.c14 + x22
        x24 = x17 @ self.c15
        x25 = self.c16 + x24
        x26 = x19.unsqueeze(0)
        x27 = torch.cat([x26, self.c18, self.c19, self.c20])
        x28 = x19.unsqueeze(0)
        x29 = torch.cat([x28, self.c22, self.c23, self.c24])
        x30 = x19.unsqueeze(0)
        x31 = torch.cat([x30, self.c26, self.c27, self.c28])
        x32 = x21.reshape([int(x) for x in x27.tolist()])
        x33 = x32.permute((0, 2, 1, 3))
        x34 = x23.reshape([int(x) for x in x29.tolist()])
        x35 = x25.reshape([int(x) for x in x31.tolist()])
        x36 = x35.permute((0, 2, 1, 3))
        x37 = x34.permute((0, 2, 3, 1))
        x38 = x33 @ x37
        x39 = x38 * 0.25
        x40 = torch.tensor(x39.shape, dtype=torch.int64)
        x41 = self.flatten1(x39)
        x42 = self.softmax1(x41)
        x43 = x42.reshape([int(x) for x in x40.tolist()])
        x44 = x43 @ x36
        x45 = x44.permute((0, 2, 1, 3))
        x46 = x19.unsqueeze(0)
        x47 = torch.cat([x46, self.c31, self.c32])
        x48 = x45.reshape([int(x) for x in x47.tolist()])
        x49 = x48 @ self.c33
        x50 = self.c34 + x49
        x51 = x50 + x14
        x52 = x51.permute((0, 2, 1))
        x53 = self.batchnorm2d2(x52.unsqueeze(2)).squeeze(2)
        x54 = x53.permute((0, 2, 1))
        x55 = x54 @ self.c35
        x56 = self.c36 + x55
        x57 = self.relu1(x56)
        x58 = x57 @ self.c37
        x59 = self.c38 + x58
        x60 = x59 + x51
        x61 = x60.permute((0, 2, 1))
        x62 = self.batchnorm2d3(x61.unsqueeze(2)).squeeze(2)
        x63 = x62.permute((0, 2, 1))
        x64 = torch.tensor(x63.shape, dtype=torch.int64)
        x65 = x64[0]
        x66 = x63 @ self.c40
        x67 = self.c41 + x66
        x68 = x63 @ self.c42
        x69 = self.c43 + x68
        x70 = x63 @ self.c44
        x71 = self.c45 + x70
        x72 = x65.unsqueeze(0)
        x73 = torch.cat([x72, self.c47, self.c48, self.c49])
        x74 = x65.unsqueeze(0)
        x75 = torch.cat([x74, self.c51, self.c52, self.c53])
        x76 = x65.unsqueeze(0)
        x77 = torch.cat([x76, self.c55, self.c56, self.c57])
        x78 = x67.reshape([int(x) for x in x73.tolist()])
        x79 = x78.permute((0, 2, 1, 3))
        x80 = x69.reshape([int(x) for x in x75.tolist()])
        x81 = x71.reshape([int(x) for x in x77.tolist()])
        x82 = x81.permute((0, 2, 1, 3))
        x83 = x80.permute((0, 2, 3, 1))
        x84 = x79 @ x83
        x85 = x84 * 0.25
        x86 = torch.tensor(x85.shape, dtype=torch.int64)
        x87 = self.flatten2(x85)
        x88 = self.softmax2(x87)
        x89 = x88.reshape([int(x) for x in x86.tolist()])
        x90 = x89 @ x82
        x91 = x90.permute((0, 2, 1, 3))
        x92 = x65.unsqueeze(0)
        x93 = torch.cat([x92, self.c60, self.c61])
        x94 = x91.reshape([int(x) for x in x93.tolist()])
        x95 = x94 @ self.c62
        x96 = self.c63 + x95
        x97 = x96 + x60
        x98 = x97.permute((0, 2, 1))
        x99 = self.batchnorm2d4(x98.unsqueeze(2)).squeeze(2)
        x100 = x99.permute((0, 2, 1))
        x101 = x100 @ self.c64
        x102 = self.c65 + x101
        x103 = self.relu2(x102)
        x104 = x103 @ self.c66
        x105 = self.c67 + x104
        x106 = x105 + x97
        x107 = x106.permute((0, 2, 1))
        x108 = self.batchnorm2d5(x107.unsqueeze(2)).squeeze(2)
        x109 = x108.permute((0, 2, 1))
        x110 = torch.tensor(x109.shape, dtype=torch.int64)
        x111 = x110[0]
        x112 = x109 @ self.c69
        x113 = self.c70 + x112
        x114 = x109 @ self.c71
        x115 = self.c72 + x114
        x116 = x109 @ self.c73
        x117 = self.c74 + x116
        x118 = x111.unsqueeze(0)
        x119 = torch.cat([x118, self.c76, self.c77, self.c78])
        x120 = x111.unsqueeze(0)
        x121 = torch.cat([x120, self.c80, self.c81, self.c82])
        x122 = x111.unsqueeze(0)
        x123 = torch.cat([x122, self.c84, self.c85, self.c86])
        x124 = x113.reshape([int(x) for x in x119.tolist()])
        x125 = x124.permute((0, 2, 1, 3))
        x126 = x115.reshape([int(x) for x in x121.tolist()])
        x127 = x117.reshape([int(x) for x in x123.tolist()])
        x128 = x127.permute((0, 2, 1, 3))
        x129 = x126.permute((0, 2, 3, 1))
        x130 = x125 @ x129
        x131 = x130 * 0.25
        x132 = torch.tensor(x131.shape, dtype=torch.int64)
        x133 = self.flatten3(x131)
        x134 = self.softmax3(x133)
        x135 = x134.reshape([int(x) for x in x132.tolist()])
        x136 = x135 @ x128
        x137 = x136.permute((0, 2, 1, 3))
        x138 = x111.unsqueeze(0)
        x139 = torch.cat([x138, self.c89, self.c90])
        x140 = x137.reshape([int(x) for x in x139.tolist()])
        x141 = x140 @ self.c91
        x142 = self.c92 + x141
        x143 = x142 + x106
        x144 = x143.permute((0, 2, 1))
        x145 = self.batchnorm2d6(x144.unsqueeze(2)).squeeze(2)
        x146 = x145.permute((0, 2, 1))
        x147 = x146 @ self.c93
        x148 = self.c94 + x147
        x149 = self.relu3(x148)
        x150 = x149 @ self.c95
        x151 = self.c96 + x150
        x152 = x151 + x143
        x153 = torch.mean(x152, self.c97.tolist(), keepdim=False)
        x154 = self.batchnorm2d7(x153.unsqueeze(2).unsqueeze(3)).squeeze(2).squeeze(2)
        x155 = self.linear1(x154)
        return x155

Testing & Validation

TorchONNX is extensively tested on the VNNCOMP 2024 benchmarks, the official benchmark suite for neural network verification competitions. The test suite includes:

  • Vision Transformers (ViT): Complex transformer architectures with attention mechanisms
  • Convolutional Neural Networks: Various CNN architectures from traffic sign detection to autonomous control
  • Feedforward Networks: MLPs with various activation functions and normalizations
  • Hybrid Architectures: Models combining multiple architectural patterns

All converted models are validated to produce numerically identical outputs to their original ONNX counterparts, ensuring correctness across diverse model types and operations.
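A validation check of this kind boils down to running both models on the same inputs and comparing outputs. The sketch below is generic: `outputs_match` is a hypothetical helper, and the tolerances are assumptions. In a real check, `reference` would wrap an `onnxruntime.InferenceSession` (`session.run(None, {input_name: x})[0]`) and `candidate` the converted PyTorch model; simple NumPy functions stand in for both here.

```python
from collections.abc import Callable, Iterable

import numpy as np


def outputs_match(
    reference: Callable[[np.ndarray], np.ndarray],
    candidate: Callable[[np.ndarray], np.ndarray],
    inputs: Iterable[np.ndarray],
    rtol: float = 1e-5,
    atol: float = 1e-6,
) -> bool:
    """Return True if both callables agree (within tolerance) on every input."""
    return all(
        np.allclose(reference(x), candidate(x), rtol=rtol, atol=atol) for x in inputs
    )


def relu_reference(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0)


def relu_candidate(x: np.ndarray) -> np.ndarray:
    return np.where(x > 0, x, 0.0)


rng = np.random.default_rng(0)
inputs = [rng.standard_normal((2, 3)).astype(np.float32) for _ in range(5)]
print(outputs_match(relu_reference, relu_candidate, inputs))  # True
```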

Running Tests

To test with VNNCOMP 2024 benchmarks, clone the vnncomp2024 repository and ensure the following structure:

torchonnx/
├── torchonnx/
├── test/
├── README.md
└── ...
vnncomp2024/
├── benchmarks/
└── ...

Then run the test suite:

cd torchonnx/test
python test_benchmarks.py

Supported Operations

The tool implements the operations most commonly used in feedforward neural networks and transformers:

Layers (nn.Module)

  • Convolution: Conv1d, Conv2d, ConvTranspose1d, ConvTranspose2d
  • Pooling: MaxPool2d, AvgPool2d, AdaptiveAvgPool2d
  • Normalization: BatchNorm2d (with automatic dimension handling)
  • Activation: ReLU, LeakyReLU, Sigmoid, Tanh, Softmax, ELU, GELU
  • Linear: Linear
  • Dropout: Dropout
  • Upsampling: Upsample
  • Shape Operations: Flatten
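The generated ViT code applies BatchNorm2d to 3-D tensors with the pattern `self.batchnorm2d1(x15.unsqueeze(2)).squeeze(2)`. The sketch below shows why that is valid in eval mode: inserting a size-1 height dimension turns an (N, C, L) tensor into (N, C, 1, L), and per-channel normalization gives the same result as BatchNorm1d on the original tensor. The statistics and shapes are illustrative, and whether torchonnx chooses this form for exactly this reason is an assumption.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn1d = nn.BatchNorm1d(48).eval()
bn2d = nn.BatchNorm2d(48).eval()

# Give both layers the same non-trivial running statistics and affine parameters.
with torch.no_grad():
    for bn in (bn1d, bn2d):
        bn.running_mean.copy_(torch.linspace(-1.0, 1.0, 48))
        bn.running_var.copy_(torch.linspace(0.5, 2.0, 48))
        bn.weight.copy_(torch.linspace(0.9, 1.1, 48))
        bn.bias.copy_(torch.linspace(-0.1, 0.1, 48))

x = torch.randn(2, 48, 17)  # (N, C, L), like the permuted ViT tokens
y1 = bn1d(x)
y2 = bn2d(x.unsqueeze(2)).squeeze(2)  # (N, C, 1, L) -> BatchNorm2d -> (N, C, L)
print(torch.allclose(y1, y2, atol=1e-5))  # True
```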

Functions (F.* and torch.*)

  • Convolution: F.conv1d, F.conv2d, F.conv_transpose1d, F.conv_transpose2d
  • Linear: F.linear
  • Resizing: F.interpolate
  • Padding: F.pad
  • Concatenation: torch.cat
  • Indexing: torch.gather, scatter_nd
  • Reduction: torch.mean, torch.sum, torch.min, torch.max, torch.argmax
  • Clipping: torch.clamp
  • Conditional: torch.where
  • Generation: torch.full, torch.arange

Operators (torch.*)

  • Arithmetic: add (+), sub (-), mul (*), div (/), matmul (@), pow, neg
  • Comparison: equal (==)

Tensor Operations

  • Shape: reshape, permute, squeeze, unsqueeze, shape, expand, cast
  • Slicing: slice, split
  • Math: sign, cos, sin, floor

Transformer-based architectures are supported by decomposing them into the basic operations listed above, as in the ViT example.

Related Projects

  • ShapeONNX: Advanced shape inference for ONNX models. SlimONNX uses ShapeONNX for shape-dependent optimizations.
  • TorchVNNLIB: PyTorch library for neural network verification, often used together with SlimONNX for verification tasks. It converts VNNLIB data files to .pth format for PyTorch or .npz format for NumPy.
  • SlimONNX: ONNX model simplification tool that removes redundant operations and optimizes the graph before conversion.
  • VNN-COMP: International Verification of Neural Networks Competition. SlimONNX is tested on all VNN-COMP 2024 benchmarks.
  • ONNX Simplifier: Alternative ONNX optimization tool with different optimization strategies.

Contributing

Contributions are welcome from the community. Whether fixing bugs, adding features, improving documentation, or sharing ideas, all contributions are appreciated.

Note: Direct pushes to the main branch are restricted. Please fork the repository and submit a Pull Request for any changes.

License

This project is licensed under the MIT License. See the LICENSE file for details.

About

TorchONNX is a tool to convert an ONNX model to a PyTorch model.
