diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9e1736b..259e79d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,11 +32,11 @@ jobs: run: | uv lock --locked uv run pre-commit run --all-files - uv run mypy app - uv run deptry app + uv run mypy app config + uv run deptry app config - name: Run tests with coverage - run: uv run pytest --cov=app --cov-report=xml --cov-report=term + run: uv run pytest --cov=app --cov=config --cov-report=xml --cov-report=term - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 diff --git a/Makefile b/Makefile index 25219f4..acfc803 100644 --- a/Makefile +++ b/Makefile @@ -49,9 +49,9 @@ check: @echo "🚀 Running pre-commit checks" @uv run pre-commit run -a @echo "🚀 Running static type checks (mypy)" - @uv run mypy app + @uv run mypy app config @echo "🚀 Checking for obsolete dependencies (deptry)" - @uv run deptry app + @uv run deptry app config test: @echo "🚀 Running tests with coverage" diff --git a/alembic/env.py b/alembic/env.py index 67b02d5..4cbd01f 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -1,87 +1,38 @@ -"""Alembic migration environment configuration.""" +"""Alembic migration environment configuration. -import os +Note: Memory-related tables are managed by memu-py (via ``ddl_mode: "create"``). +Alembic is kept here for any future server-specific schema migrations. + +Configuration note: + This env.py reuses the application's ``Settings`` class so that Alembic + reads the same ``POSTGRES_*`` / ``DATABASE_URL`` environment variables as + the running server. ``Settings.DATABASE_URL`` is already normalised to + ``postgresql+psycopg://`` by the ``assemble_db_url`` validator, which is + the synchronous driver required by Alembic. +""" # pylint: disable=no-member from logging.config import fileConfig -from urllib.parse import quote from sqlalchemy import pool from alembic import context - -# Import Base metadata for autogenerate support -# Note: We import from app.models.base which is side-effect free -# (doesn't create database connections or read environment variables) -from app.models.base import Base +from config.settings import Settings def get_sync_database_url() -> str: - """ - Get synchronous database URL for Alembic migrations. + """Return a synchronous database URL for Alembic migrations. - Alembic uses synchronous database connections, so we need to convert - the async URL (postgresql+asyncpg://) to sync format (postgresql+psycopg://). + Delegates to ``Settings`` so that the same ``POSTGRES_*`` / + ``DATABASE_URL`` environment variables used by the application are + honoured here as well. The ``assemble_db_url`` field-validator in + ``Settings`` already normalises the URL to ``postgresql+psycopg://``. Returns: - str: Synchronous database connection URL + Synchronous database connection URL suitable for SQLAlchemy. """ - # First try DATABASE_URL from environment - database_url = os.getenv("DATABASE_URL") - if database_url: - # Normalize common PostgreSQL DSNs to use the psycopg (sync) driver. - # Handle bare postgres:// and postgresql:// URLs that don't specify a driver. - if database_url.startswith("postgres://"): - # postgres://user:pass@host/db -> postgresql+psycopg://user:pass@host/db - database_url = "postgresql+psycopg://" + database_url[len("postgres://") :] - elif database_url.startswith("postgresql://") and not database_url.startswith("postgresql+"): - # postgresql://user:pass@host/db -> postgresql+psycopg://user:pass@host/db - database_url = "postgresql+psycopg://" + database_url[len("postgresql://") :] - - # Convert async driver to sync driver if needed - # postgresql+asyncpg:// -> postgresql+psycopg:// - if database_url.startswith("postgresql+asyncpg://"): - database_url = "postgresql+psycopg://" + database_url[len("postgresql+asyncpg://") :] - - return database_url - - # Construct from individual variables - db_host = os.getenv("DATABASE_HOST") - db_port = os.getenv("DATABASE_PORT", "5432") # Default PostgreSQL port - db_user = os.getenv("DATABASE_USER") - db_pass = os.getenv("DATABASE_PASSWORD") - db_name = os.getenv("DATABASE_NAME") - - # Validate required environment variables (consistent with app/database.py) - missing_vars = [ - name - for name, value in [ - ("DATABASE_HOST", db_host), - ("DATABASE_USER", db_user), - ("DATABASE_PASSWORD", db_pass), - ("DATABASE_NAME", db_name), - ] - if not value - ] - - if missing_vars: - raise RuntimeError( - f"Database configuration is incomplete. Missing environment variables: {', '.join(missing_vars)}" - ) - - # At this point, we know these are not None - assert db_host is not None - assert db_user is not None - assert db_pass is not None - assert db_name is not None - - # URL-encode username and password to handle special characters - # Use quote(..., safe="") instead of quote_plus() for URL userinfo section - db_user_encoded = quote(db_user, safe="") - db_pass_encoded = quote(db_pass, safe="") - - # Use psycopg (sync) for Alembic migrations - return f"postgresql+psycopg://{db_user_encoded}:{db_pass_encoded}@{db_host}:{db_port}/{db_name}" + settings = Settings() + return settings.DATABASE_URL # this is the Alembic Config object, which provides @@ -93,7 +44,20 @@ def get_sync_database_url() -> str: if config.config_file_name is not None: fileConfig(config.config_file_name) -target_metadata = Base.metadata +target_metadata = None +# NOTE: Alembic's autogenerate relies on `target_metadata` to discover the +# current schema from SQLAlchemy models. It is intentionally set to None +# because all current tables are managed externally by memu-py (see module +# docstring), so there is no server-specific SQLAlchemy metadata to inspect. +# +# Implications: +# * `alembic revision --autogenerate` will NOT detect any changes. +# * Any migrations must use explicit operations (e.g. op.create_table). +# +# When you introduce server-specific tables, define a SQLAlchemy +# MetaData / declarative Base and assign it here, for example: +# from myapp.models import Base +# target_metadata = Base.metadata # other values from the config, defined by the needs of env.py, # can be acquired: diff --git a/app/database.py b/app/database.py deleted file mode 100644 index ace0607..0000000 --- a/app/database.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Database configuration and session management.""" - -import os -from urllib.parse import quote - -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine - -from app.models.base import Base - - -def get_database_url() -> str: - """ - Get database URL from environment variables. - - Priority: DATABASE_URL > constructed from individual variables - - Returns: - str: Database connection URL - - Raises: - RuntimeError: If required environment variables are missing - """ - database_url = os.getenv("DATABASE_URL") - if database_url: - # Normalize common PostgreSQL DSNs to use the psycopg async driver. - # Handle bare postgres:// and postgresql:// URLs that don't specify a driver. - if database_url.startswith("postgres://"): - # postgres://user:pass@host/db -> postgresql+psycopg://user:pass@host/db - database_url = "postgresql+psycopg://" + database_url[len("postgres://") :] - elif database_url.startswith("postgresql://") and not database_url.startswith("postgresql+"): - # postgresql://user:pass@host/db -> postgresql+psycopg://user:pass@host/db - database_url = "postgresql+psycopg://" + database_url[len("postgresql://") :] - - # Convert asyncpg driver to psycopg if needed (asyncpg is not a project dependency) - # postgresql+asyncpg:// -> postgresql+psycopg:// - if database_url.startswith("postgresql+asyncpg://"): - database_url = "postgresql+psycopg://" + database_url[len("postgresql+asyncpg://") :] - - return database_url - - # Construct from individual variables - db_host = os.getenv("DATABASE_HOST") - db_port = os.getenv("DATABASE_PORT", "5432") # Default PostgreSQL port - db_user = os.getenv("DATABASE_USER") - db_pass = os.getenv("DATABASE_PASSWORD") - db_name = os.getenv("DATABASE_NAME") - - # TODO: Improve validation to check for empty strings explicitly - # Current check 'if not value' treats empty string as missing - missing_vars = [ - name - for name, value in [ - ("DATABASE_HOST", db_host), - ("DATABASE_USER", db_user), - ("DATABASE_PASSWORD", db_pass), - ("DATABASE_NAME", db_name), - ] - if not value - ] - - if missing_vars: - raise RuntimeError( - f"Database configuration is incomplete. Missing environment variables: {', '.join(missing_vars)}" - ) - - # At this point, we know db_user, db_pass, db_host, db_name are not None - # Use assertion to help mypy understand this - assert db_user is not None - assert db_pass is not None - assert db_host is not None - assert db_name is not None - - # URL-encode username and password to handle special characters like '@', ':', '/' - # Use quote(..., safe="") instead of quote_plus() for URL userinfo section - db_user_encoded = quote(db_user, safe="") - db_pass_encoded = quote(db_pass, safe="") - - return f"postgresql+psycopg://{db_user_encoded}:{db_pass_encoded}@{db_host}:{db_port}/{db_name}" - - -# TODO: Consider lazy initialization to avoid executing during module import -# This would prevent database connection issues from failing tests that don't use the database -# Get database URL using the shared function -DATABASE_URL = get_database_url() - -# Create SQLAlchemy async engine -engine = create_async_engine( - DATABASE_URL, - pool_pre_ping=True, - pool_size=10, - max_overflow=20, -) -# Async session factory -# Note: autocommit is removed as it's not supported in SQLAlchemy 2.x async_sessionmaker -SessionLocal: async_sessionmaker[AsyncSession] = async_sessionmaker( - autoflush=False, - expire_on_commit=False, - bind=engine, -) -# Re-export Base for backward compatibility -__all__ = ["Base", "SessionLocal", "engine", "get_db", "get_database_url"] - - -async def get_db(): - """Dependency for FastAPI to get async database session.""" - async with SessionLocal() as db: - yield db diff --git a/app/main.py b/app/main.py index b59db47..39beab0 100644 --- a/app/main.py +++ b/app/main.py @@ -1,50 +1,57 @@ +"""memU Server - FastAPI application entry point.""" + import json -import os -import traceback +import logging import uuid +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from pathlib import Path from typing import Any -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request from fastapi.responses import JSONResponse -from memu.app import MemoryService -from app.database import get_database_url +from app.services.memu import create_memory_service +from config.settings import Settings + +logger = logging.getLogger(__name__) -app = FastAPI(title="memU Server", version="0.1.0") +# Load settings from environment / .env +settings = Settings() -# Ensure required environment variables are set -openai_api_key = os.getenv("OPENAI_API_KEY") -if not openai_api_key: - raise RuntimeError( +if not settings.OPENAI_API_KEY.strip(): + # EM101/EM102: extract message to variable to satisfy ruff errmsg rules + msg = ( "OPENAI_API_KEY environment variable is not set or is empty. " "Set OPENAI_API_KEY to a valid OpenAI API key before starting the server." ) - -# Get database URL using shared configuration utility -database_url = get_database_url() - -service = MemoryService( - llm_profiles={ - "default": { - "provider": "openai", - "api_key": openai_api_key, - "base_url": os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"), - "model": os.getenv("DEFAULT_LLM_MODEL", "gpt-4o-mini"), - } - }, - database_config={"url": database_url}, -) + raise RuntimeError(msg) # Storage directory for conversation files -# Support both new STORAGE_PATH and legacy MEMU_STORAGE_DIR for backward compatibility -storage_dir = Path(os.getenv("STORAGE_PATH") or os.getenv("MEMU_STORAGE_DIR") or "./data") -storage_dir.mkdir(parents=True, exist_ok=True) +storage_dir = Path(settings.STORAGE_PATH) + + +@asynccontextmanager +async def lifespan(_app: FastAPI) -> AsyncIterator[None]: + """Initialise MemoryService on startup (defers DB connection until the app runs).""" + try: + storage_dir.mkdir(parents=True, exist_ok=True) + _app.state.service = create_memory_service(settings) + except Exception as exc: + # Log full traceback for operators and wrap in a clearer startup error + msg = "Failed to initialize MemoryService during application startup" + logger.exception(msg) + raise RuntimeError(msg) from exc + yield + + +app = FastAPI(title="memU Server", version="0.1.0", lifespan=lifespan) @app.post("/memorize") -async def memorize(payload: dict[str, Any]): +async def memorize(request: Request, payload: dict[str, Any]): try: + service = request.app.state.service file_path = storage_dir / f"conversation-{uuid.uuid4().hex}.json" with file_path.open("w", encoding="utf-8") as f: json.dump(payload, f, ensure_ascii=False) @@ -52,19 +59,21 @@ async def memorize(payload: dict[str, Any]): result = await service.memorize(resource_url=str(file_path), modality="conversation") return JSONResponse(content={"status": "success", "result": result}) except Exception as exc: - traceback.print_exc() - raise HTTPException(status_code=500, detail=str(exc)) from exc + logger.exception("Memorize request failed") + raise HTTPException(status_code=500, detail="Internal server error") from exc @app.post("/retrieve") -async def retrieve(payload: dict[str, Any]): +async def retrieve(request: Request, payload: dict[str, Any]): if "query" not in payload: raise HTTPException(status_code=400, detail="Missing 'query' in request body") try: + service = request.app.state.service result = await service.retrieve([payload["query"]]) return JSONResponse(content={"status": "success", "result": result}) except Exception as exc: - raise HTTPException(status_code=500, detail=str(exc)) from exc + logger.exception("Retrieve request failed") + raise HTTPException(status_code=500, detail="Internal server error") from exc @app.get("/") diff --git a/app/models/__init__.py b/app/models/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/models/base.py b/app/models/base.py deleted file mode 100644 index a0916be..0000000 --- a/app/models/base.py +++ /dev/null @@ -1,14 +0,0 @@ -"""SQLAlchemy Base class for model definitions. - -This module is intentionally side-effect-free - it only defines the Base class -without creating any database connections or reading environment variables. -This allows safe imports from alembic/env.py for migration autogeneration. -""" - -from sqlalchemy.orm import DeclarativeBase - - -class Base(DeclarativeBase): - """Base class for all SQLAlchemy models.""" - - pass diff --git a/app/services/memu.py b/app/services/memu.py new file mode 100644 index 0000000..e8df706 --- /dev/null +++ b/app/services/memu.py @@ -0,0 +1,37 @@ +"""MemU service factory for creating MemoryService instances.""" + +from typing import Any + +from memu.app import MemoryService + +from config.memu import build_memu_config +from config.settings import Settings + + +def create_memory_service( + settings: Settings | None = None, + memorize_config: dict[str, Any] | None = None, + retrieve_config: dict[str, Any] | None = None, +) -> MemoryService: + """Create a configured MemoryService instance. + + Args: + settings: Application settings. Uses default if not provided. + memorize_config: Optional memorize workflow config override. + retrieve_config: Optional retrieve workflow config override. + + Returns: + Configured MemoryService instance. + """ + if settings is None: + settings = Settings() + + memu_config = build_memu_config(settings) + + kwargs = {**memu_config} + if memorize_config: + kwargs["memorize_config"] = memorize_config + if retrieve_config: + kwargs["retrieve_config"] = retrieve_config + + return MemoryService(**kwargs) diff --git a/config/__init__.py b/config/__init__.py index e69de29..5d1518a 100644 --- a/config/__init__.py +++ b/config/__init__.py @@ -0,0 +1,11 @@ +"""Configuration management for memu-server.""" + +from .memu import MemUUser, build_memu_config, build_memu_llm_profiles +from .settings import Settings + +__all__ = [ + "Settings", + "MemUUser", + "build_memu_config", + "build_memu_llm_profiles", +] diff --git a/config/memu.py b/config/memu.py new file mode 100644 index 0000000..fdf76da --- /dev/null +++ b/config/memu.py @@ -0,0 +1,51 @@ +"""MemU configuration for memory service.""" + +from typing import Any + +from pydantic import BaseModel + +from config.settings import Settings + + +class MemUUser(BaseModel): + """User model for memu-py.""" + + user_id: str + agent_id: str | None = None + + +def build_memu_llm_profiles(settings: Settings) -> dict[str, Any]: + """Build LLM profiles for memu-py.""" + return { + "default": { + "api_key": settings.OPENAI_API_KEY, + "base_url": settings.OPENAI_BASE_URL, + "chat_model": settings.DEFAULT_LLM_MODEL, + }, + "embedding": { + "api_key": settings.EMBEDDING_API_KEY or settings.OPENAI_API_KEY, + "base_url": settings.EMBEDDING_BASE_URL, + "embed_model": settings.EMBEDDING_MODEL, + }, + } + + +def build_memu_config(settings: Settings) -> dict[str, Any]: + """Build memu-py core configuration. + + This configures memu-py to: + 1. Connect to PostgreSQL with pgvector + 2. Auto-create tables (ddl_mode: create) + 3. Use configured LLM profiles + """ + return { + "llm_profiles": build_memu_llm_profiles(settings), + "database_config": { + "metadata_store": { + "provider": "postgres", + "ddl_mode": "create", # Auto-create tables + "dsn": settings.DATABASE_URL, + } + }, + "user_config": {"model": MemUUser}, + } diff --git a/config/settings.py b/config/settings.py new file mode 100644 index 0000000..834c536 --- /dev/null +++ b/config/settings.py @@ -0,0 +1,82 @@ +"""Application settings for memu-server.""" + +from urllib.parse import quote + +from pydantic import ValidationInfo, field_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Application settings loaded from environment variables. + + Values are resolved in order: init kwargs > environment variable > .env file > default. + """ + + # ── Database ── + POSTGRES_USER: str = "postgres" + POSTGRES_PASSWORD: str = "postgres" + POSTGRES_HOST: str = "localhost" + POSTGRES_PORT: int = 5432 + POSTGRES_DB: str = "memu" + DATABASE_URL: str = "" + + # ── LLM ── + # Empty defaults allow Settings() to be constructed in tests and service + # factories without requiring a live key. The non-empty check for + # OPENAI_API_KEY lives in app/main.py (RuntimeError with clear message). + OPENAI_API_KEY: str = "" + OPENAI_BASE_URL: str = "https://api.openai.com/v1" + DEFAULT_LLM_MODEL: str = "gpt-4o-mini" + + # ── Embedding ── + # Falls back to OPENAI_API_KEY in build_memu_llm_profiles when empty + EMBEDDING_API_KEY: str = "" + EMBEDDING_BASE_URL: str = "https://api.voyageai.com/v1" + EMBEDDING_MODEL: str = "voyage-3.5-lite" + + # ── Temporal ── + TEMPORAL_HOST: str = "localhost" + TEMPORAL_PORT: int = 7233 + TEMPORAL_NAMESPACE: str = "default" + + # ── Storage ── + STORAGE_PATH: str = "./data/storage" + + @field_validator("DATABASE_URL", mode="after") + @classmethod + def assemble_db_url(cls, v: str, info: ValidationInfo) -> str: + """Build DATABASE_URL from POSTGRES_* components when not explicitly set. + + When a URL is provided explicitly, common Postgres prefixes + (``postgresql://``, ``postgres://``, ``postgresql+asyncpg://``) + are normalised to ``postgresql+psycopg://`` so the correct + driver is always selected. + """ + if v.strip(): + # Normalise common DSN prefixes to the psycopg driver + for prefix in ("postgres://", "postgresql://", "postgresql+asyncpg://"): + if v.startswith(prefix): + return "postgresql+psycopg://" + v[len(prefix) :] + return v + # Preserve RFC 3986 sub-delimiters so characters like '!' are + # NOT percent-encoded. Percent-encoded values (e.g. %21) break + # memu-py's internal Alembic configparser which treats '%' as an + # interpolation character. + _sub_delims = "!$&'()*+,;=" + user = quote(info.data["POSTGRES_USER"], safe=_sub_delims) + password = quote(info.data["POSTGRES_PASSWORD"], safe=_sub_delims) + return ( + f"postgresql+psycopg://{user}:{password}" + f"@{info.data['POSTGRES_HOST']}:{info.data['POSTGRES_PORT']}/{info.data['POSTGRES_DB']}" + ) + + @property + def temporal_url(self) -> str: + return f"{self.TEMPORAL_HOST}:{self.TEMPORAL_PORT}" + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=True, + extra="ignore", + ) diff --git a/pyproject.toml b/pyproject.toml index 7ef794a..16e3712 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,26 +8,20 @@ requires-python = ">=3.13" dependencies = [ # Web Framework "fastapi[standard]>=0.122.0", - "memu-py>=1.2.0", "uvicorn[standard]>=0.35.0", - # Database & ORM - "sqlmodel>=0.0.27", - "sqlalchemy[asyncio]>=2.0.41", + # Memory Service (includes sqlmodel, openai, pendulum, etc.) + "memu-py[postgres]>=1.2.0", + + # Database driver (needed for connection URL construction) "psycopg[binary,pool]>=3.2.9", - "alembic>=1.16.2", - "pgvector>=0.3.2", # Workflow Engine "temporalio==1.16.0", - # LLM & AI - "openai>=1.54.4", - # Configuration & Utils "pydantic-settings>=2.10.1", "python-dotenv>=1.0.0", - "pendulum>=3.1.0", ] [dependency-groups] @@ -49,7 +43,7 @@ requires = ["hatchling"] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] -packages = ["app"] +packages = ["app", "config"] [tool.ruff] line-length = 120 @@ -76,7 +70,7 @@ asyncio_mode = "auto" testpaths = ["tests"] [tool.coverage.run] -source = ["app"] +source = ["app", "config"] omit = [ "*/tests/*", "*/__pycache__/*", @@ -110,24 +104,19 @@ enable = ["W0102", "W0212", "W0611"] # dangerous-default-value, protected-acces [tool.deptry.per_rule_ignores] DEP002 = [ "uvicorn", # Will be used for production deployment - "sqlmodel", # Planned for database models - "sqlalchemy", # Database ORM (used by sqlmodel) - "psycopg", # PostgreSQL adapter - "alembic", # Database migrations - "pgvector", # Vector database support + "psycopg", # PostgreSQL adapter used by memu-py[postgres] and alembic "temporalio", # Planned for workflow orchestration - "openai", # AI functionality (may be used indirectly via memu-py) - "pydantic-settings", # Configuration management (future use) - "python-dotenv", # Environment variables (future use) - "pendulum", # Date/time handling (future use) + "python-dotenv", # Environment variables (loaded by pydantic-settings) ] DEP003 = [ "app", # Project's own package, not an external dependency + "config", # Project's own config package + "pydantic", # Re-exported by pydantic-settings ] [tool.mypy] python_version = "3.13" -files = ["app"] +files = ["app", "config"] warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false # Gradually enable diff --git a/tests/test_env_validation.py b/tests/test_env_validation.py index f49bf2f..5dfadce 100644 --- a/tests/test_env_validation.py +++ b/tests/test_env_validation.py @@ -4,6 +4,7 @@ This approach avoids issues with shared module state and ensures clean test isolation. """ +import os import subprocess import sys from pathlib import Path @@ -23,8 +24,6 @@ def _run_import_test(env_vars: dict[str, str], remove_vars: list[str] | None = N Returns: CompletedProcess with returncode, stdout, and stderr """ - import os - # Start with a clean environment based on current env env = os.environ.copy() @@ -37,36 +36,28 @@ def _run_import_test(env_vars: dict[str, str], remove_vars: list[str] | None = N env.update(env_vars) # Run Python subprocess that imports app.main - result = subprocess.run( + return subprocess.run( # noqa: S603 [sys.executable, "-c", "from app.main import app; print(app.title)"], env=env, capture_output=True, text=True, cwd=str(PROJECT_ROOT), timeout=30, # Prevent tests from hanging indefinitely - ) - return result - - -def test_app_requires_openai_api_key(): - """Test that app refuses to start when OPENAI_API_KEY is not set.""" - result = _run_import_test( - env_vars={ - "DATABASE_URL": "postgresql+psycopg://test:test@localhost:5432/test", - }, - remove_vars=["OPENAI_API_KEY"], + check=False, ) - assert result.returncode != 0 - assert "OPENAI_API_KEY environment variable is not set or is empty" in result.stderr +def test_app_requires_openai_api_key(tmp_path): + """Test that app refuses to start when OPENAI_API_KEY is empty. -def test_app_refuses_empty_openai_api_key(): - """Test that app refuses to start when OPENAI_API_KEY is empty.""" + OPENAI_API_KEY defaults to empty string in Settings and is validated + at startup to be non-empty via a RuntimeError guard in main.py. + """ result = _run_import_test( env_vars={ "OPENAI_API_KEY": "", - "DATABASE_URL": "postgresql+psycopg://test:test@localhost:5432/test", + "EMBEDDING_API_KEY": "test", + "STORAGE_PATH": str(tmp_path / "storage"), }, ) @@ -74,43 +65,13 @@ def test_app_refuses_empty_openai_api_key(): assert "OPENAI_API_KEY environment variable is not set or is empty" in result.stderr -def test_app_requires_database_url(): - """Test that app refuses to start when DATABASE_URL is not set.""" - result = _run_import_test( - env_vars={ - "OPENAI_API_KEY": "test-key", - }, - remove_vars=["DATABASE_URL", "DATABASE_HOST", "DATABASE_USER", "DATABASE_PASSWORD", "DATABASE_NAME"], - ) - - assert result.returncode != 0 - assert "Database configuration is incomplete" in result.stderr - - -def test_app_with_individual_db_vars(): - """Test that app starts with individual DATABASE_* variables.""" - result = _run_import_test( - env_vars={ - "OPENAI_API_KEY": "test-key", - "DATABASE_HOST": "localhost", - "DATABASE_PORT": "54320", - "DATABASE_USER": "test_user", - "DATABASE_PASSWORD": "test_pass", - "DATABASE_NAME": "test_db", - }, - remove_vars=["DATABASE_URL"], - ) - - assert result.returncode == 0 - assert "memU Server" in result.stdout - - -def test_app_starts_with_valid_openai_api_key(): +def test_app_starts_with_valid_openai_api_key(tmp_path): """Test that app starts successfully with valid OPENAI_API_KEY.""" result = _run_import_test( env_vars={ "OPENAI_API_KEY": "test-valid-key", - "DATABASE_URL": "postgresql+psycopg://test:test@localhost:5432/test", + "EMBEDDING_API_KEY": "test-embed-key", + "STORAGE_PATH": str(tmp_path / "storage"), }, ) diff --git a/tests/test_health.py b/tests/test_health.py index 625ce4b..d2ed8ea 100644 --- a/tests/test_health.py +++ b/tests/test_health.py @@ -1,5 +1,7 @@ """Tests for the application's root ("/") endpoint.""" +from http import HTTPStatus + import pytest from fastapi.testclient import TestClient @@ -8,17 +10,17 @@ def client(): """Create FastAPI test client with proper env setup.""" try: - from app.main import app + from app.main import app # noqa: PLC0415 return TestClient(app) - except Exception as exc: + except (RuntimeError, ImportError) as exc: # lifespan wraps errors in RuntimeError; import may raise ImportError pytest.skip(f"Could not initialize test client due to application setup error: {exc}") def test_root_endpoint(client): """Test root endpoint returns welcome message.""" response = client.get("/") - assert response.status_code == 200 + assert response.status_code == HTTPStatus.OK data = response.json() assert "message" in data assert data["message"] == "Hello MemU user!" diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index 970d289..0000000 --- a/tests/test_main.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Basic tests for the application. - -Note: Full integration tests with FastAPI TestClient will be added -as the project evolves. Currently using placeholder tests to ensure -CI pipeline runs successfully. -""" - -import pytest - - -def test_placeholder(): - """Placeholder test to ensure pytest runs successfully. - - This test will be replaced with actual integration tests - as features are implemented. - """ - assert True - - -def test_imports(): - """Test that main application modules can be imported.""" - try: - from app import main - - assert hasattr(main, "app") - assert hasattr(main, "service") - except Exception as e: - pytest.skip(f"Import test skipped due to compatibility issue: {e}") diff --git a/tests/test_memu_config.py b/tests/test_memu_config.py new file mode 100644 index 0000000..8cefa10 --- /dev/null +++ b/tests/test_memu_config.py @@ -0,0 +1,126 @@ +"""Tests for memu configuration.""" + +from config.memu import MemUUser, build_memu_config, build_memu_llm_profiles +from config.settings import Settings + + +def test_memu_user_model(): + """Test MemUUser model.""" + user = MemUUser(user_id="user123", agent_id="agent456") + assert user.user_id == "user123" + assert user.agent_id == "agent456" + + +def test_memu_user_optional_agent(): + """Test MemUUser with optional agent_id.""" + user = MemUUser(user_id="user123") + assert user.user_id == "user123" + assert user.agent_id is None + + +def test_build_memu_llm_profiles(): + """Test building LLM profiles.""" + settings = Settings() + profiles = build_memu_llm_profiles(settings) + + assert "default" in profiles + assert "embedding" in profiles + assert "api_key" in profiles["default"] + assert "chat_model" in profiles["default"] + assert "embed_model" in profiles["embedding"] + + +def test_build_memu_llm_profiles_uses_env_values(monkeypatch): + """Test that LLM profiles pick up values from environment variables.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-test-key") + monkeypatch.setenv("DEFAULT_LLM_MODEL", "gpt-4o") + monkeypatch.setenv("EMBEDDING_API_KEY", "embed-key") + + settings = Settings() + profiles = build_memu_llm_profiles(settings) + + assert profiles["default"]["api_key"] == "sk-test-key" + assert profiles["default"]["chat_model"] == "gpt-4o" + assert profiles["embedding"]["api_key"] == "embed-key" + + +def test_build_memu_llm_profiles_embedding_fallback(): + """Test that embedding profile falls back to OPENAI_API_KEY when EMBEDDING_API_KEY is empty.""" + settings = Settings(OPENAI_API_KEY="sk-openai", EMBEDDING_API_KEY="") + profiles = build_memu_llm_profiles(settings) + + assert profiles["embedding"]["api_key"] == "sk-openai" + + +def test_build_memu_config(): + """Test building complete memu config.""" + settings = Settings() + config = build_memu_config(settings) + + assert "llm_profiles" in config + assert "database_config" in config + assert "user_config" in config + + # Check database config + db_config = config["database_config"]["metadata_store"] + assert db_config["provider"] == "postgres" + assert db_config["ddl_mode"] == "create" + assert db_config["dsn"] is not None + + # Check user config + assert config["user_config"]["model"] == MemUUser + + +def test_memu_config_database_url(monkeypatch): + """Test that memu config assembles database URL from POSTGRES_* components.""" + monkeypatch.delenv("DATABASE_URL", raising=False) + + settings = Settings(DATABASE_URL="") + config = build_memu_config(settings) + + dsn = config["database_config"]["metadata_store"]["dsn"] + assert "postgresql+psycopg" in dsn + assert settings.POSTGRES_DB in dsn + + +def test_memu_config_database_url_override(): + """Test that an explicit DATABASE_URL overrides POSTGRES_* assembly.""" + explicit_url = "postgresql+psycopg://custom:pass@remote:5432/prod" + settings = Settings(DATABASE_URL=explicit_url) + config = build_memu_config(settings) + + dsn = config["database_config"]["metadata_store"]["dsn"] + assert dsn == explicit_url + + +def test_memu_config_database_url_password_encoding(monkeypatch): + """Test that assembled DATABASE_URL properly encodes special characters in password.""" + monkeypatch.setenv("POSTGRES_USER", "testuser") + monkeypatch.setenv("POSTGRES_PASSWORD", "p@ss:word") + monkeypatch.setenv("POSTGRES_DB", "testdb") + # Clear DATABASE_URL so the validator assembles from POSTGRES_* components + monkeypatch.delenv("DATABASE_URL", raising=False) + + settings = Settings(DATABASE_URL="") + config = build_memu_config(settings) + + dsn = config["database_config"]["metadata_store"]["dsn"] + # The raw password with special characters should not appear in the DSN + assert "p@ss:word" not in dsn + # URL-encoded representations of '@' and ':' should be present + assert "%40" in dsn + assert "%3A" in dsn + + +def test_database_url_prefix_normalisation(): + """Test that common Postgres DSN prefixes are normalised to postgresql+psycopg://.""" + for prefix in ("postgres://", "postgresql://", "postgresql+asyncpg://"): + url = f"{prefix}user:pass@host:5432/db" + settings = Settings(DATABASE_URL=url) + assert settings.DATABASE_URL.startswith("postgresql+psycopg://") + assert settings.DATABASE_URL.endswith("user:pass@host:5432/db") + + # Already correct prefix should remain unchanged + correct = "postgresql+psycopg://user:pass@host:5432/db" + settings = Settings(DATABASE_URL=correct) + assert settings.DATABASE_URL == correct diff --git a/tests/test_memu_service.py b/tests/test_memu_service.py new file mode 100644 index 0000000..382e7ba --- /dev/null +++ b/tests/test_memu_service.py @@ -0,0 +1,40 @@ +"""Tests for the MemoryService factory.""" + +from unittest.mock import MagicMock, patch + +from app.services.memu import create_memory_service +from config.settings import Settings + + +@patch("app.services.memu.MemoryService") +def test_create_memory_service_returns_instance(mock_cls): + """Test that create_memory_service returns a MemoryService.""" + mock_cls.return_value = MagicMock() + settings = Settings(OPENAI_API_KEY="test-key") + service = create_memory_service(settings) + mock_cls.assert_called_once() + assert service is mock_cls.return_value + + +@patch("app.services.memu.MemoryService") +def test_create_memory_service_default_settings(mock_cls): + """Test that create_memory_service works with default Settings.""" + mock_cls.return_value = MagicMock() + service = create_memory_service() + mock_cls.assert_called_once() + assert service is mock_cls.return_value + + +@patch("app.services.memu.MemoryService") +def test_create_memory_service_with_overrides(mock_cls): + """Test that create_memory_service passes optional config overrides.""" + mock_cls.return_value = MagicMock() + settings = Settings(OPENAI_API_KEY="test-key") + create_memory_service( + settings, + memorize_config={"some_option": True}, + retrieve_config={"another_option": 42}, + ) + call_kwargs = mock_cls.call_args.kwargs + assert call_kwargs["memorize_config"] == {"some_option": True} + assert call_kwargs["retrieve_config"] == {"another_option": 42} diff --git a/uv.lock b/uv.lock index 2588958..a74654f 100644 --- a/uv.lock +++ b/uv.lock @@ -728,22 +728,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/24/4d/c05457130a196904f5a8989dffa6e8a5cd0c29021178c082e7e2e5d0725b/memu_py-1.2.0-cp313-abi3-win_amd64.whl", hash = "sha256:d766ff2929e4e7d06e3ab2f4d0b436413a1bea98af0a2236164140d3f43515b7", size = 241934, upload-time = "2026-01-14T15:00:22.303Z" }, ] +[package.optional-dependencies] +postgres = [ + { name = "pgvector" }, + { name = "sqlalchemy", extra = ["postgresql-psycopgbinary"] }, +] + [[package]] name = "memu-server" version = "0.1.0" source = { editable = "." } dependencies = [ - { name = "alembic" }, { name = "fastapi", extra = ["standard"] }, - { name = "memu-py" }, - { name = "openai" }, - { name = "pendulum" }, - { name = "pgvector" }, + { name = "memu-py", extra = ["postgres"] }, { name = "psycopg", extra = ["binary", "pool"] }, { name = "pydantic-settings" }, { name = "python-dotenv" }, - { name = "sqlalchemy", extra = ["asyncio"] }, - { name = "sqlmodel" }, { name = "temporalio" }, { name = "uvicorn", extra = ["standard"] }, ] @@ -764,17 +764,11 @@ dev = [ [package.metadata] requires-dist = [ - { name = "alembic", specifier = ">=1.16.2" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.122.0" }, - { name = "memu-py", specifier = ">=1.2.0" }, - { name = "openai", specifier = ">=1.54.4" }, - { name = "pendulum", specifier = ">=3.1.0" }, - { name = "pgvector", specifier = ">=0.3.2" }, + { name = "memu-py", extras = ["postgres"], specifier = ">=1.2.0" }, { name = "psycopg", extras = ["binary", "pool"], specifier = ">=3.2.9" }, { name = "pydantic-settings", specifier = ">=2.10.1" }, { name = "python-dotenv", specifier = ">=1.0.0" }, - { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" }, - { name = "sqlmodel", specifier = ">=0.0.27" }, { name = "temporalio", specifier = "==1.16.0" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.35.0" }, ] @@ -1510,8 +1504,8 @@ wheels = [ ] [package.optional-dependencies] -asyncio = [ - { name = "greenlet" }, +postgresql-psycopgbinary = [ + { name = "psycopg", extra = ["binary"] }, ] [[package]]