This repository was archived by the owner on Aug 5, 2025. It is now read-only.

cannot use hf models #1147

@LYMDLUT

Description

The script below (pippy_llama.py) is launched with:

$ torchrun --nproc-per-node 4 pippy_llama.py

import os
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.distributed.pipelining import ScheduleGPipe, PipelineStage

# Grab the model

whole_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", device_map="meta"
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
tokenizer.pad_token = tokenizer.eos_token

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.distributed.init_process_group(rank=rank, world_size=world_size)

# Cut the model by an equal number of layers per rank

layers_per_rank = whole_model.config.num_hidden_layers // world_size
print(f"layers_per_rank = {layers_per_rank}")

stage_idx = rank
num_stages = world_size

def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
    # Work on a per-stage copy so the shared meta model is not mutated in place.
    model = copy.deepcopy(whole_model)
    if not is_first:
        model.model.embed_tokens = None

    # Walk the decoder layers from the back and drop everything outside this stage's range.
    drop_layers = stop_layer is not None
    last_layer_idx = len(model.model.layers) - 1
    for idx in range(last_layer_idx, -1, -1):
        if f"layers.{idx}" == stop_layer:
            drop_layers = False
        if f"layers.{idx}" == start_layer:
            drop_layers = True
        if drop_layers:
            del model.model.layers[idx]
    # drop_layers = start_layer is not None
    # for name in list(model.model.layers.keys()):
    #     # we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
    #     if f"layers.{name}" == start_layer:
    #         drop_layers = False
    #     if f"layers.{name}" == stop_layer:
    #         drop_layers = True
    #     if drop_layers:
    #         del model.model.layers[name]

    if not is_last:
        model.model.norm = None
        model.lm_head = None

    stage = PipelineStage(
        model,
        stage_idx,
        num_stages,
        device,
        # group=pp_mesh.get_group("pp"),
    )
    return stage, model

base_interval = whole_model.config.num_hidden_layers // num_stages
extra_layers = whole_model.config.num_hidden_layers % num_stages

splits = []
current_layer = 0
for i in range(num_stages - 1):
    if i == 0:
        current_layer += base_interval
    else:
        # Middle stages get an extra layer if there are any remaining
        if extra_layers > 0:
            current_layer += base_interval + 1
            extra_layers -= 1
        else:
            current_layer += base_interval
    splits.append("layers." + str(current_layer))

start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
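# Illustrative only (assumes 32 hidden layers and 4 ranks): splits == ["layers.8", "layers.16", "layers.24"],
# so rank 0 gets stop_layer "layers.8", rank 1 gets "layers.8"/"layers.16",
# rank 2 gets "layers.16"/"layers.24", and rank 3 gets start_layer "layers.24".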
stage, model_chunk = _build_stage(
    stage_idx,
    start_layer,
    stop_layer,
    is_first=stage_idx == 0,
    is_last=stage_idx == num_stages - 1,
)
model_chunk.to_empty(device=device)

# Run-time inputs

full_batch_prompts = (
    "How do you", "I like to", "Can I help", "You need to",
    "The weather is", "I found a", "What is your", "You are so",
)  # full batch size = 8
inputs = tokenizer(full_batch_prompts, return_tensors="pt", padding=True)["input_ids"].to(device)

schedule = ScheduleGPipe(stage, num_stages)
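# Note (not stated in the issue; based on the torch.distributed.pipelining API): the second
# positional argument of ScheduleGPipe is the number of microbatches, so the batch of 8 prompts
# is split into num_stages (= 4) microbatches of 2.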

# Run

if rank == 0:
    schedule.step(inputs)
elif rank == world_size - 1:
    output = schedule.step()
    if output is not None:
        next_token_logits = output[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1)
        print(tokenizer.batch_decode(next_token))
else:
    schedule.step()
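# Not part of the original repro: a typical addition at this point is
# torch.distributed.destroy_process_group() to shut the process group down cleanly.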
