
Commit 1af667d

Feat: support multiple dtypes per checkpoint (#16)

* Feat: support multiple dtypes
* add revert command, fix missing types, add robust test
* add pre commit hook
* lints
* dos2unix
* fix integrations
* update tests
* lints
* more lints
* more lints
* add docs, update colors

1 parent 36f1f06 · commit 1af667d

23 files changed (+1172, −490 lines)

.github/workflows/pre-commit.yaml

Lines changed: 18 additions & 0 deletions
```yaml
name: Lint (pre-commit)

on:
  push:
    branches:
      - main
  pull_request:
    types: [assigned, opened, synchronize, reopened]

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v3
        with:
          python-version: "3.11"
      - uses: pre-commit/[email protected]
```

.pre-commit-config.yaml

Lines changed: 37 additions & 0 deletions
```yaml
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.0.1
    hooks:
      - id: check-docstring-first
      - id: check-toml
      - id: check-yaml
        exclude: packaging/.*
        args:
          - --allow-multiple-documents
      - id: mixed-line-ending
        args: [--fix=lf]
      - id: end-of-file-fixer

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: 'v0.3.4'
    hooks:
      - id: ruff
        name: lint with ruff
      - id: ruff
        name: sort imports with ruff
        args: [--select, I, --fix]
      - id: ruff-format
        name: format with ruff

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.3.0
    hooks:
      - id: check-added-large-files
      - id: check-merge-conflict
      - id: check-vcs-permalinks
      - id: debug-statements

  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v14.0.6
    hooks:
      - id: clang-format
```

README.md

Lines changed: 187 additions & 91 deletions
@@ -1,91 +1,187 @@

*(The previous README body is removed by this commit; it reappears, with additions, in the new file content below.)*
<div align="center">
  <picture>
    <source media="(prefers-color-scheme: dark)" srcset="https://github.com/fal-ai/flashpack/blob/main/media/flashpack-logo-white.png?raw=true">
    <source media="(prefers-color-scheme: light)" srcset="https://github.com/fal-ai/flashpack/blob/main/media/flashpack-logo-black.png?raw=true">
    <img alt="FlashPack Logo" src="https://github.com/fal-ai/flashpack/blob/main/media/flashpack-logo-black.png?raw=true">
  </picture>
  <h2>Disk-to-GPU Tensor loading at up to 25Gbps without GDS</h2>
</div>

## Updates

- **2025-11-25**: Now supports **multiple data types per checkpoint** with no regressions in speed!

<div align="center">
  <picture>
    <source media="(prefers-color-scheme: dark)" srcset="https://github.com/fal-ai/flashpack/blob/main/media/benchmark-white.png?raw=true">
    <source media="(prefers-color-scheme: light)" srcset="https://github.com/fal-ai/flashpack/blob/main/media/benchmark-black.png?raw=true">
    <img alt="Benchmark Results" src="https://github.com/fal-ai/flashpack/blob/main/media/benchmark-black.png?raw=true">
  </picture>
  <em>Run this benchmark in `scripts/run_benchmark.py`</em>
</div>

<div align="center">
  <picture>
    <source media="(prefers-color-scheme: dark)" srcset="https://github.com/fal-ai/flashpack/blob/main/media/load-state-dict-comparison-white.png?raw=true">
    <source media="(prefers-color-scheme: light)" srcset="https://github.com/fal-ai/flashpack/blob/main/media/load-state-dict-comparison-black.png?raw=true">
    <img alt="Benchmark Results" src="https://github.com/fal-ai/flashpack/blob/main/media/load-state-dict-comparison-black.png?raw=true">
  </picture>
  <em>Run this benchmark in `tests/test_speed_comparison.py`</em>
</div>

# Integration Guide

## Mixins

### Diffusers/Transformers

```py
# Integration classes
from flashpack.integrations.diffusers import FlashPackDiffusersModelMixin, FlashPackDiffusionPipeline
from flashpack.integrations.transformers import FlashPackTransformersModelMixin

# Base classes
from diffusers.models import MyModel, SomeOtherModel
from diffusers.pipelines import MyPipeline

# Define mixed classes
class FlashPackMyModel(MyModel, FlashPackDiffusersModelMixin):
    pass

class FlashPackMyPipeline(MyPipeline, FlashPackDiffusionPipeline):
    def __init__(
        self,
        my_model: FlashPackMyModel,
        other_model: SomeOtherModel,
    ) -> None:
        super().__init__()

# Load base pipeline
pipeline = FlashPackMyPipeline.from_pretrained("some/repository")

# Save flashpack pipeline
pipeline.save_pretrained_flashpack(
    "some_directory",
    push_to_hub=False,  # pass repo_id when using this
)

# Load directly from flashpack directory or repository
pipeline = FlashPackMyPipeline.from_pretrained_flashpack("my/flashpack-repository")
```
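The class definitions above rely on Python's ordinary mixin pattern: the FlashPack mixin contributes save/load methods while the base model class is left untouched. A minimal, dependency-free sketch of that pattern (plain Python, not the flashpack API — `SaveMixin`, `BaseModel`, and `PackedModel` are invented names):

```python
# The mixin adds behavior; the base class knows nothing about it.
class SaveMixin:
    def save(self) -> str:
        # In FlashPack this would serialize the state dict; here we just report.
        return f"saved {type(self).__name__}"

class BaseModel:
    def __init__(self, dim: int) -> None:
        self.dim = dim

# Combining them mirrors `class FlashPackMyModel(MyModel, FlashPackDiffusersModelMixin)`.
class PackedModel(BaseModel, SaveMixin):
    pass

m = PackedModel(dim=8)
print(m.save())  # saved PackedModel
```

Because the mixin defines no `__init__`, the base class's constructor is used unchanged, which is why `pass` suffices in the combined class.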
### Vanilla PyTorch

```py
import torch.nn as nn

from flashpack import FlashPackMixin

class MyModule(nn.Module, FlashPackMixin):
    def __init__(self, some_arg: int = 4) -> None:
        ...

module = MyModule(some_arg=4)
module.save_flashpack("model.flashpack")

loaded_module = module.from_flashpack("model.flashpack", some_arg=4)
```
## Direct Integration

```py
import torch.nn as nn

from flashpack import pack_to_file, assign_from_file

flashpack_path = "/path/to/model.flashpack"
model = nn.Module(...)  # any module instance

pack_to_file(model, flashpack_path)  # write state dict to file
assign_from_file(model, flashpack_path)  # load state dict from file
```
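The pack/assign pair above works by laying every tensor into one contiguous blob plus a small index, so loading is a single sequential read. A toy sketch of that idea in plain Python — this is NOT FlashPack's real on-disk format; the layout, field names, and per-tensor `dtype` entry (which this commit's multi-dtype support would require) are all invented for illustration:

```python
import json
import struct

# Raw bytes stand in for tensor data; each entry carries its own dtype.
tensors = {
    "weight": (b"\x00" * 16, "float32"),
    "bias": (b"\x01" * 8, "float16"),
}

# Pack: build a JSON index of offsets/sizes/dtypes and one contiguous blob.
index, blob = {}, bytearray()
for name, (data, dtype) in tensors.items():
    index[name] = {"offset": len(blob), "size": len(data), "dtype": dtype}
    blob += data

header = json.dumps(index).encode("utf-8")
packed = struct.pack("<I", len(header)) + header + bytes(blob)

# Unpack: one parse of the index, then slice each tensor out of the blob.
(hlen,) = struct.unpack_from("<I", packed, 0)
loaded_index = json.loads(packed[4 : 4 + hlen])
body = packed[4 + hlen :]
entry = loaded_index["bias"]
bias_bytes = body[entry["offset"] : entry["offset"] + entry["size"]]
print(entry["dtype"], len(bias_bytes))  # float16 8
```

Storing the dtype per index entry rather than once per file is what lets a single blob mix float32 and float16 tensors without any extra reads.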
# CLI Commands

FlashPack provides a command-line interface for converting, inspecting, and reverting flashpack files.

## `flashpack convert`

Convert a model to a flashpack file.

```bash
flashpack convert <path_or_repo_id> [destination_path] [options]
```

**Arguments:**

- `path_or_repo_id` - Local path or Hugging Face repository ID
- `destination_path` - (Optional) Output path for the flashpack file

**Options:**

| Option | Description |
|--------|-------------|
| `--subfolder` | Subfolder of the model (for repo_id) |
| `--variant` | Model variant (for repo_id) |
| `--dtype` | Target dtype for the flashpack file. When omitted, no type changes are made |
| `--ignore-names` | Tensor names to ignore (can be specified multiple times) |
| `--ignore-prefixes` | Tensor prefixes to ignore (can be specified multiple times) |
| `--ignore-suffixes` | Tensor suffixes to ignore (can be specified multiple times) |
| `--use-transformers` | Load the path as a transformers model |
| `--use-diffusers` | Load the path as a diffusers model |
| `-v, --verbose` | Enable verbose output |

**Examples:**

```bash
# Convert a local model
flashpack convert ./my_model ./my_model.flashpack

# Convert from Hugging Face
flashpack convert stabilityai/stable-diffusion-xl-base-1.0 --subfolder unet --use-diffusers

# Convert with specific dtype
flashpack convert ./my_model ./my_model.flashpack --dtype float16
```
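The three `--ignore-*` options describe a simple name filter. A hypothetical sketch of what such a filter might look like — `keep_tensor` is an invented helper, and the real CLI's logic may differ:

```python
def keep_tensor(
    name: str,
    ignore_names: tuple = (),
    ignore_prefixes: tuple = (),
    ignore_suffixes: tuple = (),
) -> bool:
    """Return True if a tensor with this name should be packed."""
    if name in ignore_names:
        return False
    if any(name.startswith(p) for p in ignore_prefixes):
        return False
    if any(name.endswith(s) for s in ignore_suffixes):
        return False
    return True

names = ["unet.weight", "unet.bias", "text_encoder.weight"]
kept = [n for n in names if keep_tensor(n, ignore_prefixes=("text_encoder.",))]
print(kept)  # ['unet.weight', 'unet.bias']
```

Since each flag is repeatable, all values accumulate into the corresponding tuple and a tensor is dropped if any rule matches.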
## `flashpack revert`

Revert a flashpack file back to safetensors or torch format.

```bash
flashpack revert <path> [destination_path] [options]
```

**Arguments:**

- `path` - Path to the flashpack file
- `destination_path` - (Optional) Output path for the reverted file

**Options:**

| Option | Description |
|--------|-------------|
| `-v, --verbose` | Enable verbose output |

**Example:**

```bash
flashpack revert ./my_model.flashpack ./my_model.safetensors
```

## `flashpack metadata`

Print the metadata of a flashpack file.

```bash
flashpack metadata <path> [options]
```

**Arguments:**

- `path` - Path to the flashpack file

**Options:**

| Option | Description |
|--------|-------------|
| `-i, --show-index` | Show the tensor index |
| `-j, --json` | Output metadata in JSON format |

**Examples:**

```bash
# View basic metadata
flashpack metadata ./my_model.flashpack

# View metadata with tensor index
flashpack metadata ./my_model.flashpack --show-index

# Output as JSON
flashpack metadata ./my_model.flashpack --json
```
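The `--json` flag makes the metadata easy to consume from scripts. A sketch of such a consumer — note the JSON shape shown here (`version`, a `tensors` mapping with `dtype` and `shape`) is hypothetical; inspect the actual output of your own file before relying on any field names:

```python
import json

# Stand-in for `flashpack metadata ./my_model.flashpack --json` output.
raw = """
{
  "version": 1,
  "tensors": {
    "weight": {"dtype": "float16", "shape": [4, 4]},
    "bias": {"dtype": "float32", "shape": [4]}
  }
}
"""

meta = json.loads(raw)

# With multi-dtype checkpoints, a file can now report more than one dtype.
dtypes = {info["dtype"] for info in meta["tensors"].values()}
print(sorted(dtypes))  # ['float16', 'float32']
```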

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -46,4 +46,4 @@ dev = [
 ]

 [tool.setuptools_scm]
-write_to = "src/flashpack/version.py"
+write_to = "src/flashpack/version.py"
```

The line's text is unchanged; the diff only adds the missing trailing newline at end of file.

scripts/plot_benchmark.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -10,7 +10,7 @@
 # configuration
 accelerate_color = "#0f5ef3"
 flashpack_color = "#adff02"
-label_color = "#111111"
+label_color = "#eeeeee"
 model_labels = {
     "Wan-AI/Wan2.1-T2V-1.3B-Diffusers": "Wan2.1 1.3B DiT",
     "Wan-AI/Wan2.1-T2V-14B-Diffusers": "Wan2.1 14B DiT",
```

scripts/run_benchmark.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -8,13 +8,13 @@
 import torch
 from flashpack.integrations.diffusers import patch_diffusers_auto_model
 from flashpack.integrations.transformers import patch_transformers_auto_model
+from huggingface_hub import snapshot_download

 patch_diffusers_auto_model()
 patch_transformers_auto_model()

-from diffusers.models import AutoModel as DiffusersAutoModel
-from huggingface_hub import snapshot_download
-from transformers import AutoModel as TransformersAutoModel
+from diffusers.models import AutoModel as DiffusersAutoModel  # noqa: E402
+from transformers import AutoModel as TransformersAutoModel  # noqa: E402


 def test_model(
@@ -201,4 +201,4 @@ def sync_and_flush() -> None:
     )
     print_test_result(
         test_model_name, accelerate_time, flashpack_time, total_bytes
-    )
+    )
```
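The diff keeps the `AutoModel` imports *below* the `patch_*` calls (hence the `# noqa: E402` suppressions): a patch that replaces attributes must run before other code binds to them. A minimal stdlib illustration of that ordering rule (plain Python; `Hub` and `patch_hub` are invented stand-ins, not the flashpack API):

```python
class Hub:
    @staticmethod
    def load() -> str:
        return "original loader"

def patch_hub() -> None:
    # Replace the attribute in place, like patch_diffusers_auto_model does.
    Hub.load = staticmethod(lambda: "patched loader")

# A binding taken before patching captures the old behavior...
early = Hub.load
patch_hub()
# ...while one taken after patching sees the replacement.
late = Hub.load

print(early())  # original loader
print(late())   # patched loader
```

`from module import name` binds the same way `early = Hub.load` does, which is why the benchmark script imports `AutoModel` only after patching.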
