[Model] Fuse Z-Image's qkv_proj and gate_up_proj
#226
Conversation
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name not in name:
        continue
    name = name.replace(weight_name, param_name)
    param = params_dict[name]
```
Skip remapping names that are already fused
The new load_weights remaps checkpoint keys by substring, so a fused key like .to_qkv.weight or .w13.weight also matches the .to_q/.w1 patterns. The subsequent name.replace turns it into .to_qkvkv.weight/.w133.weight, which is absent from params_dict and will throw during weight loading. Any state dict saved from the fused model (the format this commit introduces) will now fail to reload. Please guard the remapping so it only applies to unfused .to_q/.w1 names.
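One possible shape for that guard, sketched below with illustrative mapping entries inferred from the patterns quoted above (not the actual Z-Image model code; the real loader would go on to look up `params_dict[name]` and call the shard's `weight_loader`, which is omitted here):

```python
# Hypothetical mapping, mirroring the patterns quoted in the review:
# (fused param name, unfused checkpoint substring, shard id)
stacked_params_mapping = [
    ("to_qkv", "to_q", "q"),
    ("to_qkv", "to_k", "k"),
    ("to_qkv", "to_v", "v"),
    ("w13", "w1", 0),
    ("w13", "w3", 1),
]

def remap(name: str) -> str:
    """Map an unfused checkpoint key onto its fused parameter name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # Guard: an already-fused key such as "blocks.0.attn.to_qkv.weight"
        # also contains "to_q" as a substring; skip it so the replace below
        # cannot mangle it into "to_qkvkv.weight".
        if param_name in name:
            continue
        return name.replace(weight_name, param_name)
    return name

assert remap("blocks.0.attn.to_q.weight") == "blocks.0.attn.to_qkv.weight"
assert remap("blocks.0.attn.to_qkv.weight") == "blocks.0.attn.to_qkv.weight"
assert remap("blocks.0.mlp.w1.weight") == "blocks.0.mlp.w13.weight"
assert remap("blocks.0.mlp.w13.weight") == "blocks.0.mlp.w13.weight"
```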
Any test results in terms of accuracy and speedup?
Signed-off-by: Isotr0py <[email protected]>
Just benchmarked manually on an RTX 3090 with default settings: main branch 45.59s, this PR 45.12s.
Is the runtime corresponding to omni.generate(...)?
Yes, it's measured manually with these changes:

```diff
+ torch.cuda.synchronize()
+ start_time = time.time()
  images = omni.generate(
      args.prompt,
      height=args.height,
      width=args.width,
      generator=generator,
      true_cfg_scale=args.cfg_scale,
      num_inference_steps=args.num_inference_steps,
      num_outputs_per_prompt=args.num_images_per_prompt,
  )
+ torch.cuda.synchronize()
+ end_time = time.time()
+ print(f"Image generation took {end_time - start_time:.2f} seconds.")
```
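(The torch.cuda.synchronize() calls are what make these wall-clock numbers meaningful: CUDA kernel launches are asynchronous, so the first call drains any pending GPU work before the timer starts, and the second ensures generation has actually finished before the timer stops.)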
hsliuustc0106 left a comment
lgtm

Purpose
Fuse Z-Image's qkv_proj and gate_up_proj (a conceptual sketch of the fusion is included at the end of this description).
Test Plan
Manually benchmark end-to-end image generation via omni.generate(...) before and after the change, using the timing harness shown in the conversation above.
Test Result
On an RTX 3090 with default settings: main branch 45.59s, this PR 45.12s.
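As a conceptual illustration only (plain PyTorch with hypothetical module names, not the actual vLLM-Omni implementation), fusing the projections replaces several small GEMMs with one wider one followed by a cheap split:

```python
import torch
from torch import nn

class UnfusedProj(nn.Module):
    """Three separate GEMMs for Q, K, and V."""
    def __init__(self, dim: int):
        super().__init__()
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.to_q(x), self.to_k(x), self.to_v(x)

class FusedProj(nn.Module):
    """One wide GEMM, then a near-free chunk along the last dimension."""
    def __init__(self, dim: int):
        super().__init__()
        self.to_qkv = nn.Linear(dim, 3 * dim, bias=False)

    def forward(self, x):
        return self.to_qkv(x).chunk(3, dim=-1)

x = torch.randn(2, 16, 64)
q, k, v = FusedProj(64)(x)
assert q.shape == k.shape == v.shape == x.shape
```

The same pattern applies to the MLP, where the gate and up projections become a single gate_up matmul; the benefit is fewer kernel launches and one larger, better-utilized GEMM, which lines up with the modest end-to-end speedup measured above.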