Cannot use HF models #1147
Description
$ torchrun --nproc-per-node 4 pippy_llama.py

pippy_llama.py:
import os
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.distributed.pipelining import ScheduleGPipe, PipelineStage
# Grab the model
whole_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", device_map="meta"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
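# Llama checkpoints ship no pad token; reuse EOS so batched prompts can be padded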
tokenizer.pad_token = tokenizer.eos_token
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.distributed.init_process_group(rank=rank, world_size=world_size)
# Cut model by equal number of layers per rank
layers_per_rank = whole_model.config.num_hidden_layers // world_size
print(f"layers_per_rank = {layers_per_rank}")
stage_idx = rank
num_stages = world_size
def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
    # Work on a copy so each stage can prune layers without mutating the shared meta model
    model = copy.deepcopy(whole_model)
    if not is_first:
        model.model.embed_tokens = None
    # Keep only this stage's contiguous block of decoder layers; iterate in
    # reverse so deletions from the ModuleList don't shift the remaining indices
    drop_layers = stop_layer is not None
    num_layers = len(model.model.layers) - 1
    for idx in range(num_layers, -1, -1):
        if f"layers.{idx}" == stop_layer:
            drop_layers = False
        if f"layers.{idx}" == start_layer:
            drop_layers = True
        if drop_layers:
            del model.model.layers[idx]
    # drop_layers = start_layer is not None
    # for name in list(model.model.layers.keys()):
    #     # we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
    #     if f"layers.{name}" == start_layer:
    #         drop_layers = False
    #     if f"layers.{name}" == stop_layer:
    #         drop_layers = True
    #     if drop_layers:
    #         del model.model.layers[name]
    if not is_last:
        model.model.norm = None
        model.lm_head = None
    stage = PipelineStage(
        model,
        stage_idx,
        num_stages,
        device,
        # group=pp_mesh.get_group("pp"),
    )
    return stage, model
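
# Compute split boundaries of the form "layers.<idx>"; the first stage gets
# base_interval layers and any remainder is spread over the middle stages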
base_interval = whole_model.config.num_hidden_layers // num_stages
extra_layers = whole_model.config.num_hidden_layers % num_stages
splits = []
current_layer = 0
for i in range(num_stages - 1):
    if i == 0:
        current_layer += base_interval
    else:
        # Middle stages get an extra layer if there are any remaining
        if extra_layers > 0:
            current_layer += base_interval + 1
            extra_layers -= 1
        else:
            current_layer += base_interval
    splits.append("layers." + str(current_layer))
start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
stage, model_chunk = _build_stage(
    stage_idx,
    start_layer,
    stop_layer,
    is_first=stage_idx == 0,
    is_last=stage_idx == num_stages - 1,
)
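# Materialize the meta-device parameters on the local GPU. Note that
# to_empty() allocates uninitialized storage, so this repro runs with
# garbage weights unless a checkpoint is loaded afterwards.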
model_chunk.to_empty(device=device)
# Run-time inputs
full_batch_prompts = (
    "How do you", "I like to", "Can I help", "You need to",
    "The weather is", "I found a", "What is your", "You are so",
)  # full batch size = 8
inputs = tokenizer(full_batch_prompts, return_tensors="pt", padding=True)["input_ids"].to(device)
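# GPipe schedule; the second argument is the number of microbatches, so the
# batch of 8 prompts is split into world_size microbatches (2 prompts each on 4 ranks)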
schedule = ScheduleGPipe(stage, num_stages)
# Run
if rank == 0:
    # First stage feeds the input ids
    schedule.step(inputs)
elif rank == world_size - 1:
    # Last stage receives the logits
    output = schedule.step()
    if output is not None:
        next_token_logits = output[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1)
        print(tokenizer.batch_decode(next_token))
else:
    schedule.step()
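
For comparison, the upstream pippy_llama.py example builds its stages with the tracer frontend (torch.distributed.pipelining.pipeline) rather than this manual layer surgery. A rough sketch under the same variable names as the script above, assuming the model is loaded with real weights instead of device_map="meta" (the tracer exports the model against example microbatch inputs):

from torch.distributed.pipelining import SplitPoint, pipeline

# Split before layers 8, 16 and 24: 4 stages of 8 decoder layers each
split_spec = {
    f"model.layers.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, world_size)
}
# Example microbatch for tracing: batch of 8 / 4 microbatches = 2 prompts
mb_inputs = tokenizer(full_batch_prompts[:2], return_tensors="pt", padding=True)
pipe = pipeline(whole_model, mb_args=(mb_inputs["input_ids"],), split_spec=split_spec)
stage = pipe.build_stage(stage_idx, device=device)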
