Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/twinkle/hub/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .hub import HFHub, HubOperation, MSHub
from .model_alias import MODEL_ID_ALIASES_ENV, build_model_alias_map, load_model_alias_map, resolve_model_id_alias
2 changes: 2 additions & 0 deletions src/twinkle/hub/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Dict, List, Literal, Optional, Union

from ..utils import requires
from .model_alias import resolve_model_id_alias

_executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)
_futures = {}
Expand Down Expand Up @@ -204,6 +205,7 @@ def download_model(cls,
Returns:
The local dir
"""
model_id_or_path = resolve_model_id_alias(model_id_or_path)
if kwargs.pop('ignore_model', False):
ignore_patterns = set(ignore_patterns or []) | set(large_file_pattern)
if os.path.exists(model_id_or_path):
Expand Down
59 changes: 59 additions & 0 deletions src/twinkle/hub/model_alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from __future__ import annotations

import json
import os
from collections.abc import Iterable, Mapping
from typing import Any

MODEL_ID_ALIASES_ENV = 'TWINKLE_MODEL_ID_ALIASES'


def _get_value(obj: Any, key: str, default: Any = None) -> Any:
if isinstance(obj, Mapping):
return obj.get(key, default)
return getattr(obj, key, default)


def extract_model_alias(route_prefix: str, service_type: str = 'model') -> str | None:
    """Return the alias that follows ``/<service_type>/`` in a route prefix.

    Args:
        route_prefix: The route prefix to inspect, e.g. ``/api/v1/model/my-alias``.
        service_type: The path segment that precedes the alias.

    Returns:
        Everything after the first ``/<service_type>/`` occurrence with leading and
        trailing slashes removed, or ``None`` when the prefix is empty, the marker
        is absent, or nothing is left after stripping.
    """
    separator = f'/{service_type}/'
    if route_prefix and separator in route_prefix:
        # Only the first occurrence splits the prefix; later ones stay in the alias.
        _, _, remainder = route_prefix.partition(separator)
        trimmed = remainder.strip('/')
        if trimmed:
            return trimmed
    return None


def build_model_alias_map(applications: Iterable[Any]) -> dict[str, str]:
    """Collect ``alias -> model_id`` pairs from application specs.

    Each application may be a mapping or an object. Only entries whose
    ``import_path`` equals ``'model'`` are considered. The alias is taken from
    the entry's ``route_prefix`` and the model id from ``args.model_id``.

    Args:
        applications: Application specs to scan; a falsy value yields an empty map.

    Returns:
        A dict mapping each alias to its model id. Entries with a missing alias,
        a missing model id, or an alias identical to the model id are left out.
    """
    result: dict[str, str] = {}
    if not applications:
        return result
    for app in applications:
        if _get_value(app, 'import_path') != 'model':
            continue
        prefix = _get_value(app, 'route_prefix', '')
        alias = extract_model_alias(prefix, service_type='model')
        app_args = _get_value(app, 'args')
        model_id = None if app_args is None else _get_value(app_args, 'model_id')
        # An alias equal to the model id would be a no-op mapping, so skip it.
        if not alias or not model_id or alias == model_id:
            continue
        result[alias] = model_id
    return result


def load_model_alias_map(raw: str | Mapping[str, str] | None = None) -> dict[str, str]:
    """Load the ``alias -> model_id`` map from a mapping, a JSON string, or the environment.

    Args:
        raw: Either a mapping, a JSON-encoded object string, or ``None``. When
            ``None``, the JSON string is read from the environment variable
            named by ``MODEL_ID_ALIASES_ENV``.

    Returns:
        A new dict with keys and values converted to ``str``; pairs with a falsy
        key or value are dropped. An empty dict is returned when the input is
        missing or empty, cannot be decoded as JSON, or decodes to something
        other than a JSON object.
    """
    if raw is None:
        raw = os.environ.get(MODEL_ID_ALIASES_ENV)
    if not raw:
        return {}
    if isinstance(raw, Mapping):
        data = raw
    else:
        try:
            data = json.loads(raw)
        except (TypeError, ValueError, RecursionError):
            # Best effort: a malformed or non-string value must not break model
            # resolution. JSONDecodeError and UnicodeDecodeError are ValueError
            # subclasses; TypeError covers non-str/bytes input; RecursionError
            # covers pathologically nested documents.
            return {}
        if not isinstance(data, dict):
            return {}
    # Single normalization path shared by the mapping and JSON inputs.
    return {str(key): str(value) for key, value in data.items() if key and value}


def resolve_model_id_alias(model_id_or_path: str | None, alias_map: Mapping[str, str] | None = None) -> str | None:
    """Translate a model alias into its model id, passing unknown values through.

    Args:
        model_id_or_path: The alias, model id, or path to resolve.
        alias_map: Alias table to use; when ``None`` the table is obtained from
            ``load_model_alias_map()``.

    Returns:
        The mapped model id when ``model_id_or_path`` is a known alias; otherwise
        the input unchanged (including ``None`` and the empty string).
    """
    if model_id_or_path:
        lookup = load_model_alias_map() if alias_map is None else alias_map
        return lookup.get(model_id_or_path, model_id_or_path)
    # Falsy inputs are returned untouched without consulting any alias table.
    return model_id_or_path
1 change: 1 addition & 0 deletions src/twinkle/server/launcher/env_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
'TWINKLE_TELEMETRY_SERVICE',
'TWINKLE_TELEMETRY_ENDPOINT',
'TWINKLE_TELEMETRY_INTERVAL',
'TWINKLE_MODEL_ID_ALIASES',
)


Expand Down
8 changes: 8 additions & 0 deletions src/twinkle/server/launcher/server_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""
from __future__ import annotations

import json
import os
import signal
import threading
Expand All @@ -16,6 +17,7 @@
from typing import Any

from twinkle import get_logger
from twinkle.hub.model_alias import MODEL_ID_ALIASES_ENV, build_model_alias_map
from twinkle.server.config import ServerConfig
from twinkle.server.config.application_spec import ApplicationSpec
from twinkle.server.utils.ray_serve_patch import apply_ray_serve_patches, get_runtime_env_for_patches
Expand Down Expand Up @@ -221,6 +223,12 @@ def launch(self) -> None:
os.environ[k] = v
logger.info(f'Persistence backend configured: mode={persistence.mode}')

model_alias_map = build_model_alias_map(self.config.applications)
if model_alias_map:
os.environ[MODEL_ID_ALIASES_ENV] = json.dumps(model_alias_map, ensure_ascii=False)
else:
os.environ.pop(MODEL_ID_ALIASES_ENV, None)

self._init_ray()
self._start_serve()

Expand Down
5 changes: 4 additions & 1 deletion src/twinkle_client/common/serialize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import json
from dataclasses import fields
from numbers import Number
from peft import LoraConfig
from pydantic import BaseModel
Expand All @@ -15,6 +16,7 @@
primitive_types = (str, Number, bool, bytes, type(None))
container_types = (Mapping, list, tuple, set, frozenset)
basic_types = (*primitive_types, *container_types)
_DATASET_META_FIELDS = {field.name for field in fields(DatasetMeta)}


def _serialize_data_slice(data_slice):
Expand Down Expand Up @@ -45,7 +47,7 @@ def _deserialize_data_slice(data_slice):

def serialize_object(obj) -> str:
if isinstance(obj, DatasetMeta):
data = obj.__dict__.copy()
data = {name: getattr(obj, name) for name in _DATASET_META_FIELDS}
data['data_slice'] = _serialize_data_slice(data.get('data_slice'))
data['_TWINKLE_TYPE_'] = 'DatasetMeta'
return json.dumps(data, ensure_ascii=False)
Expand Down Expand Up @@ -80,6 +82,7 @@ def deserialize_object(data: str) -> Any:
if '_TWINKLE_TYPE_' in data:
_type = data.pop('_TWINKLE_TYPE_')
if _type == 'DatasetMeta':
data = {key: value for key, value in data.items() if key in _DATASET_META_FIELDS}
data['data_slice'] = _deserialize_data_slice(data.get('data_slice'))
return DatasetMeta(**data)
elif _type == 'LoraConfig':
Expand Down
Loading