-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathserver.py
More file actions
257 lines (216 loc) · 8.25 KB
/
server.py
File metadata and controls
257 lines (216 loc) · 8.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
#!/usr/bin/env python3
"""
GNN Pipeline FastAPI Server.
Provides REST endpoints for pipeline job management and tool invocation.
No authentication — designed for local research use.
Run with:
python -m api.server
# or:
uvicorn api.server:app --host 0.0.0.0 --port 8000 --reload
"""
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Dict
logger = logging.getLogger(__name__)
try:
import uvicorn
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
except ImportError as e:
raise ImportError(
"FastAPI and uvicorn are required for the GNN API server. "
"Install with: uv sync --extra api"
) from e
from api import processor as job_mgr
from api.models import (
HealthResponse,
JobResponse,
JobStatus,
JobStatusResponse,
ProcessRequest,
ToolInfo,
ToolRequest,
ToolsResponse,
)
# Application metadata
app = FastAPI(
    title="GNN Pipeline API",
    description=(
        "REST interface for the Generalized Notation Notation (GNN) processing pipeline. "
        "Submit jobs, poll status, and invoke individual pipeline steps. "
        "No authentication required — research tool for local use."
    ),
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# CORS for local browser access.
# BUG FIX: Starlette's CORSMiddleware compares allow_origins entries as exact
# strings (only a bare "*" is a wildcard), so entries like "http://localhost:*"
# never match any real Origin header and every cross-origin browser request was
# silently rejected. allow_origin_regex is the supported way to accept an
# arbitrary port on loopback hosts.
app.add_middleware(
    CORSMiddleware,
    allow_origin_regex=r"http://(localhost|127\.0\.0\.1)(:\d+)?",
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/api/v1/health", response_model=HealthResponse, tags=["Meta"])
async def health_check() -> HealthResponse:
    """Report API liveness plus a snapshot of current job activity."""
    all_jobs = job_mgr.list_jobs()
    # A job counts as active while it is still queued or executing.
    in_flight = len([j for j in all_jobs if j.get("status") in ("pending", "running")])
    return HealthResponse(
        status="healthy",
        version="1.0.0",
        pipeline_steps=len(job_mgr.PIPELINE_STEPS),
        active_jobs=in_flight,
        timestamp=datetime.now(),
    )
@app.post("/api/v1/process", response_model=JobResponse, tags=["Jobs"])
async def submit_process_job(request: ProcessRequest, background_tasks: BackgroundTasks) -> JobResponse:
    """
    Submit a GNN pipeline processing job.

    Accepts a target directory (absolute, or relative to the repo root) and an
    optional step selection. Returns a job ID for polling with
    GET /api/v1/jobs/{job_id}.

    Raises:
        HTTPException: 400 if the target directory does not exist or resolves
            outside the repository root.
    """
    # Single source of truth for the repository root. The original computed
    # this twice — once unresolved for the relative-path fallback, once
    # resolved for the boundary check. resolve() normalizes symlinks/"..",
    # so the containment check below cannot be bypassed.
    repo_root = Path(__file__).parent.parent.parent.resolve()

    # Validate target directory exists; fall back to repo-root-relative.
    target_path = Path(request.target_dir)
    if not target_path.exists():
        target_path = repo_root / request.target_dir
        if not target_path.exists():
            raise HTTPException(
                status_code=400,
                detail=f"Target directory not found: {request.target_dir}"
            )

    # Enforce path boundary: resolved path must stay within repo root.
    try:
        target_path.resolve().relative_to(repo_root)
    except ValueError as err:
        raise HTTPException(
            status_code=400,
            detail=f"Target directory must be within the repository root: {request.target_dir}"
        ) from err

    job_id = job_mgr.create_job(
        target_dir=str(target_path),
        steps=request.steps,
        skip_steps=request.skip_steps,
        verbose=request.verbose,
        strict=request.strict
    )
    # Launch async execution after the response is sent; clients poll status.
    background_tasks.add_task(job_mgr.execute_job_async, job_id)

    job = job_mgr.get_job(job_id)
    return JobResponse(
        job_id=job_id,
        status=JobStatus.PENDING,
        created_at=datetime.fromisoformat(job["created_at"]),
        steps_requested=request.steps,
        message=f"Job {job_id} queued. Poll GET /api/v1/jobs/{job_id} for status."
    )
@app.get("/api/v1/jobs/{job_id}", response_model=JobStatusResponse, tags=["Jobs"])
async def get_job_status(job_id: str) -> JobStatusResponse:
    """Poll the status of a submitted pipeline job."""
    job = job_mgr.get_job(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job not found: {job_id}")

    # Timestamps are stored as ISO-8601 strings; absent/empty values map to None.
    def parse_ts(value):
        if not value:
            return None
        return datetime.fromisoformat(value)

    return JobStatusResponse(
        job_id=job["job_id"],
        status=JobStatus(job["status"]),
        created_at=parse_ts(job["created_at"]),
        started_at=parse_ts(job.get("started_at")),
        completed_at=parse_ts(job.get("completed_at")),
        progress_step=job.get("progress_step"),
        steps_completed=job.get("steps_completed", []),
        steps_failed=job.get("steps_failed", []),
        exit_code=job.get("exit_code"),
        error_message=job.get("error_message"),
        output_dir=job.get("output_dir"),
    )
@app.delete("/api/v1/jobs/{job_id}", tags=["Jobs"])
async def cancel_job(job_id: str) -> Dict[str, Any]:
    """Cancel a pending or running job."""
    if job_mgr.cancel_job(job_id):
        return {"message": f"Job {job_id} cancelled"}

    # Cancellation failed — distinguish "unknown job" from "already finished".
    job = job_mgr.get_job(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job not found: {job_id}")
    raise HTTPException(
        status_code=409,
        detail=f"Job {job_id} is already in terminal state: {job['status']}"
    )
@app.get("/api/v1/jobs", tags=["Jobs"])
async def list_jobs(limit: int = 20) -> Dict[str, Any]:
    """List recent pipeline jobs (most recent `limit` entries)."""
    recent = job_mgr.list_jobs(limit=limit)
    return {"jobs": recent, "total": len(recent)}
@app.get("/api/v1/tools", response_model=ToolsResponse, tags=["Tools"])
async def list_tools() -> ToolsResponse:
    """List all available pipeline steps/tools."""
    tool_infos = [ToolInfo(**spec) for spec in job_mgr.get_pipeline_tools()]
    return ToolsResponse(tools=tool_infos, total=len(tool_infos))
@app.post("/api/v1/tools/{step}", response_model=JobResponse, tags=["Tools"])
async def invoke_tool(step: int, request: ToolRequest, background_tasks: BackgroundTasks) -> JobResponse:
    """
    Invoke a single pipeline step as a job.

    Equivalent to submitting a process request with steps=[step].

    Raises:
        HTTPException: 404 for an unknown pipeline step; 400 if the target
            directory does not exist or resolves outside the repository root.
    """
    if step not in job_mgr.PIPELINE_STEPS:
        raise HTTPException(status_code=404, detail=f"Unknown pipeline step: {step}")

    # Resolved once and used for both the relative-path fallback and the
    # boundary check (the original used an unresolved root for the fallback).
    repo_root = Path(__file__).parent.parent.parent.resolve()

    target_path = Path(request.target_dir)
    if not target_path.exists():
        target_path = repo_root / request.target_dir
        if not target_path.exists():
            raise HTTPException(status_code=400, detail=f"Target directory not found: {request.target_dir}")

    # Enforce path boundary: resolved path must stay within repo root.
    try:
        target_path.resolve().relative_to(repo_root)
    except ValueError as err:
        raise HTTPException(
            status_code=400,
            detail=f"Target directory must be within the repository root: {request.target_dir}"
        ) from err

    job_id = job_mgr.create_job(
        target_dir=str(target_path),
        steps=[step],
        verbose=request.verbose
    )
    background_tasks.add_task(job_mgr.execute_job_async, job_id)

    job = job_mgr.get_job(job_id)
    step_name = job_mgr.PIPELINE_STEPS[step][0]
    return JobResponse(
        job_id=job_id,
        status=JobStatus.PENDING,
        # FIX: report the timestamp stored on the job record, consistent with
        # /api/v1/process and with what GET /api/v1/jobs/{job_id} will return,
        # instead of a fresh datetime.now() that can drift from the record.
        created_at=datetime.fromisoformat(job["created_at"]),
        steps_requested=[step],
        message=f"Step {step} ({step_name}) queued as job {job_id}"
    )
def run_server(host: str = "127.0.0.1", port: int = 8000, reload: bool = False):
    """Start the API server.

    Warns (but does not refuse) when binding beyond loopback, since the API
    has no authentication.
    """
    loopback_hosts = ("127.0.0.1", "localhost")
    if host not in loopback_hosts:
        logger.warning(
            "Binding to non-loopback address %s with no authentication — "
            "ensure network-level access control is in place",
            host,
        )
    uvicorn.run(
        "api.server:app",
        host=host,
        port=port,
        reload=reload,
        log_level="info",
    )
if __name__ == "__main__":
    import argparse

    # Minimal CLI mirroring run_server()'s defaults.
    cli = argparse.ArgumentParser(description="GNN Pipeline API Server")
    cli.add_argument("--host", default="127.0.0.1")
    cli.add_argument("--port", type=int, default=8000)
    cli.add_argument("--reload", action="store_true", help="Auto-reload on code changes")
    ns = cli.parse_args()
    run_server(ns.host, ns.port, ns.reload)