15 changes: 9 additions & 6 deletions tinker_cookbook/distillation/train_on_policy.py
@@ -142,8 +142,10 @@ class Config:
kl_penalty_coef: float = 1.0
kl_discount_factor: float = 0.0

# Loss function to use for training: "importance_sampling" or "ppo"
# Loss function and configuration.
# See https://tinker-docs.thinkingmachines.ai/losses
loss_fn: LossFnType = "importance_sampling"
loss_fn_config: dict[str, Any] | None = None

# Number of optimizer steps per training iteration.
# Useful for very large batch sizes.
@@ -244,11 +246,12 @@ async def do_train_step_and_get_sampling_client(

with timed("train", metrics):
training_logprobs_D = await train_step(
data_D,
training_client,
cfg.learning_rate,
cfg.num_substeps,
cfg.loss_fn,
data_D=data_D,
training_client=training_client,
learning_rate=cfg.learning_rate,
num_substeps=cfg.num_substeps,
loss_fn=cfg.loss_fn,
loss_fn_config=cfg.loss_fn_config,
)

sampling_client, full_batch_metrics = await compute_full_batch_metrics_and_get_sampling_client(
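These hunks add an optional `loss_fn_config` that travels alongside `loss_fn` and is forwarded into `train_step` unchanged, so leaving it at `None` preserves current behavior. Below is a minimal, self-contained sketch of that pass-through pattern (not the actual tinker API; the PPO key is purely hypothetical):

```python
from typing import Any


def forward_backward_sketch(loss_fn: str, loss_fn_config: dict[str, Any] | None = None) -> None:
    """Sketch only: forward the optional config verbatim to whatever computes the loss."""
    overrides = loss_fn_config or {}
    print(f"loss={loss_fn}, overrides={overrides}")


forward_backward_sketch("importance_sampling")          # default: no overrides, behavior unchanged
forward_backward_sketch("ppo", {"clip_epsilon": 0.2})   # hypothetical PPO knob; see the losses docs
```

Keeping the default at `None` means existing configs and launch scripts keep working without modification.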
11 changes: 9 additions & 2 deletions tinker_cookbook/recipes/distillation/on_policy_distillation.py
@@ -29,8 +29,10 @@
import logging
import os
from datetime import datetime
from typing import Any

import chz
from tinker.types import LossFnType
from tinker_cookbook import cli_utils, model_info
from tinker_cookbook.distillation import train_on_policy
from tinker_cookbook.distillation.datasets import (
@@ -70,7 +72,11 @@ class CLIConfig:

# Optimizer configuration
num_substeps: int = 1
loss_fn: str = "importance_sampling"

# Loss function and configuration.
# See https://tinker-docs.thinkingmachines.ai/losses
loss_fn: LossFnType = "importance_sampling"
loss_fn_config: dict[str, Any] | None = None

# Logging configuration
log_path: str | None = None
@@ -146,7 +152,8 @@ async def cli_main(cli_config: CLIConfig):
kl_penalty_coef=cli_config.kl_penalty_coef,
kl_discount_factor=cli_config.kl_discount_factor,
num_substeps=cli_config.num_substeps,
loss_fn=cli_config.loss_fn, # type: ignore
loss_fn=cli_config.loss_fn,
loss_fn_config=cli_config.loss_fn_config,
wandb_project=cli_config.wandb_project,
wandb_name=wandb_name,
log_path=log_path,
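Typing `loss_fn` as `LossFnType` rather than `str` is what lets the `# type: ignore` at the `cli_main` call site go away. Assuming `LossFnType` is a `Literal` alias along these lines (its exact members are not shown in this diff), here is a sketch of why the annotation now satisfies the type checker:

```python
from typing import Literal

# Assumption: LossFnType looks roughly like this Literal alias; the real
# definition lives in tinker.types and is not part of this diff.
LossFnTypeSketch = Literal["importance_sampling", "ppo"]


def build_config(loss_fn: LossFnTypeSketch) -> None:
    print(f"training with {loss_fn}")


build_config("importance_sampling")   # a recognized literal: accepted by the type checker
# build_config("imortance_sampling")  # a typo like this would now be flagged statically
```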
11 changes: 9 additions & 2 deletions tinker_cookbook/recipes/distillation/on_policy_multi_teacher.py
@@ -21,8 +21,10 @@
import logging
import os
from datetime import datetime
from typing import Any

import chz
from tinker.types import LossFnType
from tinker_cookbook import cli_utils
from tinker_cookbook.distillation import train_on_policy
from tinker_cookbook.distillation.datasets import (
@@ -64,7 +66,11 @@ class CLIConfig:

# Optimizer configuration
num_substeps: int = 1
loss_fn: str = "importance_sampling"

# Loss function and configuration.
# See https://tinker-docs.thinkingmachines.ai/losses
loss_fn: LossFnType = "importance_sampling"
loss_fn_config: dict[str, Any] | None = None

# Logging configuration
log_path: str | None = None
@@ -156,7 +162,8 @@ async def cli_main(cli_config: CLIConfig):
kl_penalty_coef=cli_config.kl_penalty_coef,
kl_discount_factor=cli_config.kl_discount_factor,
num_substeps=cli_config.num_substeps,
loss_fn=cli_config.loss_fn, # type: ignore
loss_fn=cli_config.loss_fn,
loss_fn_config=cli_config.loss_fn_config,
wandb_project=cli_config.wandb_project,
wandb_name=wandb_name,
log_path=log_path,
6 changes: 6 additions & 0 deletions tinker_cookbook/recipes/math_rl/train.py
@@ -1,6 +1,7 @@
import asyncio
import logging
from datetime import datetime
from typing import Any

import chz
from tinker_cookbook import cli_utils, model_info
@@ -59,7 +60,11 @@ class CLIConfig:
behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "ask"

max_steps_off_policy: int | None = None

# Loss function and configuration.
# See https://tinker-docs.thinkingmachines.ai/losses
loss_fn: LossFnType = "importance_sampling"
loss_fn_config: dict[str, Any] | None = None


def get_dataset_builder(
@@ -143,6 +148,7 @@ async def cli_main(cli_config: CLIConfig):
if cli_config.max_steps_off_policy is not None
else None,
loss_fn=cli_config.loss_fn,
loss_fn_config=cli_config.loss_fn_config,
)

cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)
33 changes: 22 additions & 11 deletions tinker_cookbook/rl/train.py
@@ -158,10 +158,11 @@ async def forward_backward(
training_client: tinker.TrainingClient,
batch_d: List[tinker.Datum],
loss_fn: LossFnType,
loss_fn_config: dict[str, Any] | None = None,
) -> List[torch.Tensor]:
"""Accumulate gradients on a minibatch of data"""
fwd_bwd_future = await training_client.forward_backward_async(
list(map(remove_mask, batch_d)), loss_fn=loss_fn
data=list(map(remove_mask, batch_d)), loss_fn=loss_fn, loss_fn_config=loss_fn_config
)
fwd_bwd_result = await fwd_bwd_future.result_async()

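The call site keeps the submit-then-wait shape already used here: the first `await` submits the request and returns a handle, and the second `await` on `result_async()` waits for the actual forward/backward result. A toy, self-contained illustration of that two-step pattern (class and function names below are invented for the sketch, not tinker's):

```python
import asyncio
from typing import Any


class PendingResult:
    """Stand-in for a submit-style handle (not tinker's actual future type)."""

    async def result_async(self) -> str:
        await asyncio.sleep(0.05)  # pretend we are waiting on the training service
        return "fwd/bwd metrics"


async def forward_backward_async_sketch(
    data: list[str], loss_fn: str, loss_fn_config: dict[str, Any] | None = None
) -> PendingResult:
    # First await: submit the request and return a handle almost immediately.
    await asyncio.sleep(0.01)
    return PendingResult()


async def main() -> None:
    future = await forward_backward_async_sketch(["datum-0"], "importance_sampling")
    print(await future.result_async())  # second await: block until the result is ready


asyncio.run(main())
```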
@@ -182,12 +183,18 @@ async def train_step(
learning_rate: float,
num_substeps: int,
loss_fn: LossFnType,
loss_fn_config: dict[str, Any] | None = None,
) -> List[torch.Tensor]:
"""Train the model on collected trajectories."""
batches_md = split_list(data_D, min(num_substeps, len(data_D)))
training_logprobs_D: list[torch.Tensor] = []
for batch_d in batches_md:
training_logprobs = await forward_backward(training_client, batch_d, loss_fn)
training_logprobs = await forward_backward(
training_client=training_client,
batch_d=batch_d,
loss_fn=loss_fn,
loss_fn_config=loss_fn_config,
)
training_logprobs_D.extend(training_logprobs)
await optim_step(training_client, learning_rate)
return training_logprobs_D
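For context on the loop above: `split_list(data_D, min(num_substeps, len(data_D)))` carves the batch into one minibatch per substep, and the `min(...)` appears to guard against asking for more substeps than there are datums. A rough sketch of what such a splitter might do (this is not the cookbook's actual `split_list`):

```python
def split_list_sketch(items: list[int], n: int) -> list[list[int]]:
    """Sketch only: split items into n near-equal chunks, preserving order."""
    base, extra = divmod(len(items), n)
    chunks, start = [], 0
    for i in range(n):
        size = base + (1 if i < extra else 0)
        chunks.append(items[start : start + size])
        start += size
    return chunks


data = list(range(7))
print(split_list_sketch(data, min(3, len(data))))  # [[0, 1, 2], [3, 4], [5, 6]]
```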
@@ -237,8 +244,10 @@ class Config:
kl_penalty_coef: float = 0.0
kl_discount_factor: float = 0.0

# Loss function to use for training: "importance_sampling" or "ppo"
# Loss function and configuration.
# See https://tinker-docs.thinkingmachines.ai/losses
loss_fn: LossFnType = "importance_sampling"
loss_fn_config: dict[str, Any] | None = None

# Number of optimizer steps per training iteration.
# Useful for very large batch sizes.
@@ -854,9 +863,10 @@ async def do_train_step_streaming_and_get_sampling_client(
f"train/forward_backward_substep_{i_substep}_minibatch_{i_minibatch}", metrics
):
training_logprobs_D = await forward_backward(
training_client,
data_D,
cfg.loss_fn,
training_client=training_client,
batch_d=data_D,
loss_fn=cfg.loss_fn,
loss_fn_config=cfg.loss_fn_config,
)
all_data_D.extend(data_D)
all_training_logprobs_D.extend(training_logprobs_D)
@@ -919,11 +929,12 @@ async def do_train_step_and_get_sampling_client(

with timed("train", metrics):
training_logprobs_D = await train_step(
data_D,
training_client,
cfg.learning_rate,
cfg.num_substeps,
cfg.loss_fn,
data_D=data_D,
training_client=training_client,
learning_rate=cfg.learning_rate,
num_substeps=cfg.num_substeps,
loss_fn=cfg.loss_fn,
loss_fn_config=cfg.loss_fn_config,
)

sampling_client, full_batch_metrics = await compute_full_batch_metrics_and_get_sampling_client(