-
Notifications
You must be signed in to change notification settings - Fork 293
Granite4 FP8 Block Quantization #2001
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
krishnateja95
wants to merge
9
commits into
vllm-project:main
Choose a base branch
from
krishnateja95:fix/granite4-example-updates
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+309
−40
Open
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
ebf9ff3
Granit4 FP8 Block Quantization
krishnateja95 fc14b3a
Update src/llmcompressor/modeling/granite4.py
krishnateja95 b6a92a7
Merge branch 'vllm-project:main' into fix/granite4-example-updates
krishnateja95 c22bbd4
Update modeling/granite4.py
krishnateja95 27c01fc
Merge branch 'vllm-project:main' into fix/granite4-example-updates
krishnateja95 4097e07
Granite4 MoECalibrationModule Update
krishnateja95 313e83c
Remove old granite 4 example
brian-dellabetta 6e389e2
rename block example file
brian-dellabetta 80f779c
Merge branch 'main' into fix/granite4-example-updates
brian-dellabetta File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,67 +1,42 @@ | ||
| from compressed_tensors.utils import replace_module | ||
| import torch | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
|
|
||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
| from transformers.models.granitemoehybrid.modeling_granitemoehybrid import ( | ||
| GraniteMoeHybridParallelExperts, | ||
| ) | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modeling.granite4 import GraniteMoeHybridParallelExpertsLinear | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
| from llmcompressor.utils import dispatch_for_generation | ||
| from llmcompressor.modeling import replace_modules_for_calibration | ||
|
|
||
| """Please see details in `README_granite4.md`.""" | ||
| MODEL_ID = "ibm-granite/granite-4.0-h-small" | ||
|
|
||
| MODEL_ID = "ibm-granite/granite-4.0-tiny-preview" | ||
|
|
||
| # Load model. | ||
| model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") | ||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | ||
|
|
||
| skip_router_only = True # assume we want to quantize input/output moe layers | ||
| ignore_lay = [ | ||
| "lm_head", | ||
| ] | ||
| if skip_router_only: | ||
| # swap moe linears to a custom class | ||
| for n, m in model.named_modules(): | ||
| if isinstance(m, GraniteMoeHybridParallelExperts): | ||
| new_mod = GraniteMoeHybridParallelExpertsLinear.from_3d_expert(m) | ||
| replace_module(model, n, new_mod) | ||
| ignore_lay += ["re:.*block_sparse_moe.router"] | ||
| SAVE_DIR = "ibm-granite-4-tiny-fp8-dynamic-skipMoeRouter" | ||
| else: | ||
| # Skip all .input_linear, .output-linear, and router layers. | ||
| ignore_lay += ["re:.*block_sparse_moe"] | ||
| SAVE_DIR = "ibm-granite-4-tiny-fp8-dynamic-skipMoe" | ||
| model = replace_modules_for_calibration(model) | ||
|
|
||
| ignore_lay = ["lm_head"] | ||
|
|
||
| recipe = QuantizationModifier( | ||
| targets=["Linear", "GraniteMoeHybridParallelExpertsLinear"], | ||
| targets=["Linear"], | ||
| scheme="FP8_DYNAMIC", | ||
| ignore=ignore_lay, | ||
| ) | ||
|
|
||
| # Apply quantization. | ||
| oneshot(model=model, recipe=recipe) | ||
|
|
||
| # Confirm generations of the quantized model look sane. | ||
| print("========== SAMPLE GENERATION ==============") | ||
| dispatch_for_generation(model) | ||
| input_ids = tokenizer( | ||
| "What is your favorite TV show?", return_tensors="pt" | ||
| ).input_ids.to("cuda") | ||
| output = model.generate(input_ids, max_new_tokens=20) | ||
| "Describe Large Language Model", return_tensors="pt" | ||
| ).input_ids.to(model.device) | ||
| output = model.generate(input_ids, max_new_tokens=35) | ||
| print(tokenizer.decode(output[0])) | ||
| print("==========================================") | ||
|
|
||
| # Revert weights of MoE experts to 3D format (num_experts, output_size, input_size) | ||
| for n, m in model.named_modules(): | ||
| if isinstance(m, GraniteMoeHybridParallelExpertsLinear): | ||
| # NOTE: can assert type != "meta" instead, which is sign of offloading | ||
| assert m.weight.device.type == "cuda", ( | ||
| "Found some offloaded weights. This is not compatible with reshaping " | ||
| "experts to 3D prior model save. Ensure the model is fully on cuda." | ||
| ) | ||
| m.to_3d_expert() | ||
| SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-block" | ||
| print(f"Saving to {SAVE_DIR}") | ||
|
|
||
| model.save_pretrained(SAVE_DIR) | ||
| tokenizer.save_pretrained(SAVE_DIR) |
42 changes: 42 additions & 0 deletions
42
examples/quantization_w8a8_fp8/granite4_fp8_block_example.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| import torch | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
krishnateja95 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.quantization import QuantizationModifier | ||
| from llmcompressor.utils import dispatch_for_generation | ||
| from llmcompressor.modeling import replace_modules_for_calibration | ||
|
|
||
| MODEL_ID = "ibm-granite/granite-4.0-h-small" | ||
|
|
||
| model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") | ||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | ||
|
|
||
| model = replace_modules_for_calibration(model) | ||
|
|
||
| ignore_lay = ["lm_head", "re:.*block_sparse_moe.router", "re:.*mamba.in_proj", "re:.*shared_mlp.input_linear"] | ||
|
|
||
| recipe = QuantizationModifier( | ||
| targets=["Linear"], | ||
| scheme="FP8_BLOCK", | ||
| ignore=ignore_lay, | ||
| ) | ||
|
|
||
| oneshot(model=model, recipe=recipe) | ||
|
|
||
| print("========== SAMPLE GENERATION ==============") | ||
| dispatch_for_generation(model) | ||
| input_ids = tokenizer( | ||
| "Describe Large Language Model", return_tensors="pt" | ||
| ).input_ids.to(model.device) | ||
| output = model.generate(input_ids, max_new_tokens=35) | ||
| print(tokenizer.decode(output[0])) | ||
| print("==========================================") | ||
|
|
||
| SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-block" | ||
| print(f"Saving to {SAVE_DIR}") | ||
|
|
||
| model.save_pretrained(SAVE_DIR) | ||
| tokenizer.save_pretrained(SAVE_DIR) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.