
Conversation

@Isotr0py (Member) commented Dec 6, 2025


Purpose

  • Fuse Z-Image's separate q/k/v and gate/up projections into qkv_proj and gate_up_proj (see the sketch below the list)
  • Reuse vLLM's gated SiLU activation kernel
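
As a rough, plain-PyTorch sketch of what the fusion changes (illustrative names and sizes, not Z-Image's real dimensions; the PR itself presumably uses vLLM's fused linear layers and its SiluAndMul kernel, which computes silu of the first half of the last dimension times the second half, rather than these eager ops):

import torch
import torch.nn.functional as F

# Illustrative sizes only.
hidden, head_dim, num_heads, mlp_hidden = 64, 16, 4, 256

# One projection producing q, k and v instead of three separate ones.
qkv_proj = torch.nn.Linear(hidden, 3 * num_heads * head_dim)
# One projection producing the gate and up halves of the MLP input.
gate_up_proj = torch.nn.Linear(hidden, 2 * mlp_hidden)

x = torch.randn(2, hidden)
q, k, v = qkv_proj(x).split(num_heads * head_dim, dim=-1)  # one GEMM, three outputs
gate, up = gate_up_proj(x).chunk(2, dim=-1)                # one GEMM, two outputs
mlp_act = F.silu(gate) * up                                # what the gated SiLU kernel fuses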

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

Signed-off-by: Isotr0py <[email protected]>
@Isotr0py Isotr0py requested a review from ZJY0516 December 6, 2025 16:48

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +669 to +673
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name not in name:
        continue
    name = name.replace(weight_name, param_name)
    param = params_dict[name]


P1: Skip remapping names that are already fused

The new load_weights remaps checkpoint keys by substring, so a fused key like .to_qkv.weight or .w13.weight also matches the .to_q/.w1 patterns. The subsequent name.replace turns it into .to_qkvkv.weight/.w133.weight, which is absent from params_dict and will throw during weight loading. Any state dict saved from the fused model (the format this commit introduces) will now fail to reload. Please guard the remapping so it only applies to unfused .to_q/.w1 names.
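
A minimal sketch of such a guard, following vLLM's usual load_weights for/else pattern (params_dict, loaded_weight, default_weight_loader and the overall control flow are assumed from that pattern, not copied from this PR):

for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name not in name:
        continue
    # Keys saved from the fused model already contain the fused name
    # (".to_qkv" contains ".to_q", ".w13" contains ".w1"); remapping them
    # again would produce ".to_qkvkv" / ".w133", so fall through to the
    # default loader instead.
    if param_name in name:
        continue
    name = name.replace(weight_name, param_name)
    param = params_dict[name]
    param.weight_loader(param, loaded_weight, shard_id)
    break
else:
    param = params_dict[name]
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, loaded_weight)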


@hsliuustc0106 (Collaborator)

Any test results in terms of accuracy and speed-up?

@Isotr0py (Member, Author) commented Dec 8, 2025

Any test results in terms of accuracy and speed-up?

Just benchmarked manually on an RTX 3090; the default text_to_image.py runtime drops from 45.59s to 45.12s:

Main branch (45.59s)
[05:12 PM]-[mozf@A405-RTX-Server]-[~/develop-projects/vllm-omni]- |main → upstream U:1 ?:7 ✗|
$ python examples/offline_inference/qwen_image/text_to_image.py --model /home/mozf/LLM/Z-Image-Turbo/
INFO 12-08 17:13:00 [__init__.py:216] Automatically detected platform cuda.
WARNING:vllm_omni.diffusion.attention.backends.flash_attn:FlashAttentionBackend is not available. You may install flash-attn by running `uv pip install flash-attn==2.8.1 --no-build-isolation` or install pre-built flash-attn from https://github.com/Dao-AILab/flash-attention/releases
INFO 12-08 17:13:05 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_cee42fe4'), local_subscribe_addr='ipc:///tmp/2c6522e3-36f1-4e16-87bb-ba78b9f3f481', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO:vllm_omni.diffusion.diffusion_engine:Starting server...
INFO 12-08 17:13:07 [__init__.py:216] Automatically detected platform cuda.
INFO 12-08 17:13:11 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_a8be8967'), local_subscribe_addr='ipc:///tmp/09c9916b-5f54-4240-844f-c01a5018a2da', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0 created result MessageQueue
WARNING 12-08 17:13:11 [__init__.py:755] Current vLLM config is not set.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
WARNING 12-08 17:13:11 [__init__.py:755] Current vLLM config is not set.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
INFO 12-08 17:13:11 [parallel_state.py:1208] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0: Initialized device and distributed environment.
WARNING:vllm_omni.diffusion.attention.backends.flash_attn:FlashAttentionBackend is not available. You may install flash-attn by running `uv pip install flash-attn==2.8.1 --no-build-isolation` or install pre-built flash-attn from https://github.com/Dao-AILab/flash-attention/releases
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.60it/s]
WARNING 12-08 17:13:14 [__init__.py:755] Current vLLM config is not set.
Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:02<00:04,  2.09s/it]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:04<00:02,  2.32s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:06<00:00,  1.94s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:06<00:00,  2.02s/it]

INFO:vllm_omni.diffusion.model_loader.diffusers_loader:Loading weights took 6.27 seconds
INFO:vllm_omni.diffusion.worker.gpu_worker:Model loading took 19.2180 GiB and 10.107239 seconds
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0: Model loaded successfully.
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0: Scheduler loop started.
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0 ready to receive requests via shared memory
INFO:vllm_omni.diffusion.scheduler:SyncScheduler initialized result MessageQueue
INFO:vllm_omni.diffusion.omni_diffusion:Prepared 1 requests for generation.
INFO:vllm_omni.diffusion.diffusion_engine:Generation completed successfully.
INFO:vllm_omni.diffusion.diffusion_engine:Post-processing completed in 0.0696 seconds
Image generation took 45.59 seconds.
Saved generated image to qwen_image_output.png
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0: Received shutdown message
INFO:vllm_omni.diffusion.worker.gpu_worker:event loop terminated.
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0: Destroyed process group
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0: Shutdown complete.
PR (45.12s)
[05:14 PM]-[mozf@A405-RTX-Server]-[~/develop-projects/vllm-omni]- |fuse-z-image → origin U:1 ?:7 ✗|
$ python examples/offline_inference/qwen_image/text_to_image.py --model /home/mozf/LLM/Z-Image-Turbo/
INFO 12-08 17:14:47 [__init__.py:216] Automatically detected platform cuda.
WARNING:vllm_omni.diffusion.attention.backends.flash_attn:FlashAttentionBackend is not available. You may install flash-attn by running `uv pip install flash-attn==2.8.1 --no-build-isolation` or install pre-built flash-attn from https://github.com/Dao-AILab/flash-attention/releases
INFO 12-08 17:14:51 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_cde70bc3'), local_subscribe_addr='ipc:///tmp/37205691-b2f6-4053-8033-abcdd23658dd', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO:vllm_omni.diffusion.diffusion_engine:Starting server...
INFO 12-08 17:14:54 [__init__.py:216] Automatically detected platform cuda.
INFO 12-08 17:14:58 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_7630175f'), local_subscribe_addr='ipc:///tmp/ba68dad9-8662-4567-8d44-5ca39d2d36b2', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0 created result MessageQueue
WARNING 12-08 17:14:58 [__init__.py:755] Current vLLM config is not set.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
WARNING 12-08 17:14:58 [__init__.py:755] Current vLLM config is not set.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
INFO 12-08 17:14:58 [parallel_state.py:1208] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0: Initialized device and distributed environment.
WARNING:vllm_omni.diffusion.attention.backends.flash_attn:FlashAttentionBackend is not available. You may install flash-attn by running `uv pip install flash-attn==2.8.1 --no-build-isolation` or install pre-built flash-attn from https://github.com/Dao-AILab/flash-attention/releases
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.62it/s]
WARNING 12-08 17:15:01 [__init__.py:755] Current vLLM config is not set.
Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:02<00:04,  2.12s/it]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:04<00:02,  2.34s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:06<00:00,  1.96s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:06<00:00,  2.04s/it]

INFO:vllm_omni.diffusion.model_loader.diffusers_loader:Loading weights took 6.35 seconds
INFO:vllm_omni.diffusion.worker.gpu_worker:Model loading took 19.1516 GiB and 10.126088 seconds
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0: Model loaded successfully.
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0: Scheduler loop started.
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0 ready to receive requests via shared memory
INFO:vllm_omni.diffusion.scheduler:SyncScheduler initialized result MessageQueue
INFO:vllm_omni.diffusion.omni_diffusion:Prepared 1 requests for generation.
INFO:vllm_omni.diffusion.diffusion_engine:Generation completed successfully.
INFO:vllm_omni.diffusion.diffusion_engine:Post-processing completed in 0.0707 seconds
Image generation took 45.12 seconds.
Saved generated image to qwen_image_output.png
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0: Received shutdown message
INFO:vllm_omni.diffusion.worker.gpu_worker:event loop terminated.
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0: Destroyed process group
INFO:vllm_omni.diffusion.worker.gpu_worker:Worker 0: Shutdown complete

And the generated image looks reasonable:
[Generated image: qwen_image_output.png]

@SamitHuang (Collaborator)

Does the reported runtime correspond to omni.generate(...)?

@Isotr0py (Member, Author) commented Dec 8, 2025

Does the reported runtime correspond to omni.generate(...)?

Yes, it's measured manually with these changes:

+++ torch.cuda.synchronize()
+++ start_time = time.time()

    images = omni.generate(
        args.prompt,
        height=args.height,
        width=args.width,
        generator=generator,
        true_cfg_scale=args.cfg_scale,
        num_inference_steps=args.num_inference_steps,
        num_outputs_per_prompt=args.num_images_per_prompt,
    )

+++ torch.cuda.synchronize()
+++ end_time = time.time()
+++ print(f"Image generation took {end_time - start_time:.2f} seconds.")
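
(For context: the torch.cuda.synchronize() calls matter because CUDA kernel launches are asynchronous, so without them time.time() could be read before the queued GPU work has actually finished; the snippet assumes time and torch are already imported in the example script.)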

@hsliuustc0106 (Collaborator) left a comment


lgtm

@Isotr0py Isotr0py merged commit 83c2723 into vllm-project:main Dec 8, 2025
4 checks passed
@Isotr0py Isotr0py deleted the fuse-z-image branch December 8, 2025 15:21