diff --git a/tinker_cookbook/distillation/train_on_policy.py b/tinker_cookbook/distillation/train_on_policy.py
index 8b431935..8e950073 100644
--- a/tinker_cookbook/distillation/train_on_policy.py
+++ b/tinker_cookbook/distillation/train_on_policy.py
@@ -142,8 +142,10 @@ class Config:
     kl_penalty_coef: float = 1.0
     kl_discount_factor: float = 0.0
 
-    # Loss function to use for training: "importance_sampling" or "ppo"
+    # Loss function and configuration.
+    # See https://tinker-docs.thinkingmachines.ai/losses
     loss_fn: LossFnType = "importance_sampling"
+    loss_fn_config: dict[str, Any] | None = None
 
     # Number of optimizer steps per training iteration.
     # Useful for very large batch sizes.
@@ -244,11 +246,12 @@ async def do_train_step_and_get_sampling_client(
 
     with timed("train", metrics):
         training_logprobs_D = await train_step(
-            data_D,
-            training_client,
-            cfg.learning_rate,
-            cfg.num_substeps,
-            cfg.loss_fn,
+            data_D=data_D,
+            training_client=training_client,
+            learning_rate=cfg.learning_rate,
+            num_substeps=cfg.num_substeps,
+            loss_fn=cfg.loss_fn,
+            loss_fn_config=cfg.loss_fn_config,
         )
 
     sampling_client, full_batch_metrics = await compute_full_batch_metrics_and_get_sampling_client(
diff --git a/tinker_cookbook/recipes/distillation/on_policy_distillation.py b/tinker_cookbook/recipes/distillation/on_policy_distillation.py
index 6616327c..06e8aa17 100644
--- a/tinker_cookbook/recipes/distillation/on_policy_distillation.py
+++ b/tinker_cookbook/recipes/distillation/on_policy_distillation.py
@@ -29,8 +29,10 @@
 import logging
 import os
 from datetime import datetime
+from typing import Any
 
 import chz
+from tinker.types import LossFnType
 from tinker_cookbook import cli_utils, model_info
 from tinker_cookbook.distillation import train_on_policy
 from tinker_cookbook.distillation.datasets import (
@@ -70,7 +72,11 @@ class CLIConfig:
 
     # Optimizer configuration
     num_substeps: int = 1
-    loss_fn: str = "importance_sampling"
+
+    # Loss function and configuration.
+    # See https://tinker-docs.thinkingmachines.ai/losses
+    loss_fn: LossFnType = "importance_sampling"
+    loss_fn_config: dict[str, Any] | None = None
 
     # Logging configuration
     log_path: str | None = None
@@ -146,7 +152,8 @@ async def cli_main(cli_config: CLIConfig):
         kl_penalty_coef=cli_config.kl_penalty_coef,
         kl_discount_factor=cli_config.kl_discount_factor,
         num_substeps=cli_config.num_substeps,
-        loss_fn=cli_config.loss_fn,  # type: ignore
+        loss_fn=cli_config.loss_fn,
+        loss_fn_config=cli_config.loss_fn_config,
         wandb_project=cli_config.wandb_project,
         wandb_name=wandb_name,
         log_path=log_path,
diff --git a/tinker_cookbook/recipes/distillation/on_policy_multi_teacher.py b/tinker_cookbook/recipes/distillation/on_policy_multi_teacher.py
index 8a7b47ba..0964bc84 100644
--- a/tinker_cookbook/recipes/distillation/on_policy_multi_teacher.py
+++ b/tinker_cookbook/recipes/distillation/on_policy_multi_teacher.py
@@ -21,8 +21,10 @@
 import logging
 import os
 from datetime import datetime
+from typing import Any
 
 import chz
+from tinker.types import LossFnType
 from tinker_cookbook import cli_utils
 from tinker_cookbook.distillation import train_on_policy
 from tinker_cookbook.distillation.datasets import (
@@ -64,7 +66,11 @@ class CLIConfig:
 
     # Optimizer configuration
     num_substeps: int = 1
-    loss_fn: str = "importance_sampling"
+
+    # Loss function and configuration.
+    # See https://tinker-docs.thinkingmachines.ai/losses
+    loss_fn: LossFnType = "importance_sampling"
+    loss_fn_config: dict[str, Any] | None = None
 
     # Logging configuration
     log_path: str | None = None
@@ -156,7 +162,8 @@ async def cli_main(cli_config: CLIConfig):
         kl_penalty_coef=cli_config.kl_penalty_coef,
         kl_discount_factor=cli_config.kl_discount_factor,
         num_substeps=cli_config.num_substeps,
-        loss_fn=cli_config.loss_fn,  # type: ignore
+        loss_fn=cli_config.loss_fn,
+        loss_fn_config=cli_config.loss_fn_config,
         wandb_project=cli_config.wandb_project,
         wandb_name=wandb_name,
         log_path=log_path,
diff --git a/tinker_cookbook/recipes/math_rl/train.py b/tinker_cookbook/recipes/math_rl/train.py
index c93fac13..f7576e0b 100644
--- a/tinker_cookbook/recipes/math_rl/train.py
+++ b/tinker_cookbook/recipes/math_rl/train.py
@@ -1,6 +1,7 @@
 import asyncio
 import logging
 from datetime import datetime
+from typing import Any
 
 import chz
 from tinker_cookbook import cli_utils, model_info
@@ -59,7 +60,11 @@ class CLIConfig:
 
     behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "ask"
     max_steps_off_policy: int | None = None
+
+    # Loss function and configuration.
+    # See https://tinker-docs.thinkingmachines.ai/losses
     loss_fn: LossFnType = "importance_sampling"
+    loss_fn_config: dict[str, Any] | None = None
 
 
 def get_dataset_builder(
@@ -143,6 +148,7 @@ async def cli_main(cli_config: CLIConfig):
         if cli_config.max_steps_off_policy is not None
         else None,
         loss_fn=cli_config.loss_fn,
+        loss_fn_config=cli_config.loss_fn_config,
     )
 
     cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)
diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py
index c49ae6a9..0aa5d694 100644
--- a/tinker_cookbook/rl/train.py
+++ b/tinker_cookbook/rl/train.py
@@ -158,10 +158,11 @@ async def forward_backward(
     training_client: tinker.TrainingClient,
     batch_d: List[tinker.Datum],
     loss_fn: LossFnType,
+    loss_fn_config: dict[str, Any] | None = None,
 ) -> List[torch.Tensor]:
     """Accumulate gradients on a minibatch of data"""
     fwd_bwd_future = await training_client.forward_backward_async(
-        list(map(remove_mask, batch_d)), loss_fn=loss_fn
+        data=list(map(remove_mask, batch_d)), loss_fn=loss_fn, loss_fn_config=loss_fn_config
     )
     fwd_bwd_result = await fwd_bwd_future.result_async()
 
@@ -182,12 +183,18 @@ async def train_step(
     learning_rate: float,
     num_substeps: int,
     loss_fn: LossFnType,
+    loss_fn_config: dict[str, Any] | None = None,
 ) -> List[torch.Tensor]:
     """Train the model on collected trajectories."""
     batches_md = split_list(data_D, min(num_substeps, len(data_D)))
     training_logprobs_D: list[torch.Tensor] = []
     for batch_d in batches_md:
-        training_logprobs = await forward_backward(training_client, batch_d, loss_fn)
+        training_logprobs = await forward_backward(
+            training_client=training_client,
+            batch_d=batch_d,
+            loss_fn=loss_fn,
+            loss_fn_config=loss_fn_config,
+        )
         training_logprobs_D.extend(training_logprobs)
     await optim_step(training_client, learning_rate)
     return training_logprobs_D
@@ -237,8 +244,10 @@ class Config:
     kl_penalty_coef: float = 0.0
     kl_discount_factor: float = 0.0
 
-    # Loss function to use for training: "importance_sampling" or "ppo"
+    # Loss function and configuration.
+    # See https://tinker-docs.thinkingmachines.ai/losses
     loss_fn: LossFnType = "importance_sampling"
+    loss_fn_config: dict[str, Any] | None = None
 
     # Number of optimizer steps per training iteration.
     # Useful for very large batch sizes.
@@ -854,9 +863,10 @@ async def do_train_step_streaming_and_get_sampling_client(
                f"train/forward_backward_substep_{i_substep}_minibatch_{i_minibatch}", metrics
            ):
                training_logprobs_D = await forward_backward(
-                    training_client,
-                    data_D,
-                    cfg.loss_fn,
+                    training_client=training_client,
+                    batch_d=data_D,
+                    loss_fn=cfg.loss_fn,
+                    loss_fn_config=cfg.loss_fn_config,
                )
                all_data_D.extend(data_D)
                all_training_logprobs_D.extend(training_logprobs_D)
@@ -919,11 +929,12 @@ async def do_train_step_and_get_sampling_client(
 
     with timed("train", metrics):
         training_logprobs_D = await train_step(
-            data_D,
-            training_client,
-            cfg.learning_rate,
-            cfg.num_substeps,
-            cfg.loss_fn,
+            data_D=data_D,
+            training_client=training_client,
+            learning_rate=cfg.learning_rate,
+            num_substeps=cfg.num_substeps,
+            loss_fn=cfg.loss_fn,
+            loss_fn_config=cfg.loss_fn_config,
         )
 
     sampling_client, full_batch_metrics = await compute_full_batch_metrics_and_get_sampling_client(
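
Not part of the diff above: a minimal usage sketch of the new keyword, assuming only the call shape that rl/train.py already uses (training_client.forward_backward_async(data=..., loss_fn=..., loss_fn_config=...)). The helper name run_substep and the "clip_epsilon" key are illustrative, not part of the codebase; the keys each loss actually accepts are documented at https://tinker-docs.thinkingmachines.ai/losses.

from typing import Any

import tinker
from tinker.types import LossFnType


async def run_substep(
    training_client: tinker.TrainingClient,
    batch: list[tinker.Datum],
    loss_fn: LossFnType = "importance_sampling",
    loss_fn_config: dict[str, Any] | None = None,
) -> None:
    """Run one forward/backward pass, forwarding the optional per-loss config."""
    # loss_fn_config is passed straight through to the Tinker API, e.g.
    # loss_fn="ppo", loss_fn_config={"clip_epsilon": 0.2} (key name illustrative only).
    future = await training_client.forward_backward_async(
        data=batch,
        loss_fn=loss_fn,
        loss_fn_config=loss_fn_config,
    )
    await future.result_async()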