diff --git a/src/twinkle/hub/__init__.py b/src/twinkle/hub/__init__.py index 0eb798227..4180d5f63 100644 --- a/src/twinkle/hub/__init__.py +++ b/src/twinkle/hub/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .hub import HFHub, HubOperation, MSHub +from .model_alias import MODEL_ID_ALIASES_ENV, build_model_alias_map, load_model_alias_map, resolve_model_id_alias diff --git a/src/twinkle/hub/hub.py b/src/twinkle/hub/hub.py index 5ee00cd0d..c53727449 100644 --- a/src/twinkle/hub/hub.py +++ b/src/twinkle/hub/hub.py @@ -9,6 +9,7 @@ from typing import Dict, List, Literal, Optional, Union from ..utils import requires +from .model_alias import resolve_model_id_alias _executor = concurrent.futures.ProcessPoolExecutor(max_workers=8) _futures = {} @@ -204,6 +205,7 @@ def download_model(cls, Returns: The local dir """ + model_id_or_path = resolve_model_id_alias(model_id_or_path) if kwargs.pop('ignore_model', False): ignore_patterns = set(ignore_patterns or []) | set(large_file_pattern) if os.path.exists(model_id_or_path): diff --git a/src/twinkle/hub/model_alias.py b/src/twinkle/hub/model_alias.py new file mode 100644 index 000000000..af9dd6237 --- /dev/null +++ b/src/twinkle/hub/model_alias.py @@ -0,0 +1,59 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from __future__ import annotations + +import json +import os +from collections.abc import Iterable, Mapping +from typing import Any + +MODEL_ID_ALIASES_ENV = 'TWINKLE_MODEL_ID_ALIASES' + + +def _get_value(obj: Any, key: str, default: Any = None) -> Any: + if isinstance(obj, Mapping): + return obj.get(key, default) + return getattr(obj, key, default) + + +def extract_model_alias(route_prefix: str, service_type: str = 'model') -> str | None: + marker = f'/{service_type}/' + if not route_prefix or marker not in route_prefix: + return None + alias = route_prefix.split(marker, 1)[1].strip('/') + return alias or None + + +def build_model_alias_map(applications: Iterable[Any]) -> dict[str, str]: + alias_map: dict[str, str] = {} + for application in applications or []: + if _get_value(application, 'import_path') != 'model': + continue + alias = extract_model_alias(_get_value(application, 'route_prefix', ''), service_type='model') + args = _get_value(application, 'args') + model_id = _get_value(args, 'model_id') if args is not None else None + if alias and model_id and alias != model_id: + alias_map[alias] = model_id + return alias_map + + +def load_model_alias_map(raw: str | Mapping[str, str] | None = None) -> dict[str, str]: + if raw is None: + raw = os.environ.get(MODEL_ID_ALIASES_ENV) + if not raw: + return {} + if isinstance(raw, Mapping): + return {str(key): str(value) for key, value in raw.items() if key and value} + try: + data = json.loads(raw) + except Exception: + return {} + if not isinstance(data, dict): + return {} + return {str(key): str(value) for key, value in data.items() if key and value} + + +def resolve_model_id_alias(model_id_or_path: str | None, alias_map: Mapping[str, str] | None = None) -> str | None: + if not model_id_or_path: + return model_id_or_path + aliases = alias_map if alias_map is not None else load_model_alias_map() + return aliases.get(model_id_or_path, model_id_or_path) diff --git a/src/twinkle/server/launcher/env_propagation.py b/src/twinkle/server/launcher/env_propagation.py index 0c815e548..4d4a7a474 100644 --- a/src/twinkle/server/launcher/env_propagation.py +++ b/src/twinkle/server/launcher/env_propagation.py @@ -18,6 +18,7 @@ 'TWINKLE_TELEMETRY_SERVICE', 'TWINKLE_TELEMETRY_ENDPOINT', 'TWINKLE_TELEMETRY_INTERVAL', + 'TWINKLE_MODEL_ID_ALIASES', ) diff --git a/src/twinkle/server/launcher/server_launcher.py b/src/twinkle/server/launcher/server_launcher.py index 8f1248557..4c64c1208 100644 --- a/src/twinkle/server/launcher/server_launcher.py +++ b/src/twinkle/server/launcher/server_launcher.py @@ -8,6 +8,7 @@ """ from __future__ import annotations +import json import os import signal import threading @@ -16,6 +17,7 @@ from typing import Any from twinkle import get_logger +from twinkle.hub.model_alias import MODEL_ID_ALIASES_ENV, build_model_alias_map from twinkle.server.config import ServerConfig from twinkle.server.config.application_spec import ApplicationSpec from twinkle.server.utils.ray_serve_patch import apply_ray_serve_patches, get_runtime_env_for_patches @@ -221,6 +223,12 @@ def launch(self) -> None: os.environ[k] = v logger.info(f'Persistence backend configured: mode={persistence.mode}') + model_alias_map = build_model_alias_map(self.config.applications) + if model_alias_map: + os.environ[MODEL_ID_ALIASES_ENV] = json.dumps(model_alias_map, ensure_ascii=False) + else: + os.environ.pop(MODEL_ID_ALIASES_ENV, None) + self._init_ray() self._start_serve() diff --git a/src/twinkle_client/common/serialize.py b/src/twinkle_client/common/serialize.py index 3093ec58b..42a27beb3 100644 --- a/src/twinkle_client/common/serialize.py +++ b/src/twinkle_client/common/serialize.py @@ -1,5 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import json +from dataclasses import fields from numbers import Number from peft import LoraConfig from pydantic import BaseModel @@ -15,6 +16,7 @@ primitive_types = (str, Number, bool, bytes, type(None)) container_types = (Mapping, list, tuple, set, frozenset) basic_types = (*primitive_types, *container_types) +_DATASET_META_FIELDS = {field.name for field in fields(DatasetMeta)} def _serialize_data_slice(data_slice): @@ -45,7 +47,7 @@ def _deserialize_data_slice(data_slice): def serialize_object(obj) -> str: if isinstance(obj, DatasetMeta): - data = obj.__dict__.copy() + data = {name: getattr(obj, name) for name in _DATASET_META_FIELDS} data['data_slice'] = _serialize_data_slice(data.get('data_slice')) data['_TWINKLE_TYPE_'] = 'DatasetMeta' return json.dumps(data, ensure_ascii=False) @@ -80,6 +82,7 @@ def deserialize_object(data: str) -> Any: if '_TWINKLE_TYPE_' in data: _type = data.pop('_TWINKLE_TYPE_') if _type == 'DatasetMeta': + data = {key: value for key, value in data.items() if key in _DATASET_META_FIELDS} data['data_slice'] = _deserialize_data_slice(data.get('data_slice')) return DatasetMeta(**data) elif _type == 'LoraConfig':