diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 37aa08432..67a83ca0a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: timeout-minutes: 10 name: lint runs-on: ${{ github.repository == 'stainless-sdks/togetherai-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} - if: github.event_name == 'push' || github.event.pull_request.head.repo.fork + if: (github.event_name == 'push' || github.event.pull_request.head.repo.fork) && (github.event_name != 'push' || github.event.head_commit.message != 'codegen metadata') steps: - uses: actions/checkout@v6 @@ -38,7 +38,7 @@ jobs: run: uv run ruff format --check build: - if: github.event_name == 'push' || github.event.pull_request.head.repo.fork + if: (github.event_name == 'push' || github.event.pull_request.head.repo.fork) && (github.event_name != 'push' || github.event.head_commit.message != 'codegen metadata') timeout-minutes: 10 name: build permissions: diff --git a/.gitignore b/.gitignore index 95ceb189a..3824f4c48 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .prism.log +.stdy.log _dev __pycache__ diff --git a/.stats.yml b/.stats.yml index 67859756b..dc5073571 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 74 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai%2Ftogetherai-31893d157d3c85caa1d8615b73a5fa431ea2cc126bd2410e0f84f3defd5c7dec.yml -openapi_spec_hash: b652a4d504b4a3dbf585ab803b0f59fc +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/togetherai%2Ftogetherai-575d8f82668a560f98d2897556fbda60505d406d22032dd64d17a2eaf36a3f74.yml +openapi_spec_hash: 22cff840f581894c6a8c7d99d097f07d config_hash: 52d213100a0ca1a4b2cdcd2718936b51 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6da2d0c8a..d54cb8837 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,7 +85,7 @@ $ pip install ./path-to-wheel-file.whl ## Running tests -Most tests require you to 
[set up a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests. +Most tests require you to [set up a mock server](https://github.com/dgellow/steady) against the OpenAPI spec to run the tests. ```sh $ ./scripts/mock diff --git a/scripts/mock b/scripts/mock index bcf3b392b..09eb49f65 100755 --- a/scripts/mock +++ b/scripts/mock @@ -19,34 +19,34 @@ fi echo "==> Starting mock server with URL ${URL}" -# Run prism mock on the given spec +# Run steady mock on the given spec if [ "$1" == "--daemon" ]; then # Pre-install the package so the download doesn't eat into the startup timeout - npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism --version + npm exec --package=@stdy/cli@0.19.7 -- steady --version - npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL" &> .prism.log & + npm exec --package=@stdy/cli@0.19.7 -- steady --host 127.0.0.1 -p 4010 --validator-form-array-format=comma --validator-query-array-format=comma --validator-form-object-format=brackets --validator-query-object-format=brackets "$URL" &> .stdy.log & - # Wait for server to come online (max 30s) + # Wait for server to come online via health endpoint (max 30s) echo -n "Waiting for server" attempts=0 - while ! grep -q "✖ fatal\|Prism is listening" ".prism.log" ; do + while ! curl --silent --fail "http://127.0.0.1:4010/_x-steady/health" >/dev/null 2>&1; do + if ! kill -0 $! 2>/dev/null; then + echo + cat .stdy.log + exit 1 + fi attempts=$((attempts + 1)) if [ "$attempts" -ge 300 ]; then echo - echo "Timed out waiting for Prism server to start" - cat .prism.log + echo "Timed out waiting for Steady server to start" + cat .stdy.log exit 1 fi echo -n "." 
sleep 0.1 done - if grep -q "✖ fatal" ".prism.log"; then - cat .prism.log - exit 1 - fi - echo else - npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock "$URL" + npm exec --package=@stdy/cli@0.19.7 -- steady --host 127.0.0.1 -p 4010 --validator-form-array-format=comma --validator-query-array-format=comma --validator-form-object-format=brackets --validator-query-object-format=brackets "$URL" fi diff --git a/scripts/test b/scripts/test index b56970b78..d13602ae1 100755 --- a/scripts/test +++ b/scripts/test @@ -9,8 +9,8 @@ GREEN='\033[0;32m' YELLOW='\033[0;33m' NC='\033[0m' # No Color -function prism_is_running() { - curl --silent "http://localhost:4010" >/dev/null 2>&1 +function steady_is_running() { + curl --silent "http://127.0.0.1:4010/_x-steady/health" >/dev/null 2>&1 } kill_server_on_port() { @@ -25,7 +25,7 @@ function is_overriding_api_base_url() { [ -n "$TEST_API_BASE_URL" ] } -if ! is_overriding_api_base_url && ! prism_is_running ; then +if ! is_overriding_api_base_url && ! steady_is_running ; then # When we exit this script, make sure to kill the background mock server process trap 'kill_server_on_port 4010' EXIT @@ -36,19 +36,19 @@ fi if is_overriding_api_base_url ; then echo -e "${GREEN}✔ Running tests against ${TEST_API_BASE_URL}${NC}" echo -elif ! prism_is_running ; then - echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Prism server" +elif ! steady_is_running ; then + echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Steady server" echo -e "running against your OpenAPI spec." 
echo echo -e "To run the server, pass in the path or url of your OpenAPI" - echo -e "spec to the prism command:" + echo -e "spec to the steady command:" echo - echo -e " \$ ${YELLOW}npm exec --package=@stainless-api/prism-cli@5.15.0 -- prism mock path/to/your.openapi.yml${NC}" + echo -e " \$ ${YELLOW}npm exec --package=@stdy/cli@0.19.7 -- steady path/to/your.openapi.yml --host 127.0.0.1 -p 4010 --validator-form-array-format=comma --validator-query-array-format=comma --validator-form-object-format=brackets --validator-query-object-format=brackets${NC}" echo exit 1 else - echo -e "${GREEN}✔ Mock prism server is running with your OpenAPI spec${NC}" + echo -e "${GREEN}✔ Mock steady server is running with your OpenAPI spec${NC}" echo fi diff --git a/src/together/_utils/__init__.py b/src/together/_utils/__init__.py index dc64e29a1..10cb66d2d 100644 --- a/src/together/_utils/__init__.py +++ b/src/together/_utils/__init__.py @@ -1,3 +1,4 @@ +from ._path import path_template as path_template from ._sync import asyncify as asyncify from ._proxy import LazyProxy as LazyProxy from ._utils import ( diff --git a/src/together/_utils/_path.py b/src/together/_utils/_path.py new file mode 100644 index 000000000..4d6e1e4cb --- /dev/null +++ b/src/together/_utils/_path.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import re +from typing import ( + Any, + Mapping, + Callable, +) +from urllib.parse import quote + +# Matches '.' or '..' where each dot is either literal or percent-encoded (%2e / %2E). +_DOT_SEGMENT_RE = re.compile(r"^(?:\.|%2[eE]){1,2}$") + +_PLACEHOLDER_RE = re.compile(r"\{(\w+)\}") + + +def _quote_path_segment_part(value: str) -> str: + """Percent-encode `value` for use in a URI path segment. + + Considers characters not in `pchar` set from RFC 3986 §3.3 to be unsafe. 
+ https://datatracker.ietf.org/doc/html/rfc3986#section-3.3 + """ + # quote() already treats unreserved characters (letters, digits, and -._~) + # as safe, so we only need to add sub-delims, ':', and '@'. + # Notably, unlike the default `safe` for quote(), / is unsafe and must be quoted. + return quote(value, safe="!$&'()*+,;=:@") + + +def _quote_query_part(value: str) -> str: + """Percent-encode `value` for use in a URI query string. + + Considers &, = and characters not in `query` set from RFC 3986 §3.4 to be unsafe. + https://datatracker.ietf.org/doc/html/rfc3986#section-3.4 + """ + return quote(value, safe="!$'()*+,;:@/?") + + +def _quote_fragment_part(value: str) -> str: + """Percent-encode `value` for use in a URI fragment. + + Considers characters not in `fragment` set from RFC 3986 §3.5 to be unsafe. + https://datatracker.ietf.org/doc/html/rfc3986#section-3.5 + """ + return quote(value, safe="!$&'()*+,;=:@/?") + + +def _interpolate( + template: str, + values: Mapping[str, Any], + quoter: Callable[[str], str], +) -> str: + """Replace {name} placeholders in `template`, quoting each value with `quoter`. + + Placeholder names are looked up in `values`. + + Raises: + KeyError: If a placeholder is not found in `values`. + """ + # re.split with a capturing group returns alternating + # [text, name, text, name, ..., text] elements. + parts = _PLACEHOLDER_RE.split(template) + + for i in range(1, len(parts), 2): + name = parts[i] + if name not in values: + raise KeyError(f"a value for placeholder {{{name}}} was not provided") + val = values[name] + if val is None: + parts[i] = "null" + elif isinstance(val, bool): + parts[i] = "true" if val else "false" + else: + parts[i] = quoter(str(values[name])) + + return "".join(parts) + + +def path_template(template: str, /, **kwargs: Any) -> str: + """Interpolate {name} placeholders in `template` from keyword arguments. + + Args: + template: The template string containing {name} placeholders. 
+ **kwargs: Keyword arguments to interpolate into the template. + + Returns: + The template with placeholders interpolated and percent-encoded. + + Safe characters for percent-encoding are dependent on the URI component. + Placeholders in path and fragment portions are percent-encoded where the `segment` + and `fragment` sets from RFC 3986 respectively are considered safe. + Placeholders in the query portion are percent-encoded where the `query` set from + RFC 3986 §3.4 is considered safe except for = and & characters. + + Raises: + KeyError: If a placeholder is not found in `kwargs`. + ValueError: If resulting path contains /./ or /../ segments (including percent-encoded dot-segments). + """ + # Split the template into path, query, and fragment portions. + fragment_template: str | None = None + query_template: str | None = None + + rest = template + if "#" in rest: + rest, fragment_template = rest.split("#", 1) + if "?" in rest: + rest, query_template = rest.split("?", 1) + path_template = rest + + # Interpolate each portion with the appropriate quoting rules. + path_result = _interpolate(path_template, kwargs, _quote_path_segment_part) + + # Reject dot-segments (. and ..) in the final assembled path. The check + # runs after interpolation so that adjacent placeholders or a mix of static + # text and placeholders that together form a dot-segment are caught. + # Also reject percent-encoded dot-segments to protect against incorrectly + # implemented normalization in servers/proxies. + for segment in path_result.split("/"): + if _DOT_SEGMENT_RE.match(segment): + raise ValueError(f"Constructed path {path_result!r} contains dot-segment {segment!r} which is not allowed") + + result = path_result + if query_template is not None: + result += "?" 
+ _interpolate(query_template, kwargs, _quote_query_part) + if fragment_template is not None: + result += "#" + _interpolate(fragment_template, kwargs, _quote_fragment_part) + + return result diff --git a/src/together/lib/cli/__init__.py b/src/together/lib/cli/__init__.py index 7140e0503..7043cf686 100644 --- a/src/together/lib/cli/__init__.py +++ b/src/together/lib/cli/__init__.py @@ -91,7 +91,7 @@ def block_requests_for_api_key(_: httpx.Request) -> None: invoked_command = click.get_current_context().command_path invoked_command_name = invoked_command.split("together ")[1] click.secho( - "Error: api key missing.\n\nThe api_key must be set either by passing --api-key to the command or by setting the TOGETHER_API_KEY environment variable", + "Error: api key missing.\n\nThe api key must be set either by passing --api-key to the command or by setting the TOGETHER_API_KEY environment variable", fg="red", ) click.secho("\nYou can find your api key at https://api.together.xyz/settings/api-keys", fg="yellow") diff --git a/src/together/lib/cli/api/_utils.py b/src/together/lib/cli/api/_utils.py index 5c7d19575..9ee305aff 100644 --- a/src/together/lib/cli/api/_utils.py +++ b/src/together/lib/cli/api/_utils.py @@ -10,8 +10,10 @@ from functools import wraps import click +from rich import print_json from together import APIError +from together._utils._json import openapi_dumps from together.lib.types.fine_tuning import COMPLETED_STATUSES, FinetuneResponse from together.types.finetune_response import FinetuneResponse as _FinetuneResponse from together.types.fine_tuning_list_response import Data @@ -175,6 +177,7 @@ def handle_api_errors(prefix: str) -> Callable[[F], F]: def decorator(f: F) -> F: @wraps(f) def wrapper(*args: Any, **kwargs: Any) -> Any: + json_mode = kwargs.get("json", False) try: return f(*args, **kwargs) # User aborted the command @@ -187,15 +190,25 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: error_msg = getattr(e.body, "message", str(e.body)) else: 
error_msg = str(e) - click.echo(prefix_styled + click.style("Failed", fg="red")) - click.echo(prefix_styled + click.style(error_msg, fg="red")) + + if json_mode: + print_json(openapi_dumps({"error": error_msg}).decode("utf-8")) + else: + click.echo(prefix_styled + click.style("Failed", fg="red"), file=sys.stderr) + click.echo(prefix_styled + click.style(error_msg, fg="red"), file=sys.stderr) sys.exit(1) except Exception as e: if os.getenv("TOGETHER_LOG", "").lower() == "debug": # Raise the error with the full traceback raise - click.echo(prefix_styled + click.style("Failed", fg="red")) - click.echo(prefix_styled + click.style(f"An unexpected error occurred - {str(e)}", fg="red")) + if json_mode: + print_json(openapi_dumps({"error": str(e)}).decode("utf-8")) + else: + click.echo(prefix_styled + click.style("Failed", fg="red"), file=sys.stderr) + click.echo( + prefix_styled + click.style(f"An unexpected error occurred - {str(e)}", fg="red"), + file=sys.stderr, + ) sys.exit(1) return wrapper # type: ignore diff --git a/src/together/lib/cli/api/beta/clusters/create.py b/src/together/lib/cli/api/beta/clusters/create.py index 1961edf60..bb9b41c4c 100644 --- a/src/together/lib/cli/api/beta/clusters/create.py +++ b/src/together/lib/cli/api/beta/clusters/create.py @@ -1,13 +1,13 @@ from __future__ import annotations -import json as json_lib import getpass from typing import List, Literal import click -from rich import print +from rich import print, print_json from together import Together +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors from together.types.beta.cluster_create_params import SharedVolume, ClusterCreateParams @@ -59,6 +59,7 @@ is_flag=True, help="Output in JSON format", ) +@click.option("--non-interactive", is_flag=True, default=False, help="Disable interactive mode") @click.pass_context @handle_api_errors("Clusters") def create( @@ -73,6 +74,7 @@ def create( cluster_type: Literal["KUBERNETES", "SLURM"] 
| None = None, volume: str | None = None, json: bool = False, + non_interactive: bool = False, ) -> None: """Create a cluster""" client: Together = ctx.obj @@ -93,7 +95,7 @@ def create( params["volume_id"] = volume # JSON Mode skips hand holding through the argument setup - if not json: + if not json and not non_interactive: if not name: params["cluster_name"] = click.prompt("Clusters: Cluster name:", default=getpass.getuser(), type=str) @@ -177,7 +179,7 @@ def create( response = client.beta.clusters.create(**params) if json: - click.echo(json_lib.dumps(response.model_dump(exclude_none=True), indent=4)) + print_json(openapi_dumps(response).decode("utf-8")) else: click.echo(f"Clusters: Cluster created successfully") click.echo(f"Clusters: {response.cluster_id}") diff --git a/src/together/lib/cli/api/beta/clusters/storage/create.py b/src/together/lib/cli/api/beta/clusters/storage/create.py index 91e3f6f4a..d227c89c8 100644 --- a/src/together/lib/cli/api/beta/clusters/storage/create.py +++ b/src/together/lib/cli/api/beta/clusters/storage/create.py @@ -1,8 +1,8 @@ -import json as json_lib - import click +from rich import print_json from together import Together +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors @@ -43,7 +43,7 @@ def create(ctx: click.Context, region: str, size_tib: int, volume_name: str, jso ) if json: - click.echo(json_lib.dumps(response.model_dump_json(), indent=2)) + print_json(openapi_dumps(response).decode("utf-8")) else: click.echo(f"Storage volume created successfully") click.echo(response.volume_id) diff --git a/src/together/lib/cli/api/beta/jig/jig.py b/src/together/lib/cli/api/beta/jig/jig.py index 23aca141f..2c5c0554e 100644 --- a/src/together/lib/cli/api/beta/jig/jig.py +++ b/src/together/lib/cli/api/beta/jig/jig.py @@ -148,7 +148,8 @@ def validate(value: Any, value_type: type, path: str = "") -> str | None: return err return None - if origin is Union or origin is getattr(types, 
"UnionType", None): + union_type = getattr(types, "UnionType", None) + if origin is Union or (union_type is not None and origin is union_type): errs = [validate(value, a, path) for a in args if a is not type(None)] if not all(errs): return None @@ -818,7 +819,7 @@ def format_status(self, d: Deployment) -> str: Name : {d.name} ┃ ID: {d.id} Image : {image} Status : {d.status} - Created : {_age(d.created_at)} ┃ Updated : {_age(d.updated_at)}""" + Created : {_age(d.created_at.isoformat() if d.created_at else None)} ┃ Updated : {_age(d.updated_at.isoformat() if d.updated_at else None)}""" ] if a := d.autoscaling: @@ -1071,8 +1072,12 @@ def queue_status(jig: Jig) -> Any: @jig.command("list") +# This method is always outputting json, so it's a bit nebulous to have a --json option +# Doing this for consistency with other commands and to have tests pass for this. +# Eventually we should change this to output human text and json text. +@click.option("--json", "_json_output", is_flag=True, help="Output raw JSON") @_command -def list_deployments(jig: Jig) -> Any: +def list_deployments(jig: Jig, _json_output: bool) -> Any: """List all deployments""" return jig.api.with_raw_response.list() diff --git a/src/together/lib/cli/api/endpoints/create.py b/src/together/lib/cli/api/endpoints/create.py index 058a3587e..551b3ffb5 100644 --- a/src/together/lib/cli/api/endpoints/create.py +++ b/src/together/lib/cli/api/endpoints/create.py @@ -149,8 +149,8 @@ def create( or "the selected configuration" in error_msg or "hardware is required" in error_msg ): - click.secho("Invalid hardware selected.", fg="red", err=True) - click.echo("\nAvailable hardware options:") + click.secho("Invalid hardware selected.", fg="red", err=True, file=sys.stderr) + click.echo("\nAvailable hardware options:", file=sys.stderr) ctx.invoke(list_hardware, available=True, model=model, json=False) sys.exit(1) elif "model" in error_msg and ( @@ -162,14 +162,17 @@ def create( click.echo( f"Error: Model '{model}' was not 
found or is not available for dedicated endpoints.", err=True, + file=sys.stderr, ) click.echo( "Please check that the model name is correct and that it supports dedicated endpoint deployment.", err=True, + file=sys.stderr, ) click.echo( "You can browse available models at: https://api.together.ai/models", err=True, + file=sys.stderr, ) sys.exit(1) raise e diff --git a/src/together/lib/cli/api/endpoints/delete.py b/src/together/lib/cli/api/endpoints/delete.py index f7b63fec7..2902c479c 100644 --- a/src/together/lib/cli/api/endpoints/delete.py +++ b/src/together/lib/cli/api/endpoints/delete.py @@ -16,6 +16,7 @@ def delete(client: Together, endpoint_id: str, json: bool) -> None: """Delete a dedicated inference endpoint.""" client.endpoints.delete(endpoint_id) + if json: click.echo(json_lib.dumps({"message": "Successfully deleted endpoint"}, indent=2)) return diff --git a/src/together/lib/cli/api/endpoints/hardware.py b/src/together/lib/cli/api/endpoints/hardware.py index 0b8467a00..5f91a5b1d 100644 --- a/src/together/lib/cli/api/endpoints/hardware.py +++ b/src/together/lib/cli/api/endpoints/hardware.py @@ -1,16 +1,16 @@ from __future__ import annotations import re -import json as json_lib from typing import Any, Dict, List import click +from rich import print_json from tabulate import tabulate from together import Together, omit from together.types import EndpointListHardwareResponse +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors -from together.lib.utils.serializer import datetime_serializer from together.lib.cli.api.endpoints._utils import handle_endpoint_api_errors @@ -36,8 +36,7 @@ def hardware(client: Together, model: str | None, json: bool, available: bool) - if hardware.availability is not None and hardware.availability.status == "available" ] if json: - json_output = [hardware.model_dump() for hardware in hardware_options.data] - click.echo(json_lib.dumps(json_output, default=datetime_serializer, 
indent=2)) + print_json(openapi_dumps(hardware_options.data).decode("utf-8")) else: _format_hardware_options(hardware_options, show_availability=model is not None) diff --git a/src/together/lib/cli/api/endpoints/list.py b/src/together/lib/cli/api/endpoints/list.py index cf594ff44..cd121523f 100644 --- a/src/together/lib/cli/api/endpoints/list.py +++ b/src/together/lib/cli/api/endpoints/list.py @@ -1,13 +1,13 @@ from __future__ import annotations -import json as json_lib from typing import Literal import click +from rich import print_json from together import Together, omit +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors -from together.lib.utils.serializer import datetime_serializer from together.lib.cli.api.endpoints._utils import handle_endpoint_api_errors @@ -49,11 +49,7 @@ def list( ) if json: - click.echo( - json_lib.dumps( - [endpoint.model_dump() for endpoint in endpoints.data], default=datetime_serializer, indent=2 - ) - ) + print_json(openapi_dumps(endpoints.data).decode("utf-8")) return if not endpoints: diff --git a/src/together/lib/cli/api/endpoints/retrieve.py b/src/together/lib/cli/api/endpoints/retrieve.py index 5873d3e31..a7340ae69 100644 --- a/src/together/lib/cli/api/endpoints/retrieve.py +++ b/src/together/lib/cli/api/endpoints/retrieve.py @@ -1,8 +1,9 @@ import click +from rich import print_json from together import Together +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors -from together.lib.utils.serializer import datetime_serializer from together.lib.cli.api.endpoints._utils import handle_endpoint_api_errors @@ -18,8 +19,6 @@ def retrieve(ctx: click.Context, endpoint_id: str, json: bool) -> None: endpoint = client.endpoints.retrieve(endpoint_id) if json: - import json as json_lib - - click.echo(json_lib.dumps(endpoint.model_dump(), indent=2, default=datetime_serializer)) + 
print_json(openapi_dumps(endpoint.model_dump()).decode("utf-8")) else: ctx.obj.print_endpoint(endpoint) diff --git a/src/together/lib/cli/api/endpoints/start.py b/src/together/lib/cli/api/endpoints/start.py index 003087681..4d2ef7ff0 100644 --- a/src/together/lib/cli/api/endpoints/start.py +++ b/src/together/lib/cli/api/endpoints/start.py @@ -1,10 +1,10 @@ -import json as json_lib import click +from rich import print_json from together import Together +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors -from together.lib.utils.serializer import datetime_serializer from together.lib.cli.api.endpoints._utils import handle_endpoint_api_errors @@ -20,7 +20,7 @@ def start(client: Together, endpoint_id: str, wait: bool, json: bool) -> None: response = client.endpoints.update(endpoint_id, state="STARTED") if json: - click.echo(json_lib.dumps(response.model_dump(), default=datetime_serializer, indent=2)) + print_json(openapi_dumps(response.model_dump()).decode("utf-8")) return click.echo("Successfully marked endpoint as starting", err=True) diff --git a/src/together/lib/cli/api/endpoints/update.py b/src/together/lib/cli/api/endpoints/update.py index 4781b475a..245bf6174 100644 --- a/src/together/lib/cli/api/endpoints/update.py +++ b/src/together/lib/cli/api/endpoints/update.py @@ -1,14 +1,14 @@ from __future__ import annotations import sys -import json as json_lib from typing import Any, Dict import click +from rich import print_json from together import Together +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors -from together.lib.utils.serializer import datetime_serializer from together.lib.cli.api.endpoints._utils import handle_endpoint_api_errors @@ -69,7 +69,7 @@ def update( response = client.endpoints.update(endpoint_id, **kwargs) if json: - click.echo(json_lib.dumps(response.model_dump(), default=datetime_serializer, indent=2)) + 
print_json(openapi_dumps(response.model_dump()).decode("utf-8")) return # Print what was updated diff --git a/src/together/lib/cli/api/evals/list.py b/src/together/lib/cli/api/evals/list.py index b39d90646..6c253c6e2 100644 --- a/src/together/lib/cli/api/evals/list.py +++ b/src/together/lib/cli/api/evals/list.py @@ -1,9 +1,11 @@ from typing import Any, Dict, List, Union, Literal import click +from rich import print_json from tabulate import tabulate from together import Together, omit +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors @@ -18,12 +20,14 @@ type=int, help="Limit number of results (max 100).", ) +@click.option("--json", is_flag=True, help="Print output in JSON format") @click.pass_context @handle_api_errors("Evals") def list( ctx: click.Context, status: Union[Literal["pending", "queued", "running", "completed", "error", "user_error"], None], limit: Union[int, None], + json: bool, ) -> None: """List evals""" @@ -31,6 +35,10 @@ def list( response = client.evals.list(status=status or omit, limit=limit or omit) + if json: + print_json(openapi_dumps(response).decode("utf-8")) + return + display_list: List[Dict[str, Any]] = [] for job in response: if job.parameters: diff --git a/src/together/lib/cli/api/evals/retrieve.py b/src/together/lib/cli/api/evals/retrieve.py index 15b964948..7beec6dbc 100644 --- a/src/together/lib/cli/api/evals/retrieve.py +++ b/src/together/lib/cli/api/evals/retrieve.py @@ -1,21 +1,43 @@ -import json +from typing import Any import click +from rich import print, print_json from together import Together +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors -from together.lib.utils.serializer import datetime_serializer @click.command() @click.pass_context @click.argument("evaluation_id", type=str, required=True) +@click.option("--json", is_flag=True, help="Print output in JSON format") @handle_api_errors("Evals") -def 
retrieve(ctx: click.Context, evaluation_id: str) -> None: +def retrieve(ctx: click.Context, evaluation_id: str, json: bool) -> None: """Get details of a specific evaluation job""" client: Together = ctx.obj response = client.evals.retrieve(evaluation_id) - click.echo(json.dumps(response.model_dump(exclude_none=True), default=datetime_serializer, indent=4)) + if json: + print_json(openapi_dumps(response.model_dump(exclude_none=True)).decode("utf-8")) + else: + print_dict(response.to_dict()) + + +def print_dict(data: Any, indent: int = 0) -> None: + if isinstance(data, dict): + for key, value in data.items(): + if isinstance(value, dict): + print(f"{' ' * indent}[bold dim]{key}:[/bold dim]") + print_dict(value, indent=indent+2) + elif isinstance(value, list): + print(f"{' ' * indent}[bold dim]{key}:[/bold dim]") + for index, item in enumerate(value): + print(f"{' ' * indent}[bold dim][{index}]:[/bold dim]") + print_dict(item, indent=indent+2) + else: + print(f"{' ' * indent}[bold dim]{key}:[/bold dim] {value}") + else: + print(f"{' ' * indent}{data}") \ No newline at end of file diff --git a/src/together/lib/cli/api/evals/status.py b/src/together/lib/cli/api/evals/status.py index ea788dc50..5c95577d1 100644 --- a/src/together/lib/cli/api/evals/status.py +++ b/src/together/lib/cli/api/evals/status.py @@ -1,20 +1,38 @@ -import json +from __future__ import annotations + +from typing import Any import click +from rich import print, print_json from together import Together +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors @click.command() @click.pass_context @click.argument("evaluation_id", type=str, required=True) +@click.option("--json", is_flag=True, help="Print output in JSON format") @handle_api_errors("Evals") -def status(ctx: click.Context, evaluation_id: str) -> None: +def status(ctx: click.Context, evaluation_id: str, json: bool) -> None: """Get the status and results of a specific evaluation job""" client: 
Together = ctx.obj response = client.evals.status(evaluation_id) - click.echo(json.dumps(response.model_dump(exclude_none=True), indent=4)) + if json: + print_json(openapi_dumps(response).decode("utf-8")) + return + else: + print(f"[bold dim]Status:[/bold dim] [green]{response.status}[/green]") + print_dict(response.results.to_dict() if response.results else {}) + +def print_dict(data: dict[str, Any], indent: int = 0) -> None: + for key, value in data.items(): + if isinstance(value, dict): + print(f"{' ' * indent}[bold dim]{key}:[/bold dim]") + print_dict(value, indent=indent+2) + else: + print(f"{' ' * indent}[bold dim]{key}:[/bold dim] {value}") diff --git a/src/together/lib/cli/api/files/check.py b/src/together/lib/cli/api/files/check.py index 62bd02744..3a16893d6 100644 --- a/src/together/lib/cli/api/files/check.py +++ b/src/together/lib/cli/api/files/check.py @@ -1,7 +1,9 @@ -import json +import sys +import json as json_lib import pathlib import click +from rich import print, print_json from together.lib.utils import check_file @@ -13,9 +15,20 @@ type=click.Path(exists=True, file_okay=True, resolve_path=True, readable=True, dir_okay=False), required=True, ) -def check(_ctx: click.Context, file: pathlib.Path) -> None: +@click.option( + "--json", + is_flag=True, + help="Output the response in JSON format", +) +def check(_ctx: click.Context, file: pathlib.Path, json: bool) -> None: """Check file for issues""" report = check_file(file) - click.echo(json.dumps(report, indent=4)) + if json: + print_json(json_lib.dumps(report)) + else: + icon = "✅" if report["is_check_passed"] else "❌" + print(f"{icon} {report['message']}") + if report["is_check_passed"] is False: + sys.exit(1) diff --git a/src/together/lib/cli/api/files/delete.py b/src/together/lib/cli/api/files/delete.py index 33cd8a233..4eb26f144 100644 --- a/src/together/lib/cli/api/files/delete.py +++ b/src/together/lib/cli/api/files/delete.py @@ -1,20 +1,29 @@ -import json - import click +from rich import 
print, print_json from together import Together +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors @click.command() @click.pass_context @click.argument("id", type=str, required=True) +@click.option( + "--json", + is_flag=True, + help="Output the response in JSON format", +) @handle_api_errors("Files") -def delete(ctx: click.Context, id: str) -> None: +def delete(ctx: click.Context, id: str, json: bool) -> None: """Delete remote file""" client: Together = ctx.obj response = client.files.delete(id=id) - click.echo(json.dumps(response.model_dump(exclude_none=True), indent=4)) + if json: + print_json(openapi_dumps(response).decode("utf-8")) + return + + print(f"[green]File {id} deleted[/green]") diff --git a/src/together/lib/cli/api/files/list.py b/src/together/lib/cli/api/files/list.py index d48910ddc..99dec3242 100644 --- a/src/together/lib/cli/api/files/list.py +++ b/src/together/lib/cli/api/files/list.py @@ -2,6 +2,7 @@ from datetime import datetime, timezone import click +from rich import print_json from tabulate import tabulate from together import Together @@ -29,7 +30,7 @@ def list(ctx: click.Context, json: bool) -> None: response.data.sort(key=lambda x: x.created_at or epoch_start, reverse=True) if json: - click.echo(openapi_dumps(response.data)) + print_json(openapi_dumps(response.data).decode("utf-8")) return display_list: List[Dict[str, Any]] = [] diff --git a/src/together/lib/cli/api/files/retrieve.py b/src/together/lib/cli/api/files/retrieve.py index f8b10f2a7..146f18d53 100644 --- a/src/together/lib/cli/api/files/retrieve.py +++ b/src/together/lib/cli/api/files/retrieve.py @@ -1,20 +1,36 @@ -import json - import click +from rich import print, print_json from together import Together +from together.lib.utils import convert_bytes, convert_unix_timestamp +from together._utils._json import openapi_dumps +from together.lib.utils.tools import format_timestamp from together.lib.cli.api._utils import 
handle_api_errors @click.command() @click.pass_context @click.argument("id", type=str, required=True) +@click.option( + "--json", + is_flag=True, + help="Output the response in JSON format", +) @handle_api_errors("Files") -def retrieve(ctx: click.Context, id: str) -> None: - """Upload file""" +def retrieve(ctx: click.Context, id: str, json: bool) -> None: + """Retrieve file details""" client: Together = ctx.obj response = client.files.retrieve(id=id) - click.echo(json.dumps(response.model_dump(exclude_none=True), indent=4)) + if json: + print_json(openapi_dumps(response).decode("utf-8")) + return + + # print(f"[bold]File details [/bold][dim white]({response.id})[/dim white]") + print(f"[dim]Name[/dim]: [white]{response.filename}[/white]") + print(f"[dim]Size[/dim]: [white]{convert_bytes(response.bytes)}[/white]") + print(f"[dim]Type[/dim]: [white]{response.file_type}[/white]") + print(f"[dim]Purpose[/dim]: [white]{response.purpose}[/white]") + print(f"[dim]Created[/dim]: [white]{format_timestamp(convert_unix_timestamp(response.created_at))}[/white]") diff --git a/src/together/lib/cli/api/files/upload.py b/src/together/lib/cli/api/files/upload.py index 1a79560e9..7a447e9dc 100644 --- a/src/together/lib/cli/api/files/upload.py +++ b/src/together/lib/cli/api/files/upload.py @@ -1,9 +1,14 @@ +import os +import sys +import json as json_lib import pathlib from typing import get_args import click +from rich import print, print_json from together import Together +from together.lib import check_file from together.types import FilePurpose from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors @@ -37,11 +42,25 @@ def upload(ctx: click.Context, file: pathlib.Path, purpose: FilePurpose, check: """Upload file""" client: Together = ctx.obj + if json: + os.environ.setdefault("TOGETHER_DISABLE_TQDM", "true") + + # Manually handle check here so we can exit and provide the user good error messages + if check: + report = 
check_file(file) + if report["is_check_passed"] is False: + if json: + print_json(json_lib.dumps(report)) + else: + print(f"❌ {report['message']}") + + # Make sure to exit + sys.exit(1) - response = client.files.upload(file=file, purpose=purpose, check=check) + response = client.files.upload(file=file, purpose=purpose, check=False) if json: - click.echo(openapi_dumps(response.model_dump(exclude_none=True))) + print_json(openapi_dumps(response).decode("utf-8")) return click.echo( diff --git a/src/together/lib/cli/api/fine_tuning/cancel.py b/src/together/lib/cli/api/fine_tuning/cancel.py index bc4e80739..02cba217b 100644 --- a/src/together/lib/cli/api/fine_tuning/cancel.py +++ b/src/together/lib/cli/api/fine_tuning/cancel.py @@ -1,10 +1,11 @@ -import json +import sys import click +from rich import print, print_json from together import Together +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors -from together.lib.utils.serializer import datetime_serializer NON_CANCELLABLE_STATES = ["cancel_requested", "cancelled", "error", "completed", "user_error"] @@ -13,16 +14,22 @@ @click.pass_context @click.argument("fine_tune_id", type=str, required=True) @click.option("--quiet", is_flag=True, help="Do not prompt for confirmation before cancelling job") +@click.option("--json", is_flag=True, help="Print output in JSON format, must use --force to use this option") @handle_api_errors("Fine-tuning") -def cancel(ctx: click.Context, fine_tune_id: str, quiet: bool = False) -> None: +def cancel(ctx: click.Context, fine_tune_id: str, quiet: bool = False, json: bool = False) -> None: """Cancel fine-tuning job""" client: Together = ctx.obj job = client.fine_tuning.retrieve(fine_tune_id) + + if json and not quiet: + raise click.BadOptionUsage("json", "To use json mode, you must use --quiet") + if job.status in NON_CANCELLABLE_STATES: click.echo( click.style(f"Fine-tuning: ", fg="blue") + f"Training is not currently cancellable. 
Current status is " - + click.style(job.status, fg="yellow") + + click.style(job.status, fg="yellow"), + file=sys.stderr if json else None, ) return @@ -32,8 +39,16 @@ def cancel(ctx: click.Context, fine_tune_id: str, quiet: bool = False) -> None: f"Do you want to cancel job {fine_tune_id}? [y/N]" ) if "y" not in confirm_response.lower(): - click.echo(json.dumps({"status": "Cancel not submitted"}, indent=4)) + if json: + print_json('{"status": "Cancel not submitted"}') + else: + click.echo("Cancel not submitted") return + response = client.fine_tuning.cancel(fine_tune_id) - click.echo(json.dumps(response.model_dump(exclude_none=True), indent=4, default=datetime_serializer)) + if json: + print_json(openapi_dumps(response).decode("utf-8")) + return + + print("Cancelled fine-tuning job") diff --git a/src/together/lib/cli/api/fine_tuning/create.py b/src/together/lib/cli/api/fine_tuning/create.py index 626952685..d13287d64 100644 --- a/src/together/lib/cli/api/fine_tuning/create.py +++ b/src/together/lib/cli/api/fine_tuning/create.py @@ -49,6 +49,7 @@ def get_confirmation_message(price: str, warning: str) -> str: default="", help="Validation file ID from Files API or local path to a file to be uploaded.", ) +@click.option("--packing", type=bool, default=True, help="Whether to use packing for training.") @click.option("--n-evals", type=int, default=0, help="Number of evaluation loops") @click.option("--n-checkpoints", "-c", type=int, default=1, help="Number of checkpoints to save") @click.option("--batch-size", "-b", type=INT_WITH_MAX, default="max", help="Train batch size") @@ -215,6 +216,7 @@ def create( training_file: str, validation_file: str, model: str | None, + packing: bool, n_epochs: int, n_evals: int, n_checkpoints: int, @@ -260,6 +262,7 @@ def create( model=model, n_epochs=n_epochs, validation_file=validation_file, + packing=packing, n_evals=n_evals, n_checkpoints=n_checkpoints, batch_size=batch_size, diff --git a/src/together/lib/cli/api/fine_tuning/delete.py 
b/src/together/lib/cli/api/fine_tuning/delete.py index fd1aade4c..b83581f96 100644 --- a/src/together/lib/cli/api/fine_tuning/delete.py +++ b/src/together/lib/cli/api/fine_tuning/delete.py @@ -1,8 +1,8 @@ -import json - import click +from rich import print, print_json from together import Together +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors @@ -10,13 +10,19 @@ @click.pass_context @click.argument("fine_tune_id", type=str, required=True) @click.option("--force", is_flag=True, help="Force deletion without confirmation") -@click.option("--quiet", is_flag=True, help="Do not prompt for confirmation before deleting job") +@click.option("--quiet", is_flag=True, help="Deprecated, use --force instead") +@click.option("--json", is_flag=True, help="Print output in JSON format, must use --force to use this option") @handle_api_errors("Fine-tuning") -def delete(ctx: click.Context, fine_tune_id: str, force: bool = False, quiet: bool = False) -> None: +def delete(ctx: click.Context, fine_tune_id: str, force: bool = False, quiet: bool = False, json: bool = False) -> None: """Delete fine-tuning job""" client: Together = ctx.obj - if not quiet: + skip_confirmation = force or quiet + + if not skip_confirmation: + if json: + raise click.BadOptionUsage("json", "To use json mode, you must use --force") + confirm_response = input( f"Are you sure you want to delete fine-tuning job {fine_tune_id}? This action cannot be undone. 
[y/N] " ) @@ -24,6 +30,10 @@ def delete(ctx: click.Context, fine_tune_id: str, force: bool = False, quiet: bo click.echo("Deletion cancelled") return - response = client.fine_tuning.delete(fine_tune_id, force=force) + response = client.fine_tuning.delete(fine_tune_id) + + if json: + print_json(openapi_dumps(response).decode("utf-8")) + return - click.echo(json.dumps(response.model_dump(exclude_none=True), indent=4)) + print(f"Deleted fine-tuning job") diff --git a/src/together/lib/cli/api/fine_tuning/download.py b/src/together/lib/cli/api/fine_tuning/download.py index d30840694..cd7b123ac 100644 --- a/src/together/lib/cli/api/fine_tuning/download.py +++ b/src/together/lib/cli/api/fine_tuning/download.py @@ -1,7 +1,7 @@ from __future__ import annotations +import os import re -import json from typing import Union, Literal from pathlib import Path @@ -9,6 +9,7 @@ from together import NOT_GIVEN, APIError, NotGiven, Together, APIStatusError from together.lib import DownloadManager +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors from together.types.finetune_response import TrainingTypeFullTrainingType, TrainingTypeLoRaTrainingType @@ -41,6 +42,7 @@ default="merged", help="Specifies checkpoint type. 
'merged' and 'adapter' options work only for LoRA jobs.", ) +@click.option("--json", is_flag=True, help="Print output in JSON format") @handle_api_errors("Fine-tuning") def download( ctx: click.Context, @@ -48,6 +50,7 @@ def download( output_dir: str | None = None, checkpoint_step: Union[int, NotGiven] = NOT_GIVEN, checkpoint_type: Literal["default", "merged", "adapter"] | NotGiven = NOT_GIVEN, + json: bool = False, ) -> None: """Download fine-tuning checkpoint""" client: Together = ctx.obj @@ -92,6 +95,10 @@ def download( if isinstance(output_dir, str): output = Path(output_dir) + # Disable tqdm for json mode + if json: + os.environ.setdefault("TOGETHER_DISABLE_TQDM", "true") + try: file_path, file_size = DownloadManager(client).download( url=url, @@ -101,7 +108,9 @@ def download( ) click.echo( - json.dumps({"object": "local", "id": fine_tune_id, "filename": file_path, "size": file_size}, indent=4) + openapi_dumps({"object": "local", "id": fine_tune_id, "filename": file_path, "size": file_size}).decode( + "utf-8" + ) ) except APIStatusError as e: raise APIError( diff --git a/src/together/lib/cli/api/fine_tuning/list.py b/src/together/lib/cli/api/fine_tuning/list.py index c11aa980e..b727af428 100644 --- a/src/together/lib/cli/api/fine_tuning/list.py +++ b/src/together/lib/cli/api/fine_tuning/list.py @@ -2,6 +2,7 @@ from datetime import datetime, timezone import click +from rich import print_json from tabulate import tabulate from together import Together @@ -29,7 +30,7 @@ def list(ctx: click.Context, json: bool) -> None: response.data.sort(key=lambda x: x.created_at or epoch_start, reverse=True) if json: - click.echo(openapi_dumps(response.data)) + print_json(openapi_dumps(response.data).decode("utf-8")) return display_list: List[Dict[str, Any]] = [] diff --git a/src/together/lib/cli/api/fine_tuning/list_checkpoints.py b/src/together/lib/cli/api/fine_tuning/list_checkpoints.py index 44cd9a2d7..08064ee04 100644 --- 
a/src/together/lib/cli/api/fine_tuning/list_checkpoints.py +++ b/src/together/lib/cli/api/fine_tuning/list_checkpoints.py @@ -1,9 +1,11 @@ from typing import Any, Dict, List import click +from rich import print_json from tabulate import tabulate from together import Together +from together._utils._json import openapi_dumps from together.lib.utils.tools import format_timestamp from together.lib.cli.api._utils import handle_api_errors @@ -11,14 +13,19 @@ @click.command() @click.pass_context @click.argument("fine_tune_id", type=str, required=True) +@click.option("--json", is_flag=True, help="Print output in JSON format") @handle_api_errors("Fine-tuning") -def list_checkpoints(ctx: click.Context, fine_tune_id: str) -> None: +def list_checkpoints(ctx: click.Context, fine_tune_id: str, json: bool) -> None: """List available checkpoints for a fine-tuning job""" client: Together = ctx.obj checkpoints = client.fine_tuning.list_checkpoints(fine_tune_id) checkpoints.data = checkpoints.data or [] + if json: + print_json(openapi_dumps(checkpoints.data).decode("utf-8")) + return + display_list: List[Dict[str, Any]] = [] for checkpoint in checkpoints.data: name = ( diff --git a/src/together/lib/cli/api/fine_tuning/list_events.py b/src/together/lib/cli/api/fine_tuning/list_events.py index 0d0b4aa44..363c3de8b 100644 --- a/src/together/lib/cli/api/fine_tuning/list_events.py +++ b/src/together/lib/cli/api/fine_tuning/list_events.py @@ -2,17 +2,20 @@ from textwrap import wrap import click +from rich import print_json from tabulate import tabulate from together import Together +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors @click.command() @click.pass_context @click.argument("fine_tune_id", type=str, required=True) +@click.option("--json", is_flag=True, help="Print output in JSON format") @handle_api_errors("Fine-tuning") -def list_events(ctx: click.Context, fine_tune_id: str) -> None: +def list_events(ctx: click.Context, 
fine_tune_id: str, json: bool) -> None: """List fine-tuning events""" client: Together = ctx.obj @@ -20,6 +23,10 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None: response.data = response.data or [] + if json: + print_json(openapi_dumps(response.data).decode("utf-8")) + return + display_list: List[Dict[str, Any]] = [] for i in response.data: display_list.append( diff --git a/src/together/lib/cli/api/fine_tuning/retrieve.py b/src/together/lib/cli/api/fine_tuning/retrieve.py index 2747cec4b..a8d8746e3 100644 --- a/src/together/lib/cli/api/fine_tuning/retrieve.py +++ b/src/together/lib/cli/api/fine_tuning/retrieve.py @@ -1,13 +1,11 @@ from datetime import datetime, timezone import click -from rich import print as rprint -from rich.json import JSON +from rich import print as rprint, print_json from together import Together from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors, generate_progress_bar -from together.lib.utils.serializer import datetime_serializer @click.command() @@ -22,13 +20,13 @@ def retrieve(ctx: click.Context, fine_tune_id: str, json: bool) -> None: response = client.fine_tuning.retrieve(fine_tune_id) if json: - click.echo(openapi_dumps(response.model_dump(exclude_none=True))) + print_json(openapi_dumps(response).decode("utf-8")) return # remove events from response for cleaner output response.events = None - rprint(JSON.from_data(response.model_dump(exclude_none=True), default=datetime_serializer)) + print_json(openapi_dumps(response).decode("utf-8")) progress_text = generate_progress_bar(response, datetime.now(timezone.utc), use_rich=True) prefix = f"Status: [bold]{response.status}[/bold]," rprint(f"{prefix} {progress_text}") diff --git a/src/together/lib/cli/api/models/list.py b/src/together/lib/cli/api/models/list.py index e109d9988..23aca2bd6 100644 --- a/src/together/lib/cli/api/models/list.py +++ b/src/together/lib/cli/api/models/list.py @@ -1,13 +1,13 @@ -import json as 
json_lib from typing import Any, Dict, List, Optional import click +from rich import print_json from tabulate import tabulate from together import Together, omit from together._response import APIResponse as APIResponse +from together._utils._json import openapi_dumps from together.lib.cli.api._utils import handle_api_errors -from together.lib.utils.serializer import datetime_serializer @click.command() @@ -30,8 +30,7 @@ def list(ctx: click.Context, type: Optional[str], json: bool) -> None: models_list = client.models.list(dedicated=type == "dedicated" if type else omit) if json: - items = [model.model_dump() for model in models_list] - click.echo(json_lib.dumps(items, indent=2, default=datetime_serializer)) + print_json(openapi_dumps(models_list).decode("utf-8")) return display_list: List[Dict[str, Any]] = [] diff --git a/src/together/lib/constants.py b/src/together/lib/constants.py index 0061e5437..6bc0ba4b0 100644 --- a/src/together/lib/constants.py +++ b/src/together/lib/constants.py @@ -13,7 +13,6 @@ # Download defaults DOWNLOAD_BLOCK_SIZE = 10 * 1024 * 1024 # 10 MB -DISABLE_TQDM = False MAX_DOWNLOAD_RETRIES = 5 # Maximum retries for download failures DOWNLOAD_INITIAL_RETRY_DELAY = 1.0 # Initial retry delay in seconds DOWNLOAD_MAX_RETRY_DELAY = 30.0 # Maximum retry delay in seconds @@ -75,3 +74,7 @@ class DatasetFormat(enum.Enum): } REQUIRED_COLUMNS_MESSAGE = ["role"] POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant", "tool"] + +# DO NOT USE THIS +# Pull this from the environment variable TOGETHER_DISABLE_TQDM so cli can disable tqdm if needed +DISABLE_TQDM = False diff --git a/src/together/lib/resources/files.py b/src/together/lib/resources/files.py index 6645fe19a..be98b5aa0 100644 --- a/src/together/lib/resources/files.py +++ b/src/together/lib/resources/files.py @@ -25,7 +25,6 @@ from ...types import FileType, FilePurpose, FileResponse from ..._types import RequestOptions from ..constants import ( - DISABLE_TQDM, NUM_BYTES_IN_GB, 
MAX_FILE_SIZE_GB, MIN_PART_SIZE_MB, @@ -211,6 +210,8 @@ def download( retry_count = 0 retry_delay = DOWNLOAD_INITIAL_RETRY_DELAY + DISABLE_TQDM = os.environ.get("TOGETHER_DISABLE_TQDM", "false").lower() == "true" + with tqdm( total=file_size, unit="B", @@ -391,6 +392,7 @@ def _upload_single_file( redirect_url, file_id = self.get_upload_url(url, file, checksum, purpose, filetype) # type: ignore file_size = os.stat(file.as_posix()).st_size + DISABLE_TQDM = os.environ.get("TOGETHER_DISABLE_TQDM", "false").lower() == "true" with tqdm( total=file_size, @@ -564,6 +566,8 @@ def _upload_parts_concurrent(self, file: Path, upload_info: Dict[str, Any], part parts = upload_info["parts"] completed_parts: List[Dict[str, Any]] = [] + DISABLE_TQDM = os.environ.get("TOGETHER_DISABLE_TQDM", "false").lower() == "true" + with ThreadPoolExecutor(max_workers=self.max_concurrent_parts) as executor: with tqdm(total=len(parts), desc="Uploading parts", unit="part", disable=bool(DISABLE_TQDM)) as pbar: with open(file, "rb") as f: @@ -795,6 +799,8 @@ async def _upload_single_file( file_size = os.stat(file.as_posix()).st_size + DISABLE_TQDM = os.environ.get("TOGETHER_DISABLE_TQDM", "false").lower() == "true" + with tqdm( total=file_size, unit="B", @@ -956,6 +962,8 @@ async def _upload_parts_concurrent( # Use ThreadPoolExecutor for HTTP I/O efficiency loop = asyncio.get_event_loop() + DISABLE_TQDM = os.environ.get("TOGETHER_DISABLE_TQDM", "false").lower() == "true" + with ThreadPoolExecutor(max_workers=self.max_concurrent_parts) as executor: with tqdm(total=len(parts), desc="Uploading parts", unit="part", disable=bool(DISABLE_TQDM)) as pbar: with open(file, "rb") as f: diff --git a/src/together/lib/resources/fine_tuning.py b/src/together/lib/resources/fine_tuning.py index 89ea5b09b..26bd64790 100644 --- a/src/together/lib/resources/fine_tuning.py +++ b/src/together/lib/resources/fine_tuning.py @@ -37,6 +37,7 @@ def create_finetune_request( model: str | None = None, n_epochs: int = 1, 
validation_file: str | None = "", + packing: bool = True, n_evals: int | None = 0, n_checkpoints: int | None = 1, batch_size: int | Literal["max"] = "max", @@ -234,6 +235,7 @@ def create_finetune_request( model=model, training_file=training_file, validation_file=validation_file, + packing=packing, n_epochs=n_epochs, n_evals=n_evals, n_checkpoints=n_checkpoints, diff --git a/src/together/lib/types/fine_tuning.py b/src/together/lib/types/fine_tuning.py index 88eae4390..95a84280e 100644 --- a/src/together/lib/types/fine_tuning.py +++ b/src/together/lib/types/fine_tuning.py @@ -451,6 +451,8 @@ class FinetuneRequest(BaseModel): training_file: str # validation file id validation_file: Union[str, None] = None + # whether to use packing for training + packing: bool = True # base model string model: Union[str, None] = None # number of epochs to train for diff --git a/src/together/lib/utils/files.py b/src/together/lib/utils/files.py index d86c175bf..86d0c0ceb 100644 --- a/src/together/lib/utils/files.py +++ b/src/together/lib/utils/files.py @@ -12,7 +12,6 @@ from together.types import FilePurpose from together.lib.constants import ( MIN_SAMPLES, - DISABLE_TQDM, MAX_IMAGE_BYTES, NUM_BYTES_IN_GB, MAX_FILE_SIZE_GB, @@ -627,6 +626,8 @@ def _check_jsonl(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]: if not report_dict["utf8"]: return report_dict + DISABLE_TQDM = os.environ.get("TOGETHER_DISABLE_TQDM", "false").lower() == "true" + dataset_format = None with file.open() as f: idx = -1 diff --git a/src/together/lib/utils/serializer.py b/src/together/lib/utils/serializer.py deleted file mode 100644 index 302286bc2..000000000 --- a/src/together/lib/utils/serializer.py +++ /dev/null @@ -1,10 +0,0 @@ -from __future__ import annotations - -from typing import Any -from datetime import datetime - - -def datetime_serializer(obj: Any) -> str: - if isinstance(obj, datetime): - return obj.isoformat() - raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON 
serializable") diff --git a/src/together/resources/audio/speech.py b/src/together/resources/audio/speech.py index b2b195180..6bc982716 100644 --- a/src/together/resources/audio/speech.py +++ b/src/together/resources/audio/speech.py @@ -89,7 +89,7 @@ def create( You can view the voices supported for each model using the /v1/voices endpoint sending the model name as the query parameter. - [View all supported voices here](https://docs.together.ai/docs/text-to-speech#voices-available). + [View all supported voices here](https://docs.together.ai/docs/text-to-speech#supported-voices). language: Language of input text. @@ -160,7 +160,7 @@ def create( You can view the voices supported for each model using the /v1/voices endpoint sending the model name as the query parameter. - [View all supported voices here](https://docs.together.ai/docs/text-to-speech#voices-available). + [View all supported voices here](https://docs.together.ai/docs/text-to-speech#supported-voices). language: Language of input text. @@ -227,7 +227,7 @@ def create( You can view the voices supported for each model using the /v1/voices endpoint sending the model name as the query parameter. - [View all supported voices here](https://docs.together.ai/docs/text-to-speech#voices-available). + [View all supported voices here](https://docs.together.ai/docs/text-to-speech#supported-voices). language: Language of input text. @@ -357,7 +357,7 @@ async def create( You can view the voices supported for each model using the /v1/voices endpoint sending the model name as the query parameter. - [View all supported voices here](https://docs.together.ai/docs/text-to-speech#voices-available). + [View all supported voices here](https://docs.together.ai/docs/text-to-speech#supported-voices). language: Language of input text. @@ -428,7 +428,7 @@ async def create( You can view the voices supported for each model using the /v1/voices endpoint sending the model name as the query parameter. 
- [View all supported voices here](https://docs.together.ai/docs/text-to-speech#voices-available). + [View all supported voices here](https://docs.together.ai/docs/text-to-speech#supported-voices). language: Language of input text. @@ -495,7 +495,7 @@ async def create( You can view the voices supported for each model using the /v1/voices endpoint sending the model name as the query parameter. - [View all supported voices here](https://docs.together.ai/docs/text-to-speech#voices-available). + [View all supported voices here](https://docs.together.ai/docs/text-to-speech#supported-voices). language: Language of input text. diff --git a/src/together/resources/batches.py b/src/together/resources/batches.py index cb9067598..58c61accc 100644 --- a/src/together/resources/batches.py +++ b/src/together/resources/batches.py @@ -6,7 +6,7 @@ from ..types import batch_create_params from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given -from .._utils import maybe_transform, async_maybe_transform +from .._utils import path_template, maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import ( @@ -126,7 +126,7 @@ def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( - f"/batches/{id}", + path_template("/batches/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -180,7 +180,7 @@ def cancel( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._post( - f"/batches/{id}/cancel", + path_template("/batches/{id}/cancel", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -291,7 +291,7 @@ async def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but 
received {id!r}") return await self._get( - f"/batches/{id}", + path_template("/batches/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -345,7 +345,7 @@ async def cancel( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._post( - f"/batches/{id}/cancel", + path_template("/batches/{id}/cancel", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/together/resources/beta/clusters/clusters.py b/src/together/resources/beta/clusters/clusters.py index 6f5eebcc0..fe8c85508 100644 --- a/src/together/resources/beta/clusters/clusters.py +++ b/src/together/resources/beta/clusters/clusters.py @@ -15,7 +15,7 @@ AsyncStorageResourceWithStreamingResponse, ) from ...._types import Body, Omit, Query, Headers, NotGiven, omit, not_given -from ...._utils import maybe_transform, async_maybe_transform +from ...._utils import path_template, maybe_transform, async_maybe_transform from ...._compat import cached_property from ...._resource import SyncAPIResource, AsyncAPIResource from ...._response import ( @@ -170,7 +170,7 @@ def retrieve( if not cluster_id: raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") return self._get( - f"/compute/clusters/{cluster_id}", + path_template("/compute/clusters/{cluster_id}", cluster_id=cluster_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -212,7 +212,7 @@ def update( if not cluster_id: raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") return self._put( - f"/compute/clusters/{cluster_id}", + path_template("/compute/clusters/{cluster_id}", cluster_id=cluster_id), body=maybe_transform( { "cluster_type": cluster_type, @@ -273,7 +273,7 @@ def 
delete( if not cluster_id: raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") return self._delete( - f"/compute/clusters/{cluster_id}", + path_template("/compute/clusters/{cluster_id}", cluster_id=cluster_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -436,7 +436,7 @@ async def retrieve( if not cluster_id: raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") return await self._get( - f"/compute/clusters/{cluster_id}", + path_template("/compute/clusters/{cluster_id}", cluster_id=cluster_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -478,7 +478,7 @@ async def update( if not cluster_id: raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") return await self._put( - f"/compute/clusters/{cluster_id}", + path_template("/compute/clusters/{cluster_id}", cluster_id=cluster_id), body=await async_maybe_transform( { "cluster_type": cluster_type, @@ -539,7 +539,7 @@ async def delete( if not cluster_id: raise ValueError(f"Expected a non-empty value for `cluster_id` but received {cluster_id!r}") return await self._delete( - f"/compute/clusters/{cluster_id}", + path_template("/compute/clusters/{cluster_id}", cluster_id=cluster_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/together/resources/beta/clusters/storage.py b/src/together/resources/beta/clusters/storage.py index 0d6260e03..c6abf44e8 100644 --- a/src/together/resources/beta/clusters/storage.py +++ b/src/together/resources/beta/clusters/storage.py @@ -5,7 +5,7 @@ import httpx from ...._types import Body, Omit, Query, Headers, NotGiven, omit, not_given -from ...._utils import maybe_transform, async_maybe_transform +from ...._utils 
import path_template, maybe_transform, async_maybe_transform from ...._compat import cached_property from ...._resource import SyncAPIResource, AsyncAPIResource from ...._response import ( @@ -122,7 +122,7 @@ def retrieve( if not volume_id: raise ValueError(f"Expected a non-empty value for `volume_id` but received {volume_id!r}") return self._get( - f"/compute/clusters/storage/volumes/{volume_id}", + path_template("/compute/clusters/storage/volumes/{volume_id}", volume_id=volume_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -221,7 +221,7 @@ def delete( if not volume_id: raise ValueError(f"Expected a non-empty value for `volume_id` but received {volume_id!r}") return self._delete( - f"/compute/clusters/storage/volumes/{volume_id}", + path_template("/compute/clusters/storage/volumes/{volume_id}", volume_id=volume_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -328,7 +328,7 @@ async def retrieve( if not volume_id: raise ValueError(f"Expected a non-empty value for `volume_id` but received {volume_id!r}") return await self._get( - f"/compute/clusters/storage/volumes/{volume_id}", + path_template("/compute/clusters/storage/volumes/{volume_id}", volume_id=volume_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -427,7 +427,7 @@ async def delete( if not volume_id: raise ValueError(f"Expected a non-empty value for `volume_id` but received {volume_id!r}") return await self._delete( - f"/compute/clusters/storage/volumes/{volume_id}", + path_template("/compute/clusters/storage/volumes/{volume_id}", volume_id=volume_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/together/resources/beta/jig/jig.py 
b/src/together/resources/beta/jig/jig.py index 953eb2eb4..24f4a1b59 100644 --- a/src/together/resources/beta/jig/jig.py +++ b/src/together/resources/beta/jig/jig.py @@ -32,7 +32,7 @@ AsyncVolumesResourceWithStreamingResponse, ) from ...._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given -from ...._utils import maybe_transform, async_maybe_transform +from ...._utils import path_template, maybe_transform, async_maybe_transform from ...._compat import cached_property from ...._resource import SyncAPIResource, AsyncAPIResource from ...._response import ( @@ -110,7 +110,7 @@ def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( - f"/deployments/{id}", + path_template("/deployments/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -128,7 +128,7 @@ def update( description: str | Omit = omit, environment_variables: Iterable[jig_update_params.EnvironmentVariable] | Omit = omit, gpu_count: int | Omit = omit, - gpu_type: Literal["h100-80gb", " a100-80gb"] | Omit = omit, + gpu_type: Literal["h100-80gb"] | Omit = omit, health_check_path: str | Omit = omit, image: str | Omit = omit, max_replicas: int | Omit = omit, @@ -155,8 +155,7 @@ def update( args: Args overrides the container's CMD. Provide as an array of arguments (e.g., ["python", "app.py"]) - autoscaling: Autoscaling configuration for the deployment. Omit or set to null to disable - autoscaling + autoscaling: Autoscaling configuration for the deployment. Set to {} to disable autoscaling command: Command overrides the container's ENTRYPOINT. 
Provide as an array (e.g., ["/bin/sh", "-c"]) @@ -211,7 +210,7 @@ def update( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._patch( - f"/deployments/{id}", + path_template("/deployments/{id}", id=id), body=maybe_transform( { "args": args, @@ -263,7 +262,7 @@ def list( def deploy( self, *, - gpu_type: Literal["h100-80gb", "a100-80gb"], + gpu_type: Literal["h100-80gb"], image: str, name: str, args: SequenceNotStr[str] | Omit = omit, @@ -411,7 +410,7 @@ def destroy( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._delete( - f"/deployments/{id}", + path_template("/deployments/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -449,7 +448,7 @@ def retrieve_logs( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( - f"/deployments/{id}/logs", + path_template("/deployments/{id}/logs", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, @@ -521,7 +520,7 @@ async def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( - f"/deployments/{id}", + path_template("/deployments/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -539,7 +538,7 @@ async def update( description: str | Omit = omit, environment_variables: Iterable[jig_update_params.EnvironmentVariable] | Omit = omit, gpu_count: int | Omit = omit, - gpu_type: Literal["h100-80gb", " a100-80gb"] | Omit = omit, + gpu_type: Literal["h100-80gb"] | Omit = omit, health_check_path: str | Omit = omit, image: str | Omit = omit, max_replicas: int | Omit = omit, @@ -566,8 +565,7 @@ async def update( args: Args overrides the container's CMD. 
Provide as an array of arguments (e.g., ["python", "app.py"]) - autoscaling: Autoscaling configuration for the deployment. Omit or set to null to disable - autoscaling + autoscaling: Autoscaling configuration for the deployment. Set to {} to disable autoscaling command: Command overrides the container's ENTRYPOINT. Provide as an array (e.g., ["/bin/sh", "-c"]) @@ -622,7 +620,7 @@ async def update( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._patch( - f"/deployments/{id}", + path_template("/deployments/{id}", id=id), body=await async_maybe_transform( { "args": args, @@ -674,7 +672,7 @@ async def list( async def deploy( self, *, - gpu_type: Literal["h100-80gb", "a100-80gb"], + gpu_type: Literal["h100-80gb"], image: str, name: str, args: SequenceNotStr[str] | Omit = omit, @@ -822,7 +820,7 @@ async def destroy( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._delete( - f"/deployments/{id}", + path_template("/deployments/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -860,7 +858,7 @@ async def retrieve_logs( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( - f"/deployments/{id}/logs", + path_template("/deployments/{id}/logs", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, diff --git a/src/together/resources/beta/jig/secrets.py b/src/together/resources/beta/jig/secrets.py index f1b874ca2..64c80bcae 100644 --- a/src/together/resources/beta/jig/secrets.py +++ b/src/together/resources/beta/jig/secrets.py @@ -5,7 +5,7 @@ import httpx from ...._types import Body, Omit, Query, Headers, NotGiven, omit, not_given -from ...._utils import maybe_transform, async_maybe_transform +from ...._utils import path_template, maybe_transform, async_maybe_transform from 
...._compat import cached_property from ...._resource import SyncAPIResource, AsyncAPIResource from ...._response import ( @@ -126,7 +126,7 @@ def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( - f"/deployments/secrets/{id}", + path_template("/deployments/secrets/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -178,7 +178,7 @@ def update( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._patch( - f"/deployments/secrets/{id}", + path_template("/deployments/secrets/{id}", id=id), body=maybe_transform( { "description": description, @@ -241,7 +241,7 @@ def delete( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._delete( - f"/deployments/secrets/{id}", + path_template("/deployments/secrets/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -353,7 +353,7 @@ async def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( - f"/deployments/secrets/{id}", + path_template("/deployments/secrets/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -405,7 +405,7 @@ async def update( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._patch( - f"/deployments/secrets/{id}", + path_template("/deployments/secrets/{id}", id=id), body=await async_maybe_transform( { "description": description, @@ -468,7 +468,7 @@ async def delete( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._delete( - f"/deployments/secrets/{id}", + path_template("/deployments/secrets/{id}", 
id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/together/resources/beta/jig/volumes.py b/src/together/resources/beta/jig/volumes.py index f93447ee2..817058984 100644 --- a/src/together/resources/beta/jig/volumes.py +++ b/src/together/resources/beta/jig/volumes.py @@ -7,7 +7,7 @@ import httpx from ...._types import Body, Omit, Query, Headers, NotGiven, omit, not_given -from ...._utils import maybe_transform, async_maybe_transform +from ...._utils import path_template, maybe_transform, async_maybe_transform from ...._compat import cached_property from ...._resource import SyncAPIResource, AsyncAPIResource from ...._response import ( @@ -119,7 +119,7 @@ def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( - f"/deployments/storage/volumes/{id}", + path_template("/deployments/storage/volumes/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -163,7 +163,7 @@ def update( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._patch( - f"/deployments/storage/volumes/{id}", + path_template("/deployments/storage/volumes/{id}", id=id), body=maybe_transform( { "content": content, @@ -225,7 +225,7 @@ def delete( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._delete( - f"/deployments/storage/volumes/{id}", + path_template("/deployments/storage/volumes/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -328,7 +328,7 @@ async def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( - f"/deployments/storage/volumes/{id}", + 
path_template("/deployments/storage/volumes/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -372,7 +372,7 @@ async def update( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._patch( - f"/deployments/storage/volumes/{id}", + path_template("/deployments/storage/volumes/{id}", id=id), body=await async_maybe_transform( { "content": content, @@ -434,7 +434,7 @@ async def delete( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._delete( - f"/deployments/storage/volumes/{id}", + path_template("/deployments/storage/volumes/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/together/resources/endpoints.py b/src/together/resources/endpoints.py index 573e82bcf..afa3cfce3 100644 --- a/src/together/resources/endpoints.py +++ b/src/together/resources/endpoints.py @@ -14,7 +14,7 @@ endpoint_list_hardware_params, ) from .._types import Body, Omit, Query, Headers, NoneType, NotGiven, omit, not_given -from .._utils import maybe_transform, async_maybe_transform +from .._utils import path_template, maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import ( @@ -158,7 +158,7 @@ def retrieve( if not endpoint_id: raise ValueError(f"Expected a non-empty value for `endpoint_id` but received {endpoint_id!r}") return self._get( - f"/endpoints/{endpoint_id}", + path_template("/endpoints/{endpoint_id}", endpoint_id=endpoint_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -208,7 +208,7 @@ def update( if not endpoint_id: raise ValueError(f"Expected a non-empty value for `endpoint_id` but received 
{endpoint_id!r}") return self._patch( - f"/endpoints/{endpoint_id}", + path_template("/endpoints/{endpoint_id}", endpoint_id=endpoint_id), body=maybe_transform( { "autoscaling": autoscaling, @@ -306,7 +306,7 @@ def delete( raise ValueError(f"Expected a non-empty value for `endpoint_id` but received {endpoint_id!r}") extra_headers = {"Accept": "*/*", **(extra_headers or {})} return self._delete( - f"/endpoints/{endpoint_id}", + path_template("/endpoints/{endpoint_id}", endpoint_id=endpoint_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -500,7 +500,7 @@ async def retrieve( if not endpoint_id: raise ValueError(f"Expected a non-empty value for `endpoint_id` but received {endpoint_id!r}") return await self._get( - f"/endpoints/{endpoint_id}", + path_template("/endpoints/{endpoint_id}", endpoint_id=endpoint_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -550,7 +550,7 @@ async def update( if not endpoint_id: raise ValueError(f"Expected a non-empty value for `endpoint_id` but received {endpoint_id!r}") return await self._patch( - f"/endpoints/{endpoint_id}", + path_template("/endpoints/{endpoint_id}", endpoint_id=endpoint_id), body=await async_maybe_transform( { "autoscaling": autoscaling, @@ -648,7 +648,7 @@ async def delete( raise ValueError(f"Expected a non-empty value for `endpoint_id` but received {endpoint_id!r}") extra_headers = {"Accept": "*/*", **(extra_headers or {})} return await self._delete( - f"/endpoints/{endpoint_id}", + path_template("/endpoints/{endpoint_id}", endpoint_id=endpoint_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/together/resources/evals.py b/src/together/resources/evals.py index b3ebd5634..16b93c22d 100644 --- a/src/together/resources/evals.py +++ 
b/src/together/resources/evals.py @@ -8,7 +8,7 @@ from ..types import eval_list_params, eval_create_params from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given -from .._utils import maybe_transform, async_maybe_transform +from .._utils import path_template, maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import ( @@ -117,7 +117,7 @@ def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( - f"/evaluation/{id}", + path_template("/evaluation/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -198,7 +198,7 @@ def status( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( - f"/evaluation/{id}/status", + path_template("/evaluation/{id}/status", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -297,7 +297,7 @@ async def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( - f"/evaluation/{id}", + path_template("/evaluation/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -378,7 +378,7 @@ async def status( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( - f"/evaluation/{id}/status", + path_template("/evaluation/{id}/status", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/together/resources/files.py b/src/together/resources/files.py index 74121dab1..a72af09fc 100644 --- a/src/together/resources/files.py +++ 
b/src/together/resources/files.py @@ -14,6 +14,7 @@ from ..lib import FileTypeError, UploadManager, AsyncUploadManager, check_file from ..types import FilePurpose from .._types import Body, Query, Headers, NotGiven, not_given +from .._utils import path_template from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import ( @@ -86,7 +87,7 @@ def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( - f"/files/{id}", + path_template("/files/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -140,7 +141,7 @@ def delete( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._delete( - f"/files/{id}", + path_template("/files/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -213,7 +214,7 @@ def content( raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") extra_headers = {"Accept": "application/binary", **(extra_headers or {})} return self._get( - f"/files/{id}/content", + path_template("/files/{id}/content", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -269,7 +270,7 @@ async def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( - f"/files/{id}", + path_template("/files/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -323,7 +324,7 @@ async def delete( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._delete( - f"/files/{id}", + path_template("/files/{id}", id=id), 
options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -396,7 +397,7 @@ async def content( raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") extra_headers = {"Accept": "application/binary", **(extra_headers or {})} return await self._get( - f"/files/{id}/content", + path_template("/files/{id}/content", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/together/resources/fine_tuning.py b/src/together/resources/fine_tuning.py index 025b73ce8..ec05286f9 100644 --- a/src/together/resources/fine_tuning.py +++ b/src/together/resources/fine_tuning.py @@ -10,7 +10,7 @@ from ..types import fine_tuning_delete_params, fine_tuning_content_params, fine_tuning_estimate_price_params from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given -from .._utils import maybe_transform, async_maybe_transform +from .._utils import path_template, maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import ( @@ -82,6 +82,7 @@ def create( model: str | None = None, n_epochs: int = 1, validation_file: str | None = "", + packing: bool = True, n_evals: int | None = 0, n_checkpoints: int | None = 1, batch_size: int | Literal["max"] = "max", @@ -127,6 +128,7 @@ def create( model (str, optional): Name of the base model to run fine-tune job on n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1. validation file (str, optional): File ID of a file uploaded to the Together API for validation. + packing (bool, optional): Whether to use packing for training. Defaults to True. n_evals (int, optional): Number of evaluation loops to run. Defaults to 0. n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning. Defaults to 1. 
@@ -210,6 +212,7 @@ def create( model=model, n_epochs=n_epochs, validation_file=validation_file, + packing=packing, n_evals=n_evals, n_checkpoints=n_checkpoints, batch_size=batch_size, @@ -311,7 +314,7 @@ def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( - f"/fine-tunes/{id}", + path_template("/fine-tunes/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -372,7 +375,7 @@ def delete( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._delete( - f"/fine-tunes/{id}", + path_template("/fine-tunes/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, @@ -413,7 +416,7 @@ def cancel( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._post( - f"/fine-tunes/{id}/cancel", + path_template("/fine-tunes/{id}/cancel", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -574,7 +577,7 @@ def list_checkpoints( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( - f"/fine-tunes/{id}/checkpoints", + path_template("/fine-tunes/{id}/checkpoints", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -609,7 +612,7 @@ def list_events( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( - f"/fine-tunes/{id}/events", + path_template("/fine-tunes/{id}/events", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -644,6 +647,7 @@ async def create( model: str | None = None, n_epochs: int = 1, validation_file: str | None = "", + packing: 
bool = True, n_evals: int | None = 0, n_checkpoints: int | None = 1, batch_size: int | Literal["max"] = "max", @@ -689,6 +693,7 @@ async def create( model (str, optional): Name of the base model to run fine-tune job on n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1. validation file (str, optional): File ID of a file uploaded to the Together API for validation. + packing (bool, optional): Whether to use packing for training. Defaults to True. n_evals (int, optional): Number of evaluation loops to run. Defaults to 0. n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning. Defaults to 1. @@ -772,6 +777,7 @@ async def create( model=model, n_epochs=n_epochs, validation_file=validation_file, + packing=packing, n_evals=n_evals, n_checkpoints=n_checkpoints, batch_size=batch_size, @@ -868,7 +874,7 @@ async def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( - f"/fine-tunes/{id}", + path_template("/fine-tunes/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -929,7 +935,7 @@ async def delete( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._delete( - f"/fine-tunes/{id}", + path_template("/fine-tunes/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, @@ -970,7 +976,7 @@ async def cancel( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._post( - f"/fine-tunes/{id}/cancel", + path_template("/fine-tunes/{id}/cancel", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -1131,7 +1137,7 @@ async def list_checkpoints( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await 
self._get( - f"/fine-tunes/{id}/checkpoints", + path_template("/fine-tunes/{id}/checkpoints", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -1166,7 +1172,7 @@ async def list_events( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( - f"/fine-tunes/{id}/events", + path_template("/fine-tunes/{id}/events", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/together/resources/models/uploads.py b/src/together/resources/models/uploads.py index f0cacfcd4..95260eff3 100644 --- a/src/together/resources/models/uploads.py +++ b/src/together/resources/models/uploads.py @@ -5,6 +5,7 @@ import httpx from ..._types import Body, Query, Headers, NotGiven, not_given +from ..._utils import path_template from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -67,7 +68,7 @@ def status( if not job_id: raise ValueError(f"Expected a non-empty value for `job_id` but received {job_id!r}") return self._get( - f"/jobs/{job_id}", + path_template("/jobs/{job_id}", job_id=job_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -123,7 +124,7 @@ async def status( if not job_id: raise ValueError(f"Expected a non-empty value for `job_id` but received {job_id!r}") return await self._get( - f"/jobs/{job_id}", + path_template("/jobs/{job_id}", job_id=job_id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/together/resources/videos.py b/src/together/resources/videos.py index 81a7a78aa..d721fb980 100644 --- a/src/together/resources/videos.py +++ b/src/together/resources/videos.py @@ -9,7 +9,7 @@ from ..types 
import video_create_params from .._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given -from .._utils import maybe_transform, async_maybe_transform +from .._utils import path_template, maybe_transform, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import ( @@ -169,7 +169,8 @@ def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( - f"/videos/{id}" if self._client._base_url_overridden else f"https://api.together.xyz/v2/videos/{id}", + ("https://api.together.xyz/v2" if not self._client._base_url_overridden else "") + + path_template("/videos/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -322,7 +323,8 @@ async def retrieve( if not id: raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( - f"/videos/{id}" if self._client._base_url_overridden else f"https://api.together.xyz/v2/videos/{id}", + ("https://api.together.xyz/v2" if not self._client._base_url_overridden else "") + + path_template("/videos/{id}", id=id), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/together/types/audio/speech_create_params.py b/src/together/types/audio/speech_create_params.py index 1db9f9ca6..b2e0767ed 100644 --- a/src/together/types/audio/speech_create_params.py +++ b/src/together/types/audio/speech_create_params.py @@ -30,7 +30,7 @@ class SpeechCreateParamsBase(TypedDict, total=False): You can view the voices supported for each model using the /v1/voices endpoint sending the model name as the query parameter. - [View all supported voices here](https://docs.together.ai/docs/text-to-speech#voices-available). 
+ [View all supported voices here](https://docs.together.ai/docs/text-to-speech#supported-voices). """ language: Literal["en", "de", "fr", "es", "hi", "it", "ja", "ko", "nl", "pl", "pt", "ru", "sv", "tr", "zh"] diff --git a/src/together/types/beta/deployment.py b/src/together/types/beta/deployment.py index 30fbb9213..64bf81ab5 100644 --- a/src/together/types/beta/deployment.py +++ b/src/together/types/beta/deployment.py @@ -1,6 +1,7 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import Dict, List, Union, Optional +from datetime import datetime from typing_extensions import Literal, TypeAlias from ..._models import BaseModel @@ -180,7 +181,7 @@ class Deployment(BaseModel): value is allowed) """ - created_at: Optional[str] = None + created_at: Optional[datetime] = None """CreatedAt is the ISO8601 timestamp when this deployment was created""" description: Optional[str] = None @@ -246,7 +247,7 @@ class Deployment(BaseModel): allocated to each replica """ - updated_at: Optional[str] = None + updated_at: Optional[datetime] = None """UpdatedAt is the ISO8601 timestamp when this deployment was last updated""" volumes: Optional[List[Volume]] = None diff --git a/src/together/types/beta/jig_deploy_params.py b/src/together/types/beta/jig_deploy_params.py index 097fd890f..fbd30b6ac 100644 --- a/src/together/types/beta/jig_deploy_params.py +++ b/src/together/types/beta/jig_deploy_params.py @@ -19,7 +19,7 @@ class JigDeployParams(TypedDict, total=False): - gpu_type: Required[Literal["h100-80gb", "a100-80gb"]] + gpu_type: Required[Literal["h100-80gb"]] """GPUType specifies the GPU hardware to use (e.g., "h100-80gb").""" image: Required[str] diff --git a/src/together/types/beta/jig_update_params.py b/src/together/types/beta/jig_update_params.py index 7e742b6df..deb33e9eb 100644 --- a/src/together/types/beta/jig_update_params.py +++ b/src/together/types/beta/jig_update_params.py @@ -26,10 +26,7 @@ class JigUpdateParams(TypedDict, 
total=False): """ autoscaling: Autoscaling - """Autoscaling configuration for the deployment. - - Omit or set to null to disable autoscaling - """ + """Autoscaling configuration for the deployment. Set to {} to disable autoscaling""" command: SequenceNotStr[str] """Command overrides the container's ENTRYPOINT. @@ -55,7 +52,7 @@ class JigUpdateParams(TypedDict, total=False): gpu_count: int """GPUCount is the number of GPUs to allocate per container instance""" - gpu_type: Literal["h100-80gb", " a100-80gb"] + gpu_type: Literal["h100-80gb"] """GPUType specifies the GPU hardware to use (e.g., "h100-80gb")""" health_check_path: str diff --git a/src/together/types/fine_tuning_cancel_response.py b/src/together/types/fine_tuning_cancel_response.py index 345973c5c..531bb6b5b 100644 --- a/src/together/types/fine_tuning_cancel_response.py +++ b/src/together/types/fine_tuning_cancel_response.py @@ -174,9 +174,19 @@ class FineTuningCancelResponse(BaseModel): owner_address: Optional[str] = None """Owner address information""" + packing: Optional[bool] = None + """Whether sequence packing is being used for training.""" + progress: Optional[Progress] = None """Progress information for the fine-tuning job""" + random_seed: Optional[int] = None + """Random seed used for training. + + Integer when set; null if not stored (e.g. legacy jobs) or no explicit seed was + recorded. 
+ """ + started_at: Optional[datetime] = None """Start timestamp of the current stage of the fine-tune job""" diff --git a/src/together/types/fine_tuning_list_response.py b/src/together/types/fine_tuning_list_response.py index 5b87c9461..b7cccd413 100644 --- a/src/together/types/fine_tuning_list_response.py +++ b/src/together/types/fine_tuning_list_response.py @@ -175,9 +175,19 @@ class Data(BaseModel): owner_address: Optional[str] = None """Owner address information""" + packing: Optional[bool] = None + """Whether sequence packing is being used for training.""" + progress: Optional[DataProgress] = None """Progress information for the fine-tuning job""" + random_seed: Optional[int] = None + """Random seed used for training. + + Integer when set; null if not stored (e.g. legacy jobs) or no explicit seed was + recorded. + """ + started_at: Optional[datetime] = None """Start timestamp of the current stage of the fine-tune job""" diff --git a/tests/api_resources/audio/test_transcriptions.py b/tests/api_resources/audio/test_transcriptions.py index 19b71667b..32b0a5d2f 100644 --- a/tests/api_resources/audio/test_transcriptions.py +++ b/tests/api_resources/audio/test_transcriptions.py @@ -25,6 +25,7 @@ def test_method_create(self, client: Together) -> None: assert_matches_type(TranscriptionCreateResponse, transcription, path=["response"]) @parametrize + @pytest.mark.skip(reason="Skipping this test for now - I beleive stainless introduced a bug here") def test_method_create_with_all_params(self, client: Together) -> None: transcription = client.audio.transcriptions.create( file=b"Example data", @@ -78,6 +79,7 @@ async def test_method_create(self, async_client: AsyncTogether) -> None: assert_matches_type(TranscriptionCreateResponse, transcription, path=["response"]) @parametrize + @pytest.mark.skip(reason="Skipping this test for now - I beleive stainless introduced a bug here") async def test_method_create_with_all_params(self, async_client: AsyncTogether) -> None: 
transcription = await async_client.audio.transcriptions.create( file=b"Example data", diff --git a/tests/api_resources/audio/test_translations.py b/tests/api_resources/audio/test_translations.py index 2bec84b6a..f0ea12375 100644 --- a/tests/api_resources/audio/test_translations.py +++ b/tests/api_resources/audio/test_translations.py @@ -25,6 +25,7 @@ def test_method_create(self, client: Together) -> None: assert_matches_type(TranslationCreateResponse, translation, path=["response"]) @parametrize + @pytest.mark.skip(reason="Skipping this test for now - I believe stainless introduced a bug here") def test_method_create_with_all_params(self, client: Together) -> None: translation = client.audio.translations.create( file=b"Example data", @@ -75,6 +76,7 @@ async def test_method_create(self, async_client: AsyncTogether) -> None: assert_matches_type(TranslationCreateResponse, translation, path=["response"]) @parametrize + @pytest.mark.skip(reason="Skipping this test for now - I believe stainless introduced a bug here") async def test_method_create_with_all_params(self, async_client: AsyncTogether) -> None: translation = await async_client.audio.translations.create( file=b"Example data", diff --git a/tests/api_resources/beta/test_jig.py b/tests/api_resources/beta/test_jig.py index dc1113532..30f5e8303 100644 --- a/tests/api_resources/beta/test_jig.py +++ b/tests/api_resources/beta/test_jig.py @@ -91,11 +91,11 @@ def test_method_update_with_all_params(self, client: Together) -> None: health_check_path="health_check_path", image="image", max_replicas=0, - memory=0.1, + memory=1000, min_replicas=0, name="x", - port=0, - storage=0, + port=1, + storage=400, termination_grace_period_seconds=0, volumes=[ { @@ -197,10 +197,10 @@ def test_method_deploy_with_all_params(self, client: Together) -> None: gpu_count=0, health_check_path="health_check_path", max_replicas=0, - memory=0.1, + memory=1000, min_replicas=0, - port=0, - storage=0, + port=1, + storage=400, 
termination_grace_period_seconds=0, volumes=[ { @@ -400,11 +400,11 @@ async def test_method_update_with_all_params(self, async_client: AsyncTogether) health_check_path="health_check_path", image="image", max_replicas=0, - memory=0.1, + memory=1000, min_replicas=0, name="x", - port=0, - storage=0, + port=1, + storage=400, termination_grace_period_seconds=0, volumes=[ { @@ -506,10 +506,10 @@ async def test_method_deploy_with_all_params(self, async_client: AsyncTogether) gpu_count=0, health_check_path="health_check_path", max_replicas=0, - memory=0.1, + memory=1000, min_replicas=0, - port=0, - storage=0, + port=1, + storage=400, termination_grace_period_seconds=0, volumes=[ { diff --git a/tests/api_resources/code_interpreter/test_sessions.py b/tests/api_resources/code_interpreter/test_sessions.py index 687efd2ad..f3959569e 100644 --- a/tests/api_resources/code_interpreter/test_sessions.py +++ b/tests/api_resources/code_interpreter/test_sessions.py @@ -17,13 +17,11 @@ class TestSessions: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize def test_method_list(self, client: Together) -> None: session = client.code_interpreter.sessions.list() assert_matches_type(SessionListResponse, session, path=["response"]) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize def test_raw_response_list(self, client: Together) -> None: response = client.code_interpreter.sessions.with_raw_response.list() @@ -33,7 +31,6 @@ def test_raw_response_list(self, client: Together) -> None: session = response.parse() assert_matches_type(SessionListResponse, session, path=["response"]) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize def test_streaming_response_list(self, client: Together) -> None: with client.code_interpreter.sessions.with_streaming_response.list() as response: @@ -51,13 +48,11 
@@ class TestAsyncSessions: "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] ) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize async def test_method_list(self, async_client: AsyncTogether) -> None: session = await async_client.code_interpreter.sessions.list() assert_matches_type(SessionListResponse, session, path=["response"]) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize async def test_raw_response_list(self, async_client: AsyncTogether) -> None: response = await async_client.code_interpreter.sessions.with_raw_response.list() @@ -67,7 +62,6 @@ async def test_raw_response_list(self, async_client: AsyncTogether) -> None: session = await response.parse() assert_matches_type(SessionListResponse, session, path=["response"]) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize async def test_streaming_response_list(self, async_client: AsyncTogether) -> None: async with async_client.code_interpreter.sessions.with_streaming_response.list() as response: diff --git a/tests/api_resources/test_code_interpreter.py b/tests/api_resources/test_code_interpreter.py index 2d06e917e..9f73f6bba 100644 --- a/tests/api_resources/test_code_interpreter.py +++ b/tests/api_resources/test_code_interpreter.py @@ -17,7 +17,6 @@ class TestCodeInterpreter: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize def test_method_execute(self, client: Together) -> None: code_interpreter = client.code_interpreter.execute( @@ -26,7 +25,6 @@ def test_method_execute(self, client: Together) -> None: ) assert_matches_type(ExecuteResponse, code_interpreter, path=["response"]) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize def test_method_execute_with_all_params(self, 
client: Together) -> None: code_interpreter = client.code_interpreter.execute( @@ -43,7 +41,6 @@ def test_method_execute_with_all_params(self, client: Together) -> None: ) assert_matches_type(ExecuteResponse, code_interpreter, path=["response"]) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize def test_raw_response_execute(self, client: Together) -> None: response = client.code_interpreter.with_raw_response.execute( @@ -56,7 +53,6 @@ def test_raw_response_execute(self, client: Together) -> None: code_interpreter = response.parse() assert_matches_type(ExecuteResponse, code_interpreter, path=["response"]) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize def test_streaming_response_execute(self, client: Together) -> None: with client.code_interpreter.with_streaming_response.execute( @@ -77,7 +73,6 @@ class TestAsyncCodeInterpreter: "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] ) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize async def test_method_execute(self, async_client: AsyncTogether) -> None: code_interpreter = await async_client.code_interpreter.execute( @@ -86,7 +81,6 @@ async def test_method_execute(self, async_client: AsyncTogether) -> None: ) assert_matches_type(ExecuteResponse, code_interpreter, path=["response"]) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize async def test_method_execute_with_all_params(self, async_client: AsyncTogether) -> None: code_interpreter = await async_client.code_interpreter.execute( @@ -103,7 +97,6 @@ async def test_method_execute_with_all_params(self, async_client: AsyncTogether) ) assert_matches_type(ExecuteResponse, code_interpreter, path=["response"]) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize async def test_raw_response_execute(self, async_client: AsyncTogether) -> None: 
response = await async_client.code_interpreter.with_raw_response.execute( @@ -116,7 +109,6 @@ async def test_raw_response_execute(self, async_client: AsyncTogether) -> None: code_interpreter = await response.parse() assert_matches_type(ExecuteResponse, code_interpreter, path=["response"]) - @pytest.mark.skip(reason="Mock server doesn't support callbacks yet") @parametrize async def test_streaming_response_execute(self, async_client: AsyncTogether) -> None: async with async_client.code_interpreter.with_streaming_response.execute( diff --git a/tests/cli/data.jsonl b/tests/cli/data.jsonl new file mode 100644 index 000000000..9e26dfeeb --- /dev/null +++ b/tests/cli/data.jsonl @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/tests/cli/test_beta_clusters.py b/tests/cli/test_beta_clusters.py new file mode 100644 index 000000000..da0177279 --- /dev/null +++ b/tests/cli/test_beta_clusters.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +import os +import json +import base64 +from typing import Any, cast + +import httpx +import pytest +from respx import MockRouter +from respx.models import Call +from click.testing import CliRunner + +from together.lib.cli import main + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +API_KEY = "0000000000000000000000000000000000000000" +_ENV = {"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY} + + +def _cluster_body(cluster_id: str = "cluster-1", name: str = "my-cluster", **overrides: Any) -> dict[str, Any]: + body: dict[str, Any] = { + "cluster_id": cluster_id, + "cluster_name": name, + "cluster_type": "KUBERNETES", + "control_plane_nodes": [], + "driver_version": "CUDA_12_6_565", + "duration_hours": 24, + "gpu_type": "H100_SXM", + "gpu_worker_nodes": [], + "kube_config": base64.b64encode(b"").decode("ascii"), + "num_gpus": 8, + "region": "us-central-8", + "status": "Ready", + "volumes": [], + } + body.update(overrides) + return body + + +_REGIONS_BODY = { + "regions": [ + { + "name": 
"us-central-8", + "driver_versions": ["CUDA_12_6_565"], + "supported_instance_types": ["H100_SXM"], + } + ] +} + +_VOLUME_BODY = { + "volume_id": "vol-1", + "volume_name": "data", + "size_tib": 2, + "status": "available", +} + + +class TestBetaClustersList: + @pytest.mark.respx(base_url=base_url) + def test_list_table(self, respx_mock: MockRouter) -> None: + respx_mock.get("/compute/clusters").mock( + return_value=httpx.Response( + 200, + json={"clusters": [_cluster_body("a", "alpha"), _cluster_body("b", "beta")]}, + ) + ) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["beta", "clusters", "list"]) + assert result.exit_code == 0 + assert "a" in result.output + assert "alpha" in result.output + assert "b" in result.output + + @pytest.mark.respx(base_url=base_url) + def test_list_json(self, respx_mock: MockRouter) -> None: + payload = {"clusters": [_cluster_body()]} + respx_mock.get("/compute/clusters").mock(return_value=httpx.Response(200, json=payload)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["beta", "clusters", "list", "--json"]) + assert result.exit_code == 0 + assert json.loads(result.output) == payload + + +class TestBetaClustersListRegions: + @pytest.mark.respx(base_url=base_url) + def test_list_regions_json(self, respx_mock: MockRouter) -> None: + respx_mock.get("/compute/regions").mock(return_value=httpx.Response(200, json=_REGIONS_BODY)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["beta", "clusters", "list-regions", "--json"]) + assert result.exit_code == 0 + assert json.loads(result.output) == _REGIONS_BODY + + +class TestBetaClustersRetrieve: + @pytest.mark.respx(base_url=base_url) + def test_retrieve_json(self, respx_mock: MockRouter) -> None: + c = _cluster_body() + respx_mock.get("/compute/clusters/cluster-1").mock(return_value=httpx.Response(200, json=c)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["beta", "clusters", "retrieve", "cluster-1", "--json"]) + assert 
result.exit_code == 0 + assert json.loads(result.output) == c + + +class TestBetaClustersCreate: + @pytest.mark.respx(base_url=base_url) + def test_create_non_interactive_posts_expected_body(self, respx_mock: MockRouter) -> None: + created = _cluster_body("new-id", "together-py-testing-suite") + route = respx_mock.post("/compute/clusters").mock(return_value=httpx.Response(200, json=created)) + runner = CliRunner(env=_ENV) + result = runner.invoke( + main, + [ + "beta", + "clusters", + "create", + "--non-interactive", + "--cluster-type", + "KUBERNETES", + "--gpu-type", + "H100_SXM", + "--driver-version", + "CUDA_12_6_565", + "--region", + "us-central-8", + "--num-gpus", + "8", + "--billing-type", + "ON_DEMAND", + "--name", + "together-py-testing-suite", + "--volume", + "vol-attach", + ], + ) + assert result.exit_code == 0 + assert "new-id" in result.output + raw = cast(Call, route.calls[0]).request.content.decode() + body = json.loads(raw) + assert body["cluster_name"] == "together-py-testing-suite" + assert body["volume_id"] == "vol-attach" + assert body["num_gpus"] == 8 + assert body["billing_type"] == "ON_DEMAND" + + +class TestBetaClustersUpdate: + @pytest.mark.respx(base_url=base_url) + def test_update_json_triggers_put_and_second_get(self, respx_mock: MockRouter) -> None: + updated = _cluster_body("c1", num_gpus=16, cluster_type="SLURM") + put = respx_mock.put("/compute/clusters/c1").mock(return_value=httpx.Response(200, json=updated)) + get = respx_mock.get("/compute/clusters/c1").mock(return_value=httpx.Response(200, json=updated)) + runner = CliRunner(env=_ENV) + result = runner.invoke( + main, + ["beta", "clusters", "update", "c1", "--num-gpus", "16", "--cluster-type", "SLURM", "--json"], + ) + assert result.exit_code == 0 + assert put.calls + assert get.calls + assert json.loads(result.output)["num_gpus"] == 16 + put_body = json.loads(cast(Call, put.calls[0]).request.content.decode()) + assert put_body["num_gpus"] == 16 + assert put_body["cluster_type"] 
== "SLURM" + + +class TestBetaClustersDelete: + @pytest.mark.respx(base_url=base_url) + def test_delete_json(self, respx_mock: MockRouter) -> None: + respx_mock.delete("/compute/clusters/c-del").mock( + return_value=httpx.Response(200, json={"cluster_id": "c-del"}) + ) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["beta", "clusters", "delete", "c-del", "--json"]) + assert result.exit_code == 0 + assert json.loads(result.output) == {"cluster_id": "c-del"} + + @pytest.mark.respx(base_url=base_url) + def test_delete_confirm_yes(self, respx_mock: MockRouter) -> None: + c = _cluster_body("c1", "to-delete") + respx_mock.get("/compute/clusters/c1").mock(return_value=httpx.Response(200, json=c)) + respx_mock.delete("/compute/clusters/c1").mock(return_value=httpx.Response(200, json={"cluster_id": "c1"})) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["beta", "clusters", "delete", "c1"], input="y\n") + assert result.exit_code == 0 + assert "Deleted" in result.output + + +class TestBetaClustersGetCredentials: + @pytest.mark.respx(base_url=base_url) + def test_get_credentials_stdout(self, respx_mock: MockRouter) -> None: + cfg = "apiVersion: v1\nkind: Config\n" + c = _cluster_body(kube_config=base64.b64encode(cfg.encode()).decode("ascii")) + respx_mock.get("/compute/clusters/c1").mock(return_value=httpx.Response(200, json=c)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["beta", "clusters", "get-credentials", "c1", "--file", "-"]) + assert result.exit_code == 0 + assert result.output.strip() == cfg.strip() + + +class TestBetaClustersStorage: + @pytest.mark.respx(base_url=base_url) + def test_storage_list_json(self, respx_mock: MockRouter) -> None: + payload = {"volumes": [_VOLUME_BODY]} + respx_mock.get("/compute/clusters/storage/volumes").mock(return_value=httpx.Response(200, json=payload)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["beta", "clusters", "storage", "list", "--json"]) + assert result.exit_code 
== 0 + assert json.loads(result.output) == payload + + @pytest.mark.respx(base_url=base_url) + def test_storage_create_json(self, respx_mock: MockRouter) -> None: + route = respx_mock.post("/compute/clusters/storage/volumes").mock( + return_value=httpx.Response(200, json=_VOLUME_BODY) + ) + runner = CliRunner(env=_ENV) + result = runner.invoke( + main, + [ + "beta", + "clusters", + "storage", + "create", + "--region", + "us-east-1", + "--size-tib", + "1", + "--volume-name", + "test-volume", + "--json", + ], + ) + assert result.exit_code == 0 + out = json.loads(result.output) + assert out["volume_id"] == "vol-1" + raw = cast(Call, route.calls[0]).request.content.decode() + assert json.loads(raw) == {"region": "us-east-1", "size_tib": 1, "volume_name": "test-volume"} + + @pytest.mark.respx(base_url=base_url) + def test_storage_retrieve_json(self, respx_mock: MockRouter) -> None: + respx_mock.get("/compute/clusters/storage/volumes/vol-1").mock( + return_value=httpx.Response(200, json=_VOLUME_BODY) + ) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["beta", "clusters", "storage", "retrieve", "vol-1", "--json"]) + assert result.exit_code == 0 + assert json.loads(result.output) == _VOLUME_BODY + + @pytest.mark.respx(base_url=base_url) + def test_storage_delete_json(self, respx_mock: MockRouter) -> None: + respx_mock.delete("/compute/clusters/storage/volumes/vol-1").mock( + return_value=httpx.Response(200, json={"success": True}) + ) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["beta", "clusters", "storage", "delete", "vol-1", "--json"]) + assert result.exit_code == 0 + assert json.loads(result.output) == {"success": True} diff --git a/tests/cli/test_beta_jig.py b/tests/cli/test_beta_jig.py new file mode 100644 index 000000000..c78bd8677 --- /dev/null +++ b/tests/cli/test_beta_jig.py @@ -0,0 +1,378 @@ +from __future__ import annotations + +import os +import sys +import json +from typing import Any, cast +from pathlib import Path +from 
contextlib import contextmanager +from unittest.mock import patch + +import httpx +import pytest +from respx import MockRouter +from respx.models import Call +from click.testing import CliRunner + +from together.lib.cli import main + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +API_KEY = "0000000000000000000000000000000000000000" +_ENV = {"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY} + +# Imported into jig CLI module namespace +_jig_mod = sys.modules["together.lib.cli.api.beta.jig.jig"] + +_DEPLOY_NAME = "jig-cli-test" + + +@contextmanager +def _chdir(path: Path): + prev = os.getcwd() + os.chdir(path) + try: + yield + finally: + os.chdir(prev) + + +def _noop_config_post_init(_self: Any) -> None: + """Stub replacing Config.__post_init__ when skipping validation in tests.""" + return None + + +@contextmanager +def _patched_jig_config(tmp_path: Path): + """Avoid Config.find() + validate on py3.9 (DeployConfig uses PEP 604 hints).""" + with patch.object(_jig_mod.Config, "__post_init__", _noop_config_post_init): + cfg = _jig_mod.Config( + model_name=_DEPLOY_NAME, + image=_jig_mod.ImageConfig(), + deploy=_jig_mod.DeployConfig(), + _path=tmp_path / "pyproject.toml", + _unique_name_hint="h", + ) + + def _find(*_args: Any): + return cfg + + with patch.object(_jig_mod.Config, "find", classmethod(_find)): + yield + + +_PYPROJECT = f"""[project] +name = "{_DEPLOY_NAME}" +version = "0.1.0" + +[tool.jig.image] +python_version = "3.11" +cmd = "python app.py" + +[tool.jig.deploy] +description = "test" +gpu_type = "h100-80gb" +gpu_count = 1 +""" + + +def _write_jig_project(path: Path) -> None: + path.joinpath("pyproject.toml").write_text(_PYPROJECT, encoding="utf-8") + + +def _secret_api_body(name: str) -> dict[str, object]: + return { + "id": "sec-1", + "name": name, + "object": "secret", + "description": "", + } + + +def _volume_api_body(name: str, **extra: object) -> dict[str, object]: + body: dict[str, object] = { + "id": "vol-id-1", + 
"name": name, + "object": "volume", + "type": "readOnly", + "current_version": 0, + } + body.update(extra) + return body + + +class TestBetaJigSecretsSet: + @pytest.mark.respx(base_url=base_url) + def test_set_creates_when_update_returns_not_found(self, respx_mock: MockRouter, tmp_path: Path) -> None: + scoped = f"{_DEPLOY_NAME}-apikey" + respx_mock.get(f"/deployments/{_DEPLOY_NAME}").mock(return_value=httpx.Response(404, json={})) + respx_mock.patch(f"/deployments/secrets/{scoped}").mock(return_value=httpx.Response(404, json={})) + post = respx_mock.post("/deployments/secrets").mock( + return_value=httpx.Response(200, json=_secret_api_body(scoped)) + ) + + runner = CliRunner(env=_ENV) + with _patched_jig_config(tmp_path), _chdir(tmp_path): + result = runner.invoke( + main, + [ + "beta", + "jig", + "secrets", + "set", + "--name", + "apikey", + "--value", + "secret-val", + "--description", + "d1", + ], + ) + assert result.exit_code == 0 + assert "Created secret apikey" in result.output + raw = cast(Call, post.calls[0]).request.content.decode() + body = json.loads(raw) + assert body["name"] == scoped + assert body["value"] == "secret-val" + assert body["description"] == "d1" + state = json.loads((tmp_path / ".jig.json").read_text()) + assert state[_DEPLOY_NAME]["secrets"]["apikey"] == scoped + + @pytest.mark.respx(base_url=base_url) + def test_set_updates_when_secret_exists(self, respx_mock: MockRouter, tmp_path: Path) -> None: + scoped = f"{_DEPLOY_NAME}-apikey" + respx_mock.get(f"/deployments/{_DEPLOY_NAME}").mock(return_value=httpx.Response(404, json={})) + patch_route = respx_mock.patch(f"/deployments/secrets/{scoped}").mock( + return_value=httpx.Response(200, json=_secret_api_body(scoped)) + ) + + runner = CliRunner(env=_ENV) + with _patched_jig_config(tmp_path), _chdir(tmp_path): + result = runner.invoke( + main, + ["beta", "jig", "secrets", "set", "--name", "apikey", "--value", "v2"], + ) + assert result.exit_code == 0 + assert "Updated secret apikey" in 
result.output + assert patch_route.called + raw = cast(Call, patch_route.calls[0]).request.content.decode() + assert json.loads(raw)["value"] == "v2" + + +class TestBetaJigSecretsList: + @pytest.mark.respx(base_url=base_url) + def test_list_merges_local_and_remote(self, respx_mock: MockRouter, tmp_path: Path) -> None: + (tmp_path / ".jig.json").write_text( + json.dumps({_DEPLOY_NAME: {"secrets": {"localonly": f"{_DEPLOY_NAME}-localonly"}}}), + encoding="utf-8", + ) + respx_mock.get("/deployments/secrets").mock( + return_value=httpx.Response( + 200, + json={ + "object": "list", + "data": [ + {"name": f"{_DEPLOY_NAME}-localonly", "object": "secret"}, + {"name": f"{_DEPLOY_NAME}-remoteonly", "object": "secret"}, + {"name": "other-project-x", "object": "secret"}, + ], + }, + ) + ) + + runner = CliRunner(env=_ENV) + with _patched_jig_config(tmp_path), _chdir(tmp_path): + result = runner.invoke(main, ["beta", "jig", "secrets", "list"]) + assert result.exit_code == 0 + assert "localonly" in result.output + assert "remoteonly" in result.output + assert "synced" in result.output or "local only" in result.output + + @pytest.mark.respx(base_url=base_url) + def test_list_empty_message(self, respx_mock: MockRouter, tmp_path: Path) -> None: + respx_mock.get("/deployments/secrets").mock( + return_value=httpx.Response(200, json={"object": "list", "data": []}) + ) + + runner = CliRunner(env=_ENV) + with _patched_jig_config(tmp_path), _chdir(tmp_path): + result = runner.invoke(main, ["beta", "jig", "secrets", "list"]) + assert result.exit_code == 0 + assert "No secrets configured" in result.output + + +class TestBetaJigSecretsUnset: + def test_unset_removes_known_secret(self, tmp_path: Path) -> None: + (tmp_path / ".jig.json").write_text( + json.dumps({_DEPLOY_NAME: {"secrets": {"tok": f"{_DEPLOY_NAME}-tok"}}}), + encoding="utf-8", + ) + + runner = CliRunner(env=_ENV) + with _patched_jig_config(tmp_path), _chdir(tmp_path): + result = runner.invoke(main, ["beta", "jig", "secrets", 
"unset", "--name", "tok"]) + assert result.exit_code == 0 + assert "Deleted secret tok" in result.output + state = json.loads((tmp_path / ".jig.json").read_text()) + assert "tok" not in state[_DEPLOY_NAME].get("secrets", {}) + + def test_unset_missing_secret_message(self, tmp_path: Path) -> None: + (tmp_path / ".jig.json").write_text( + json.dumps({_DEPLOY_NAME: {"secrets": {}}}), + encoding="utf-8", + ) + + runner = CliRunner(env=_ENV) + with _patched_jig_config(tmp_path), _chdir(tmp_path): + result = runner.invoke(main, ["beta", "jig", "secrets", "unset", "--name", "nope"]) + assert result.exit_code == 0 + assert "Secret nope is not set" in result.output + + +class TestBetaJigVolumes: + @pytest.mark.respx(base_url=base_url) + def test_delete(self, respx_mock: MockRouter, tmp_path: Path) -> None: + _write_jig_project(tmp_path) + respx_mock.delete("/deployments/storage/volumes/data-vol").mock(return_value=httpx.Response(200, json={})) + + runner = CliRunner(env=_ENV) + with _chdir(tmp_path): + result = runner.invoke(main, ["beta", "jig", "volumes", "delete", "--name", "data-vol"]) + assert result.exit_code == 0 + assert "Deleted volume data-vol" in result.output + + @pytest.mark.respx(base_url=base_url) + def test_delete_not_found(self, respx_mock: MockRouter, tmp_path: Path) -> None: + _write_jig_project(tmp_path) + respx_mock.delete("/deployments/storage/volumes/missing").mock( + return_value=httpx.Response(404, json={"error": {"message": "not found"}}) + ) + + runner = CliRunner(env=_ENV) + with _chdir(tmp_path): + result = runner.invoke(main, ["beta", "jig", "volumes", "delete", "--name", "missing"]) + assert result.exit_code == 1 + assert "not found" in result.output.lower() + + @pytest.mark.respx(base_url=base_url) + def test_describe_json(self, respx_mock: MockRouter, tmp_path: Path) -> None: + _write_jig_project(tmp_path) + payload = _volume_api_body("v1", current_version=2) + 
respx_mock.get("/deployments/storage/volumes/v1").mock(return_value=httpx.Response(200, json=payload)) + + runner = CliRunner(env=_ENV) + with _chdir(tmp_path): + result = runner.invoke(main, ["beta", "jig", "volumes", "describe", "--name", "v1"]) + assert result.exit_code == 0 + assert json.loads(result.output) == payload + + @pytest.mark.respx(base_url=base_url) + def test_list_json(self, respx_mock: MockRouter, tmp_path: Path) -> None: + _write_jig_project(tmp_path) + payload = {"object": "list", "data": [_volume_api_body("a"), _volume_api_body("b")]} + respx_mock.get("/deployments/storage/volumes").mock(return_value=httpx.Response(200, json=payload)) + + runner = CliRunner(env=_ENV) + with _chdir(tmp_path): + result = runner.invoke(main, ["beta", "jig", "volumes", "list"]) + assert result.exit_code == 0 + assert json.loads(result.output) == payload + + @pytest.mark.respx(base_url=base_url) + def test_create_invokes_upload(self, respx_mock: MockRouter, tmp_path: Path) -> None: + _write_jig_project(tmp_path) + src = tmp_path / "srcdir" + src.mkdir() + (src / "f.txt").write_text("x", encoding="utf-8") + + respx_mock.post("/deployments/storage/volumes").mock( + return_value=httpx.Response(200, json=_volume_api_body("myvol")) + ) + + uploaded: list[tuple[Path, str]] = [] + + class _FakeUploader: + def __init__(self, _client: object) -> None: + pass + + async def upload_files(self, source: Path, prefix: str) -> None: + uploaded.append((source, prefix)) + + with patch.object(_jig_mod, "Uploader", _FakeUploader): + runner = CliRunner(env=_ENV) + with _chdir(tmp_path): + result = runner.invoke( + main, + ["beta", "jig", "volumes", "create", "--name", "myvol", "--source", str(src)], + ) + + assert result.exit_code == 0 + assert uploaded == [(src, "myvol/0")] + assert "Volume created" in result.output + + @pytest.mark.respx(base_url=base_url) + def test_create_rolls_back_on_upload_failure(self, respx_mock: MockRouter, tmp_path: Path) -> None: + 
_write_jig_project(tmp_path) + src = tmp_path / "srcdir" + src.mkdir() + + respx_mock.post("/deployments/storage/volumes").mock( + return_value=httpx.Response(200, json=_volume_api_body("badvol")) + ) + del_vol = respx_mock.delete("/deployments/storage/volumes/badvol").mock( + return_value=httpx.Response(200, json={}) + ) + + class _FakeUploader: + def __init__(self, _client: object) -> None: + pass + + async def upload_files(self, _source: Path, _prefix: str) -> None: + raise RuntimeError("upload boom") + + with patch.object(_jig_mod, "Uploader", _FakeUploader): + runner = CliRunner(env=_ENV) + with _chdir(tmp_path): + result = runner.invoke( + main, + ["beta", "jig", "volumes", "create", "--name", "badvol", "--source", str(src)], + ) + + assert result.exit_code == 1 + assert del_vol.called + + @pytest.mark.respx(base_url=base_url) + def test_update_bumps_version_and_uploads(self, respx_mock: MockRouter, tmp_path: Path) -> None: + _write_jig_project(tmp_path) + src = tmp_path / "newsrc" + src.mkdir() + (src / "a.bin").write_bytes(b"\0") + + respx_mock.get("/deployments/storage/volumes/shared").mock( + return_value=httpx.Response(200, json=_volume_api_body("shared", current_version=3)) + ) + patch_r = respx_mock.patch("/deployments/storage/volumes/shared").mock( + return_value=httpx.Response(200, json=_volume_api_body("shared", current_version=4)) + ) + + uploaded: list[tuple[Path, str]] = [] + + class _FakeUploader: + def __init__(self, _client: object) -> None: + pass + + async def upload_files(self, source: Path, prefix: str) -> None: + uploaded.append((source, prefix)) + + with patch.object(_jig_mod, "Uploader", _FakeUploader): + runner = CliRunner(env=_ENV) + with _chdir(tmp_path): + result = runner.invoke( + main, + ["beta", "jig", "volumes", "update", "--name", "shared", "--source", str(src)], + ) + + assert result.exit_code == 0 + assert uploaded == [(src, "shared/4")] + assert patch_r.called + patch_body = json.loads(cast(Call, 
patch_r.calls[0]).request.content.decode()) + assert patch_body["content"] == {"type": "files", "source_prefix": "shared/4"} diff --git a/tests/cli/test_endpoints.py b/tests/cli/test_endpoints.py new file mode 100644 index 000000000..3b9013fff --- /dev/null +++ b/tests/cli/test_endpoints.py @@ -0,0 +1,344 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +import json +from typing import cast + +import httpx +import pytest +from respx import MockRouter +from respx.models import Call +from click.testing import CliRunner + +from together.lib.cli import main + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +API_KEY = "0000000000000000000000000000000000000000" +_ENV = {"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY} + +model_data = { + "data": [ + { + "id": "2x_nvidia_a100_80gb_sxm", + "object": "hardware", + "specs": { + "gpu_type": "a100", + "gpu_memory": 80, + "gpu_count": 2, + "gpu_link": "sxm", + }, + "pricing": { + "cents_per_minute": 5, + }, + "updated_at": "2026-03-23T12:00:00Z", + "availability": { + "status": "available", + }, + }, + { + "id": "1x_nvidia_a100_80gb_sxm", + "object": "hardware", + "specs": { + "gpu_type": "a100", + "gpu_memory": 80, + "gpu_count": 1, + "gpu_link": "sxm", + }, + "pricing": { + "cents_per_minute": 5, + }, + "updated_at": "2026-03-23T12:00:00Z", + "availability": { + "status": "unavailable", + }, + }, + ], + "object": "list", +} + +hardware_list_unfiltered = { + "object": "list", + "data": [model_data["data"][0]], +} + +DEDICATED_EP = { + "id": "endpoint-123", + "object": "endpoint", + "type": "dedicated", + "name": "sys-name", + "display_name": "My Endpoint", + "hardware": "2x_nvidia_a100_80gb_sxm", + "model": "deepseek-ai/DeepSeek-R1", + "owner": "user", + "state": "STARTED", + "created_at": "2024-01-01T00:00:00Z", + "autoscaling": {"min_replicas": 1, "max_replicas": 4}, +} + +ENDPOINT_LIST_ITEM = { + 
"id": "ep-list-1", + "object": "endpoint", + "type": "dedicated", + "name": "n1", + "model": "m1", + "owner": "o1", + "state": "STARTED", + "created_at": "2024-01-01T00:00:00Z", +} + + +class TestEndpointsCreate: + # Test for endpoint create requiring the model + def test_requires_model(self) -> None: + runner = CliRunner(env=_ENV) + assert runner.invoke(main, ["endpoints", "create"]).exit_code == 2 + + # Test for when the API returns an error saying hardware is required + @pytest.mark.respx(base_url=base_url) + def test_invalid_hardware(self, respx_mock: MockRouter) -> None: + respx_mock.post("/endpoints").mock( + return_value=httpx.Response(400, json={"error": {"message": "Hardware is required", "type": "bad_request"}}) + ) + runner = CliRunner(env=_ENV) + result = runner.invoke( + main, ["endpoints", "create", "--model", "deepseek-ai/DeepSeek-R1", "--hardware", "foooooooo"] + ) + assert result.exit_code == 1 + assert "Invalid hardware selected." in result.output + + # Test for when the API returns an error saying model not found + @pytest.mark.respx(base_url=base_url) + def test_invalid_model(self, respx_mock: MockRouter) -> None: + respx_mock.post("/endpoints").mock( + return_value=httpx.Response(400, json={"error": {"message": "Model not found", "type": "bad_request"}}) + ) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["endpoints", "create", "--model", "deepseek-ai/DeepSeek-R1"]) + assert result.exit_code == 1 + assert ( + "Model 'deepseek-ai/DeepSeek-R1' was not found or is not available for dedicated endpoints." 
+ in result.output + ) + + +class TestEndpointsHardware: + @pytest.mark.respx(base_url=base_url) + def test_hardware_list(self, respx_mock: MockRouter) -> None: + respx_mock.get("/hardware").mock(return_value=httpx.Response(200, json=hardware_list_unfiltered)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["endpoints", "hardware"]) + assert result.exit_code == 0 + assert ( + result.output.strip() + == """ +Hardware ID GPU Memory Count Price (per minute) +----------------------- ----- -------- ------- -------------------- +2x_nvidia_a100_80gb_sxm a100 80GB 2 $0.05 +""".strip() + ) + + @pytest.mark.respx(base_url=base_url) + def test_hardware_list_with_model(self, respx_mock: MockRouter) -> None: + respx_mock.get("/hardware").mock(return_value=httpx.Response(200, json=model_data)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["endpoints", "hardware", "--model", "deepseek-ai/DeepSeek-R1"]) + assert result.exit_code == 0 + assert ( + result.output.strip() + == """ +Hardware ID GPU Memory Count Price (per minute) availability +----------------------- ----- -------- ------- -------------------- -------------- +2x_nvidia_a100_80gb_sxm a100 80GB 2 $0.05 ✓ available +1x_nvidia_a100_80gb_sxm a100 80GB 1 $0.05 ✗ unavailable +""".strip() + ) + + @pytest.mark.respx(base_url=base_url) + def test_hardware_list_with_model_and_available(self, respx_mock: MockRouter) -> None: + respx_mock.get("/hardware").mock(return_value=httpx.Response(200, json=model_data)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["endpoints", "hardware", "--model", "deepseek-ai/DeepSeek-R1", "--available"]) + assert result.exit_code == 0 + assert ( + result.output.strip() + == """ +Hardware ID GPU Memory Count Price (per minute) availability +----------------------- ----- -------- ------- -------------------- -------------- +2x_nvidia_a100_80gb_sxm a100 80GB 2 $0.05 ✓ available +""".strip() + ) + + @pytest.mark.respx(base_url=base_url) + def 
test_hardware_list_with_model_and_available_json(self, respx_mock: MockRouter) -> None: + respx_mock.get("/hardware").mock(return_value=httpx.Response(200, json=model_data)) + runner = CliRunner(env=_ENV) + result = runner.invoke( + main, ["endpoints", "hardware", "--model", "deepseek-ai/DeepSeek-R1", "--available", "--json"] + ) + + data = json.loads(result.output) + + for item in data: + assert item["availability"]["status"] == "available" + + +class TestEndpointsStart: + def test_start_requires_id(self) -> None: + runner = CliRunner(env=_ENV) + assert runner.invoke(main, ["endpoints", "start"]).exit_code == 2 + + @pytest.mark.respx(base_url=base_url) + def test_start_endpoint(self, respx_mock: MockRouter) -> None: + respx_mock.patch("/endpoints/endpoint-123").mock(return_value=httpx.Response(200, json=DEDICATED_EP)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["endpoints", "start", "endpoint-123"]) + assert result.exit_code == 0 + assert result.output.strip() == "Successfully marked endpoint as starting\nendpoint-123" + + @pytest.mark.respx(base_url=base_url) + def test_start_json(self, respx_mock: MockRouter) -> None: + respx_mock.patch("/endpoints/endpoint-123").mock(return_value=httpx.Response(200, json=DEDICATED_EP)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["endpoints", "start", "endpoint-123", "--json"]) + assert result.exit_code == 0 + body = json.loads(result.output) + assert body["id"] == "endpoint-123" + assert body["state"] == "STARTED" + + @pytest.mark.respx(base_url=base_url) + def test_start_wait(self, respx_mock: MockRouter) -> None: + from unittest.mock import patch + + starting = {**DEDICATED_EP, "state": "STARTING"} + respx_mock.patch("/endpoints/endpoint-123").mock(return_value=httpx.Response(200, json=DEDICATED_EP)) + respx_mock.get("/endpoints/endpoint-123").mock( + side_effect=[ + httpx.Response(200, json=starting), + httpx.Response(200, json=DEDICATED_EP), + ] + ) + runner = CliRunner(env=_ENV) + with 
patch("time.sleep"): + result = runner.invoke(main, ["endpoints", "start", "endpoint-123", "--wait"]) + assert result.exit_code == 0 + assert "Endpoint started" in result.output + assert "endpoint-123" in result.output + + +class TestEndpointsStop: + def test_stop_requires_id(self) -> None: + runner = CliRunner(env=_ENV) + assert runner.invoke(main, ["endpoints", "stop"]).exit_code == 2 + + @pytest.mark.respx(base_url=base_url) + def test_stop_endpoint(self, respx_mock: MockRouter) -> None: + stopped = {**DEDICATED_EP, "state": "STOPPED"} + respx_mock.patch("/endpoints/endpoint-123").mock(return_value=httpx.Response(200, json=stopped)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["endpoints", "stop", "endpoint-123"]) + assert result.exit_code == 0 + assert result.output.strip() == "Successfully marked endpoint as stopping\nendpoint-123" + + @pytest.mark.respx(base_url=base_url) + def test_stop_json(self, respx_mock: MockRouter) -> None: + stopped = {**DEDICATED_EP, "state": "STOPPED"} + respx_mock.patch("/endpoints/endpoint-123").mock(return_value=httpx.Response(200, json=stopped)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["endpoints", "stop", "endpoint-123", "--json"]) + assert result.exit_code == 0 + assert json.loads(result.output)["message"] == "Successfully marked endpoint as stopping" + + @pytest.mark.respx(base_url=base_url) + def test_stop_wait(self, respx_mock: MockRouter) -> None: + from unittest.mock import patch + + stopping = {**DEDICATED_EP, "state": "STOPPING"} + stopped = {**DEDICATED_EP, "state": "STOPPED"} + respx_mock.patch("/endpoints/endpoint-123").mock(return_value=httpx.Response(200, json=stopping)) + respx_mock.get("/endpoints/endpoint-123").mock( + side_effect=[ + httpx.Response(200, json=stopping), + httpx.Response(200, json=stopped), + ] + ) + runner = CliRunner(env=_ENV) + with patch("time.sleep"): + result = runner.invoke(main, ["endpoints", "stop", "endpoint-123", "--wait"]) + assert 
result.exit_code == 0 + assert "Endpoint stopped" in result.output + + +class TestEndpointsListRetrieveDeleteUpdateAz: + @pytest.mark.respx(base_url=base_url) + def test_list_type_and_mine_query(self, respx_mock: MockRouter) -> None: + list_body = {"object": "list", "data": [ENDPOINT_LIST_ITEM]} + route = respx_mock.get("/endpoints").mock(return_value=httpx.Response(200, json=list_body)) + runner = CliRunner(env=_ENV) + assert ( + runner.invoke( + main, + ["endpoints", "list", "--type", "dedicated", "--mine", "--usage-type", "on-demand"], + ).exit_code + == 0 + ) + url = str(cast(Call, route.calls[0]).request.url) + assert "type=dedicated" in url + assert "mine=true" in url + assert "usage_type=on-demand" in url or "usage-type" in url + + @pytest.mark.respx(base_url=base_url) + def test_list_json(self, respx_mock: MockRouter) -> None: + list_body = {"object": "list", "data": [ENDPOINT_LIST_ITEM]} + respx_mock.get("/endpoints").mock(return_value=httpx.Response(200, json=list_body)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["endpoints", "list", "--json"]) + assert result.exit_code == 0 + assert json.loads(result.output)[0]["id"] == "ep-list-1" + + @pytest.mark.respx(base_url=base_url) + def test_retrieve_json(self, respx_mock: MockRouter) -> None: + respx_mock.get("/endpoints/ep-1").mock(return_value=httpx.Response(200, json=DEDICATED_EP)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["endpoints", "retrieve", "ep-1", "--json"]) + assert result.exit_code == 0 + assert json.loads(result.output)["display_name"] == "My Endpoint" + + @pytest.mark.respx(base_url=base_url) + def test_delete_json(self, respx_mock: MockRouter) -> None: + respx_mock.delete("/endpoints/ep-del").mock(return_value=httpx.Response(200)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["endpoints", "delete", "ep-del", "--json"]) + assert result.exit_code == 0 + assert json.loads(result.output)["message"] == "Successfully deleted endpoint" + + def 
test_update_requires_option(self) -> None: + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["endpoints", "update", "ep-1"]) + assert result.exit_code == 1 + assert "At least one update option" in result.output + + @pytest.mark.respx(base_url=base_url) + def test_update_min_max_replicas(self, respx_mock: MockRouter) -> None: + patch_route = respx_mock.patch("/endpoints/ep-1").mock(return_value=httpx.Response(200, json=DEDICATED_EP)) + runner = CliRunner(env=_ENV) + result = runner.invoke( + main, + ["endpoints", "update", "ep-1", "--min-replicas", "1", "--max-replicas", "3"], + ) + assert result.exit_code == 0 + assert "ep-1" in result.output + req = cast(Call, patch_route.calls[0]).request + body = json.loads(req.content.decode()) + assert body["autoscaling"] == {"min_replicas": 1, "max_replicas": 3} + + @pytest.mark.respx(base_url=base_url) + def test_availability_zones_json(self, respx_mock: MockRouter) -> None: + respx_mock.get("/clusters/availability-zones").mock( + return_value=httpx.Response(200, json={"avzones": ["us-east-1a", "us-west-2b"]}) + ) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["endpoints", "availability-zones", "--json"]) + assert result.exit_code == 0 + assert json.loads(result.output)["avzones"] == ["us-east-1a", "us-west-2b"] diff --git a/tests/cli/test_evals.py b/tests/cli/test_evals.py new file mode 100644 index 000000000..dec5dffe4 --- /dev/null +++ b/tests/cli/test_evals.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import os +from typing import cast + +import httpx +import pytest +from respx import MockRouter +from respx.models import Call +from click.testing import CliRunner + +from together.lib.cli import main + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +API_KEY = "0000000000000000000000000000000000000000" + +_ENV = {"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY} + +_EVAL_JOB = { + "workflow_id": "eval-wf-1", + "type": "classify", + "status": 
"completed", + "created_at": "2024-01-01T00:00:00Z", + "parameters": {"model_to_evaluate": "m1", "model_a": "", "model_b": ""}, +} + +_EVAL_STATUS = {"status": "completed", "results": None} + + +class TestEvalsList: + @pytest.mark.respx(base_url=base_url) + def test_list_passes_status_and_limit(self, respx_mock: MockRouter) -> None: + route = respx_mock.get("/evaluation").mock(return_value=httpx.Response(200, json=[_EVAL_JOB])) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["evals", "list", "--status", "completed", "--limit", "5"]) + assert result.exit_code == 0 + assert "eval-wf-1" in result.output + req = cast(Call, route.calls[0]).request + assert "status=completed" in str(req.url) + assert "limit=5" in str(req.url) + + @pytest.mark.respx(base_url=base_url) + def test_list_requires_nothing(self, respx_mock: MockRouter) -> None: + respx_mock.get("/evaluation").mock(return_value=httpx.Response(200, json=[])) + runner = CliRunner(env=_ENV) + assert runner.invoke(main, ["evals", "list"]).exit_code == 0 + + +class TestEvalsRetrieveAndStatus: + @pytest.mark.respx(base_url=base_url) + def test_retrieve(self, respx_mock: MockRouter) -> None: + respx_mock.get("/evaluation/eval-wf-1").mock(return_value=httpx.Response(200, json=_EVAL_JOB)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["evals", "retrieve", "eval-wf-1"]) + assert result.exit_code == 0 + assert "workflow_id: eval-wf-1" in result.output + + @pytest.mark.respx(base_url=base_url) + def test_status(self, respx_mock: MockRouter) -> None: + respx_mock.get("/evaluation/eval-wf-1/status").mock(return_value=httpx.Response(200, json=_EVAL_STATUS)) + runner = CliRunner(env=_ENV) + result = runner.invoke(main, ["evals", "status", "eval-wf-1"]) + assert result.exit_code == 0 + assert "Status: completed" in result.output diff --git a/tests/cli/test_files.py b/tests/cli/test_files.py new file mode 100644 index 000000000..779b66442 --- /dev/null +++ b/tests/cli/test_files.py @@ -0,0 +1,214 @@ 
+from __future__ import annotations + +import os +import sys +import json +from typing import Any +from pathlib import Path +from unittest.mock import patch + +import httpx +import pytest +from respx import MockRouter +from click.testing import CliRunner + +from together.lib.cli import main +from together.types.file_response import FileResponse + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + +API_KEY = "0000000000000000000000000000000000000000" + +# Submodule path; package attribute `upload` is the Click command and shadows this module. +_files_upload_cli = sys.modules["together.lib.cli.api.files.upload"] + +FILE_ROW_NEWER = { + "id": "file-newer", + "bytes": 2048, + "created_at": 1700000000, + "filename": "newer.jsonl", + "FileType": "jsonl", + "object": "file", + "Processed": True, + "purpose": "fine-tune", +} + +FILE_ROW_OLDER = { + "id": "file-older", + "bytes": 512, + "created_at": 1600000000, + "filename": "older.jsonl", + "FileType": "jsonl", + "object": "file", + "Processed": False, + "purpose": "eval", +} + + +def _file_response(**kwargs: Any) -> FileResponse: + defaults: dict[str, Any] = { + "id": "file-up", + "bytes": 10, + "created_at": 1, + "filename": "x.jsonl", + "FileType": "jsonl", + "object": "file", + "Processed": True, + "purpose": "fine-tune", + } + defaults.update(kwargs) + if hasattr(FileResponse, "model_validate"): + return FileResponse.model_validate(defaults) + return FileResponse.parse_obj(defaults) # pyright: ignore[reportDeprecated] + + +class TestFilesCheck: + def test_check(self, tmp_path: Path) -> None: + sample = tmp_path / "ok.jsonl" + sample.write_text('{"text": "hello"}\n', encoding="utf-8") + runner = CliRunner(env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}) + result = runner.invoke(main, ["files", "check", str(sample)]) + assert result.exit_code == 0 + + +class TestFilesDelete: + @pytest.mark.respx(base_url=base_url) + def test_delete(self, respx_mock: MockRouter) -> None: + 
respx_mock.delete("/files/file-to-delete").mock( + return_value=httpx.Response(200, json={"id": "file-to-delete", "deleted": True}) + ) + runner = CliRunner(env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}) + result = runner.invoke(main, ["files", "delete", "file-to-delete"]) + assert result.exit_code == 0 + assert "file-to-delete" in result.output + assert "deleted" in result.output.lower() + + +class TestFilesList: + @pytest.mark.respx(base_url=base_url) + def test_list(self, respx_mock: MockRouter) -> None: + respx_mock.get("/files").mock(return_value=httpx.Response(200, json={"data": [FILE_ROW_OLDER, FILE_ROW_NEWER]})) + runner = CliRunner(env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}) + result = runner.invoke(main, ["files", "list"]) + assert result.exit_code == 0 + assert "file-newer" in result.output + assert "file-older" in result.output + newer_pos = result.output.index("file-newer") + older_pos = result.output.index("file-older") + assert newer_pos < older_pos + + @pytest.mark.respx(base_url=base_url) + def test_list_json(self, respx_mock: MockRouter) -> None: + respx_mock.get("/files").mock(return_value=httpx.Response(200, json={"data": [FILE_ROW_OLDER, FILE_ROW_NEWER]})) + runner = CliRunner(env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}) + result = runner.invoke(main, ["files", "list", "--json"]) + assert result.exit_code == 0 + parsed = json.loads(result.output) + assert [row["id"] for row in parsed] == ["file-newer", "file-older"] + + +class TestFilesRetrieve: + @pytest.mark.respx(base_url=base_url) + def test_retrieve(self, respx_mock: MockRouter) -> None: + respx_mock.get("/files/file-meta").mock(return_value=httpx.Response(200, json=FILE_ROW_NEWER)) + runner = CliRunner(env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}) + result = runner.invoke(main, ["files", "retrieve", "file-meta"]) + assert result.exit_code == 0 + assert "newer.jsonl" in result.output + assert "fine-tune" in 
result.output + + +class TestFilesRetrieveContent: + def test_retrieve_content_no_options(self) -> None: + runner = CliRunner(env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}) + result = runner.invoke(main, ["files", "retrieve-content", "file-1"]) + assert result.exit_code == 2 + assert "Either --output" in result.output or "must be specified" in result.output + + @pytest.mark.respx(base_url=base_url) + def test_specifying_output(self, respx_mock: MockRouter, tmp_path: Path) -> None: + respx_mock.get("/files/file-1/content").mock(return_value=httpx.Response(200, content=b"line1\nline2\n")) + out = tmp_path / "saved.jsonl" + runner = CliRunner(env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}) + result = runner.invoke(main, ["files", "retrieve-content", "file-1", "--output", str(out)]) + assert result.exit_code == 0 + assert out.read_bytes() == b"line1\nline2\n" + assert str(out) in result.output + + @pytest.mark.respx(base_url=base_url) + def test_specifying_stdout(self, respx_mock: MockRouter) -> None: + respx_mock.get("/files/file-1/content").mock(return_value=httpx.Response(200, content=b"stdout-bytes")) + runner = CliRunner(env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}) + result = runner.invoke(main, ["files", "retrieve-content", "file-1", "--stdout"]) + assert result.exit_code == 0 + assert result.output == "stdout-bytes\n" + + @pytest.mark.respx(base_url=base_url) + def test_specifying_both_output_and_stdout(self, respx_mock: MockRouter, tmp_path: Path) -> None: + respx_mock.get("/files/file-1/content").mock(return_value=httpx.Response(200, content=b"to-stdout")) + out = tmp_path / "should-not-exist.bin" + runner = CliRunner(env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}) + result = runner.invoke( + main, + ["files", "retrieve-content", "file-1", "--stdout", "--output", str(out)], + ) + assert result.exit_code == 0 + assert result.output == "to-stdout\n" + assert not out.exists() + + +class 
TestFilesUpload: + def test_upload_with_invalid_purpose(self, tmp_path: Path) -> None: + f = tmp_path / "empty.jsonl" + f.write_text("{}\n") + with patch("together.resources.files.FilesResource.upload") as upload_mock: + runner = CliRunner(env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}) + result = runner.invoke( + main, + ["files", "upload", str(f), "--purpose", "not-a-real-purpose"], + ) + assert result.exit_code == 2 + upload_mock.assert_not_called() + + def test_upload_does_check_by_default(self, tmp_path: Path) -> None: + f = tmp_path / "data.jsonl" + f.write_text("{}\n") + with patch.object(_files_upload_cli, "check_file") as check_mock, patch( + "together.resources.files.FilesResource.upload" + ) as upload_mock: + check_mock.return_value = {"is_check_passed": False, "message": "failed validation"} + runner = CliRunner(env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}) + result = runner.invoke(main, ["files", "upload", str(f)]) + assert result.exit_code == 1 + check_mock.assert_called_once() + upload_mock.assert_not_called() + + def test_upload_does_not_check_if_disabled(self, tmp_path: Path) -> None: + f = tmp_path / "data.jsonl" + f.write_text("{}\n") + uploaded = _file_response(id="uploaded-id", purpose="fine-tune") + with patch.object(_files_upload_cli, "check_file") as check_mock, patch( + "together.resources.files.FilesResource.upload", return_value=uploaded + ) as upload_mock: + runner = CliRunner(env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}) + result = runner.invoke(main, ["files", "upload", str(f), "--no-check"]) + assert result.exit_code == 0 + check_mock.assert_not_called() + upload_mock.assert_called_once() + call_kw = upload_mock.call_args.kwargs + assert call_kw["check"] is False + assert "uploaded-id" in result.output + + def test_upload_does_check_if_enabled(self, tmp_path: Path) -> None: + f = tmp_path / "data.jsonl" + f.write_text("{}\n") + uploaded = _file_response() + with 
patch.object(_files_upload_cli, "check_file") as check_mock, patch( + "together.resources.files.FilesResource.upload", return_value=uploaded + ) as upload_mock: + check_mock.return_value = {"is_check_passed": True, "message": "Checks passed"} + runner = CliRunner(env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}) + result = runner.invoke(main, ["files", "upload", str(f), "--check"]) + assert result.exit_code == 0 + check_mock.assert_called_once() + upload_mock.assert_called_once() diff --git a/tests/cli/test_fine_tuning.py b/tests/cli/test_fine_tuning.py new file mode 100644 index 000000000..51a80cb14 --- /dev/null +++ b/tests/cli/test_fine_tuning.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +import os +import sys +import json +from pathlib import Path +from unittest.mock import patch + +import httpx +import pytest +from respx import MockRouter +from click.testing import CliRunner + +from together.lib.cli import main + +# Real module; package attribute `download` is the Click command and shadows this name. 
# Handle to the fine-tuning "download" CLI module so tests can patch its
# DownloadManager attribute in place.
_ft_download_mod = sys.modules["together.lib.cli.api.fine_tuning.download"]

# Mock-server URL; CI overrides this via TEST_API_BASE_URL.
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
API_KEY = "0000000000000000000000000000000000000000"

# Environment injected into every CliRunner invocation.
_ENV = {"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY}

# Completed job with the more recent timestamps — `list` should render it first.
_FT_LIST_ITEM = {
    "id": "ft-newer",
    "created_at": "2024-06-02T12:00:00Z",
    "updated_at": "2024-06-02T12:00:00Z",
    "status": "completed",
    "total_price": 200,
    "model": "meta-llama/Llama-3-8b",
    "suffix": "my-run",
}

# Older, still-running job (carries a progress estimate and an empty suffix).
_FT_LIST_ITEM_OLDER = {
    "id": "ft-older",
    "created_at": "2024-01-01T12:00:00Z",
    "updated_at": "2024-01-01T12:00:00Z",
    "status": "running",
    "total_price": 50,
    "model": "meta-llama/Llama-3-8b",
    "suffix": "",
    "progress": {"estimate_available": True, "seconds_remaining": 120},
}

# Minimal retrieve payload reused by retrieve/cancel/delete/download tests.
_FT_RETRIEVE_BODY = {
    "id": "ft-1",
    "status": "completed",
    "training_type": {"type": "Full"},
    "model_output_name": "weights.tar",
    "created_at": "2024-01-01T00:00:00Z",
    "started_at": "2024-01-01T00:00:00Z",
    "updated_at": "2024-01-01T01:00:00Z",
}

# Single fine-tune event as returned by GET /fine-tunes/{id}/events.
_FT_EVENT = {
    "checkpoint_path": "/ckpt",
    "created_at": "2024-01-01T00:00:00Z",
    "hash": "abc",
    "message": "training started",
    "model_path": "/m",
    "object": "fine-tune-event",
    "param_count": 7,
    "step": 0,
    "token_count": 0,
    "total_steps": 10,
    "training_offset": 0,
    "type": "training_start",
    "wandb_url": "",
}

# Single checkpoint as returned by GET /fine-tunes/{id}/checkpoints.
_FT_CHECKPOINT = {
    "checkpoint_type": "intermediate",
    "created_at": "2024-01-01T00:00:00Z",
    "path": "/p",
    "step": 5,
}


class TestFineTuningList:
    """`fine-tuning list` renders a table (newest first) or raw JSON."""

    @pytest.mark.respx(base_url=base_url)
    def test_list_table(self, respx_mock: MockRouter) -> None:
        respx_mock.get("/fine-tunes").mock(
            return_value=httpx.Response(200, json={"data": [_FT_LIST_ITEM_OLDER, _FT_LIST_ITEM]})
        )
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["fine-tuning", "list"])
        assert result.exit_code == 0
        assert "ft-newer" in result.output
        assert "ft-older" in result.output
        # Rows must be sorted newest-first even though the API returned the
        # older job first.
        assert result.output.index("ft-newer") < result.output.index("ft-older")

    @pytest.mark.respx(base_url=base_url)
    def test_list_json(self, respx_mock: MockRouter) -> None:
        respx_mock.get("/fine-tunes").mock(
            return_value=httpx.Response(200, json={"data": [_FT_LIST_ITEM_OLDER, _FT_LIST_ITEM]})
        )
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["fine-tuning", "list", "--json"])
        assert result.exit_code == 0
        parsed = json.loads(result.output)
        # JSON mode applies the same newest-first ordering as the table.
        assert [x["id"] for x in parsed] == ["ft-newer", "ft-older"]


class TestFineTuningRetrieve:
    """`fine-tuning retrieve` echoes the job payload in JSON mode."""

    @pytest.mark.respx(base_url=base_url)
    def test_retrieve_json(self, respx_mock: MockRouter) -> None:
        respx_mock.get("/fine-tunes/ft-1").mock(return_value=httpx.Response(200, json=_FT_RETRIEVE_BODY))
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["fine-tuning", "retrieve", "ft-1", "--json"])
        assert result.exit_code == 0
        body = json.loads(result.output)
        assert body["id"] == "ft-1"
        assert body["status"] == "completed"


class TestFineTuningCancel:
    """`fine-tuning cancel` checks job status before hitting the cancel endpoint."""

    @pytest.mark.respx(base_url=base_url)
    def test_cancel_not_cancellable(self, respx_mock: MockRouter) -> None:
        # A completed job cannot be cancelled; the CLI reports that instead of
        # calling POST /cancel.
        body = {**_FT_RETRIEVE_BODY, "status": "completed"}
        respx_mock.get("/fine-tunes/ft-1").mock(return_value=httpx.Response(200, json=body))
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["fine-tuning", "cancel", "ft-1", "--quiet"])
        assert result.exit_code == 0
        assert "not currently cancellable" in result.output
        assert "completed" in result.output

    @pytest.mark.respx(base_url=base_url)
    def test_cancel_quiet_calls_api(self, respx_mock: MockRouter) -> None:
        running = {**_FT_RETRIEVE_BODY, "status": "running"}
        respx_mock.get("/fine-tunes/ft-1").mock(return_value=httpx.Response(200, json=running))
        respx_mock.post("/fine-tunes/ft-1/cancel").mock(
            return_value=httpx.Response(200, json={**running, "status": "cancel_requested"})
        )
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["fine-tuning", "cancel", "ft-1", "--quiet"])
        assert result.exit_code == 0
        assert "Cancelled" in result.output

    @pytest.mark.respx(base_url=base_url)
    def test_cancel_json_requires_quiet(self, respx_mock: MockRouter) -> None:
        # --json without --quiet is rejected (interactive confirmation would
        # corrupt the JSON stream).
        running = {**_FT_RETRIEVE_BODY, "status": "running"}
        respx_mock.get("/fine-tunes/ft-1").mock(return_value=httpx.Response(200, json=running))
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["fine-tuning", "cancel", "ft-1", "--json"])
        assert result.exit_code != 0
        assert "quiet" in result.output.lower()

    @pytest.mark.respx(base_url=base_url)
    def test_cancel_not_cancellable_json(self, respx_mock: MockRouter) -> None:
        # In JSON mode the diagnostic must go to stderr so stdout stays
        # machine-parseable (empty here).
        body = {**_FT_RETRIEVE_BODY, "status": "completed"}
        respx_mock.get("/fine-tunes/ft-1").mock(return_value=httpx.Response(200, json=body))
        try:
            runner = CliRunner(env=_ENV, mix_stderr=False)
        except Exception:
            # Python 3.14 doesn't have the mix_stderr parameter
            runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["fine-tuning", "cancel", "ft-1", "--quiet", "--json"])
        assert result.exit_code == 0
        assert result.stdout_bytes.decode("utf-8") == ""
        assert result.stderr_bytes is not None
        assert len(result.stderr_bytes) > 0
        assert "Training is not currently cancellable" in result.stderr_bytes.decode("utf-8")


class TestFineTuningDelete:
    """`fine-tuning delete` requires --force in JSON mode; --force skips confirmation."""

    def test_delete_json_requires_force(self) -> None:
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["fine-tuning", "delete", "ft-1", "--json"])
        assert result.exit_code != 0
        assert "force" in result.output.lower()

    @pytest.mark.respx(base_url=base_url)
    def test_delete_force(self, respx_mock: MockRouter) -> None:
        respx_mock.delete("/fine-tunes/ft-1").mock(return_value=httpx.Response(200, json={"message": "deleted"}))
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["fine-tuning", "delete", "ft-1", "--force"])
        assert result.exit_code == 0
        assert "Deleted" in result.output
class TestFineTuningEventsAndCheckpoints:
    """`list-events` / `list-checkpoints` rendering for a single job."""

    @pytest.mark.respx(base_url=base_url)
    def test_list_events_json(self, respx_mock: MockRouter) -> None:
        respx_mock.get("/fine-tunes/ft-1/events").mock(return_value=httpx.Response(200, json={"data": [_FT_EVENT]}))
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["fine-tuning", "list-events", "ft-1", "--json"])
        assert result.exit_code == 0
        assert json.loads(result.output)[0]["message"] == "training started"

    @pytest.mark.respx(base_url=base_url)
    def test_list_checkpoints_table(self, respx_mock: MockRouter) -> None:
        respx_mock.get("/fine-tunes/ft-1/checkpoints").mock(
            return_value=httpx.Response(200, json={"data": [_FT_CHECKPOINT]})
        )
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["fine-tuning", "list-checkpoints", "ft-1"])
        assert result.exit_code == 0
        # The table labels each checkpoint as "<job-id>:<step>".
        assert "ft-1:5" in result.output
        assert "intermediate" in result.output

    @pytest.mark.respx(base_url=base_url)
    def test_list_checkpoints_empty_message(self, respx_mock: MockRouter) -> None:
        respx_mock.get("/fine-tunes/ft-1/checkpoints").mock(return_value=httpx.Response(200, json={"data": []}))
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["fine-tuning", "list-checkpoints", "ft-1"])
        assert result.exit_code == 0
        assert "No checkpoints found" in result.output


class TestFineTuningDownload:
    """`fine-tuning download` delegates the transfer to DownloadManager."""

    @pytest.mark.respx(base_url=base_url)
    def test_download_invokes_download_manager(self, respx_mock: MockRouter, tmp_path: Path) -> None:
        respx_mock.get("/fine-tunes/ft-abcd-12").mock(return_value=httpx.Response(200, json=_FT_RETRIEVE_BODY))
        out_file = tmp_path / "weights.tar"
        out_file.write_bytes(b"x")

        # Stub that stands in for DownloadManager: verifies the download URL
        # carries the job id and checkpoint selector, then reports a fake
        # (path, size) result without touching the network.
        class _DM:
            def __init__(self, _client: object) -> None:
                pass

            def download(self, **kwargs: object) -> tuple[str, int]:
                assert "ft_id=ft-abcd-12" in str(kwargs.get("url", ""))
                assert "checkpoint=model_output_path" in str(kwargs.get("url", ""))
                return str(out_file), 1

        with patch.object(_ft_download_mod, "DownloadManager", _DM):
            runner = CliRunner(env=_ENV)
            # Full fine-tunes require explicit --checkpoint-type default (CLI default is merged for LoRA).
            result = runner.invoke(
                main,
                [
                    "fine-tuning",
                    "download",
                    "ft-abcd-12",
                    "--checkpoint-type",
                    "default",
                    "--output_dir",
                    str(tmp_path),
                ],
            )
            assert result.exit_code == 0
            payload = json.loads(result.output.strip())
            assert payload["id"] == "ft-abcd-12"
            assert payload["size"] == 1
class JSONValidator:
    """Runs a CLI command with ``--json`` and pipes stdout through ``jq``.

    ``jq`` exiting non-zero means the command emitted something other than
    valid JSON on stdout, which is the contract JSON mode must uphold.
    """

    def __init__(self, namespace: str):
        # CLI namespace, e.g. "endpoints" or "beta jig secrets".
        self.namespace = namespace
        # One-shot skip flag, armed by the `skip` property and consumed by the
        # next run_and_assert call.
        self._skip = False

    @property
    def skip(self) -> JSONValidator:
        """Mark only the NEXT run_and_assert call as skipped."""
        self._skip = True
        return self

    # Invokes the command on the command line
    # It then pipes the results to jq to assert that the JSON is valid
    def run_and_assert(self, command: str, **kwargs: Any) -> None:
        if self._skip:
            # BUG FIX: the flag used to stay True forever, so one
            # `validator.skip....` call silently disabled every later
            # assertion on the same validator. Consume it here instead.
            self._skip = False
            print(f"Skipping {command} because it is not supported in JSON mode")
            return

        def run_command(command: str) -> subprocess.CompletedProcess[str]:
            base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
            return subprocess.run(
                ["together", "--base-url", base_url, self.namespace, *command.split(" "), "--json"],
                capture_output=True,
                text=True,
                **kwargs,
            )

        def run_jq(command_response: str) -> subprocess.CompletedProcess[str]:
            return subprocess.run(
                ["jq"],
                input=command_response,
                capture_output=True,
                text=True,
            )

        command_result = run_command(command)
        result = run_jq(command_result.stdout)

        if result.returncode != 0:
            raise AssertionError(f'{self.namespace} {command} failed to parse JSON: "{result.stdout}"')
class TestJSONMode:
    """Every CLI namespace's --json output must be valid JSON (parseable by jq).

    Each method walks the commands of one namespace through JSONValidator;
    `.skip` marks commands whose JSON mode is not yet supported.
    """

    # All Endpoint commands
    def test_endpoints_json_mode(self) -> None:
        endpoints = JSONValidator("endpoints")
        endpoints.run_and_assert("availability-zones")
        endpoints.run_and_assert("create --model deepseek-ai/DeepSeek-R1 --hardware 1x_nvidia_a100_80gb_sxm")
        endpoints.run_and_assert("delete endpoint-123")
        endpoints.run_and_assert("hardware")
        endpoints.run_and_assert("hardware --model deepseek-ai/DeepSeek-R1")
        endpoints.run_and_assert("list")
        endpoints.run_and_assert("list --type dedicated")
        endpoints.run_and_assert("list --usage-type on-demand")
        endpoints.run_and_assert("list --usage-type reserved")
        endpoints.run_and_assert("list --mine")
        endpoints.run_and_assert("retrieve endpoint-123")
        endpoints.run_and_assert("start endpoint-123")
        endpoints.run_and_assert("stop endpoint-123")
        endpoints.run_and_assert("update endpoint-123 --min-replicas 2 --max-replicas 4 --inactive-timeout 60")

    # All Evals commands
    def test_evals_json_mode(self) -> None:
        evals = JSONValidator("evals")
        evals.skip.run_and_assert(
            "create --type classify --judge-model deepseek-ai/DeepSeek-R1 --judge-model-source dedicated --judge-system-template 'You are a helpful assistant' --input-data-file-path data.json --model-field 'generated_text' --model-to-evaluate deepseek-ai/DeepSeek-R1 --model-to-evaluate-source dedicated --model-to-evaluate-system-template 'You are a helpful assistant' --model-to-evaluate-input-template 'You are a helpful assistant' --labels 'yes,no' --pass-labels 'yes' --min-score 0.5 --max-score 1.0 --pass-threshold 0.75"
        )
        evals.run_and_assert("list")
        evals.run_and_assert("list --status completed")
        evals.run_and_assert("list --limit 1")
        evals.run_and_assert("retrieve eval-123")
        evals.run_and_assert("status eval-123")

    # All files commands
    def test_files_json_mode(self) -> None:
        files = JSONValidator("files")
        files.run_and_assert("check data.jsonl", cwd=os.path.dirname(__file__))
        files.run_and_assert("delete file-123")
        files.run_and_assert("list")
        files.run_and_assert("retrieve file-123")
        files.run_and_assert("upload data.jsonl --purpose fine-tune", cwd=os.path.dirname(__file__))

    # All fine-tuning commands
    def test_fine_tuning_json_mode(self) -> None:
        fine_tuning = JSONValidator("fine-tuning")
        fine_tuning.skip.run_and_assert("create")  # TODO:
        fine_tuning.run_and_assert("list")
        fine_tuning.run_and_assert("retrieve ft-123")
        fine_tuning.run_and_assert("cancel ft-123 --quiet")
        fine_tuning.run_and_assert("download ft-123")
        fine_tuning.run_and_assert("delete ft-123 --force")
        fine_tuning.run_and_assert("list-events ft-123")
        fine_tuning.run_and_assert("list-checkpoints ft-123")
        # Fixed: this command was asserted twice back-to-back (copy-paste
        # duplicate); one invocation is sufficient.
        fine_tuning.run_and_assert("retrieve-checkpoint ft-123/checkpoint-123")

    def test_models_json_mode(self) -> None:
        models = JSONValidator("models")
        models.run_and_assert("list")
        models.run_and_assert("list --type dedicated")
        models.run_and_assert("upload --model-name model-123/version-123 --model-source s3://model-123/version-123")

    def test_beta_clusters_json_mode(self) -> None:
        beta_clusters = JSONValidator("beta clusters")
        beta_clusters.run_and_assert(
            "create --non-interactive --cluster-type KUBERNETES --gpu-type H100_SXM --driver-version CUDA_12_6_565 --region us-central-8 --num-gpus 0 --billing-type ON_DEMAND --name together-py-testing-suite --volume 123"
        )
        beta_clusters.run_and_assert("delete cluster-123")
        beta_clusters.run_and_assert("get-credentials cluster-123")
        beta_clusters.run_and_assert("list")
        beta_clusters.run_and_assert("list-regions")
        beta_clusters.run_and_assert("retrieve cluster-123")
        beta_clusters.run_and_assert("update cluster-123 --min-replicas 2 --max-replicas 4 --inactive-timeout 60")

    def test_beta_clusters_storage_json_mode(self) -> None:
        beta_clusters_storage = JSONValidator("beta clusters storage")
        beta_clusters_storage.run_and_assert("create --region us-east-1 --size-tib 1 --volume-name test-volume")
        beta_clusters_storage.run_and_assert("delete storage-123")
        beta_clusters_storage.run_and_assert("list")
        beta_clusters_storage.run_and_assert("retrieve storage-123")

    def test_jig_json_mode(self) -> None:
        jig = JSONValidator("beta jig")
        jig.skip.run_and_assert("init")
        jig.skip.run_and_assert("dockerfile")
        jig.skip.run_and_assert("build")
        jig.skip.run_and_assert("push")
        jig.skip.run_and_assert("deploy")
        jig.skip.run_and_assert("endpoint")
        jig.skip.run_and_assert("logs")
        jig.skip.run_and_assert("destroy")
        jig.skip.run_and_assert("submit")
        jig.skip.run_and_assert("job-status")
        jig.skip.run_and_assert("queue-status")

        jig.run_and_assert("list")
        jig.run_and_assert("status")

    def test_jig_secrets_json_mode(self) -> None:
        jig = JSONValidator("beta jig secrets")
        jig.skip.run_and_assert("set")
        jig.skip.run_and_assert("unset")
        jig.skip.run_and_assert("list")

    def test_jig_volumes_json_mode(self) -> None:
        jig = JSONValidator("beta jig volumes")
        jig.skip.run_and_assert("create")
        jig.skip.run_and_assert("update")
        jig.skip.run_and_assert("delete")
        jig.skip.run_and_assert("describe")
        jig.skip.run_and_assert("list")
class TestMainGlobalOptions:
    """Smoke tests for the top-level CLI: --version, --help, and client options."""

    def test_version_exits_zero(self) -> None:
        cli = CliRunner()
        outcome = cli.invoke(main, ["--version"])
        assert outcome.exit_code == 0
        assert __version__ in outcome.output

    def test_help_without_api_key_still_works(self) -> None:
        # No TOGETHER_API_KEY in the environment — help must not require auth.
        cli = CliRunner(env={"TOGETHER_BASE_URL": base_url})
        outcome = cli.invoke(main, ["--help"])
        assert outcome.exit_code == 0
        assert "together" in outcome.output.lower() or "CLI" in outcome.output

    def test_timeout_and_max_retries_passed_to_client(self) -> None:
        from unittest.mock import patch

        # Intercept the client constructor and verify the global flags reach it.
        with patch("together.lib.cli.together.Together") as ctor:
            cli = CliRunner(
                env={"TOGETHER_BASE_URL": base_url, "TOGETHER_API_KEY": API_KEY},
            )
            outcome = cli.invoke(
                main,
                ["--timeout", "99", "--max-retries", "3", "models", "--help"],
            )
            assert outcome.exit_code == 0
            forwarded = ctor.call_args.kwargs
            assert forwarded.get("timeout") == 99
            assert forwarded.get("max_retries") == 3
# Expected response for POST /models (model-weight upload job creation).
_UPLOAD_BODY = {
    "data": {
        "job_id": "job-a15dad11-8d8e-4007-97c5-a211304de284",
        "model_name": "necolinehubner/Qwen2.5-72B-Instruct",
        "model_id": "model-c0e32dfc-637e-47b2-bf4e-e9b2e58c9da7",
        "model_source": "huggingface",
    },
    "message": "Processing model weights. Job created.",
}

# Three models covering the table-rendering branches: priced chat/language
# models and a video model with no context length or pricing.
list_data = [
    {
        "id": "model/chat",
        "created": 1742764800,
        "object": "model",
        "type": "chat",
        "context_length": 1000,
        "display_name": "Chat Model",
        "license": None,
        "link": None,
        "organization": "org/1",
        "pricing": {"base": None, "finetune": None, "hourly": None, "input": 0.05, "output": 0.10},
    },
    {
        "id": "model/lang",
        "created": 1742764800,
        "object": "model",
        "type": "language",
        "context_length": 1000,
        "display_name": "Language Model",
        "license": None,
        "link": None,
        "organization": "org/1",
        "pricing": {"base": None, "finetune": None, "hourly": None, "input": 0.5, "output": 1.0},
    },
    {
        "id": "model/video",
        "created": 1742764800,
        "object": "model",
        "type": "video",
        "context_length": None,
        "display_name": "Video Model",
        "license": None,
        "link": None,
        "organization": "org/1",
        "pricing": None,
    },
]


class TestModelsList:
    """`models list` table/JSON rendering and the --type filter."""

    # Table output: prices per 1M tokens, blank cells for missing data.
    @pytest.mark.respx(base_url=base_url)
    def test_list(self, respx_mock: MockRouter) -> None:
        respx_mock.get("/models").mock(return_value=httpx.Response(200, json=list_data))
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["models", "list"])
        assert (
            result.output.strip()
            == """
Model Type Context length Price per 1M Tokens (input/output)
----------- -------- ---------------- ------------------------------------
model/chat chat 1000 $0.05/$0.10
model/lang language 1000 $0.50/$1.00
model/video video
""".strip()
        )

    # --type dedicated must translate to a dedicated=true query parameter.
    @pytest.mark.respx(base_url=base_url)
    def test_list_dedicated(self, respx_mock: MockRouter) -> None:
        route = respx_mock.get("/models").mock(return_value=httpx.Response(200, json=list_data))
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["models", "list", "--type", "dedicated"])
        assert (
            result.output.strip()
            == """
Model Type Context length Price per 1M Tokens (input/output)
----------- -------- ---------------- ------------------------------------
model/chat chat 1000 $0.05/$0.10
model/lang language 1000 $0.50/$1.00
model/video video
""".strip()
        )
        url = str(route.calls[0].request.url)  # type: ignore[arg-type]
        assert "dedicated=true" in url

    # --json echoes the raw API payload, pretty-printed.
    @pytest.mark.respx(base_url=base_url)
    def test_list_json(self, respx_mock: MockRouter) -> None:
        respx_mock.get("/models").mock(return_value=httpx.Response(200, json=list_data))
        runner = CliRunner(env=_ENV)
        result = runner.invoke(main, ["models", "list", "--json"])
        assert result.output.strip() == json.dumps(list_data, indent=2).strip()


class TestModelsUpload:
    """`models upload` request payload and human/JSON output."""

    @pytest.mark.respx(base_url=base_url)
    def test_upload(self, respx_mock: MockRouter) -> None:
        route = respx_mock.post("/models").mock(return_value=httpx.Response(200, json=_UPLOAD_BODY))
        runner = CliRunner(env=_ENV)
        result = runner.invoke(
            main, ["models", "upload", "--model-name", "model-123", "--model-source", "s3://model-123"]
        )
        assert (
            result.output.strip()
            == """Model upload job created successfully!
Job ID: job-a15dad11-8d8e-4007-97c5-a211304de284
Model Name: necolinehubner/Qwen2.5-72B-Instruct
Model ID: model-c0e32dfc-637e-47b2-bf4e-e9b2e58c9da7
Model Source: huggingface
Message: Processing model weights. Job created.
""".strip()
        )
        raw = cast(str, route.calls[0].request.content.decode())  # type: ignore[arg-type]
        body = json.loads(raw)
        # Default model_type when none is given.
        assert body["model_type"] == "model"

    @pytest.mark.respx(base_url=base_url)
    def test_upload_adapter_sends_model_type(self, respx_mock: MockRouter) -> None:
        post = respx_mock.post("/models").mock(return_value=httpx.Response(200, json=_UPLOAD_BODY))
        runner = CliRunner(env=_ENV)
        result = runner.invoke(
            main,
            [
                "models",
                "upload",
                "--model-name",
                "m",
                "--model-source",
                "s3://x",
                "--model-type",
                "adapter",
                "--base-model",
                "base-m",
            ],
        )
        assert result.exit_code == 0
        raw = cast(str, post.calls[0].request.content.decode())  # type: ignore[arg-type]
        body = json.loads(raw)
        assert body["model_type"] == "adapter"
        assert body["base_model"] == "base-m"

    @pytest.mark.respx(base_url=base_url)
    def test_upload_json(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/models").mock(return_value=httpx.Response(200, json=_UPLOAD_BODY))
        runner = CliRunner(env=_ENV)
        result = runner.invoke(
            main,
            ["models", "upload", "--model-name", "model-123", "--model-source", "s3://model-123", "--json"],
        )
        assert result.exit_code == 0
        out = json.loads(result.output)
        assert out["message"] == _UPLOAD_BODY["message"]


class TestModelsListInvalid:
    """Invalid --type values are rejected by click with usage-error exit code 2."""

    def test_list_invalid_type_choice(self) -> None:
        runner = CliRunner(env=_ENV)
        r = runner.invoke(main, ["models", "list", "--type", "serverless"])
        assert r.exit_code == 2
        ("/v1/{a}{b}/path/{c}?val={d}#{e}", dict(a="x", b="y", c="z", d="u", e="v"), "/v1/xy/path/z?val=u#v"),
        ("/{w}/{w}", dict(w="echo"), "/echo/echo"),
        ("/v1/static", {}, "/v1/static"),
        ("", {}, ""),
        ("/v1/?q={n}&count=10", dict(n=42), "/v1/?q=42&count=10"),
        # Non-string values use JSON-style renderings.
        ("/v1/{v}", dict(v=None), "/v1/null"),
        ("/v1/{v}", dict(v=True), "/v1/true"),
        ("/v1/{v}", dict(v=False), "/v1/false"),
        ("/v1/{v}", dict(v=".hidden"), "/v1/.hidden"),  # dot prefix ok
        ("/v1/{v}", dict(v="file.txt"), "/v1/file.txt"),  # dot in middle ok
        ("/v1/{v}", dict(v="..."), "/v1/..."),  # triple dot ok
        ("/v1/{a}{b}", dict(a=".", b="txt"), "/v1/.txt"),  # dot var combining with adjacent to be ok
        ("/items?q={v}#{f}", dict(v=".", f=".."), "/items?q=.#.."),  # dots in query/fragment are fine
        # Traversal-shaped values are percent-encoded, not interpreted.
        (
            "/v1/{a}?query={b}",
            dict(a="../../other/endpoint", b="a&bad=true"),
            "/v1/..%2F..%2Fother%2Fendpoint?query=a%26bad%3Dtrue",
        ),
        ("/v1/{val}", dict(val="a/b/c"), "/v1/a%2Fb%2Fc"),
        ("/v1/{val}", dict(val="a/b/c?query=value"), "/v1/a%2Fb%2Fc%3Fquery=value"),
        ("/v1/{val}", dict(val="a/b/c?query=value&bad=true"), "/v1/a%2Fb%2Fc%3Fquery=value&bad=true"),
        ("/v1/{val}", dict(val="%20"), "/v1/%2520"),  # escapes escape sequences in input
        # Query: slash and ? are safe, # is not
        ("/items?q={v}", dict(v="a/b"), "/items?q=a/b"),
        ("/items?q={v}", dict(v="a?b"), "/items?q=a?b"),
        ("/items?q={v}", dict(v="a#b"), "/items?q=a%23b"),
        ("/items?q={v}", dict(v="a b"), "/items?q=a%20b"),
        # Fragment: slash and ? are safe
        ("/docs#{v}", dict(v="a/b"), "/docs#a/b"),
        ("/docs#{v}", dict(v="a?b"), "/docs#a?b"),
        # Path: slash, ? and # are all encoded
        ("/v1/{v}", dict(v="a/b"), "/v1/a%2Fb"),
        ("/v1/{v}", dict(v="a?b"), "/v1/a%3Fb"),
        ("/v1/{v}", dict(v="a#b"), "/v1/a%23b"),
        # same var encoded differently by component
        (
            "/v1/{v}?q={v}#{v}",
            dict(v="a/b?c#d"),
            "/v1/a%2Fb%3Fc%23d?q=a/b?c%23d#a/b?c%23d",
        ),
        ("/v1/{val}", dict(val="x?admin=true"), "/v1/x%3Fadmin=true"),  # query injection
        ("/v1/{val}", dict(val="x#admin"), "/v1/x%23admin"),  # fragment injection
    ],
)
def test_interpolation(template: str, kwargs: dict[str, Any], expected: str) -> None:
    """path_template substitutes placeholders with per-URL-component percent-encoding."""
    assert path_template(template, **kwargs) == expected


def test_missing_kwarg_raises_key_error() -> None:
    """An unfilled placeholder raises KeyError naming the missing variable."""
    with pytest.raises(KeyError, match="org_id"):
        path_template("/v1/{org_id}")


@pytest.mark.parametrize(
    "template, kwargs",
    [
        # Any path segment that resolves to "." or ".." must be rejected,
        # whether it comes from a variable, adjacent text, or percent-encoding.
        ("{a}/path", dict(a=".")),
        ("{a}/path", dict(a="..")),
        ("/v1/{a}", dict(a=".")),
        ("/v1/{a}", dict(a="..")),
        ("/v1/{a}/path", dict(a=".")),
        ("/v1/{a}/path", dict(a="..")),
        ("/v1/{a}{b}", dict(a=".", b=".")),  # adjacent vars → ".."
        ("/v1/{a}.", dict(a=".")),  # var + static → ".."
        ("/v1/{a}{b}", dict(a="", b=".")),  # empty + dot → "."
        ("/v1/%2e/{x}", dict(x="ok")),  # encoded dot in static text
        ("/v1/%2e./{x}", dict(x="ok")),  # mixed encoded ".." in static
        ("/v1/.%2E/{x}", dict(x="ok")),  # mixed encoded ".." in static
        ("/v1/{v}?q=1", dict(v="..")),
        ("/v1/{v}#frag", dict(v="..")),
    ],
)
def test_dot_segment_rejected(template: str, kwargs: dict[str, Any]) -> None:
    """Dot-segment path traversal attempts raise ValueError mentioning "dot-segment"."""
    with pytest.raises(ValueError, match="dot-segment"):
        path_template(template, **kwargs)