diff --git a/tests/pytest/__init__.py b/tests/pytest/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/pytest/filelocks/_worker.py b/tests/pytest/filelocks/_worker.py new file mode 100644 index 000000000..d8e54422c --- /dev/null +++ b/tests/pytest/filelocks/_worker.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import argparse +import importlib +import json +import sys +import traceback + + +def _load_callable(path: str): + if ":" not in path: + raise ValueError("Expected --func 'package.module:callable_name'") + mod_name, attr = path.split(":", 1) + mod = importlib.import_module(mod_name) + try: + return getattr(mod, attr) + except AttributeError as e: + raise AttributeError(f"{path!r} not found") from e + + +def main() -> None: + parser = argparse.ArgumentParser(description="Run a callable in a clean subprocess") + parser.add_argument("--func", required=True, help="Import path: 'pkg.mod:callable'") + args = parser.parse_args() + + # Payload format: {"args": [...], "kwargs": {...}} + try: + payload = json.load(sys.stdin) if not sys.stdin.isatty() else {} + except json.JSONDecodeError: + payload = {} + + call_args = payload.get("args", []) or [] + call_kwargs = payload.get("kwargs", {}) or {} + + func = _load_callable(args.func) + + try: + # Call user code; allow it to print to stdout/stderr freely. + func(*call_args, **call_kwargs) + # Success -> exit 0 (no extra prints so your function's stdout stays clean). + sys.exit(0) + except SystemExit: + # Re-raise so explicit sys.exit propagates + raise + except Exception: + # On failure, send a traceback to stderr for the parent to capture. + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/pytest/filelocks/conftest.py b/tests/pytest/filelocks/conftest.py new file mode 100644 index 000000000..96d29e981 --- /dev/null +++ b/tests/pytest/filelocks/conftest.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import json +import os +import sys +import textwrap +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + +import pytest +import subprocess +from concurrent.futures import ThreadPoolExecutor, as_completed + + +def pytest_addoption(parser: pytest.Parser) -> None: + group = parser.getgroup("filelock-runner") + group.addoption("--workers", action="store", type=int, default=8, + help="Default number of subprocess workers to launch per test") + group.addoption("--subprocess-timeout", action="store", type=float, default=30.0, + help="Per-subprocess timeout in seconds") + group.addoption("--no-output-truncation", action="store_true", + help="If set, do not truncate captured stdout/stderr on failure") + + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line("markers", "workers(n): override subprocess worker count for this test") + + +@dataclass +class ProcResult: + index: int + cmd: List[str] + payload: Dict[str, Any] + returncode: Optional[int] # None if timeout/terminated + duration_s: float + stdout: str + stderr: str + timeout: bool + + +@pytest.fixture(scope="session") +def project_root() -> Path: + return Path.cwd() + + +# @pytest.fixture(autouse=True) +# def _session_env(monkeypatch: pytest.MonkeyPatch) -> None: +# monkeypatch.setenv("PYTHONUNBUFFERED", "1") + +@pytest.fixture(scope="session", autouse=True) +def _session_env(): + prev = os.environ.get("PYTHONUNBUFFERED") + os.environ["PYTHONUNBUFFERED"] = "1" + yield + if prev is None: + 
os.environ.pop("PYTHONUNBUFFERED", None) + else: + os.environ["PYTHONUNBUFFERED"] = prev + +@pytest.fixture +def per_test_dir(tmp_path: Path) -> Path: + """ + A clean directory you can use per test. Useful to isolate lock files, + temp outputs, etc. + """ + return tmp_path + + +@pytest.fixture +def workers(request: pytest.FixtureRequest) -> int: + marker = request.node.get_closest_marker("workers") + if marker: + return int(marker.args[0]) + return int(request.config.getoption("--workers")) + + +@pytest.fixture +def subprocess_timeout(request: pytest.FixtureRequest) -> float: + return float(request.config.getoption("--subprocess-timeout")) + + +@pytest.fixture +def truncate_outputs(request: pytest.FixtureRequest) -> bool: + return not bool(request.config.getoption("--no-output-truncation")) + + +@pytest.fixture +def run_many(project_root: Path, subprocess_timeout: float): + """ + Fan-out runner. Launches multiple Python subprocesses that execute + a callable via tests/pytest/filelocks/workers.py, capturing stdout+stderr for each. + """ + def _run_many( + func: str, + payloads: Iterable[Dict[str, Any]], + *, + cwd: Optional[Path] = None, + timeout: Optional[float] = None, + env: Optional[Dict[str, str]] = None, + python: str = sys.executable, + max_parallel: Optional[int] = None, + ) -> List[ProcResult]: + payloads = list(payloads) + if not payloads: + raise ValueError("payloads must be a non-empty iterable of {'args': [...], 'kwargs': {...}} dicts") + + tmo = subprocess_timeout if timeout is None else timeout + working_dir = str((cwd or project_root).resolve()) + + # Ensure tests/ is importable so `-m tests.pytest.filelocks.workers` works + env_combined = os.environ.copy() + if env: + env_combined.update(env) + + # Guarantee PYTHONPATH has the project root + roots = [str(project_root.resolve()), working_dir] + existing = [p for p in env_combined.get("PYTHONPATH", "").split(os.pathsep) if p] + dedup = [] + for p in roots + existing: + if p and p not in dedup: + dedup.append(p) + env_combined["PYTHONPATH"] = os.pathsep.join(dedup) + + base_cmd = [python, "-u", "-m", "tests.pytest.filelocks._worker", "--func", func] + print(base_cmd) + + results: List[ProcResult] = [] + + def launch_one(index: int, payload: Dict[str, Any]) -> ProcResult: + start = time.perf_counter() + try: + proc = subprocess.run( + base_cmd, + input=json.dumps(payload), + cwd=working_dir, + env=env_combined, + capture_output=True, + text=True, + timeout=tmo, + ) + duration = time.perf_counter() - start + return ProcResult( + index=index, + cmd=base_cmd, + payload=payload, + returncode=proc.returncode, + duration_s=duration, + stdout=proc.stdout, + stderr=proc.stderr, + timeout=False, + ) + except subprocess.TimeoutExpired as e: + duration = time.perf_counter() - start + stdout = e.stdout if isinstance(e.stdout, str) else (e.stdout.decode() if e.stdout else "") + stderr = e.stderr if isinstance(e.stderr, str) else (e.stderr.decode() if e.stderr else "") + return ProcResult( + index=index, + cmd=base_cmd, + payload=payload, + returncode=None, + duration_s=duration, + stdout=stdout, + stderr=stderr, + timeout=True, + ) + + max_workers = max_parallel or min(32, len(payloads)) + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = [pool.submit(launch_one, i, payload) for i, payload in enumerate(payloads)] + for fut in as_completed(futures): + results.append(fut.result()) + + # Keep results ordered by index for readability + results.sort(key=lambda r: r.index) + return results + + return _run_many + + 
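+# Usage sketch (illustrative only, not exercised as a test): each payload maps to one
+# subprocess, which runs the named callable through the _worker entry point and comes
+# back as a ProcResult with captured stdout/stderr and return code, e.g.:
+#
+#   results = run_many(
+#       "tests.pytest.filelocks.workers:import_unsloth",
+#       [{"args": [], "kwargs": {}} for _ in range(4)],
+#       timeout=60.0,
+#   )
+#   assert all(r.returncode == 0 for r in results)
+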
+@pytest.fixture +def assert_all_ok(truncate_outputs: bool): + """ + Assert helper that fails the test if ANY subprocess failed. + It prints stdout, stderr, and the stdin payload for each failing worker. + """ + def _assert_all_ok(results: List[ProcResult], *, show_bytes: int = 4000) -> None: + failed = [ + r for r in results + if r.timeout or (r.returncode is None) or (r.returncode != 0) + ] + if not failed: + return + + def maybe_trunc(s: str) -> str: + if not truncate_outputs or len(s) <= show_bytes: + return s + head = s[:show_bytes] + tail = s[-show_bytes:] + return f"{head}\n[...output truncated...]\n{tail}" + + sections = [] + sections.append(f"FAILURES DETECTED: {len(failed)}/{len(results)} subprocesses failed") + for r in failed: + rc = "TIMEOUT" if r.timeout else r.returncode + payload_pretty = json.dumps(r.payload, indent=2, sort_keys=True) + sections.append( + textwrap.dedent( + f""" + ── worker #{r.index} ───────────────────────────────────────── + cmd: {' '.join(r.cmd)} + exit: {rc}, ran for {r.duration_s:.2f}s + stdin payload: + {payload_pretty} + + ── stdout ─────────────────────────────────────────────────── + {maybe_trunc(r.stdout).rstrip()} + + ── stderr ─────────────────────────────────────────────────── + {maybe_trunc(r.stderr).rstrip()} + """ + ).rstrip() + ) + pytest.fail("\n\n".join(sections), pytrace=False) + + return _assert_all_ok + +@pytest.fixture +def hf_cache_env(per_test_dir: Path) -> dict: + hf_home = per_test_dir / "hf_home" + (hf_home / "hub").mkdir(parents=True, exist_ok=True) + # Keep tokenizers from forking extra threads & spamming logs in CI + env = { + "HF_HOME": str(hf_home), + "TRANSFORMERS_CACHE": str(hf_home / "transformers"), + "TOKENIZERS_PARALLELISM": "false", + } + return env \ No newline at end of file diff --git a/tests/pytest/filelocks/test_10_from_pretrained.py b/tests/pytest/filelocks/test_10_from_pretrained.py new file mode 100644 index 000000000..ef85bc253 --- /dev/null +++ b/tests/pytest/filelocks/test_10_from_pretrained.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import os +from pathlib import Path +import pytest + + +@pytest.fixture(scope="session") +def unsloth_model_name() -> str: + # You can override with: UNSLOTH_TEST_MODEL=your/model + # The default is a tiny LLaMA random-weight model widely used for tests. + model = os.environ.get("UNSLOTH_TEST_MODEL", "unsloth/Qwen3-4B-Instruct-2507") + if model == 'qwen': + model = "unsloth/Qwen3-4B-Instruct-2507" + elif model == 'gemma3': + model = "unsloth/gemma-3-4b-it" + elif model == 'gpt_oss': + model = "unsloth/gpt-oss-20b" + return model + + +@pytest.fixture(scope="session") +def can_4bit() -> bool: + try: + import torch # noqa + import bitsandbytes as bnb # noqa + # 4-bit really needs a usable CUDA device in practice. 
+ import torch + return bool(torch.cuda.is_available()) + except Exception: + return False + + +@pytest.mark.parametrize("load_in_4bit", [True], ids=lambda b: f"4bit={b}") +def test_from_pretrained_many_processes( + load_in_4bit: bool, + can_4bit: bool, + unsloth_model_name: str, + per_test_dir: Path, + run_many, + assert_all_ok, + hf_cache_env: dict, + workers: int, +): + if load_in_4bit and not can_4bit: + pytest.skip("bitsandbytes/CUDA not available; skipping 4-bit path") + + # Payloads: all processes load the same model into the same HF cache + payload = {"args": [unsloth_model_name], "kwargs": {"load_in_4bit": load_in_4bit}} + payloads = [payload for _ in range(workers)] + + # --- Pre-warm single load so we fail fast if offline or the model isn't compatible. + warmup = run_many( + "tests.pytest.filelocks.workers:load_from_pretrained", + [payload], + cwd=per_test_dir, + env=hf_cache_env, + timeout=600.0, # generous first pull + max_parallel=1, + )[0] + if warmup.returncode != 0: + msg = f"Warmup load failed (rc={warmup.returncode}).\n\nstdout:\n{warmup.stdout}\n\nstderr:\n{warmup.stderr}" + pytest.skip(msg) + + # --- Now the real concurrency test + results = run_many( + "tests.pytest.filelocks.workers:load_from_pretrained", + payloads, + cwd=per_test_dir, + env=hf_cache_env, + timeout=120.0, + ) + assert_all_ok(results) diff --git a/tests/pytest/filelocks/test_20_full.py b/tests/pytest/filelocks/test_20_full.py new file mode 100644 index 000000000..3785f9e18 --- /dev/null +++ b/tests/pytest/filelocks/test_20_full.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import os +from pathlib import Path +import pytest + + +@pytest.fixture(scope="session") +def unsloth_model_name() -> str: + model = os.environ.get("UNSLOTH_TEST_MODEL", "unsloth/Qwen3-4B-Instruct-2507") + if model == 'qwen': + model = "unsloth/Qwen3-4B-Instruct-2507" + elif model == 'gemma3': + model = "unsloth/gemma-3-4b-it" + elif model == 'gpt_oss': + model = "unsloth/gpt-oss-20b" + return model + + + +@pytest.fixture(scope="session") +def chat_template_name() -> str: + return os.environ.get("UNSLOTH_CHAT_TEMPLATE", "qwen3") + + +@pytest.fixture(scope="session") +def can_4bit() -> bool: + try: + import torch # noqa: F401 + import bitsandbytes # noqa: F401 + return bool(torch.cuda.is_available()) + except Exception: + return False + + +def test_full_end_to_end_clobber( + per_test_dir: Path, + project_root: Path, + hf_cache_env: dict, + run_many, + assert_all_ok, + workers: int, + unsloth_model_name: str, + chat_template_name: str, + can_4bit: bool, +): + if os.environ.get("UNSLOTH_RUN_FULL", "0") != "1": + pytest.skip("Set UNSLOTH_RUN_FULL=1 to enable this full clobber test") + + load_in_4bit = (os.environ.get("UNSLOTH_FULL_4BIT", "1") == "1") + if load_in_4bit and not can_4bit: + pytest.skip("4-bit requested but no CUDA/bitsandbytes; skipping") + + # Shared caches to maximize lock contention (models + datasets) + env = dict(hf_cache_env) + hf_home = Path(env["HF_HOME"]) + env["HF_DATASETS_CACHE"] = str(hf_home / "datasets") + (hf_home / "datasets").mkdir(parents=True, exist_ok=True) + + # Everyone writes into the SAME target dir on purpose + artifacts = per_test_dir / "artifacts" + artifacts.mkdir(parents=True, exist_ok=True) + + barrier_base = per_test_dir / "barriers" + + enable_push = bool(os.environ.get("UNSLOTH_PUSH_TOKEN")) + payload_common = { + "load_in_4bit": load_in_4bit, + "chat_template": chat_template_name, + "barrier_base": str(barrier_base), + "nprocs": workers, + "dataset_slice": 
os.environ.get("UNSLOTH_DATASET_SLICE", "train[:1000]"), + "artifacts_dir": str(artifacts), # shared by all workers + "enable_push": enable_push, + "hf_repo": os.environ.get("UNSLOTH_HF_REPO", ""), + "push_token": os.environ.get("UNSLOTH_PUSH_TOKEN", ""), + "enable_gguf": os.environ.get("UNSLOTH_ENABLE_GGUF", "1") != "0", + } + + # Fan out + payloads = [{"args": [unsloth_model_name], "kwargs": payload_common} for _ in range(workers)] + results = run_many( + "tests.pytest.filelocks.workers:run_full", + payloads, + cwd=project_root, + env=env, + timeout=float(os.environ.get("UNSLOTH_FULL_TIMEOUT", "3600")), # protects against full hangs + ) + assert_all_ok(results) + + model_dir = artifacts / "model" + assert model_dir.exists() and any(model_dir.iterdir()), "model dir missing or empty" \ No newline at end of file diff --git a/tests/pytest/filelocks/test_import.py b/tests/pytest/filelocks/test_import.py new file mode 100644 index 000000000..a38f5964c --- /dev/null +++ b/tests/pytest/filelocks/test_import.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import os +import shutil +import stat +import subprocess +from pathlib import Path + +import pytest + + +def _which_uv() -> str: + from shutil import which + uv = which("uv") + if not uv: + pytest.skip("uv not found on PATH. Install uv first: https://docs.astral.sh/uv/") + return uv + + +def _venv_python(venv_dir: Path) -> Path: + if os.name == "nt": + return venv_dir / "Scripts" / "python.exe" + return venv_dir / "bin" / "python" + + +def _rm_tree_force(path: Path) -> None: + if not path.exists(): + return + + def _onerror(func, p, exc): + # Windows: clear read-only bits then retry + try: + os.chmod(p, stat.S_IWRITE) + func(p) + except Exception: + pass + + shutil.rmtree(path, onerror=_onerror) + + +@pytest.mark.slow +def test_uv_fresh_env_many_imports( + per_test_dir: Path, + project_root: Path, + run_many, + assert_all_ok, + workers: int, +): + """ + Creates a brand-new venv with uv, installs Unsloth into it with NO cache, + fans out many subprocesses that 'import unsloth', then deletes the env. 
+    """
+    uv = _which_uv()
+    vdir = per_test_dir / "uv-venv"
+
+    _rm_tree_force(vdir)
+
+    env = os.environ.copy()
+    env["UV_NO_CACHE"] = "1"
+    env["UV_NO_CONFIG"] = "1"
+
+    try:
+        create_cmd = [uv, "--no-cache", "--no-config", "venv", str(vdir)]
+        proc = subprocess.run(create_cmd, capture_output=True, text=True, env=env)
+        if proc.returncode != 0:
+            pytest.fail(
+                "uv venv failed\n\n"
+                f"$ {' '.join(create_cmd)}\n\n"
+                f"stdout:\n{proc.stdout}\n\nstderr:\n{proc.stderr}"
+            )
+
+        py = _venv_python(vdir)
+
+        install_cmd = [uv, "--no-cache", "--no-config", "pip", "install", "--python", str(py)]
+        uninstall_cmd = [uv, "pip", "uninstall", "--python", str(py)]
+        install_dev_cmd = [uv, "--no-cache", "--no-config", "pip", "install", "--python", str(py)]
+        editable = os.environ.get("UNSLOTH_EDITABLE") == "1"
+        spec = (os.environ.get("UNSLOTH_PIP_SPEC") or "").split() or ["unsloth", "xformers<=0.0.28.post3", "vllm<=0.9.1", "transformers<=4.49.0"]
+        uninstall_spec = ["unsloth", "unsloth_zoo"]
+        dev_install_spec = ["git+https://github.com/mmathew23/unsloth.git@locks", "git+https://github.com/mmathew23/unsloth_zoo.git@locks"]
+
+        index_url = os.environ.get("UNSLOTH_INDEX_URL")
+        extra_index_url = os.environ.get("UNSLOTH_EXTRA_INDEX_URL")
+        if index_url:
+            install_cmd += ["--index-url", index_url]
+        if extra_index_url:
+            install_cmd += ["--extra-index-url", extra_index_url]
+
+        pip_extra = os.environ.get("UNSLOTH_PIP_EXTRA", "")
+        if pip_extra:
+            install_cmd += pip_extra.split()
+
+        if editable:
+            install_cmd += ["-e", str(project_root)]
+        else:
+            install_cmd += spec
+
+        uninstall_cmd += uninstall_spec
+        install_dev_cmd += dev_install_spec
+        for icmd in [install_cmd, uninstall_cmd, install_dev_cmd]:
+            proc = subprocess.run(icmd, capture_output=True, text=True, env=env, cwd=project_root)
+            if proc.returncode != 0:
+                pytest.fail(
+                    "uv pip command failed\n\n"
+                    f"$ {' '.join(icmd)}\n\n"
+                    f"stdout:\n{proc.stdout}\n\nstderr:\n{proc.stderr}"
+                )
+
+        payloads = [{"args": [], "kwargs": {}} for _ in range(workers)]
+        results = run_many(
+            "tests.pytest.filelocks.workers:import_unsloth",
+            payloads,
+            python=str(py),
+            cwd=project_root,
+            timeout=120.0,
+        )
+        assert_all_ok(results)
+
+    finally:
+        # Always delete the environment so nothing persists on disk
+        _rm_tree_force(vdir)
diff --git a/tests/pytest/filelocks/workers.py b/tests/pytest/filelocks/workers.py
new file mode 100644
index 000000000..382b29041
--- /dev/null
+++ b/tests/pytest/filelocks/workers.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+import time
+from pathlib import Path
+import os
+import uuid
+from typing import Optional
+
+def load_from_pretrained(model_name: str, load_in_4bit: bool = False) -> None:
+    """
+    Load a small model via Unsloth. Intentionally minimal: we don't do generation;
+    the goal is to hit the file-locking code path and then exit.
+    """
+    from unsloth import FastLanguageModel
+
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=model_name,
+        load_in_4bit=load_in_4bit,
+        max_seq_length=32,
+        dtype=None,
+    )
+
+    _ = tokenizer("hi", return_tensors="pt")
+    del model, tokenizer
+    print("load_from_pretrained: done")
+
+
+def import_unsloth() -> None:
+    import unsloth # noqa: F401
+    print("import_unsloth: done")
+    # Print version so parent can see something on stdout if it wants.
+ try: + import inspect + print(getattr(unsloth, "__version__", "unknown")) + except Exception: + pass + +def barrier_wait(base_dir: str | Path, name: str, nprocs: int, timeout_s: float = 600, poll_ms: int = 100) -> None: + base = Path(base_dir) / name + base.mkdir(parents=True, exist_ok=True) + token = base / f"{os.getpid()}-{uuid.uuid4().hex}.arrived" + token.write_text("") + go = base / ".go" + start = time.perf_counter() + while True: + if go.exists(): + return + arrivals = [p for p in base.iterdir() if p.is_file() and p.name != ".go"] + if len(arrivals) >= nprocs: + try: + fd = os.open(str(go), os.O_CREAT | os.O_EXCL | os.O_WRONLY) + with os.fdopen(fd, "w") as f: + f.write("ok") + return + except FileExistsError: + return + if time.perf_counter() - start > timeout_s: + raise TimeoutError( + f"Barrier '{name}' timed out after {timeout_s}s: arrivals={len(arrivals)}/{nprocs}, dir={base}" + ) + time.sleep(poll_ms / 1000.0) + + +def run_full( + model_name: str, + load_in_4bit: bool = True, + chat_template: str = "qwen3", + *, + barrier_base: Optional[str] = None, + nprocs: int = 1, + dataset_slice: str = "train[:1000]", + artifacts_dir: Optional[str] = None, + enable_push: bool = False, + hf_repo: str = "", + push_token: str = "", + enable_gguf: bool = True, +) -> None: + import os + os.environ["UNSLOTH_LOGGING_ENABLED"] = "1" + from unsloth import FastLanguageModel + if barrier_base: + barrier_wait(barrier_base, "after_import_unsloth", nprocs) + + try: + import torch + has_cuda = bool(torch.cuda.is_available()) + except Exception: + has_cuda = False + if load_in_4bit and not has_cuda: + raise RuntimeError("load_in_4bit=True but no CUDA device available") + + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=model_name, + load_in_4bit=bool(load_in_4bit), + max_seq_length=256, + dtype=None, + ) + + model = FastLanguageModel.get_peft_model( + model, + r=2, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], + lora_alpha=2, + lora_dropout=0, + bias="none", + use_gradient_checkpointing="unsloth", + random_state=3407, + use_rslora=False, + loftq_config=None, + ) + + if barrier_base: + barrier_wait(barrier_base, "before_get_chat_template", nprocs) + + from unsloth.chat_templates import get_chat_template, standardize_data_formats + tokenizer = get_chat_template(tokenizer, chat_template=chat_template) + + from datasets import load_dataset + dataset = load_dataset("mlabonne/FineTome-100k", split=dataset_slice) + dataset = standardize_data_formats(dataset) + + def formatting_prompts_func(examples): + convos = examples["conversations"] + texts = [ + tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) + for convo in convos + ] + return {"text": texts} + + dataset = dataset.map(formatting_prompts_func, batched=True) + + try: + import bitsandbytes as _bnb # noqa: F401 + optim_name = "adamw_8bit" + except Exception: + optim_name = "adamw_torch" + + from trl import SFTTrainer, SFTConfig + trainer = SFTTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset, + eval_dataset=None, + args=SFTConfig( + dataset_text_field="text", + per_device_train_batch_size=1, + gradient_accumulation_steps=1, + warmup_steps=5, + max_steps=5, + learning_rate=2e-4, + logging_steps=1, + optim=optim_name, + weight_decay=0.01, + lr_scheduler_type="linear", + seed=3407, + report_to="none", + ), + ) + trainer.train() + + save_root = Path(artifacts_dir or ".").resolve() + save_root.mkdir(parents=True, exist_ok=True) + target_dir 
= save_root / "model" + + if barrier_base: + barrier_wait(barrier_base, "before_save_pretrained_merged", nprocs) + model.save_pretrained_merged(str(target_dir), tokenizer, save_method="merged_16bit") + + if barrier_base: + barrier_wait(barrier_base, "before_push_to_hub_merged", nprocs) + if enable_push and hf_repo and push_token: + model.push_to_hub_merged(hf_repo, tokenizer, save_method="merged_16bit", token=push_token) + else: + print("Skipping push_to_hub_merged") + + if barrier_base: + barrier_wait(barrier_base, "before_save_pretrained_gguf", nprocs) + if enable_gguf: + model.save_pretrained_gguf(str(target_dir), tokenizer, quantization_method="q4_k_m") + else: + print("Skipping GGUF save") + + if barrier_base: + barrier_wait(barrier_base, "before_push_to_hub_gguf", nprocs) + if enable_push and hf_repo and push_token and enable_gguf: + model.push_to_hub_gguf(hf_repo, tokenizer, quantization_method="q4_k_m", token=push_token) + else: + print("Skipping push_to_hub_gguf") + + del model, tokenizer + print("run_full: done") diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 99d651ae5..b6742f753 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -68,6 +68,10 @@ except Exception as exception: raise exception pass +os.environ["UNSLOTH_ZOO_UTILS_ONLY"] = "1" +import unsloth_zoo.utils +# we do this to make compile folder and locks available to modules that need +# it earlier than the regular unsloth_zoo import @functools.cache def is_hip(): @@ -273,6 +277,14 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # os.system("pip install --upgrade --no-cache-dir --no-deps --user unsloth_zoo") # except: # raise ImportError("Unsloth: Please update unsloth_zoo via `pip install --upgrade --no-cache-dir --no-deps unsloth_zoo`") + if os.environ.get("UNSLOTH_ZOO_IS_PRESENT", "0") == "0": + try: + # earlier unsloth_zoo.utils import doesn't fully import + # unsloth_zoo so delete it so it can reload + del sys.modules["unsloth_zoo"] + except: + pass + os.environ["UNSLOTH_ZOO_UTILS_ONLY"] = "0" import unsloth_zoo except: raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo`") diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index e70c6b50a..e5a1d1331 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -32,6 +32,7 @@ delete_vllm, ) from unsloth_zoo.log import logger +from unsloth_zoo.utils import get_lock import numpy as np from .synthetic_configs import ( @@ -343,32 +344,40 @@ def chunk_data(self, filename = None): if not hasattr(self, "overlap") or not hasattr(self, "max_generation_tokens"): raise RuntimeError("Please use prepare_qa_generation first!") - with open(filename, "r", encoding = "utf-8") as f: text = f.read() - - max_tokens = self.max_seq_length - self.max_generation_tokens*2 - 128 # -128 to reduce errors - if max_tokens <= 5: - raise RuntimeError("Generation length is way too long!") - input_ids = self.tokenizer(text, add_special_tokens = False).input_ids - - # Get left and right boundaries - length = len(input_ids) - n_chunks = int(np.ceil(length / (max_tokens - self.overlap))) - boundaries = np.ceil(np.linspace(0, length - self.overlap, n_chunks)).astype(int) - boundaries = np.stack((boundaries[:-1], (boundaries + self.overlap)[1:])).T - boundaries = np.minimum(boundaries, length).tolist() - - # Get extension of filename like .txt - filename, extension = os.path.splitext(filename) - if filename.endswith("/"): filename = filename[:-1] - - all_filenames = [] - for i, (left, right) in 
enumerate(boundaries): - chunked_text = self.tokenizer.decode(input_ids[left : right]) - new_filename = f"{filename}_{i}{extension}" - all_filenames.append(new_filename) - with open(new_filename, "w", encoding = "utf-8") as f: f.write(chunked_text) + lock = get_lock(filename, timeout = 20) + try: + with lock: + with open(filename, "r", encoding = "utf-8") as f: text = f.read() + + max_tokens = self.max_seq_length - self.max_generation_tokens*2 - 128 # -128 to reduce errors + if max_tokens <= 5: + raise RuntimeError("Generation length is way too long!") + input_ids = self.tokenizer(text, add_special_tokens = False).input_ids + + # Get left and right boundaries + length = len(input_ids) + n_chunks = int(np.ceil(length / (max_tokens - self.overlap))) + boundaries = np.ceil(np.linspace(0, length - self.overlap, n_chunks)).astype(int) + boundaries = np.stack((boundaries[:-1], (boundaries + self.overlap)[1:])).T + boundaries = np.minimum(boundaries, length).tolist() + + # Get extension of filename like .txt + filename, extension = os.path.splitext(filename) + if filename.endswith("/"): filename = filename[:-1] + + all_filenames = [] + for i, (left, right) in enumerate(boundaries): + chunked_text = self.tokenizer.decode(input_ids[left : right]) + new_filename = f"{filename}_{i}{extension}" + all_filenames.append(new_filename) + with open(new_filename, "w", encoding = "utf-8") as f: f.write(chunked_text) + pass + return all_filenames + except Exception as e: + if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1": + logger.error(f"Unsloth: Failed to chunk data because {str(e)}") + pass pass - return all_filenames pass def prepare_qa_generation( @@ -408,7 +417,15 @@ def prepare_qa_generation( .replace("{cleanup_batch_size}", str(cleanup_batch_size))\ .replace("{cleanup_temperature}", str(cleanup_temperature)) - with open("synthetic_data_kit_config.yaml", "w", encoding = "utf-8") as f: f.write(config) + lock = get_lock("synthetic_data_kit_config.yaml") + try: + with lock: + with open("synthetic_data_kit_config.yaml", "w", encoding = "utf-8") as f: f.write(config) + except Exception as e: + if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1": + logger.error(f"Unsloth: Failed to write synthetic_data_kit_config.yaml because {str(e)}") + pass + pass self.overlap = overlap pass diff --git a/unsloth/import_fixes.py b/unsloth/import_fixes.py index 4deb0deb5..bd6e958f4 100644 --- a/unsloth/import_fixes.py +++ b/unsloth/import_fixes.py @@ -62,6 +62,7 @@ def GetPrototype(self, descriptor): # Fix Xformers performance issues since 0.0.25 def fix_xformers_performance_issue(): + from unsloth_zoo.utils import get_lock if importlib.util.find_spec("xformers") is None: return xformers_version = importlib_version("xformers") if Version(xformers_version) < Version("0.0.29"): @@ -70,19 +71,21 @@ def fix_xformers_performance_issue(): cutlass = Path(xformers_location) / "ops" / "fmha" / "cutlass.py" try: if cutlass.exists(): - with open(cutlass, "r+", encoding = "utf-8") as f: - text = f.read() - # See https://github.com/facebookresearch/xformers/issues/1176#issuecomment-2545829591 - if "num_splits_key=-1," in text: - text = text.replace( - "num_splits_key=-1,", - "num_splits_key=None,", - ) - f.seek(0) - f.write(text) - f.truncate() - if UNSLOTH_ENABLE_LOGGING: - print("Unsloth: Patching Xformers to fix some performance issues.") + lock = get_lock(str(cutlass)) + with lock: + with open(cutlass, "r+", encoding = "utf-8") as f: + text = f.read() + # See 
https://github.com/facebookresearch/xformers/issues/1176#issuecomment-2545829591 + if "num_splits_key=-1," in text: + text = text.replace( + "num_splits_key=-1,", + "num_splits_key=None,", + ) + f.seek(0) + f.write(text) + f.truncate() + if UNSLOTH_ENABLE_LOGGING: + print("Unsloth: Patching Xformers to fix some performance issues.") except Exception as e: if UNSLOTH_ENABLE_LOGGING: print(f"Unsloth: Failed patching Xformers with error = {str(e)}") @@ -90,6 +93,7 @@ def fix_xformers_performance_issue(): # ValueError: 'aimv2' is already used by a Transformers config, pick another name. def fix_vllm_aimv2_issue(): + from unsloth_zoo.utils import get_lock if importlib.util.find_spec("vllm") is None: return vllm_version = importlib_version("vllm") if Version(vllm_version) < Version("0.10.1"): @@ -98,29 +102,31 @@ def fix_vllm_aimv2_issue(): ovis_config = Path(vllm_version) / "transformers_utils" / "configs" / "ovis.py" try: if ovis_config.exists(): - with open(ovis_config, "r+", encoding = "utf-8") as f: - text = f.read() + lock = get_lock(ovis_config) + with lock: + with open(ovis_config, "r+", encoding = "utf-8") as f: + text = f.read() # See https://github.com/vllm-project/vllm-ascend/issues/2046 - if 'AutoConfig.register("aimv2", AIMv2Config)' in text: - text = text.replace( - 'AutoConfig.register("aimv2", AIMv2Config)', - '', - ) - text = text.replace( - '''backbone_config.pop('model_type') - backbone_config = AutoConfig.for_model(model_type, - **backbone_config)''', - '''if model_type != "aimv2": - backbone_config.pop('model_type') - backbone_config = AutoConfig.for_model(model_type, **backbone_config) - else: - backbone_config = AIMv2Config(**backbone_config)''' - ) - f.seek(0) - f.write(text) - f.truncate() - if UNSLOTH_ENABLE_LOGGING: - print("Unsloth: Patching vLLM to fix `'aimv2' is already used by a Transformers config, pick another name.`") + if 'AutoConfig.register("aimv2", AIMv2Config)' in text: + text = text.replace( + 'AutoConfig.register("aimv2", AIMv2Config)', + '', + ) + text = text.replace( + '''backbone_config.pop('model_type') + backbone_config = AutoConfig.for_model(model_type, + **backbone_config)''', + '''if model_type != "aimv2": + backbone_config.pop('model_type') + backbone_config = AutoConfig.for_model(model_type, **backbone_config) + else: + backbone_config = AIMv2Config(**backbone_config)''' + ) + f.seek(0) + f.write(text) + f.truncate() + if UNSLOTH_ENABLE_LOGGING: + print("Unsloth: Patching vLLM to fix `'aimv2' is already used by a Transformers config, pick another name.`") except Exception as e: if UNSLOTH_ENABLE_LOGGING: print(f"Unsloth: Failed patching vLLM with error = {str(e)}") diff --git a/unsloth/save.py b/unsloth/save.py index 506c8a68f..3bcb2910e 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -16,6 +16,7 @@ from importlib.metadata import version as importlib_version from unsloth_zoo.hf_utils import dtype_from_config, HAS_TORCH_DTYPE from unsloth_zoo.llama_cpp import convert_to_gguf, quantize_gguf, use_local_gguf, install_llama_cpp, check_llama_cpp, _download_convert_hf_to_gguf +from unsloth_zoo.utils import get_lock from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit from peft.tuners.lora import Linear4bit as Peft_Linear4bit from peft.tuners.lora import Linear as Peft_Linear @@ -1070,130 +1071,140 @@ def save_to_gguf( print(print_info) # Step 1: Ensure llama.cpp is installed + lock = get_lock("llama.cpp", timeout=-1) try: - quantizer_location, converter_location = check_llama_cpp() - print("Unsloth: llama.cpp found in the system. 
Skipping installation.") - except: - print("Unsloth: Installing llama.cpp. This might take 3 minutes...") - if IS_KAGGLE_ENVIRONMENT: - # Kaggle: no CUDA support due to environment limitations - quantizer_location, converter_location = install_llama_cpp( - gpu_support=False, - print_output=print_output - ) - else: - quantizer_location, converter_location = install_llama_cpp( - gpu_support=False, # GGUF conversion doesn't need CUDA - print_output=print_output - ) + with lock: + try: + quantizer_location, converter_location = check_llama_cpp() + print("Unsloth: llama.cpp found in the system. Skipping installation.") + except: + print("Unsloth: Installing llama.cpp. This might take 3 minutes...") + if IS_KAGGLE_ENVIRONMENT: + # Kaggle: no CUDA support due to environment limitations + quantizer_location, converter_location = install_llama_cpp( + gpu_support=False, + print_output=print_output + ) + else: + quantizer_location, converter_location = install_llama_cpp( + gpu_support=False, # GGUF conversion doesn't need CUDA + print_output=print_output + ) + except Exception as e: + logger.error(f"Unsloth: Error installing llama.cpp: {e}", exc_info=True) # Step 2: Download and patch converter script - print("Unsloth: Preparing converter script...") - with use_local_gguf(): - converter_path, supported_text_archs, supported_vision_archs = _download_convert_hf_to_gguf() - - # Step 3: Initial GGUF conversion - print(f"Unsloth: [1] Converting model into {first_conversion_dtype} GGUF format.") - print(f"This might take 3 minutes...") - - initial_files, is_vlm_update = convert_to_gguf( - model_name=model_name, - input_folder=model_directory, - model_dtype = model_dtype, - quantization_type=first_conversion, - converter_location=converter_path, - supported_text_archs=supported_text_archs, - supported_vision_archs=supported_vision_archs, - is_vlm=is_vlm, - is_gpt_oss=is_gpt_oss, - max_shard_size="50GB", - print_output=print_output, - ) - # update is_vlm switch - is_vlm = is_vlm_update - # Check conversion success - for file in initial_files: - if not os.path.exists(file): - if IS_KAGGLE_ENVIRONMENT: - raise RuntimeError( - f"Unsloth: Conversion failed for {file}\n" - "You are in a Kaggle environment with limited disk space (20GB).\n" - "Try saving to /tmp for more space or use a smaller model.\n" - "Alternatively, save the 16bit model first, then convert manually." - ) - else: - raise RuntimeError( - f"Unsloth: Conversion failed for {file}\n" - "Please check disk space and try again." + save_lock = get_lock(f"save_to_gguf_{model_name}", timeout=-1) + try: + with save_lock: + print("Unsloth: Preparing converter script...") + with use_local_gguf(): + converter_path, supported_text_archs, supported_vision_archs = _download_convert_hf_to_gguf() + + # Step 3: Initial GGUF conversion + print(f"Unsloth: [1] Converting model into {first_conversion_dtype} GGUF format.") + print(f"This might take 3 minutes...") + + initial_files, is_vlm_update = convert_to_gguf( + model_name=model_name, + input_folder=model_directory, + model_dtype = model_dtype, + quantization_type=first_conversion, + converter_location=converter_path, + supported_text_archs=supported_text_archs, + supported_vision_archs=supported_vision_archs, + is_vlm=is_vlm, + is_gpt_oss=is_gpt_oss, + max_shard_size="50GB", + print_output=print_output, ) - - print(f"Unsloth: Initial conversion completed! 
Files: {initial_files}") - - # Step 4: Additional quantizations using llama-quantize - all_saved_locations = initial_files.copy() - - # Get CPU count for quantization - n_cpus = psutil.cpu_count() - if n_cpus is None: n_cpus = 1 - n_cpus *= 2 - - if not is_gpt_oss: - base_gguf = initial_files[0] - quants_created = False - for quant_method in quantization_method: - if quant_method != first_conversion: - print(f"Unsloth: [2] Converting GGUF {first_conversion_dtype} into {quant_method}. This might take 10 minutes...") - output_location = f"{model_name}.{quant_method.upper()}.gguf" - - try: - # Use the quantize_gguf function we created - quantized_file = quantize_gguf( - input_gguf=base_gguf, - output_gguf=output_location, - quant_type=quant_method, - quantizer_location=quantizer_location, - print_output=print_output - ) - all_saved_locations.append(quantized_file) - quants_created = True - except Exception as e: + # update is_vlm switch + is_vlm = is_vlm_update + # Check conversion success + for file in initial_files: + if not os.path.exists(file): if IS_KAGGLE_ENVIRONMENT: raise RuntimeError( - f"Unsloth: Quantization failed for {output_location}\n"\ - "You are in a Kaggle environment, which might be the reason this is failing.\n"\ - "Kaggle only provides 20GB of disk space in the working directory.\n"\ - "Merging to 16bit for 7b models use 16GB of space.\n"\ - "This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"\ - "`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"\ - "You can try saving it to the `/tmp` directory for larger disk space.\n"\ - "I suggest you to save the 16bit model first, then use manual llama.cpp conversion.\n"\ - "Error: {e}" + f"Unsloth: Conversion failed for {file}\n" + "You are in a Kaggle environment with limited disk space (20GB).\n" + "Try saving to /tmp for more space or use a smaller model.\n" + "Alternatively, save the 16bit model first, then convert manually." ) else: raise RuntimeError( - f"Unsloth: Quantization failed for {output_location}\n"\ - "You might have to compile llama.cpp yourself, then run this again.\n"\ - "You do not need to close this Python program. Run the following commands in a new terminal:\n"\ - "You must run this in the same folder as you're saving your model.\n"\ - "git clone --recursive https://github.com/ggerganov/llama.cpp\n"\ - "cd llama.cpp && make clean && make all -j\n"\ - "Once that's done, redo the quantization.\n"\ - "Error: {e}" + f"Unsloth: Conversion failed for {file}\n" + "Please check disk space and try again." ) + + print(f"Unsloth: Initial conversion completed! Files: {initial_files}") + + # Step 4: Additional quantizations using llama-quantize + all_saved_locations = initial_files.copy() + + # Get CPU count for quantization + n_cpus = psutil.cpu_count() + if n_cpus is None: n_cpus = 1 + n_cpus *= 2 + + if not is_gpt_oss: + base_gguf = initial_files[0] + quants_created = False + for quant_method in quantization_method: + if quant_method != first_conversion: + print(f"Unsloth: [2] Converting GGUF {first_conversion_dtype} into {quant_method}. 
This might take 10 minutes...") + output_location = f"{model_name}.{quant_method.upper()}.gguf" + + try: + # Use the quantize_gguf function we created + quantized_file = quantize_gguf( + input_gguf=base_gguf, + output_gguf=output_location, + quant_type=quant_method, + quantizer_location=quantizer_location, + print_output=print_output + ) + all_saved_locations.append(quantized_file) + quants_created = True + except Exception as e: + if IS_KAGGLE_ENVIRONMENT: + raise RuntimeError( + f"Unsloth: Quantization failed for {output_location}\n"\ + "You are in a Kaggle environment, which might be the reason this is failing.\n"\ + "Kaggle only provides 20GB of disk space in the working directory.\n"\ + "Merging to 16bit for 7b models use 16GB of space.\n"\ + "This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"\ + "`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"\ + "You can try saving it to the `/tmp` directory for larger disk space.\n"\ + "I suggest you to save the 16bit model first, then use manual llama.cpp conversion.\n"\ + "Error: {e}" + ) + else: + raise RuntimeError( + f"Unsloth: Quantization failed for {output_location}\n"\ + "You might have to compile llama.cpp yourself, then run this again.\n"\ + "You do not need to close this Python program. Run the following commands in a new terminal:\n"\ + "You must run this in the same folder as you're saving your model.\n"\ + "git clone --recursive https://github.com/ggerganov/llama.cpp\n"\ + "cd llama.cpp && make clean && make all -j\n"\ + "Once that's done, redo the quantization.\n"\ + "Error: {e}" + ) + pass + pass pass pass - pass - pass - print("Unsloth: Model files cleanup...") - if quants_created: - all_saved_locations.remove(base_gguf) - Path(base_gguf).unlink() + print("Unsloth: Model files cleanup...") + if quants_created: + all_saved_locations.remove(base_gguf) + Path(base_gguf).unlink() - # flip the list to get [text_model, mmproj] order. for text models stays the same. - all_saved_locations.reverse() - else: - print("Unsloth: GPT-OSS model - skipping additional quantizations") - pass + # flip the list to get [text_model, mmproj] order. for text models stays the same. 
+ all_saved_locations.reverse() + else: + print("Unsloth: GPT-OSS model - skipping additional quantizations") + pass + except Exception as e: + logger.error(f"Unsloth: Error saving to GGUF: {e}", exc_info=True) if is_gpt_oss: want_full_precision = True @@ -2180,13 +2191,18 @@ def save_lora_to_custom_dir(model, tokenizer, save_directory): os.makedirs(save_directory, exist_ok=True) # Call the unsloth_save_model function with the custom directory - unsloth_save_model( - model, - tokenizer, - save_directory=save_directory, - save_method="lora", - push_to_hub=False, - ) + save_lock = get_lock(save_directory, timeout=-1) + try: + with save_lock: + unsloth_save_model( + model, + tokenizer, + save_directory=save_directory, + save_method="lora", + push_to_hub=False, + ) + except Exception as e: + logger.error(f"Unsloth: Error saving LoRA to custom directory: {e}", exc_info=True) # Corrected method within the model class to convert LoRA to GGML and push to Hugging Face Hub def unsloth_convert_lora_to_ggml_and_push_to_hub( @@ -2203,59 +2219,69 @@ def unsloth_convert_lora_to_ggml_and_push_to_hub( temporary_location: str = "_unsloth_temporary_saved_buffers", maximum_memory_usage: float = 0.85, ): - if not os.path.exists("llama.cpp"): - if IS_KAGGLE_ENVIRONMENT: - python_install = install_python_non_blocking(["protobuf"]) - python_install.wait() - install_llama_cpp_blocking(use_cuda=False) - makefile = None - else: - git_clone = install_llama_cpp_clone_non_blocking() - python_install = install_python_non_blocking(["protobuf"]) - git_clone.wait() - makefile = install_llama_cpp_make_non_blocking() - python_install.wait() - else: - makefile = None + llama_lock = get_lock("llama.cpp", timeout=-1) + try: + with llama_lock: + if not os.path.exists("llama.cpp"): + if IS_KAGGLE_ENVIRONMENT: + python_install = install_python_non_blocking(["protobuf"]) + python_install.wait() + install_llama_cpp_blocking(use_cuda=False) + makefile = None + else: + git_clone = install_llama_cpp_clone_non_blocking() + python_install = install_python_non_blocking(["protobuf"]) + git_clone.wait() + makefile = install_llama_cpp_make_non_blocking() + python_install.wait() + else: + makefile = None + except Exception as e: + logger.error(f"Unsloth: Error installing llama.cpp: {e}", exc_info=True) for _ in range(3): gc.collect() - lora_directory_push = "lora-to-ggml-push" - save_lora_to_custom_dir(self, tokenizer, lora_directory_push) + lora_lock = get_lock("lora-to-ggml-push", timeout=-1) + try: + with lora_lock: + lora_directory_push = "lora-to-ggml-push" + save_lora_to_custom_dir(self, tokenizer, lora_directory_push) - model_type = self.config.model_type - output_file = os.path.join(lora_directory_push, "ggml-adapter-model.bin") + model_type = self.config.model_type + output_file = os.path.join(lora_directory_push, "ggml-adapter-model.bin") - print(f"Unsloth: Converting auto-saved LoRA adapters at {lora_directory_push} to GGML format.") - print(f"The output file will be {output_file}") + print(f"Unsloth: Converting auto-saved LoRA adapters at {lora_directory_push} to GGML format.") + print(f"The output file will be {output_file}") - command = f"python3 llama.cpp/convert-lora-to-ggml.py {lora_directory_push} {output_file} llama" + command = f"python3 llama.cpp/convert-lora-to-ggml.py {lora_directory_push} {output_file} llama" - try: - with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1, universal_newlines=True) as sp: - for line in sp.stdout: - print(line, end="", flush=True) - for line 
in sp.stderr: - print(line, end="", flush=True) - sp.wait() - if sp.returncode != 0: - raise subprocess.CalledProcessError(sp.returncode, command) - except subprocess.CalledProcessError as e: - print(f"Error: Conversion failed with return code {e.returncode}") - return - - print(f"Unsloth: Conversion completed! Output file: {output_file}") - - print("Unsloth: Uploading GGML file to Hugging Face Hub...") - username = upload_to_huggingface( - self, repo_id, token, - "GGML converted LoRA", "ggml", output_file, None, private, - ) - link = f"{repo_id.lstrip('/')}" - print("Unsloth: Done.") - print(f"Converted LoRA to GGML and uploaded to https://huggingface.co/{link}") - print("\nThis GGML making function was made by Maheswar. Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!") + try: + with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1, universal_newlines=True) as sp: + for line in sp.stdout: + print(line, end="", flush=True) + for line in sp.stderr: + print(line, end="", flush=True) + sp.wait() + if sp.returncode != 0: + raise subprocess.CalledProcessError(sp.returncode, command) + except subprocess.CalledProcessError as e: + print(f"Error: Conversion failed with return code {e.returncode}") + return + + print(f"Unsloth: Conversion completed! Output file: {output_file}") + + print("Unsloth: Uploading GGML file to Hugging Face Hub...") + username = upload_to_huggingface( + self, repo_id, token, + "GGML converted LoRA", "ggml", output_file, None, private, + ) + link = f"{repo_id.lstrip('/')}" + print("Unsloth: Done.") + print(f"Converted LoRA to GGML and uploaded to https://huggingface.co/{link}") + print("\nThis GGML making function was made by Maheswar. Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!") + except Exception as e: + logger.error(f"Unsloth: Error converting LoRA to GGML and pushing to Hugging Face Hub: {e}", exc_info=True) def unsloth_convert_lora_to_ggml_and_save_locally( self, @@ -2264,50 +2290,60 @@ def unsloth_convert_lora_to_ggml_and_save_locally( temporary_location: str = "_unsloth_temporary_saved_buffers", maximum_memory_usage: float = 0.85, ): - if not os.path.exists("llama.cpp"): - if IS_KAGGLE_ENVIRONMENT: - python_install = install_python_non_blocking(["protobuf"]) - python_install.wait() - install_llama_cpp_blocking(use_cuda=False) - makefile = None - else: - git_clone = install_llama_cpp_clone_non_blocking() - python_install = install_python_non_blocking(["protobuf"]) - git_clone.wait() - makefile = install_llama_cpp_make_non_blocking() - python_install.wait() - else: - makefile = None + llama_lock = get_lock("llama.cpp", timeout=-1) + try: + with llama_lock: + if not os.path.exists("llama.cpp"): + if IS_KAGGLE_ENVIRONMENT: + python_install = install_python_non_blocking(["protobuf"]) + python_install.wait() + install_llama_cpp_blocking(use_cuda=False) + makefile = None + else: + git_clone = install_llama_cpp_clone_non_blocking() + python_install = install_python_non_blocking(["protobuf"]) + git_clone.wait() + makefile = install_llama_cpp_make_non_blocking() + python_install.wait() + else: + makefile = None + except Exception as e: + logger.error(f"Unsloth: Error installing llama.cpp: {e}", exc_info=True) for _ in range(3): gc.collect() - # Use the provided save_directory for local saving - save_lora_to_custom_dir(self, tokenizer, save_directory) + lora_lock = get_lock("lora-to-ggml-save", timeout=-1) + try: + with lora_lock: + # Use 
the provided save_directory for local saving + save_lora_to_custom_dir(self, tokenizer, save_directory) - model_type = self.config.model_type - output_file = os.path.join(save_directory, "ggml-adapter-model.bin") + model_type = self.config.model_type + output_file = os.path.join(save_directory, "ggml-adapter-model.bin") - print(f"Unsloth: Converting auto-saved LoRA adapters at {save_directory} to GGML format.") - print(f"The output file will be {output_file}") + print(f"Unsloth: Converting auto-saved LoRA adapters at {save_directory} to GGML format.") + print(f"The output file will be {output_file}") - command = f"python3 llama.cpp/convert-lora-to-ggml.py {save_directory} {output_file} llama" + command = f"python3 llama.cpp/convert-lora-to-ggml.py {save_directory} {output_file} llama" - try: - with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1, universal_newlines=True) as sp: - for line in sp.stdout: - print(line, end="", flush=True) - for line in sp.stderr: - print(line, end="", flush=True) - sp.wait() - if sp.returncode != 0: - raise subprocess.CalledProcessError(sp.returncode, command) - except subprocess.CalledProcessError as e: - print(f"Error: Conversion failed with return code {e.returncode}") - return - print("Unsloth: Done.") - print(f"Unsloth: Conversion completed! Output file: {output_file}") - print("\nThis GGML making function was made by Maheswar. Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!") + try: + with subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1, universal_newlines=True) as sp: + for line in sp.stdout: + print(line, end="", flush=True) + for line in sp.stderr: + print(line, end="", flush=True) + sp.wait() + if sp.returncode != 0: + raise subprocess.CalledProcessError(sp.returncode, command) + except subprocess.CalledProcessError as e: + print(f"Error: Conversion failed with return code {e.returncode}") + return + print("Unsloth: Done.") + print(f"Unsloth: Conversion completed! Output file: {output_file}") + print("\nThis GGML making function was made by Maheswar. 
Ping him @Maheswar on the Unsloth Discord or on HuggingFace (@mahiatlinux) if you like this!") + except Exception as e: + logger.error(f"Unsloth: Error converting LoRA to GGML and saving locally: {e}", exc_info=True) pass @@ -2335,9 +2371,14 @@ def save_to_gguf_generic( if repo_id is not None and token is None: raise RuntimeError("Unsloth: Please specify a token for uploading!") - if not os.path.exists(os.path.join("llama.cpp", "unsloth_convert_hf_to_gguf.py")): - install_llama_cpp(just_clone_repo = True) - pass + llama_lock = get_lock("llama.cpp", timeout=-1) + try: + with llama_lock: + if not os.path.exists(os.path.join("llama.cpp", "unsloth_convert_hf_to_gguf.py")): + install_llama_cpp(just_clone_repo = True) + pass + except Exception as e: + logger.error(f"Unsloth: Error installing llama.cpp: {e}", exc_info=True) # Use old style quantization_method new_quantization_methods = [] @@ -2371,32 +2412,37 @@ def save_to_gguf_generic( # Go through all types and save individually - somewhat inefficient # since we save F16 / BF16 multiple times - for quantization_type in new_quantization_methods: - metadata = _convert_to_gguf( - save_directory, - print_output = True, - quantization_type = quantization_type, - ) - if repo_id is not None: - prepare_saving( - model, - repo_id, - push_to_hub = True, - max_shard_size = "50GB", - private = True, - token = token, - ) + convert_lock = get_lock("save_to_gguf_generic", timeout=-1) + try: + with convert_lock: + for quantization_type in new_quantization_methods: + metadata = _convert_to_gguf( + save_directory, + print_output = True, + quantization_type = quantization_type, + ) + if repo_id is not None: + prepare_saving( + model, + repo_id, + push_to_hub = True, + max_shard_size = "50GB", + private = True, + token = token, + ) - from huggingface_hub import HfApi - api = HfApi(token = token) - api.upload_folder( - folder_path = save_directory, - repo_id = repo_id, - repo_type = "model", - allow_patterns = ["*.gguf"], - ) - pass - pass + from huggingface_hub import HfApi + api = HfApi(token = token) + api.upload_folder( + folder_path = save_directory, + repo_id = repo_id, + repo_type = "model", + allow_patterns = ["*.gguf"], + ) + pass + pass + except Exception as e: + logger.error(f"Unsloth: Error saving to GGUF: {e}", exc_info=True) return metadata pass diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 067f2596c..064025c55 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -34,6 +34,7 @@ from unsloth_zoo.training_utils import ( fix_zero_training_loss, ) +from unsloth_zoo.utils import get_lock __all__ = [ "load_correct_tokenizer", @@ -347,56 +348,62 @@ def fix_sentencepiece_tokenizer( # We need to manually edit the sentencepiece tokenizer! 
from transformers.utils import sentencepiece_model_pb2 - if not os.path.exists(temporary_location): - os.makedirs(temporary_location) - pass + os.makedirs(temporary_location, exist_ok = True) + lock = get_lock(f"{temporary_location}/tokenizer.model", timeout = 20) + try: + with lock: - # Check if tokenizer.model exists - if not os.path.isfile(f"{temporary_location}/tokenizer.model"): - return new_tokenizer - pass + # Check if tokenizer.model exists + if not os.path.isfile(f"{temporary_location}/tokenizer.model"): + return new_tokenizer + pass - # First save the old tokenizer - old_tokenizer.save_pretrained(temporary_location) + # First save the old tokenizer + old_tokenizer.save_pretrained(temporary_location) - tokenizer_file = sentencepiece_model_pb2.ModelProto() - tokenizer_file.ParseFromString(open(f"{temporary_location}/tokenizer.model", "rb").read()) + tokenizer_file = sentencepiece_model_pb2.ModelProto() + tokenizer_file.ParseFromString(open(f"{temporary_location}/tokenizer.model", "rb").read()) - # Now save the new tokenizer - new_tokenizer.save_pretrained(temporary_location) + # Now save the new tokenizer + new_tokenizer.save_pretrained(temporary_location) - # Now correct the old tokenizer's .model file - for old_token, new_token in token_mapping.items(): - ids = old_tokenizer([old_token], add_special_tokens = False).input_ids - ids = ids[0] - if (len(ids) != 1): - # Skip this token! - print(f"Skip mapping {old_token} to {new_token} since {new_token} is already in the tokenizer!") - continue - pass - ids = ids[0] - # [TODO] Hack for Starling - try except - try: - tokenizer_piece = tokenizer_file.pieces[ids] - except: - continue - assert(tokenizer_piece.piece == old_token) - tokenizer_piece.piece = new_token - pass + # Now correct the old tokenizer's .model file + for old_token, new_token in token_mapping.items(): + ids = old_tokenizer([old_token], add_special_tokens = False).input_ids + ids = ids[0] + if (len(ids) != 1): + # Skip this token! + print(f"Skip mapping {old_token} to {new_token} since {new_token} is already in the tokenizer!") + continue + pass + ids = ids[0] + # [TODO] Hack for Starling - try except + try: + tokenizer_piece = tokenizer_file.pieces[ids] + except: + continue + assert(tokenizer_piece.piece == old_token) + tokenizer_piece.piece = new_token + pass - # And now write it - with open(f"{temporary_location}/tokenizer.model", "wb") as file: - file.write(tokenizer_file.SerializeToString()) - pass + # And now write it + with open(f"{temporary_location}/tokenizer.model", "wb") as file: + file.write(tokenizer_file.SerializeToString()) + pass - # And load it! - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained( - temporary_location, - eos_token = new_tokenizer.eos_token, - pad_token = new_tokenizer.pad_token, - ) - return tokenizer + # And load it! 
+ from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained( + temporary_location, + eos_token = new_tokenizer.eos_token, + pad_token = new_tokenizer.pad_token, + ) + return tokenizer + pass + except Exception as e: + if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1": + logger.error(f"Unsloth: Failed to fix sentencepiece tokenizer because {str(e)}") + pass pass @@ -421,50 +428,57 @@ class SentencePieceTokenTypes(IntEnum): pass # Load tokenizer.model - tokenizer_file = sentencepiece_model_pb2.ModelProto() - if not os.path.isfile(f"{saved_location}/tokenizer.model"): return - tokenizer_file.ParseFromString(open(f"{saved_location}/tokenizer.model", "rb").read()) - sentence_piece_size = len(tokenizer_file.pieces) - - # Load added_tokens_json - if not os.path.isfile(f"{saved_location}/added_tokens.json"): return - with open(f"{saved_location}/added_tokens.json", "r", encoding = "utf-8") as file: - added_tokens_json = json.load(file) - pass - if len(added_tokens_json) == 0: return - - added_tokens_json = dict(sorted(added_tokens_json.items(), key = lambda item: item[1])) - new_size = sentence_piece_size + len(added_tokens_json) - - # Confirm added_tokens_json is correct - added_tokens_ids = np.array(list(added_tokens_json.values())) - diff = np.diff(added_tokens_ids) - if (diff.min() != 1 or diff.max() != 1): return - if (added_tokens_ids.min() != sentence_piece_size): return - - # Edit sentence piece tokens with added_tokens_json - logger.warning( - f"Unsloth: Extending {saved_location}/tokenizer.model with added_tokens.json.\n"\ - f"Originally tokenizer.model is of size ({sentence_piece_size}).\n"\ - f"But we need to extend to sentencepiece vocab size ({new_size})." - ) - new_tokens = deepcopy(tokenizer_file.pieces[-len(added_tokens_ids):]) - for new_token, added_token in zip(new_tokens, added_tokens_json.keys()): - new_token.piece = added_token.encode("utf-8") - new_token.score = -1000.0 - new_token.type = SentencePieceTokenTypes.USER_DEFINED - pass + lock = get_lock(f"{saved_location}/tokenizer.model", timeout = 20) + try: + with lock: + tokenizer_file = sentencepiece_model_pb2.ModelProto() + if not os.path.isfile(f"{saved_location}/tokenizer.model"): return + tokenizer_file.ParseFromString(open(f"{saved_location}/tokenizer.model", "rb").read()) + sentence_piece_size = len(tokenizer_file.pieces) + + # Load added_tokens_json + if not os.path.isfile(f"{saved_location}/added_tokens.json"): return + with open(f"{saved_location}/added_tokens.json", "r", encoding = "utf-8") as file: + added_tokens_json = json.load(file) + pass + if len(added_tokens_json) == 0: return + + added_tokens_json = dict(sorted(added_tokens_json.items(), key = lambda item: item[1])) + new_size = sentence_piece_size + len(added_tokens_json) + + # Confirm added_tokens_json is correct + added_tokens_ids = np.array(list(added_tokens_json.values())) + diff = np.diff(added_tokens_ids) + if (diff.min() != 1 or diff.max() != 1): return + if (added_tokens_ids.min() != sentence_piece_size): return + + # Edit sentence piece tokens with added_tokens_json + logger.warning( + f"Unsloth: Extending {saved_location}/tokenizer.model with added_tokens.json.\n"\ + f"Originally tokenizer.model is of size ({sentence_piece_size}).\n"\ + f"But we need to extend to sentencepiece vocab size ({new_size})." 
+ ) + new_tokens = deepcopy(tokenizer_file.pieces[-len(added_tokens_ids):]) + for new_token, added_token in zip(new_tokens, added_tokens_json.keys()): + new_token.piece = added_token.encode("utf-8") + new_token.score = -1000.0 + new_token.type = SentencePieceTokenTypes.USER_DEFINED + pass - tokenizer_file.pieces.extend(new_tokens) + tokenizer_file.pieces.extend(new_tokens) - with open(f"{saved_location}/tokenizer.model", "wb") as file: - file.write(tokenizer_file.SerializeToString()) - pass + with open(f"{saved_location}/tokenizer.model", "wb") as file: + file.write(tokenizer_file.SerializeToString()) + pass - # Add padding tokens - # actual_vocab_size = model.config.vocab_size - # padding = actual_vocab_size - len(tokenizer_file.pieces) - return + # Add padding tokens + # actual_vocab_size = model.config.vocab_size + # padding = actual_vocab_size - len(tokenizer_file.pieces) + return + except Exception as e: + if os.environ.get("UNSLOTH_LOGGING_ENABLED", "0") == "1": + logger.error(f"Unsloth: Failed to fix sentencepiece gguf because {str(e)}") + pass pass