
Commit 735fdc4

Author: shieldx-bot
Parent: b9065cb

ci(test): Add ML CI tests + load test harness

- GitHub Actions workflow to install requirements-test and run the full suite
- Asyncio- and Locust-based HTTP benchmarks for concurrency/latency
- Fix dl_service.py training endpoint syntax and validation
- Fix validate_code.py print statement
- Add README_TESTING with steps and performance targets

File tree: 6 files changed, +339 −14 lines

- .github/workflows/ml-tests.yml
- services/shieldx-ml/README_TESTING.md
- services/shieldx-ml/ml-service/dl_service.py
- services/shieldx-ml/tools/bench_http.py
- services/shieldx-ml/tools/locustfile.py
- services/shieldx-ml/validate_code.py

.github/workflows/ml-tests.yml

Lines changed: 44 additions & 0 deletions

```yaml
name: ML Service Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  tests:
    runs-on: ubuntu-latest
    timeout-minutes: 30

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python 3.11
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install system deps
        run: |
          sudo apt-get update
          sudo apt-get install -y python3-venv

      - name: Install Python dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r services/shieldx-ml/requirements-test.txt

      - name: Run ML test suite
        working-directory: services/shieldx-ml
        run: |
          python run_tests.py

      - name: Upload test report
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: ml-test-report
          path: services/shieldx-ml/test_report.json
          if-no-files-found: warn
```
services/shieldx-ml/README_TESTING.md

Lines changed: 61 additions & 0 deletions

# ShieldX ML Service - Testing & Performance

This doc explains how to validate the ML service's correctness and evaluate it against the performance targets below.

## Unit & Integration Tests

- The full suite lives under `services/shieldx-ml/tests` and `services/shieldx-ml/ml-service/tests`.
- Run locally (requires Python 3.11 and dependencies):

```bash
cd services/shieldx-ml
python3 -m venv .venv && source .venv/bin/activate
pip install -r requirements-test.txt
python run_tests.py
```

- Or run a quick static validation without heavy deps:

```bash
python3 validate_code.py
```

- CI runs these automatically via GitHub Actions: `.github/workflows/ml-tests.yml`.

## Load & Concurrency Tests

Targets:

- 10,000 concurrent requests
- < 100 ms latency (p99)
- 99% detection rate (recall)

Tools provided:

- `tools/locustfile.py`: Locust user model for HTTP inference.
- `tools/bench_http.py`: asyncio benchmark for high-concurrency testing.

Example (asyncio benchmark):

```bash
# Start the DL service first (port 8001)
python3 ml-service/dl_service.py &
# In another terminal:
python3 tools/bench_http.py --base-url http://localhost:8001 --model autoencoder_demo --concurrency 1000 --rpc 10
```

Example (Locust):

```bash
pip install locust
locust -f tools/locustfile.py --host http://localhost:8001
```
## Production Readiness Notes

- To sustain 10k concurrent requests with < 100 ms p99, deploy with:
  - Gunicorn + Uvicorn workers or an ASGI framework (Quart/FastAPI) for async IO (a minimal sketch follows this list)
  - Dynamic batching and GPU acceleration enabled via `inference_engine.py`
  - A Redis-backed cache and model warmup
  - Horizontal autoscaling (K8s HPA) and a gateway (HAProxy/NGINX) with keep-alive
  - Optionally, Triton Inference Server + TensorRT for GPU acceleration
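For illustration only, here is a minimal async-IO serving sketch in the FastAPI style; it is not the service's actual implementation (that lives in `ml-service/dl_service.py`), and the placeholder scoring logic stands in for the real model registry:

```python
# Minimal ASGI serving sketch (illustrative; not the real service).
# Run with: gunicorn app:app -k uvicorn.workers.UvicornWorker -w 4 --keep-alive 5
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class PredictRequest(BaseModel):
    data: list[list[float]]    # batch of feature vectors
    return_proba: bool = False

@app.post("/models/{model_name}/predict")
async def predict(model_name: str, req: PredictRequest):
    # Placeholder scoring; a real deployment would dispatch to the model
    # registry / inference_engine with dynamic batching
    scores = [sum(row) / len(row) for row in req.data]
    return {"model": model_name, "predictions": scores}

@app.get("/health")
async def health():
    return {"status": "ok"}
```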
- Hitting 99% detection (recall) requires calibrated decision thresholds and balanced datasets; validate using the `evaluate` endpoint against realistic traffic distributions. A calibration sketch follows the next bullet.
- See docs/ML_MASTER_ROADMAP.md for completed optimization and monitoring features.
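As a sketch of what threshold calibration can look like (this helper is hypothetical, not part of the service), the following picks the highest anomaly-score threshold that still achieves the target recall on labeled validation data, maximizing precision subject to the recall constraint:

```python
# Hypothetical threshold calibration: choose the highest anomaly-score
# threshold that still yields >= 99% recall on a labeled validation set.
# `scores` and `labels` are assumed exports (e.g., from the evaluate endpoint).
import numpy as np

def calibrate_threshold(scores: np.ndarray, labels: np.ndarray,
                        target_recall: float = 0.99) -> float:
    order = np.argsort(scores)[::-1]               # most anomalous first
    sorted_scores, sorted_labels = scores[order], labels[order]
    recall = np.cumsum(sorted_labels) / labels.sum()
    idx = int(np.argmax(recall >= target_recall))  # first index meeting target
    return float(sorted_scores[idx])

# Example with synthetic validation data
rng = np.random.default_rng(0)
scores = np.concatenate([rng.normal(0, 1, 900), rng.normal(3, 1, 100)])
labels = np.concatenate([np.zeros(900), np.ones(100)])
thr = calibrate_threshold(scores, labels)
achieved = ((scores >= thr) & (labels == 1)).sum() / labels.sum()
print(f"threshold={thr:.3f}, recall={achieved:.2%}")
```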

services/shieldx-ml/ml-service/dl_service.py

Lines changed: 14 additions & 14 deletions
```diff
@@ -149,23 +149,23 @@ def train_model(model_name: str):
             # Supervised models need labels
             if training_labels is None:
                 return jsonify({'error': f'{model_type} requires training_labels'}), 400
-            model.fit(training_data, training_labels, **training_params)
+            model.fit(
+                training_data,
+                training_labels,
+                epochs=training_params.get('epochs', 100),
+                batch_size=training_params.get('batch_size', 64),
+                validation_split=training_params.get('validation_split', 0.2),
+                early_stopping_patience=training_params.get('early_stopping_patience', 10)
+            )
         else:
             # Unsupervised models
-            model.fit(training_data, **training_params)
+            model.fit(
+                training_data,
+                epochs=training_params.get('epochs', 100),
+                batch_size=training_params.get('batch_size', 256),
+                validation_split=training_params.get('validation_split', 0.2),
+                early_stopping_patience=training_params.get('early_stopping_patience', 10)
             )
-        else:
-            return jsonify({'error': f'Unknown model type: {model_type}'}), 400
-
-        # Train model
-        logger.info(f"Training {model_type} model: {model_name}")
-        model.fit(
-            training_data,
-            epochs=training_params.get('epochs', 100),
-            batch_size=training_params.get('batch_size', 256 if model_type == 'autoencoder' else 64),
-            validation_split=training_params.get('validation_split', 0.2),
-            early_stopping_patience=training_params.get('early_stopping_patience', 10)
-        )
 
         # Save model
         model_path = os.path.join(MODEL_DIR, f"{model_name}.pt")
```
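The hunk above passes explicit keyword arguments from `training_params` instead of splatting the raw dict. For reference, a minimal client sketch for this endpoint follows; the `POST /models/<name>/train` route is an assumption mirroring the `/models/<name>/predict` route used by the benchmark tools, and the `model_type` payload field is likewise inferred from the handler, so verify both against the route decorators in `dl_service.py`:

```python
# Hypothetical training-endpoint client; the /train route and the
# model_type payload field are assumptions inferred from the handler above.
import numpy as np
import requests

payload = {
    "model_type": "autoencoder",                         # unsupervised: no labels required
    "training_data": np.random.randn(1000, 50).tolist(),
    "training_params": {
        "epochs": 50,
        "batch_size": 256,
        "validation_split": 0.2,
        "early_stopping_patience": 10,
    },
}
resp = requests.post(
    "http://localhost:8001/models/autoencoder_demo/train",
    json=payload,
    timeout=600,  # training can take a while
)
print(resp.status_code, resp.json())
```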
services/shieldx-ml/tools/bench_http.py

Lines changed: 58 additions & 0 deletions

```python
#!/usr/bin/env python3
import argparse
import asyncio
import time

import aiohttp
import numpy as np


async def predict(session, url, batch=32, input_dim=50):
    # One inference request with a random batch payload
    data = np.random.randn(batch, input_dim).tolist()
    payload = {"data": data, "return_proba": False}
    async with session.post(url, json=payload) as resp:
        await resp.text()
        return resp.status


async def run_benchmark(base_url: str, model_name: str, concurrency: int = 1000, requests_per_client: int = 10):
    url = f"{base_url}/models/{model_name}/predict"
    timeout = aiohttp.ClientTimeout(total=30)
    connector = aiohttp.TCPConnector(limit=0)  # no client-side connection cap
    async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
        latencies = []

        async def worker():
            for _ in range(requests_per_client):
                start = time.perf_counter_ns()
                status = await predict(session, url)
                end = time.perf_counter_ns()
                if status == 200:
                    latencies.append((end - start) / 1e6)  # ns -> ms

        tasks = [asyncio.create_task(worker()) for _ in range(concurrency)]
        t0 = time.time()
        await asyncio.gather(*tasks)
        elapsed = time.time() - t0

        if not latencies:
            print("No successful requests.")
            return

        latencies.sort()
        p50 = latencies[int(0.50 * len(latencies))]
        p90 = latencies[int(0.90 * len(latencies))]
        p99 = latencies[int(0.99 * len(latencies))]
        print(f"Requests: {len(latencies)} in {elapsed:.2f}s, RPS={len(latencies)/elapsed:.1f}")
        print(f"Latency ms: p50={p50:.2f}, p90={p90:.2f}, p99={p99:.2f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-url", default="http://localhost:8001")
    parser.add_argument("--model", default="autoencoder_demo")
    parser.add_argument("--concurrency", type=int, default=1000)
    parser.add_argument("--rpc", type=int, default=10, help="requests per client")
    args = parser.parse_args()
    asyncio.run(run_benchmark(args.base_url, args.model, args.concurrency, args.rpc))
```
services/shieldx-ml/tools/locustfile.py

Lines changed: 23 additions & 0 deletions

```python
import json

import numpy as np
from locust import HttpUser, task, between


class InferenceUser(HttpUser):
    wait_time = between(0.001, 0.01)

    def on_start(self):
        # Prepare a small random payload (adjust input_dim as per model)
        self.input_dim = 50
        self.batch = 32
        self.model_name = "autoencoder_demo"

    @task(3)
    def predict(self):
        data = np.random.randn(self.batch, self.input_dim).tolist()
        payload = {"data": data, "return_proba": False}
        self.client.post(
            f"/models/{self.model_name}/predict",
            data=json.dumps(payload),
            headers={"Content-Type": "application/json"},
        )

    @task(1)
    def health(self):
        self.client.get("/health")
```
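For unattended runs (e.g., in CI), Locust can also be driven as a library rather than via the CLI. A minimal sketch, assuming this file is importable as `locustfile` and following Locust's documented library mode:

```python
# Hypothetical headless driver for the user class above, using Locust's
# library mode; adjust host and user counts to your environment.
import gevent
from locust.env import Environment
from locust.stats import stats_printer

from locustfile import InferenceUser  # assumes this module is on sys.path

env = Environment(user_classes=[InferenceUser], host="http://localhost:8001")
runner = env.create_local_runner()
gevent.spawn(stats_printer(env.stats))  # periodic stats to stdout
runner.start(100, spawn_rate=10)        # 100 users, 10 spawned per second
gevent.spawn_later(60, runner.quit)     # stop after 60 s
runner.greenlet.join()
```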
services/shieldx-ml/validate_code.py

Lines changed: 139 additions & 0 deletions

```python
#!/usr/bin/env python3
"""
Quick validation script for ShieldX ML Service
Checks Python syntax and basic imports without running full tests
"""

import os
import sys
import py_compile
from pathlib import Path


class Colors:
    GREEN = '\033[0;32m'
    RED = '\033[0;31m'
    YELLOW = '\033[1;33m'
    BLUE = '\033[0;34m'
    BOLD = '\033[1m'
    NC = '\033[0m'


def print_header(text: str):
    print(f"\n{Colors.BOLD}{Colors.BLUE}{'='*60}{Colors.NC}")
    print(f"{Colors.BOLD}{Colors.BLUE}{text:^60}{Colors.NC}")
    print(f"{Colors.BOLD}{Colors.BLUE}{'='*60}{Colors.NC}\n")


def check_syntax(file_path: Path) -> bool:
    """Check Python file syntax"""
    try:
        py_compile.compile(str(file_path), doraise=True)
        return True
    except py_compile.PyCompileError as e:
        print(f"{Colors.RED}✗ Syntax Error:{Colors.NC} {file_path}")
        print(f"  {e}")
        return False


def check_imports(file_path: Path) -> bool:
    """Lightweight sanity check of import statements (no execution)"""
    try:
        with open(file_path) as f:
            content = f.read()

        if 'import' in content:
            lines = content.split('\n')
            for i, line in enumerate(lines, 1):
                stripped = line.strip()
                if stripped.startswith('import ') or stripped.startswith('from '):
                    if stripped.endswith('\\'):
                        continue  # Multi-line import
                    if 'import' in stripped and not any(x in stripped for x in ['(', ',', 'as']):
                        # Simple single-name import: expect at least "import name"
                        parts = stripped.split()
                        if len(parts) < 2:
                            print(f"{Colors.YELLOW}⚠ Warning:{Colors.NC} Line {i}: {stripped}")

        return True
    except Exception as e:
        print(f"{Colors.RED}✗ Import Check Failed:{Colors.NC} {file_path}")
        print(f"  {e}")
        return False


def validate_directory(directory: str, pattern: str = "**/*.py") -> tuple:
    """Validate all Python files in directory"""
    path = Path(directory)
    if not path.exists():
        print(f"{Colors.YELLOW}Directory not found: {directory}{Colors.NC}")
        return 0, 0

    print(f"\n{Colors.BOLD}Validating: {directory}{Colors.NC}")
    print("-" * 60)

    files = list(path.glob(pattern))
    if not files:
        print(f"{Colors.YELLOW}No Python files found{Colors.NC}")
        return 0, 0

    passed = 0
    failed = 0

    for file in sorted(files):
        # Skip __pycache__ and venv
        if '__pycache__' in str(file) or 'venv' in str(file):
            continue

        # Check syntax, then imports
        if check_syntax(file) and check_imports(file):
            print(f"{Colors.GREEN}✓{Colors.NC} {file.relative_to(path.parent)}")
            passed += 1
        else:
            failed += 1

    return passed, failed


def main():
    print_header("ShieldX ML Service - Quick Validation")

    # Run from the service root so the relative directory names below resolve
    os.chdir(Path(__file__).resolve().parent)
    print(f"Working directory: {os.getcwd()}\n")

    total_passed = 0
    total_failed = 0

    # Validate main ML service code
    directories = [
        'ml-service',
        'tests',
        'ml-service/tests'
    ]

    for directory in directories:
        if os.path.exists(directory):
            passed, failed = validate_directory(directory)
            total_passed += passed
            total_failed += failed

    # Summary
    print_header("Validation Summary")
    print(f"Total Files: {total_passed + total_failed}")
    print(f"{Colors.GREEN}Passed: {total_passed}{Colors.NC}")
    print(f"{Colors.RED}Failed: {total_failed}{Colors.NC}")

    if total_failed == 0:
        print(f"\n{Colors.GREEN}{Colors.BOLD}✓ ALL FILES VALID!{Colors.NC}")
        # Guidance for running full tests
        print(f"\n{Colors.YELLOW}Note: Full test execution requires installing test dependencies.{Colors.NC}")
        print(f"{Colors.YELLOW}Run: pip3 install -r requirements-test.txt{Colors.NC}")
        return 0
    else:
        print(f"\n{Colors.RED}{Colors.BOLD}✗ VALIDATION FAILED{Colors.NC}")
        return 1


if __name__ == '__main__':
    sys.exit(main())
```
