Llm integration POC #1028
Open: timfdev wants to merge 12 commits into main from llm-integration
Commits:
99da746 Vector search and agent mode POC (timfdev)
960ce80 l (timfdev)
d34d467 add ag-ui package (timfdev)
cede6f5 fix linting (timfdev)
65e963d last lint fix (timfdev)
90a5d1a Streaming pipeline for indexing, using litellm to track token count v… (timfdev)
d0a23ec fix mypy issues & use pgvector image (timfdev)
4fce33d use pgvector for codspeed tests (timfdev)
b8b4eb8 Merge branch 'main' into llm-integration (timfdev)
1ec3625 update docs and cleanup (timfdev)
5d4c316 Merge branch 'llm-integration' of github.com:workfloworchestrator/orc… (timfdev)
f837226 use python 3.10+ style type hinting (timfdev)
@@ -0,0 +1,62 @@
import structlog
from fastapi import FastAPI, HTTPException
from starlette.types import ASGIApp

from orchestrator.settings import app_settings

logger = structlog.get_logger(__name__)


def _disabled_agent_app(reason: str) -> FastAPI:
    app = FastAPI(title="Agent disabled")

    @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"])
    async def _disabled(path: str) -> None:
        raise HTTPException(status_code=503, detail=f"Agent disabled: {reason}")

    return app


def build_agent_app() -> ASGIApp:
    if not app_settings.AGENT_MODEL or not app_settings.OPENAI_API_KEY:
        logger.warning("Agent route disabled: missing model or OPENAI_API_KEY")
        return _disabled_agent_app("missing configuration")

    try:
        from pydantic_ai.ag_ui import StateDeps
        from pydantic_ai.agent import Agent
        from pydantic_ai.settings import ModelSettings

        from orchestrator.search.agent.prompts import get_base_instructions, get_dynamic_instructions
        from orchestrator.search.agent.state import SearchState
        from orchestrator.search.agent.tools import search_toolset
    except ImportError:
        logger.error(
            "\nRequired packages not installed:\n"
            "WARNING: These packages are NOT compatible with the current "
            "pydantic version in orchestrator-core.\n Upgrading pydantic to install "
            "may cause incompatibilities or runtime errors.\n\n"
            " pydantic-ai==0.7.0\n"
            " ag-ui-protocol>=0.1.8\n\n"
            "Install them locally to enable the agent:\n"
            " pip install 'pydantic-ai==0.7.0' 'ag-ui-protocol>=0.1.8'\n"
        )
        logger.warning("Agent route disabled: Missing required packages")
        return _disabled_agent_app("Missing required packages")

    try:
        agent = Agent(
            model=app_settings.AGENT_MODEL,
            deps_type=StateDeps[SearchState],
            model_settings=ModelSettings(
                parallel_tool_calls=False
            ),  # https://github.com/pydantic/pydantic-ai/issues/562
            toolsets=[search_toolset],
        )
        agent.instructions(get_base_instructions)
        agent.instructions(get_dynamic_instructions)

        return agent.to_ag_ui(deps=StateDeps(SearchState()))
    except Exception as e:
        logger.error("Agent init failed; serving disabled stub.", error=str(e))
        return _disabled_agent_app(str(e))
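The fallback order in `build_agent_app` (missing configuration, then missing packages, then any init failure, each answered by a 503 stub) can be illustrated without FastAPI. The following is a minimal sketch using a bare ASGI callable and only the standard library; `disabled_asgi_app`, `build_app`, `call`, and `broken_factory` are hypothetical names for illustration and are not part of the PR.

```python
import asyncio
import json


def disabled_asgi_app(reason: str):
    """Return a bare ASGI app that answers every HTTP request with 503."""
    body = json.dumps({"detail": f"Agent disabled: {reason}"}).encode()

    async def app(scope, receive, send):
        if scope["type"] != "http":
            return  # this sketch only handles plain HTTP requests
        await send({
            "type": "http.response.start",
            "status": 503,
            "headers": [
                (b"content-type", b"application/json"),
                (b"content-length", str(len(body)).encode()),
            ],
        })
        await send({"type": "http.response.body", "body": body})

    return app


def build_app(factory, configured: bool):
    """Hypothetical builder that mirrors the fallback order of build_agent_app."""
    if not configured:
        return disabled_asgi_app("missing configuration")
    try:
        return factory()
    except ImportError:
        return disabled_asgi_app("Missing required packages")
    except Exception as exc:  # any other init failure also degrades to the stub
        return disabled_asgi_app(str(exc))


def call(app, path: str = "/"):
    """Drive an ASGI app once and collect the messages it sends."""
    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    scope = {"type": "http", "method": "GET", "path": path, "headers": []}
    asyncio.run(app(scope, receive, send))
    return sent


def broken_factory():
    # Simulates the optional agent dependencies being absent.
    raise ImportError("pydantic_ai is not installed")


messages = call(build_app(broken_factory, configured=True), "/any/path")
unconfigured = call(build_app(broken_factory, configured=False))
```

The point of the design is that the mount point always exists: callers get a 503 with a reason rather than a 404 or an import error at startup.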
@@ -0,0 +1,147 @@
from typing import Any, TypeVar, cast

from fastapi import APIRouter
from pydantic import BaseModel
from sqlalchemy import case, select
from sqlalchemy.orm import selectinload

from orchestrator.db import (
    ProcessTable,
    ProductTable,
    SubscriptionTable,
    WorkflowTable,
    db,
)
from orchestrator.schemas.search import (
    ConnectionSchema,
    PageInfoSchema,
    ProcessSearchSchema,
    ProductSearchSchema,
    SubscriptionSearchResult,
    WorkflowSearchSchema,
)
from orchestrator.schemas.subscription import SubscriptionDomainModelSchema
from orchestrator.search.retrieval import execute_search
from orchestrator.search.schemas.parameters import (
    BaseSearchParameters,
    ProcessSearchParameters,
    ProductSearchParameters,
    SubscriptionSearchParameters,
    WorkflowSearchParameters,
)

router = APIRouter(tags=["Search"], prefix="/search")
T = TypeVar("T", bound=BaseModel)


async def _perform_search_and_fetch_simple(
    search_params: BaseSearchParameters,
    db_model: Any,
    response_schema: type[BaseModel],
    pk_column_name: str,
    eager_loads: list[Any],
) -> ConnectionSchema:
    results = await execute_search(search_params=search_params, db_session=db.session, limit=20)

    if not results:
        data: dict[str, Any] = {"page_info": PageInfoSchema(), "page": []}
        return ConnectionSchema(**cast(Any, data))

    entity_ids = [res.entity_id for res in results]
    pk_column = getattr(db_model, pk_column_name)
    ordering_case = case({entity_id: i for i, entity_id in enumerate(entity_ids)}, value=pk_column)

    stmt = select(db_model).options(*eager_loads).filter(pk_column.in_(entity_ids)).order_by(ordering_case)
    entities = db.session.scalars(stmt).all()

    page = [response_schema.model_validate(entity) for entity in entities]

    data = {"page_info": PageInfoSchema(), "page": page}
    return ConnectionSchema(**cast(Any, data))


@router.post(
    "/subscriptions",
    response_model=ConnectionSchema[SubscriptionSearchResult],
    response_model_by_alias=True,
)
async def search_subscriptions(
    search_params: SubscriptionSearchParameters,
) -> ConnectionSchema[SubscriptionSearchResult]:
    search_results = await execute_search(search_params=search_params, db_session=db.session, limit=20)

    if not search_results:
        data = {"page_info": PageInfoSchema(), "page": []}
        return ConnectionSchema(**cast(Any, data))

    search_info_map = {res.entity_id: res for res in search_results}
    entity_ids = list(search_info_map.keys())

    pk_column = SubscriptionTable.subscription_id
    ordering_case = case({entity_id: i for i, entity_id in enumerate(entity_ids)}, value=pk_column)

    stmt = (
        select(SubscriptionTable)
        .options(
            selectinload(SubscriptionTable.product),
            selectinload(SubscriptionTable.customer_descriptions),
        )
        .filter(pk_column.in_(entity_ids))
        .order_by(ordering_case)
    )
    subscriptions = db.session.scalars(stmt).all()

    page = []
    for sub in subscriptions:
        search_data = search_info_map.get(str(sub.subscription_id))
        if search_data:
            subscription_model = SubscriptionDomainModelSchema.model_validate(sub)

            result_item = SubscriptionSearchResult(
                score=search_data.score,
                highlight=search_data.highlight,
                subscription=subscription_model.model_dump(),
            )
            page.append(result_item)

    data = {"page_info": PageInfoSchema(), "page": page}
    return ConnectionSchema(**cast(Any, data))


@router.post("/workflows", response_model=ConnectionSchema[WorkflowSearchSchema], response_model_by_alias=True)
async def search_workflows(search_params: WorkflowSearchParameters) -> ConnectionSchema[WorkflowSearchSchema]:
    return await _perform_search_and_fetch_simple(
        search_params=search_params,
        db_model=WorkflowTable,
        response_schema=WorkflowSearchSchema,
        pk_column_name="workflow_id",
        eager_loads=[selectinload(WorkflowTable.products)],
    )


@router.post("/products", response_model=ConnectionSchema[ProductSearchSchema], response_model_by_alias=True)
async def search_products(search_params: ProductSearchParameters) -> ConnectionSchema[ProductSearchSchema]:
    return await _perform_search_and_fetch_simple(
        search_params=search_params,
        db_model=ProductTable,
        response_schema=ProductSearchSchema,
        pk_column_name="product_id",
        eager_loads=[
            selectinload(ProductTable.workflows),
            selectinload(ProductTable.fixed_inputs),
            selectinload(ProductTable.product_blocks),
        ],
    )


@router.post("/processes", response_model=ConnectionSchema[ProcessSearchSchema], response_model_by_alias=True)
async def search_processes(search_params: ProcessSearchParameters) -> ConnectionSchema[ProcessSearchSchema]:
    return await _perform_search_and_fetch_simple(
        search_params=search_params,
        db_model=ProcessTable,
        response_schema=ProcessSearchSchema,
        pk_column_name="process_id",
        eager_loads=[
            selectinload(ProcessTable.workflow),
        ],
    )
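The search handlers above take ranked IDs from `execute_search`, re-fetch the rows with an `IN` filter, and restore the ranking with a `CASE` expression in `ORDER BY`, since `IN` by itself gives no ordering guarantee. A minimal sketch of that technique follows, using the standard library's `sqlite3` in place of SQLAlchemy and PostgreSQL. The table contents and the helper `fetch_in_ranked_order` are assumptions made for illustration; only the `workflow_id` column name is taken from the diff.

```python
import sqlite3

# In-memory stand-in for the workflows table; the rows are made up.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE workflows (workflow_id TEXT PRIMARY KEY, name TEXT)")
conn.executemany(
    "INSERT INTO workflows VALUES (?, ?)",
    [("a", "create_node"), ("b", "modify_node"), ("c", "terminate_node")],
)


def fetch_in_ranked_order(conn, ranked_ids):
    """Fetch rows whose id is in ranked_ids, ordered by position in ranked_ids."""
    if not ranked_ids:
        return []  # same early exit as the handlers: no hits, empty page

    placeholders = ", ".join("?" for _ in ranked_ids)
    whens = " ".join("WHEN ? THEN ?" for _ in ranked_ids)
    sql = (
        "SELECT workflow_id, name FROM workflows "
        f"WHERE workflow_id IN ({placeholders}) "
        f"ORDER BY CASE workflow_id {whens} END"
    )
    # Parameters follow the order of the placeholders in the SQL text:
    # first the IN list, then (id, rank) pairs for the CASE branches.
    case_params = [p for rank, eid in enumerate(ranked_ids) for p in (eid, rank)]
    return conn.execute(sql, [*ranked_ids, *case_params]).fetchall()


# "missing" was returned by the search index but no longer exists in the table.
rows = fetch_in_ranked_order(conn, ["c", "a", "missing"])
```

IDs that the index returns but the table no longer contains simply drop out of the result, so a page can hold fewer rows than the search produced.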
@@ -0,0 +1,73 @@
import typer

from orchestrator.search.core.types import EntityType
from orchestrator.search.indexing import run_indexing_for_entity

app = typer.Typer(
    name="index",
    help="Index search indexes",
)


@app.command("subscriptions")
def subscriptions_command(
    subscription_id: str | None = typer.Option(None, help="UUID (default = all)"),
    dry_run: bool = typer.Option(False, help="No DB writes"),
    force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"),
) -> None:
    """Index subscription_search_index."""
    run_indexing_for_entity(
        entity_kind=EntityType.SUBSCRIPTION,
        entity_id=subscription_id,
        dry_run=dry_run,
        force_index=force_index,
    )


@app.command("products")
def products_command(
    product_id: str | None = typer.Option(None, help="UUID (default = all)"),
    dry_run: bool = typer.Option(False, help="No DB writes"),
    force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"),
) -> None:
    """Index product_search_index."""
    run_indexing_for_entity(
        entity_kind=EntityType.PRODUCT,
        entity_id=product_id,
        dry_run=dry_run,
        force_index=force_index,
    )


@app.command("processes")
def processes_command(
    process_id: str | None = typer.Option(None, help="UUID (default = all)"),
    dry_run: bool = typer.Option(False, help="No DB writes"),
    force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"),
) -> None:
    """Index process_search_index."""
    run_indexing_for_entity(
        entity_kind=EntityType.PROCESS,
        entity_id=process_id,
        dry_run=dry_run,
        force_index=force_index,
    )


@app.command("workflows")
def workflows_command(
    workflow_id: str | None = typer.Option(None, help="UUID (default = all)"),
    dry_run: bool = typer.Option(False, help="No DB writes"),
    force_index: bool = typer.Option(False, help="Force re-index (ignore hash cache)"),
) -> None:
    """Index workflow_search_index."""
    run_indexing_for_entity(
        entity_kind=EntityType.WORKFLOW,
        entity_id=workflow_id,
        dry_run=dry_run,
        force_index=force_index,
    )


if __name__ == "__main__":
    app()
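The diff does not show `run_indexing_for_entity`, so how `dry_run` and `force_index` behave internally is not visible here. Going only by the option help texts ("No DB writes", "Force re-index (ignore hash cache)"), one plausible reading is a content-hash check that skips unchanged entities. The sketch below is a hypothetical illustration of that reading, not the PR's implementation; `content_hash`, `plan_indexing`, and the sample records are invented for the example.

```python
import hashlib


def content_hash(text: str) -> str:
    """Stable digest of an entity's indexable text."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def plan_indexing(records, stored_hashes, *, dry_run=False, force_index=False):
    """Return the ids that need (re)indexing.

    records:       mapping of entity id -> current indexable text
    stored_hashes: mapping of entity id -> hash recorded at last indexing
                   (updated in place unless dry_run is set)
    """
    to_index = []
    for entity_id, text in records.items():
        digest = content_hash(text)
        # Unchanged content is skipped unless the caller forces a re-index.
        if not force_index and stored_hashes.get(entity_id) == digest:
            continue
        to_index.append(entity_id)
        if not dry_run:
            stored_hashes[entity_id] = digest  # stands in for the DB write
    return to_index


# Hypothetical data: sub-1 was indexed before, sub-2 is new.
records = {"sub-1": "fiber link A", "sub-2": "fiber link B"}
cache = {"sub-1": content_hash("fiber link A")}

first = plan_indexing(records, cache, dry_run=True)
cache_after_dry_run = dict(cache)  # snapshot: a dry run must not change it
forced = plan_indexing(records, cache, force_index=True)
```

Under this reading, a dry run reports what would be indexed without recording anything, and a forced run touches every entity regardless of its stored hash.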
Should this have tags=["Search"]? And should we add prefix="/search" here (and update the search router accordingly)?
Ah, I see you have it in the other file. Maybe that should be moved here to keep the "overview" of how routes are mounted in this file.