diff --git a/packages/providers/pilabs/models/README.md b/packages/providers/pilabs/models/README.md
new file mode 100644
index 0000000..4f7b38c
--- /dev/null
+++ b/packages/providers/pilabs/models/README.md
@@ -0,0 +1,50 @@
+# NLWeb Pi Labs Models
+
+Pi Labs LLM scoring provider for NLWeb.
+
+## Overview
+
+This package provides integration with Pi Labs scoring API for relevance scoring in NLWeb queries.
+
+## Features
+
+- **PiLabsProvider**: LLM provider that uses Pi Labs scoring API
+- **PiLabsClient**: HTTP client for Pi Labs API
+- Async scoring with httpx and HTTP/2 support
+- Thread-safe client initialization
+
+## Installation
+
+```bash
+pip install -e packages/providers/pilabs/models
+```
+
+## Usage
+
+Configure in your `config.yaml`:
+
+```yaml
+llm:
+  scoring:
+    llm_type: pilabs
+    endpoint: "http://localhost:8001/invocations"
+    import_path: "nlweb_pilabs_models.llm"
+    class_name: "PiLabsProvider"
+```
+
+## Requirements
+
+- Python >= 3.11
+- httpx with HTTP/2 support
+- nlweb_core
+
+## API
+
+The Pi Labs provider expects:
+- `request.query`: The user query
+- `item.description`: The item to score
+- `site.itemType`: The type of item
+
+Returns:
+- `score`: Relevance score (0-100)
+- `description`: Item description
diff --git a/packages/providers/pilabs/models/nlweb_pilabs_models/__init__.py b/packages/providers/pilabs/models/nlweb_pilabs_models/__init__.py
new file mode 100644
index 0000000..25f05bb
--- /dev/null
+++ b/packages/providers/pilabs/models/nlweb_pilabs_models/__init__.py
@@ -0,0 +1,3 @@
+"""Pi Labs Models package for NLWeb."""
+
+__version__ = "0.1.0"
diff --git a/packages/providers/pilabs/models/nlweb_pilabs_models/llm/__init__.py b/packages/providers/pilabs/models/nlweb_pilabs_models/llm/__init__.py
new file mode 100644
index 0000000..35494a1
--- /dev/null
+++ b/packages/providers/pilabs/models/nlweb_pilabs_models/llm/__init__.py
@@ -0,0 +1,5 @@
+"""Pi Labs LLM providers."""
+
+from nlweb_pilabs_models.llm.pi_labs import PiLabsProvider, PiLabsClient
+
+__all__ = ["PiLabsProvider", "PiLabsClient"]
diff --git a/packages/providers/pilabs/models/nlweb_pilabs_models/llm/pi_labs.py b/packages/providers/pilabs/models/nlweb_pilabs_models/llm/pi_labs.py
new file mode 100644
index 0000000..0385dc5
--- /dev/null
+++ b/packages/providers/pilabs/models/nlweb_pilabs_models/llm/pi_labs.py
@@ -0,0 +1,176 @@
+import asyncio
+import threading
+from typing import Any
+import httpx
+import json
+
+from nlweb_core.llm import LLMProvider
+
+
+class PiLabsClient:
+    """PiLabsClient accesses a Pi Labs scoring API.
+    It lazily initializes the client it will use to make requests."""
+
+    _client: httpx.AsyncClient
+    _url: str
+
+    def __init__(self, url: str = "http://localhost:8001/invocations"):
+        self._url = url
+        # HTTP/2 + pooled connections so many concurrent score() calls share sockets.
+        self._client = httpx.AsyncClient(
+            http2=True,
+            limits=httpx.Limits(max_connections=100, max_keepalive_connections=30),
+        )
+
+    async def score(
+        self,
+        llm_input: str,
+        llm_output: str,
+        scoring_spec: list[dict[str, Any]],
+        timeout: float = 30.0,
+    ) -> float:
+        """POST one scoring request and return the API's total_score scaled to 0-100.
+
+        Raises httpx.HTTPStatusError on non-2xx responses."""
+        resp = await self._client.post(
+            url=self._url,
+            json={
+                "llm_input": llm_input,
+                "llm_output": llm_output,
+                "scoring_spec": scoring_spec,
+            },
+            timeout=timeout,
+        )
+        resp.raise_for_status()
+        # API reports total_score as a fraction; NLWeb ranking expects 0-100.
+        return resp.json().get("total_score", 0) * 100
+
+
+class PiLabsProvider(LLMProvider):
+    """PiLabsProvider accesses a Pi Labs scoring API."""
+
+    _client_lock = threading.Lock()
+    _client: PiLabsClient | None = None
+
+    @classmethod
+    def get_client(cls) -> PiLabsClient:
+        """Return the process-wide PiLabsClient, creating it once (thread-safe)."""
+        with cls._client_lock:
+            if cls._client is None:
+                cls._client = PiLabsClient()
+            return cls._client
+
+    async def get_completion(
+        self,
+        prompt: str,
+        schema: dict[str, Any],
+        model: str | None = None,
+        temperature: float = 0,
+        max_tokens: int = 0,
+        timeout: float = 30.0,
+        **kwargs,
+    ) -> dict[str, Any]:
+        """Score an item's relevance to a query via the Pi Labs API.
+
+        Only the NLWeb ranking schema {"score", "description"} is supported;
+        the query and item are read from kwargs, not from `prompt`."""
+        if schema.keys() != {"score", "description"}:
+            raise ValueError(
+                "PiLabsProvider only supports schema with 'score' and 'description' fields."
+            )
+        if {"request.query", "site.itemType", "item.description"} - kwargs.keys():
+            raise ValueError(
+                "PiLabsProvider requires 'request.query', 'site.itemType', and 'item.description' in kwargs."
+            )
+        client = self.get_client()
+        score = await client.score(
+            llm_input=kwargs["request.query"].text,
+            llm_output=json.dumps(kwargs["item.description"]),
+            scoring_spec=[
+                {"question": "Is this item relevant to the query?"},
+            ],
+            timeout=timeout,
+        )
+        return {"score": score, "description": kwargs["item.description"]}
+
+    @classmethod
+    def clean_response(cls, content: str) -> dict[str, Any]:
+        raise NotImplementedError("PiLabsProvider does not support clean_response.")
+
+
+async def pi_scoring_comparison(file):
+    """Re-score a JSONL results file with Pi Labs and append a CSV comparison.
+
+    Writes `<basename>_pi_eval.csv` with original vs. Pi scores; malformed
+    JSON lines in the input are skipped."""
+    # Generate output filename
+    base_name = file.rsplit(".", 1)[0] if "." in file else file
+    output_file = f"{base_name}_pi_eval.csv"
+    client = PiLabsProvider.get_client()
+
+    with open(file, "r") as f:
+        lines = f.readlines()
+    data = []
+    for line in lines:
+        try:
+            data.append(json.loads(line))
+        except json.JSONDecodeError:
+            continue
+
+    tasks = []
+    async with asyncio.TaskGroup() as tg:
+        for item in data:
+            tasks.append(tg.create_task(process_item(item, client)))
+
+    with open(output_file, "a") as f:
+        for task in tasks:
+            score, pi_score, csv_line = task.result()
+            if score > 64 or pi_score > 30:
+                print(csv_line)
+            f.write(csv_line + "\n")
+
+
+async def process_item(item, client):
+    """Score one result item with Pi Labs; returns (orig_score, pi_score, csv_line)."""
+    item_fields = {
+        "url": item.get("url", ""),
+        "name": item.get("name", ""),
+        "site": item.get("site", ""),
+        # NOTE(review): reads "site", not "siteUrl" — confirm this is intentional.
+        "siteUrl": item.get("site", ""),
+        "score": item.get("ranking", {}).get("score", 0),
+        "description": item.get("ranking", {}).get("description", ""),
+        "schema_object": item.get("schema_object", {}),
+        "query": item.get("query", ""),
+    }
+    desc = json.dumps(item_fields["schema_object"])
+    # score() returns a single float (previously unpacked as a 2-tuple, which
+    # raised TypeError), so measure the call latency here instead.
+    loop = asyncio.get_running_loop()
+    start = loop.time()
+    pi_score = await client.score(
+        item["query"],
+        desc,
+        scoring_spec=[
+            {"question": "Is the item relevant to the query?"},
+        ],
+    )
+    time_taken = loop.time() - start
+    score = item_fields["score"]
+
+    item["ranking"]["score"] = pi_score
+    csv_line = f"O={score},P={pi_score},T={time_taken},Q={item_fields['query']},N={item_fields['name']}"  # ,D={item_fields['description']}"
+
+    if score > 64 or pi_score > 30:
+        print(csv_line)
+    return score, pi_score, csv_line
+
+
+if __name__ == "__main__":
+    import sys
+
+    if len(sys.argv) < 2:
+        print("Usage: python3 -m nlweb_pilabs_models.llm.pi_labs <input_file.jsonl>")
+        sys.exit(1)
+
+    input_file = sys.argv[1]
+    asyncio.run(pi_scoring_comparison(input_file))
diff --git a/packages/providers/pilabs/models/pyproject.toml b/packages/providers/pilabs/models/pyproject.toml
new file mode 100644
index 0000000..81d2b2a
--- /dev/null
+++ b/packages/providers/pilabs/models/pyproject.toml
@@ -0,0 +1,24 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "nlweb_pilabs_models"
+version = "0.1.0"
+description = "Pi Labs model providers for NLWeb"
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+    "nlweb_core",
+    "httpx[http2]>=0.24.0",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=7.0.0",
+    "pytest-asyncio>=0.21.0",
+]
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["nlweb_pilabs_models*"]