Skip to content
This repository was archived by the owner on May 10, 2024. It is now read-only.

Commit 2eff234

Browse files
authored
Merge pull request #6 from Trainy-ai/patchfix
Patchfix
2 parents bc2d1d0 + c859d14 commit 2eff234

File tree

11 files changed

+76
-36
lines changed

11 files changed

+76
-36
lines changed

docs/source/quickstart/finetuning.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ To do a vicuna finetune of your first model through LLM-ATC, run the following
1414
.. code-block:: console
1515
1616
# start training
17-
$ llm-atc train --model_type vicuna --finetune_data ./vicuna_test.json --name myvicuna --description "This is a finetuned model that just says its name is vicuna" -c mycluster --cloud gcp --envs "MODEL_SIZE=7 WANDB_API_KEY=<my wandb key>" --accelerator A100-80G:4
17+
$ llm-atc train --model_type vicuna --finetune_data ./vicuna_test.json --name myvicuna --checkpoint_bucket my-trainy-bucket --checkpoint_path ~/test_vicuna --checkpoint_store S3 --description "This is a finetuned model that just says its name is vicuna" -c mycluster --cloud gcp --envs "MODEL_SIZE=7 WANDB_API_KEY=<my wandb key>" --accelerator A100-80G:4
1818
1919
# Once training is done, shutdown the cluster
2020
$ sky down

docs/source/quickstart/serving.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ by using the :code:`llm-atc/` prefix.
1313
.. code-block:: console
1414
1515
# serve an llm-atc finetuned model, requires `llm-atc/` prefix and grabs model checkpoint from object store
16-
$ llm-atc serve --name llm-atc/myvicuna --accelerator A100:1 -c servecluster --cloud gcp --region asia-southeast1 --envs "HF_TOKEN=<HuggingFace_token>"
16+
$ llm-atc serve --name llm-atc/myvicuna --source s3://my-bucket/my_vicuna/ --accelerator A100:1 -c servecluster --cloud gcp --region asia-southeast1 --envs "HF_TOKEN=<HuggingFace_token>"
1717
1818
# serve a HuggingFace model, e.g. `lmsys/vicuna-13b-v1.3`
1919
$ llm-atc serve --name lmsys/vicuna-13b-v1.3 --accelerator A100:1 -c servecluster --cloud gcp --region asia-southeast1 --envs "HF_TOKEN=<HuggingFace_token>"

llm_atc/cli.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,21 @@ def cli():
5858
required=True,
5959
help="local/cloud URI to finetuning data. (e.g ~/mychat.json, s3://my_bucket/my_chat.json)",
6060
)
61+
@click.option(
62+
"--checkpoint_bucket", type=str, required=True, help="object store bucket name"
63+
)
64+
@click.option(
65+
"--checkpoint_path",
66+
type=str,
67+
required=True,
68+
help="object store path for fine tuned checkpoints, e.g. ~/datasets",
69+
)
70+
@click.option(
71+
"--checkpoint_store",
72+
type=str,
73+
required=True,
74+
help="object store type ['S3', 'GCS', 'AZURE', 'R2', 'IBM']",
75+
)
6176
@click.option("-n", "--name", type=str, help="Name of this model run.", required=True)
6277
@click.option(
6378
"--description", type=str, default="", help="description of this model run"
@@ -100,6 +115,9 @@ def cli():
100115
def train(
101116
model_type: str,
102117
finetune_data: str,
118+
checkpoint_bucket: str,
119+
checkpoint_path: str,
120+
checkpoint_store: Optional[str],
103121
name: str,
104122
description: str,
105123
cluster: Optional[str],
@@ -118,12 +136,11 @@ def train(
118136
event="training launched",
119137
timestamp=datetime.utcnow(),
120138
)
121-
if RunTracker.run_exists(name):
122-
raise ValueError(
123-
f"Task with name {name} already exists in {llm_atc.constants.LLM_ATC_PATH}. Try again with a different name"
124-
)
125139
task = train_task(
126140
model_type,
141+
checkpoint_bucket=checkpoint_bucket,
142+
checkpoint_path=checkpoint_path,
143+
checkpoint_store=checkpoint_store,
127144
finetune_data=finetune_data,
128145
name=name,
129146
cloud=cloud,
@@ -146,7 +163,11 @@ def train(
146163
"--name",
147164
help="name of model to serve",
148165
required=True,
149-
multiple=True,
166+
)
167+
@click.option(
168+
"--source",
169+
help="object store path for llm-atc finetuned model checkpoints."
170+
"e.g. s3://<bucket-name>/<path>/<to>/<checkpoints>",
150171
)
151172
@click.option(
152173
"-e",
@@ -189,7 +210,8 @@ def train(
189210
help="Don't connect to this session",
190211
)
191212
def serve(
192-
name: List[str],
213+
name: str,
214+
source: Optional[str],
193215
accelerator: Optional[str],
194216
envs: Optional[str],
195217
cluster: Optional[str],
@@ -209,6 +231,7 @@ def serve(
209231
)
210232
serve_task = serve_route(
211233
name,
234+
source=source,
212235
accelerator=accelerator,
213236
envs=envs,
214237
cloud=cloud,

llm_atc/config/serve/serve.yml

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,6 @@ resources:
44
ports:
55
- 8000
66

7-
file_mounts:
8-
/llm-atc:
9-
name: llm-atc # Make sure it is unique or you own this bucket name
10-
mode: MOUNT # MOUNT or COPY. Defaults to MOUNT if not specified
11-
127
setup: |
138
conda activate chatbot
149
if [ $? -ne 0 ]; then
@@ -22,7 +17,7 @@ setup: |
2217
conda install -y -c conda-forge accelerate
2318
pip install sentencepiece
2419
pip install vllm
25-
pip install git+https://github.com/lm-sys/FastChat.git
20+
pip install git+https://github.com/lm-sys/FastChat.git@v0.2.28
2621
pip install --upgrade openai
2722
if [[ "$HF_TOKEN" != "" ]];
2823
then
@@ -49,7 +44,6 @@ run: |
4944
master_addr=`echo "$SKYPILOT_NODE_IPS" | head -n1`
5045
let x='SKYPILOT_NODE_RANK + 1'
5146
this_addr=`echo "$SKYPILOT_NODE_IPS" | sed -n "${x}p"`
52-
MODEL_NAME=`echo "$MODELS_LIST" | sed -n "${x}p"`
5347
5448
echo "The ip address of this machine is ${this_addr}"
5549
echo "The head address is ${master_addr}"
@@ -82,5 +76,5 @@ run: |
8276
8377
8478
envs:
85-
MODELS_LIST: lmsys/vicuna-7b-v1.3
79+
MODEL_NAME: lmsys/vicuna-7b-v1.3
8680
HF_TOKEN: ""

llm_atc/config/train/vicuna.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@ num_nodes: 1
77

88
file_mounts:
99
/artifacts:
10-
name: llm-atc
10+
name: ${MY_BUCKET} # Name of the bucket.
1111
mode: MOUNT
12-
13-
workdir: .
12+
store: ${BUCKET_TYPE} # s3, gcs, r2, ibm
1413

1514
setup: |
1615
# Setup the environment
@@ -64,7 +63,7 @@ run: |
6463
# the training for saving checkpoints.
6564
mkdir -p ~/.checkpoints
6665
LOCAL_CKPT_PATH=~/.checkpoints
67-
CKPT_PATH=/artifacts/${MODEL_NAME}
66+
CKPT_PATH=/artifacts/${BUCKET_PATH}/${MODEL_NAME}
6867
mkdir -p $CKPT_PATH
6968
last_ckpt=$(ls ${CKPT_PATH} | grep -E '[0-9]+' | sort -t'-' -k1,1 -k2,2n | tail -1)
7069
mkdir -p ~/.checkpoints/${last_ckpt}
@@ -127,3 +126,6 @@ envs:
127126
WANDB_API_KEY: ""
128127
MODEL_NAME: "vicuna_test"
129128
HF_TOKEN: ""
129+
MY_BUCKET: "llm-atc"
130+
BUCKET_PATH: "my_vicuna" # object store path.
131+
BUCKET_TYPE: "S3"

llm_atc/launch.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import os
44
import sky
5+
from sky.data.storage import Storage
56

67
from omegaconf import OmegaConf
78
from typing import Any, Dict, Optional
@@ -10,7 +11,7 @@
1011
SUPPORTED_MODELS = ("vicuna",)
1112

1213

13-
def train_task(model_type: str, **launcher_kwargs) -> sky.Task:
14+
def train_task(model_type: str, *args, **launcher_kwargs) -> sky.Task:
1415
"""
1516
Dispatch train launch to corresponding task default config
1617
@@ -39,6 +40,9 @@ class Launcher:
3940
def __init__(
4041
self,
4142
finetune_data: str,
43+
checkpoint_bucket: str = "llm-atc",
44+
checkpoint_path: str = "my_vicuna",
45+
checkpoint_store: str = "S3",
4246
name: Optional[str] = None,
4347
cloud: Optional[str] = None,
4448
region: Optional[str] = None,
@@ -47,6 +51,9 @@ def __init__(
4751
envs: Optional[str] = "",
4852
):
4953
self.finetune_data: str = finetune_data
54+
self.checkpoint_bucket: str = checkpoint_bucket
55+
self.checkpoint_path: str = checkpoint_path
56+
self.checkpoint_store: str = checkpoint_store
5057
self.name: Optional[str] = name
5158
self.cloud: Optional[str] = cloud
5259
self.region: Optional[str] = region
@@ -85,8 +92,14 @@ def launch(self) -> sky.Task:
8592
logging.warning(
8693
"No huggingface token provided. You will not be able to finetune starting from private or gated models"
8794
)
95+
self.envs["MY_BUCKET"] = self.checkpoint_bucket
96+
self.envs["BUCKET_PATH"] = self.checkpoint_path
97+
self.envs["BUCKET_TYPE"] = self.checkpoint_store
8898
task.update_envs(self.envs)
8999
task.update_file_mounts({"/data/mydata.json": self.finetune_data})
100+
storage = Storage(name=self.checkpoint_bucket)
101+
storage.add_store(self.checkpoint_store)
102+
task.update_storage_mounts({"/artifacts": storage})
90103
resource = list(task.get_resources())[0]
91104
resource._set_accelerators(self.accelerator, None)
92105
resource._cloud = sky.clouds.CLOUD_REGISTRY.from_str(self.cloud)

llm_atc/serve.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Any, Dict, List, Optional
99

1010

11-
def serve_route(model_names: List[str], **serve_kwargs):
11+
def serve_route(model_name: str, source: Optional[str] = None, **serve_kwargs):
1212
"""Routes model serve requests to the corresponding model serve config
1313
1414
Args:
@@ -17,26 +17,30 @@ def serve_route(model_names: List[str], **serve_kwargs):
1717
Raises:
1818
ValueError: requested non-existent model from llm-atc
1919
"""
20-
model_names = list(model_names)
21-
for i, name in enumerate(model_names):
22-
if name.startswith("llm-atc/") and not RunTracker.run_exists(
23-
name.split("/")[-1]
24-
):
25-
raise ValueError(f"model = {name} does not exist within llm-atc.")
26-
return Serve(model_names, **serve_kwargs).serve()
20+
if model_name.startswith("llm-atc/") and source is None:
21+
raise ValueError(
22+
"Attempting to use a finetuned model without a corresponding object store location"
23+
)
24+
elif not source is None and not model_name.startswith("llm-atc/"):
25+
logging.warning(
26+
"Specified object store mount but model is not an llm-atc model. Skipping mounting."
27+
)
28+
return Serve(model_name, source, **serve_kwargs).serve()
2729

2830

2931
class Serve:
3032
def __init__(
3133
self,
32-
names: List[str],
34+
names: str,
35+
source: Optional[str],
3336
accelerator: Optional[str] = None,
3437
cloud: Optional[str] = None,
3538
region: Optional[str] = None,
3639
zone: Optional[str] = None,
3740
envs: str = "",
3841
):
3942
self.names = names
43+
self.source = source
4044
self.num_models = len(names)
4145
self.accelerator = accelerator
4246
self.envs: Dict[Any, Any] = (
@@ -65,7 +69,7 @@ def default_serve_task(self) -> sky.Task:
6569
def serve(self) -> sky.Task:
6670
"""Deploy fastchat.serve.openai_api_server with vllm_worker"""
6771
serve_task = self.default_serve_task
68-
self.envs["MODELS_LIST"] = "\n".join(self.names)
72+
self.envs["MODEL_NAME"] = self.names
6973
if "HF_TOKEN" not in self.envs:
7074
logging.warning(
7175
"No huggingface token provided. You will not be able to access private or gated models"
@@ -76,5 +80,6 @@ def serve(self) -> sky.Task:
7680
resource._cloud = sky.clouds.CLOUD_REGISTRY.from_str(self.cloud)
7781
resource._set_region_zone(self.region, self.zone)
7882
serve_task.set_resources(resource)
79-
serve_task.num_noded = self.num_models
83+
if self.source and self.names.startswith("llm-atc/"):
84+
serve_task.update_file_mounts({"/" + self.names: self.source})
8085
return serve_task

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "llm_atc"
3-
version = "0.1.3"
3+
version = "0.1.4"
44
description = "Tools for fine tuning and serving LLMs"
55
authors = ["Andrew Aikawa <[email protected]>"]
66
readme = "README.md"

tests/test_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test_train_vicuna():
121121
test = Test(
122122
"train_vicuna",
123123
[
124-
f"llm-atc train --model_type vicuna --finetune_data {test_chat} --name {name} --description 'test case vicuna fine tune' -c mycluster --cloud gcp --envs 'MODEL_SIZE=7' --accelerator A100-80G:4",
124+
f"llm-atc train --checkpoint_bucket llm-atc --checkpoint_path ~/test_vicuna --checkpoint_store S3 --model_type vicuna --finetune_data {test_chat} --name {name} --description 'test case vicuna fine tune' -c mycluster --cloud gcp --envs 'MODEL_SIZE=7' --accelerator A100-80G:4",
125125
],
126126
f"sky down --purge -y {name}",
127127
timeout=10 * 60,

tests/test_launch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
def test_train():
77
task = train_task(
88
"vicuna",
9+
checkpoint_bucket="llm-atc",
10+
checkpoint_path="myvicuna",
11+
checkpoint_store="S3",
912
finetune_data="./vicuna_test.json",
1013
name="myvicuna",
1114
cloud="aws",

0 commit comments

Comments
 (0)