LLM integration POC #1028

Status: Open. Wants to merge 12 commits into main.
8 changes: 4 additions & 4 deletions .github/workflows/run-codspeed-tests.yml
@@ -2,12 +2,12 @@ name: CodSpeed

on:
push:
branches: [ "main" ]
branches: ["main"]
pull_request:
branches: [ "main" ]
branches: ["main"]

env:
-UV_LOCKED: true # Assert that the `uv.lock` will remain unchanged
+UV_LOCKED: true # Assert that the `uv.lock` will remain unchanged

jobs:
codspeed:
@@ -18,7 +18,7 @@ jobs:
options: --privileged
services:
postgres:
-image: postgres:15-alpine
+image: pgvector/pgvector:pg15
# Provide the password for postgres
env:
POSTGRES_PASSWORD: nwa
13 changes: 7 additions & 6 deletions .github/workflows/run-unit-tests.yml
@@ -1,26 +1,27 @@
name: Unit tests
on:
push:
-branches: [ main ]
+branches: [main]
workflow_call:
pull_request:

env:
-UV_LOCKED: true # Assert that the `uv.lock` will remain unchanged
+UV_LOCKED: true # Assert that the `uv.lock` will remain unchanged

jobs:
container_job:
name: Unit tests Python (${{ matrix.python-version }}) Postgres (${{ matrix.postgres-version }})
runs-on: ubuntu-latest
strategy:
matrix:
-python-version: ['3.11', '3.12', '3.13']
-postgres-version: ['15', '16', '17']
+python-version: ["3.11", "3.12", "3.13"]
+postgres-version: ["15", "16", "17"]
fail-fast: false
container: ubuntu:latest
services:
postgres:
-image: postgres:${{ matrix.postgres-version }}-alpine
+image: pgvector/pgvector:pg${{ matrix.postgres-version }}

# Provide the password for postgres
env:
POSTGRES_PASSWORD: nwa
@@ -68,6 +69,6 @@ jobs:
- name: "Upload coverage to Codecov"
uses: codecov/codecov-action@v3
with:
-token: ${{ secrets.CODECOV_TOKEN }} # gives error 'Could not find a repository associated with upload token'
+token: ${{ secrets.CODECOV_TOKEN }} # gives error 'Could not find a repository associated with upload token'
fail_ci_if_error: false
files: ./coverage.xml
3 changes: 3 additions & 0 deletions orchestrator/api/api_v1/api.py
@@ -22,6 +22,7 @@
product_blocks,
products,
resource_types,
+search,
settings,
subscription_customer_descriptions,
subscriptions,
@@ -83,3 +84,5 @@
tags=["Core", "Translations"],
)
api_router.include_router(ws.router, prefix="/ws", tags=["Core", "Events"])

+api_router.include_router(search.router)

Review comment:

tags=["Search"]? And should we add prefix = /search (and update the search router accordingly)?

Follow-up comment:

Ah, I see you have it in the other file. Maybe that should be moved here to keep the "overview" of how routes are mounted in this file.
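A sketch of what that suggestion could look like, assuming the prefix and tags move out of search.py so api.py keeps the full overview of mounted routes. The bare APIRouter and the include_router arguments below are the hypothetical end state, not code from this PR:

```python
from fastapi import APIRouter

from orchestrator.api.api_v1.endpoints import search

# In orchestrator/api/api_v1/endpoints/search.py the router would be
# created without routing details:
router = APIRouter()

# ...and orchestrator/api/api_v1/api.py would own the prefix and tags
# when mounting it:
api_router = APIRouter()
api_router.include_router(search.router, prefix="/search", tags=["Search"])
```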

62 changes: 62 additions & 0 deletions orchestrator/api/api_v1/endpoints/agent.py
@@ -0,0 +1,62 @@
import structlog
from fastapi import FastAPI, HTTPException
from starlette.types import ASGIApp

from orchestrator.settings import app_settings

logger = structlog.get_logger(__name__)


def _disabled_agent_app(reason: str) -> FastAPI:
app = FastAPI(title="Agent disabled")

@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"])
async def _disabled(path: str) -> None:
raise HTTPException(status_code=503, detail=f"Agent disabled: {reason}")

return app


def build_agent_app() -> ASGIApp:
if not app_settings.AGENT_MODEL or not app_settings.OPENAI_API_KEY:
logger.warning("Agent route disabled: missing model or OPENAI_API_KEY")
return _disabled_agent_app("missing configuration")

try:
from pydantic_ai.ag_ui import StateDeps
from pydantic_ai.agent import Agent
from pydantic_ai.settings import ModelSettings

from orchestrator.search.agent.prompts import get_base_instructions, get_dynamic_instructions
from orchestrator.search.agent.state import SearchState
from orchestrator.search.agent.tools import search_toolset
except ImportError:
logger.error(
"\nRequired packages not installed:\n"
"WARNING: These packages are NOT compatible with the current "
"pydantic version in orchestrator-core.\n Upgrading pydantic to install "
"may cause incompatibilities or runtime errors.\n\n"
" pydantic-ai==0.7.0\n"
" ag-ui-protocol>=0.1.8\n\n"
"Install them locally to enable the agent:\n"
" pip install 'pydantic-ai==0.7.0' 'ag-ui-protocol>=0.1.8'\n"
)
logger.warning("Agent route disabled: Missing required packages")
return _disabled_agent_app("Missing required packages")

try:
agent = Agent(
model=app_settings.AGENT_MODEL,
deps_type=StateDeps[SearchState],
model_settings=ModelSettings(
parallel_tool_calls=False
), # https://github.com/pydantic/pydantic-ai/issues/562
toolsets=[search_toolset],
)
agent.instructions(get_base_instructions)
agent.instructions(get_dynamic_instructions)

return agent.to_ag_ui(deps=StateDeps(SearchState()))
except Exception as e:
logger.error("Agent init failed; serving disabled stub.", error=str(e))
return _disabled_agent_app(str(e))
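For anyone trying the agent locally: the mount only goes live when both settings checked above are present and the optional packages are installed. A minimal sketch, assuming app_settings reads identically named environment variables (both values below are placeholders):

```python
import os

# Assumed wiring: AGENT_MODEL is any model id pydantic-ai accepts.
# Both values are placeholders, not defaults from this PR.
os.environ["AGENT_MODEL"] = "openai:gpt-4o"
os.environ["OPENAI_API_KEY"] = "sk-placeholder"

# The optional dependencies must match the pins from the warning above:
#   pip install 'pydantic-ai==0.7.0' 'ag-ui-protocol>=0.1.8'
```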
147 changes: 147 additions & 0 deletions orchestrator/api/api_v1/endpoints/search.py
@@ -0,0 +1,147 @@
from typing import Any, TypeVar, cast

from fastapi import APIRouter
from pydantic import BaseModel
from sqlalchemy import case, select
from sqlalchemy.orm import selectinload

from orchestrator.db import (
ProcessTable,
ProductTable,
SubscriptionTable,
WorkflowTable,
db,
)
from orchestrator.schemas.search import (
ConnectionSchema,
PageInfoSchema,
ProcessSearchSchema,
ProductSearchSchema,
SubscriptionSearchResult,
WorkflowSearchSchema,
)
from orchestrator.schemas.subscription import SubscriptionDomainModelSchema
from orchestrator.search.retrieval import execute_search
from orchestrator.search.schemas.parameters import (
BaseSearchParameters,
ProcessSearchParameters,
ProductSearchParameters,
SubscriptionSearchParameters,
WorkflowSearchParameters,
)

router = APIRouter(tags=["Search"], prefix="/search")
T = TypeVar("T", bound=BaseModel)


async def _perform_search_and_fetch_simple(
search_params: BaseSearchParameters,
db_model: Any,
response_schema: type[BaseModel],
pk_column_name: str,
eager_loads: list[Any],
) -> ConnectionSchema:
results = await execute_search(search_params=search_params, db_session=db.session, limit=20)

if not results:
data: dict[str, Any] = {"page_info": PageInfoSchema(), "page": []}
return ConnectionSchema(**cast(Any, data))

entity_ids = [res.entity_id for res in results]
pk_column = getattr(db_model, pk_column_name)
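# Preserve the ranking produced by execute_search: CASE maps each primary
# key to its position in entity_ids, so ORDER BY restores the search order.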
ordering_case = case({entity_id: i for i, entity_id in enumerate(entity_ids)}, value=pk_column)

stmt = select(db_model).options(*eager_loads).filter(pk_column.in_(entity_ids)).order_by(ordering_case)
entities = db.session.scalars(stmt).all()

page = [response_schema.model_validate(entity) for entity in entities]

data = {"page_info": PageInfoSchema(), "page": page}
return ConnectionSchema(**cast(Any, data))


@router.post(
"/subscriptions",
response_model=ConnectionSchema[SubscriptionSearchResult],
response_model_by_alias=True,
)
async def search_subscriptions(
search_params: SubscriptionSearchParameters,
) -> ConnectionSchema[SubscriptionSearchResult]:
search_results = await execute_search(search_params=search_params, db_session=db.session, limit=20)

if not search_results:
data = {"page_info": PageInfoSchema(), "page": []}
return ConnectionSchema(**cast(Any, data))

search_info_map = {res.entity_id: res for res in search_results}
entity_ids = list(search_info_map.keys())

pk_column = SubscriptionTable.subscription_id
ordering_case = case({entity_id: i for i, entity_id in enumerate(entity_ids)}, value=pk_column)

stmt = (
select(SubscriptionTable)
.options(
selectinload(SubscriptionTable.product),
selectinload(SubscriptionTable.customer_descriptions),
)
.filter(pk_column.in_(entity_ids))
.order_by(ordering_case)
)
subscriptions = db.session.scalars(stmt).all()

page = []
for sub in subscriptions:
search_data = search_info_map.get(str(sub.subscription_id))
if search_data:
subscription_model = SubscriptionDomainModelSchema.model_validate(sub)

result_item = SubscriptionSearchResult(
score=search_data.score,
highlight=search_data.highlight,
subscription=subscription_model.model_dump(),
)
page.append(result_item)

data = {"page_info": PageInfoSchema(), "page": page}
return ConnectionSchema(**cast(Any, data))


@router.post("/workflows", response_model=ConnectionSchema[WorkflowSearchSchema], response_model_by_alias=True)
async def search_workflows(search_params: WorkflowSearchParameters) -> ConnectionSchema[WorkflowSearchSchema]:
return await _perform_search_and_fetch_simple(
search_params=search_params,
db_model=WorkflowTable,
response_schema=WorkflowSearchSchema,
pk_column_name="workflow_id",
eager_loads=[selectinload(WorkflowTable.products)],
)


@router.post("/products", response_model=ConnectionSchema[ProductSearchSchema], response_model_by_alias=True)
async def search_products(search_params: ProductSearchParameters) -> ConnectionSchema[ProductSearchSchema]:
return await _perform_search_and_fetch_simple(
search_params=search_params,
db_model=ProductTable,
response_schema=ProductSearchSchema,
pk_column_name="product_id",
eager_loads=[
selectinload(ProductTable.workflows),
selectinload(ProductTable.fixed_inputs),
selectinload(ProductTable.product_blocks),
],
)


@router.post("/processes", response_model=ConnectionSchema[ProcessSearchSchema], response_model_by_alias=True)
async def search_processes(search_params: ProcessSearchParameters) -> ConnectionSchema[ProcessSearchSchema]:
return await _perform_search_and_fetch_simple(
search_params=search_params,
db_model=ProcessTable,
response_schema=ProcessSearchSchema,
pk_column_name="process_id",
eager_loads=[
selectinload(ProcessTable.workflow),
],
)
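A quick way to exercise the new endpoints, sketched with requests. The host, the /api mount prefix, and the "query" field are assumptions; the real request body is defined by SubscriptionSearchParameters, and response field casing may differ since the schemas serialize by alias:

```python
import requests

# Hypothetical smoke test for POST /search/subscriptions; the base URL and
# payload shape are assumptions, not part of this PR.
resp = requests.post(
    "http://localhost:8080/api/search/subscriptions",
    json={"query": "10 gigabit internet"},
)
resp.raise_for_status()
for item in resp.json()["page"]:
    print(item["score"], item["highlight"])
```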
4 changes: 4 additions & 0 deletions orchestrator/app.py
@@ -41,6 +41,7 @@
from oauth2_lib.fastapi import AuthManager, Authorization, GraphqlAuthorization, OIDCAuth
from orchestrator import __version__
from orchestrator.api.api_v1.api import api_router
+from orchestrator.api.api_v1.endpoints.agent import build_agent_app
from orchestrator.api.error_handling import ProblemDetailException
from orchestrator.cli.main import app as cli_app
from orchestrator.db import db, init_database
@@ -150,6 +151,9 @@ def __init__(
metrics_app = make_asgi_app(registry=ORCHESTRATOR_METRICS_REGISTRY)
self.mount("/api/metrics", metrics_app)

+agent_app = build_agent_app()
+self.mount("/agent", agent_app)

@self.router.get("/", response_model=str, response_class=JSONResponse, include_in_schema=False)
def _index() -> str:
return "Orchestrator Core"
73 changes: 73 additions & 0 deletions orchestrator/cli/index_llm.py
@@ -0,0 +1,73 @@
import typer

from orchestrator.search.core.types import EntityType
from orchestrator.search.indexing import run_indexing_for_entity

app = typer.Typer(
name="index",
help="Index search indexes",
)


@app.command("subscriptions")
def subscriptions_command(
subscription_id: str | None = typer.Option(None, help="UUID (default = all)"),
dry_run: bool = typer.Option(False, help="No DB writes"),
force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"),
) -> None:
"""Index subscription_search_index."""
run_indexing_for_entity(
entity_kind=EntityType.SUBSCRIPTION,
entity_id=subscription_id,
dry_run=dry_run,
force_index=force_index,
)


@app.command("products")
def products_command(
product_id: str | None = typer.Option(None, help="UUID (default = all)"),
dry_run: bool = typer.Option(False, help="No DB writes"),
force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"),
) -> None:
"""Index product_search_index."""
run_indexing_for_entity(
entity_kind=EntityType.PRODUCT,
entity_id=product_id,
dry_run=dry_run,
force_index=force_index,
)


@app.command("processes")
def processes_command(
process_id: str | None = typer.Option(None, help="UUID (default = all)"),
dry_run: bool = typer.Option(False, help="No DB writes"),
force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"),
) -> None:
"""Index process_search_index."""
run_indexing_for_entity(
entity_kind=EntityType.PROCESS,
entity_id=process_id,
dry_run=dry_run,
force_index=force_index,
)


@app.command("workflows")
def workflows_command(
workflow_id: str | None = typer.Option(None, help="UUID (default = all)"),
dry_run: bool = typer.Option(False, help="No DB writes"),
force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"),
) -> None:
"""Index workflow_search_index."""
run_indexing_for_entity(
entity_kind=EntityType.WORKFLOW,
entity_id=workflow_id,
dry_run=dry_run,
force_index=force_index,
)


if __name__ == "__main__":
app()
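Because these are plain Typer commands, they can be smoke-tested without a shell via Typer's test runner. A sketch, assuming a reachable database (the UUID is a placeholder):

```python
from typer.testing import CliRunner

from orchestrator.cli.index_llm import app

runner = CliRunner()

# Dry-run index of a single subscription: --dry-run skips DB writes and
# --force-index would bypass the hash cache. The UUID is a placeholder.
result = runner.invoke(
    app,
    ["subscriptions", "--subscription-id", "11111111-1111-1111-1111-111111111111", "--dry-run"],
)
print(result.output)
```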
4 changes: 3 additions & 1 deletion orchestrator/cli/main.py
@@ -13,12 +13,14 @@

import typer

-from orchestrator.cli import database, generate, scheduler
+from orchestrator.cli import database, generate, index_llm, scheduler, search_explore

app = typer.Typer()
app.add_typer(scheduler.app, name="scheduler", help="Access all the scheduler functions")
app.add_typer(database.app, name="db", help="Interact with the application database")
app.add_typer(generate.app, name="generate", help="Generate products, workflows and other artifacts")
+app.add_typer(index_llm.app, name="index")
+app.add_typer(search_explore.app, name="search")


if __name__ == "__main__":