Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@ __pycache__
/litellm


pipelines/*
!pipelines/.gitignore
.DS_Store
# Ignore everything in pipelines
pipelines/*

# But keep files directly inside pipelines/
!pipelines/*.*

.venv
venv/
venv/
.idea/
2 changes: 2 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@

API_KEY = os.getenv("PIPELINES_API_KEY", "0p3n-w3bu!")
PIPELINES_DIR = os.getenv("PIPELINES_DIR", "./pipelines")
DATABASE_URL = os.getenv("DATABASE_URL")
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
66 changes: 49 additions & 17 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool


from starlette.responses import StreamingResponse, Response
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Generator, Iterator


from utils.pipelines.auth import bearer_security, get_current_user
from utils.pipelines.main import get_last_user_message, stream_message_template
from utils.pipelines.misc import convert_to_raw_url
Expand All @@ -28,13 +26,13 @@
import sys
import subprocess


from config import API_KEY, PIPELINES_DIR, LOG_LEVELS
from config import API_KEY, PIPELINES_DIR, LOG_LEVELS, DATABASE_URL, REDIS_URL
from utils.pipelines.database import init_database, close_database, is_database_available
from utils.pipelines.redis_client import init_redis, close_redis, is_redis_available

if not os.path.exists(PIPELINES_DIR):
os.makedirs(PIPELINES_DIR)


PIPELINES = {}
PIPELINE_MODULES = {}
PIPELINE_NAMES = {}
Expand Down Expand Up @@ -88,13 +86,13 @@ def get_all_pipelines():
"pipelines": (
pipeline.valves.pipelines
if hasattr(pipeline, "valves")
and hasattr(pipeline.valves, "pipelines")
and hasattr(pipeline.valves, "pipelines")
else []
),
"priority": (
pipeline.valves.priority
if hasattr(pipeline, "valves")
and hasattr(pipeline.valves, "priority")
and hasattr(pipeline.valves, "priority")
else 0
),
"valves": pipeline.valves if hasattr(pipeline, "valves") else None,
Expand Down Expand Up @@ -131,7 +129,6 @@ def install_frontmatter_requirements(requirements):


async def load_module_from_path(module_name, module_path):

try:
# Read the module content
with open(module_path, "r") as file:
Expand Down Expand Up @@ -224,6 +221,22 @@ async def load_modules_from_directory(directory):


async def on_startup():
# Initialize database if DATABASE_URL is provided
if DATABASE_URL is not None:
try:
await init_database()
except Exception as e:
logging.error(f"Failed to initialize database: {e}")
# Continue without database if initialization fails

# Initialize Redis if REDIS_URL is provided
if REDIS_URL is not None:
try:
await init_redis()
except Exception as e:
logging.error(f"Failed to initialize Redis: {e}")
# Continue without Redis if it fails

await load_modules_from_directory(PIPELINES_DIR)

for module in PIPELINE_MODULES.values():
Expand All @@ -236,6 +249,14 @@ async def on_shutdown():
if hasattr(module, "on_shutdown"):
await module.on_shutdown()

# Close database connection if it was initialized
if is_database_available():
await close_database()

# Close Redis connection if it was initialized
if is_redis_available():
await close_redis()


async def reload():
await on_shutdown()
Expand All @@ -258,10 +279,8 @@ async def lifespan(app: FastAPI):

app.state.PIPELINES = PIPELINES


origins = ["*"]


app.add_middleware(
CORSMiddleware,
allow_origins=origins,
Expand Down Expand Up @@ -324,7 +343,21 @@ async def get_models(user: str = Depends(get_current_user)):
@app.get("/v1")
@app.get("/")
async def get_status():
return {"status": True}
status_info = {"status": True}

# Add database status if available
if DATABASE_URL is not None:
status_info["database"] = {
"available": is_database_available(),
"url_configured": True
}
else:
status_info["database"] = {
"available": False,
"url_configured": False
}

return status_info


@app.get("/v1/pipelines")
Expand Down Expand Up @@ -387,7 +420,7 @@ async def download_file(url: str, dest_folder: str):
@app.post("/v1/pipelines/add")
@app.post("/pipelines/add")
async def add_pipeline(
form_data: AddPipelineForm, user: str = Depends(get_current_user)
form_data: AddPipelineForm, user: str = Depends(get_current_user)
):
if user != API_KEY:
raise HTTPException(
Expand Down Expand Up @@ -417,7 +450,7 @@ async def add_pipeline(
@app.post("/v1/pipelines/upload")
@app.post("/pipelines/upload")
async def upload_pipeline(
file: UploadFile = File(...), user: str = Depends(get_current_user)
file: UploadFile = File(...), user: str = Depends(get_current_user)
):
if user != API_KEY:
raise HTTPException(
Expand Down Expand Up @@ -466,7 +499,7 @@ class DeletePipelineForm(BaseModel):
@app.delete("/v1/pipelines/delete")
@app.delete("/pipelines/delete")
async def delete_pipeline(
form_data: DeletePipelineForm, user: str = Depends(get_current_user)
form_data: DeletePipelineForm, user: str = Depends(get_current_user)
):
if user != API_KEY:
raise HTTPException(
Expand Down Expand Up @@ -552,7 +585,6 @@ async def get_valves_spec(pipeline_id: str):
@app.post("/v1/{pipeline_id}/valves/update")
@app.post("/{pipeline_id}/valves/update")
async def update_valves(pipeline_id: str, form_data: dict):

if pipeline_id not in PIPELINE_MODULES:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand Down Expand Up @@ -663,8 +695,8 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
user_message = get_last_user_message(messages)

if (
form_data.model not in app.state.PIPELINES
or app.state.PIPELINES[form_data.model]["type"] == "filter"
form_data.model not in app.state.PIPELINES
or app.state.PIPELINES[form_data.model]["type"] == "filter"
):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand Down
28 changes: 28 additions & 0 deletions pipeline.service
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
[Unit]
Description=SquadRun Pipelines Service
After=network.target

[Service]
# Run inside project directory
WorkingDirectory=/home/ubuntu/squadrun-pipelines/

# Export environment (virtualenv)
Environment="PATH=/home/ubuntu/squadrun-pipelines/.venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"

# Run the script with bash
ExecStart=/bin/bash ./start.sh

# Restart policy
Restart=always
RestartSec=5

# User/group to run as
User=ubuntu
Group=ubuntu

# Logging (journalctl -u squadrun-pipelines)
StandardOutput=journal
StandardError=journal

[Install]
WantedBy=multi-user.target
Binary file not shown.
Loading