From cb744be9209499822112f2dc22acfa04381c7208 Mon Sep 17 00:00:00 2001 From: StrongWind1 <5987034+StrongWind1@users.noreply.github.com> Date: Wed, 18 Mar 2026 00:39:23 -0400 Subject: [PATCH 1/4] fix: database concurrency, thread safety, and code quality Thread safety: - Use scoped_session as a thread-local registry (property, not shared instance) so each handler thread gets its own Session and connection - Add _release() via _scoped_session.remove() in try/finally on every public method to return connections to the pool immediately - Atomic duplicate check + INSERT inside one lock scope to prevent TOCTOU race condition - Rollback on every error path to prevent PendingRollbackError cascade - Catch PoolTimeoutError gracefully (warning + skip, hash in file stream) - Display logging gated on successful DB write Code quality: - Extract _check_duplicate() and _log_credential() from 170-line add_auth() - Add _handle_db_error() shared helper for duplicated try/except - Fix add_host_extra() lock: contextlib.nullcontext replaces manual acquire/release with bool flag - Fix add_host_extra() connection leak on standalone calls (missing _release) - Fix add_host() unconditional commit (only when values change) - Fix missing commit() in add_host_extra() update branch - Fix db_path for :memory: (was str(None)) - expire_on_commit=False so ORM objects survive _release() Engine configuration (per SQLAlchemy 2.0 docs): - :memory: -> StaticPool, SQLite file -> QueuePool, MySQL -> QueuePool - skip_autocommit_rollback + pool_reset_on_return=None - pool_use_lifo for natural idle-connection expiry - Drop deprecated future=True, dead init_dementor_db(), dead config fields - Fix :memory: URL (was missing third slash) - Rename db_raw_path -> db_url Constants: - _CLEARTEXT/_NO_USER/_HOST_INFO -> CLEARTEXT/NO_USER/HOST_INFO with backward-compatible aliases --- dementor/db/__init__.py | 24 +- dementor/db/connector.py | 163 +++++++---- dementor/db/model.py | 604 +++++++++++++++++++++++++++------------ 3 files changed, 552 insertions(+), 239 deletions(-) diff --git a/dementor/db/__init__.py b/dementor/db/__init__.py index 92d02d7..8c1d473 100644 --- a/dementor/db/__init__.py +++ b/dementor/db/__init__.py @@ -18,18 +18,36 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +"""Dementor database package -- constants, helpers, and ORM models. + +Provides the :class:`~dementor.db.model.DementorDB` wrapper for thread-safe +credential storage, the :class:`~dementor.db.connector.DatabaseConfig` for +``[DB]`` TOML configuration, and engine initialization via +:func:`~dementor.db.connector.create_db`. +""" + +__all__ = ["CLEARTEXT", "HOST_INFO", "NO_USER", "normalize_client_address"] + # --------------------------------------------------------------------------- # # Public constants # --------------------------------------------------------------------------- # -_CLEARTEXT = "Cleartext" +CLEARTEXT = "Cleartext" """Constant indicating plaintext credentials (as opposed to hashes).""" -_NO_USER = "" +NO_USER = "" """Placeholder string used when username is absent or invalid in credential logging.""" -_HOST_INFO = "_host_info" +HOST_INFO = "_host_info" """Key used in extras dict to store host information for credential logging.""" +# Backward-compatible aliases so existing imports like +# from dementor.db import _CLEARTEXT +# keep working without a mass-rename across all protocol files. +# New code should use the unprefixed names above. +_CLEARTEXT = CLEARTEXT +_NO_USER = NO_USER +_HOST_INFO = HOST_INFO + def normalize_client_address(client: str) -> str: """Normalize IPv6-mapped IPv4 addresses by stripping IPv6 prefix. diff --git a/dementor/db/connector.py b/dementor/db/connector.py index 1ce73ba..acc1ed2 100644 --- a/dementor/db/connector.py +++ b/dementor/db/connector.py @@ -18,70 +18,56 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # pyright: reportUninitializedInstanceVariable=false +"""Database engine initialization and configuration. + +Reads the ``[DB]`` TOML section via :class:`DatabaseConfig`, builds a +SQLAlchemy :class:`~sqlalchemy.engine.Engine` with backend-specific pool +settings, and exposes :func:`create_db` as the single entry point used +by :func:`~dementor.standalone.serve` at startup. +""" + import typing +from typing import Any from sqlalchemy import Engine, create_engine +from sqlalchemy.pool import StaticPool from dementor.config.session import SessionConfig -from dementor.db.model import DementorDB, ModelBase +from dementor.db.model import DementorDB from dementor.log.logger import dm_logger from dementor.config.toml import TomlConfig, Attribute as A class DatabaseConfig(TomlConfig): - """ - Configuration mapping for the ``[DB]`` TOML section. + """Configuration mapping for the ``[DB]`` TOML section. - The attributes correspond to the most common SQLAlchemy connection - parameters. All fields are optional - sensible defaults are applied - when a key is missing. + Users set EITHER ``Url`` (a full SQLAlchemy DSN for any backend, + e.g. ``mysql+pymysql://user:pass@host/db``) OR ``Path`` (a file + path for the default SQLite backend, e.g. ``Dementor.db``). + + When ``Url`` is omitted, ``Path`` is resolved relative to the + session workspace and wrapped into a ``sqlite+pysqlite://`` URL. """ _section_: typing.ClassVar[str] = "DB" _fields_: typing.ClassVar[list[A]] = [ - A("db_raw_path", "Url", None), + A("db_url", "Url", None), A("db_path", "Path", "Dementor.db"), A("db_duplicate_creds", "DuplicateCreds", False), - A("db_dialect", "Dialect", None), - A("db_driver", "Driver", None), ] if typing.TYPE_CHECKING: # pragma: no cover - only for static analysis - db_raw_path: str | None + db_url: str | None db_path: str db_duplicate_creds: bool - db_dialect: str | None - db_driver: str | None - - -def init_dementor_db(session: SessionConfig) -> Engine | None: - """ - Initialise the database engine and create all tables. - - :param session: The active :class:`~dementor.config.session.SessionConfig` - containing the ``db_config`` attribute. - :type session: SessionConfig - :return: The created SQLAlchemy ``Engine`` or ``None`` if an error - prevented initialisation. - :rtype: Engine | None - """ - engine = init_engine(session) - if engine is not None: - ModelBase.metadata.create_all(engine) - return engine def init_engine(session: SessionConfig) -> Engine | None: - """ - Build a SQLAlchemy ``Engine`` from a :class:`DatabaseConfig`. - - The logic follows the rules laid out in the SQLAlchemy documentation - (see https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls). + """Build a SQLAlchemy ``Engine`` from a :class:`DatabaseConfig`. - * If ``db_raw_path`` is supplied it is used verbatim. - * Otherwise a URL is composed from ``dialect``, ``driver`` and ``path``. - For SQLite the path is resolved relative to the session's - ``resolve_path`` helper; missing directories are created on the fly. + * If ``db_url`` (TOML ``Url``) is supplied it is used verbatim. + * Otherwise ``db_path`` (TOML ``Path``) is resolved relative to the + session workspace and wrapped into a ``sqlite+pysqlite://`` URL. Sensitive information (user/password) is hidden in the debug output. @@ -91,18 +77,19 @@ def init_engine(session: SessionConfig) -> Engine | None: :rtype: Engine | None """ # --------------------------------------------------------------- # - # 1. Resolve "raw" URL - either provided by the user or built. + # 1. Resolve URL -- either user-supplied DSN or built from Path. # --------------------------------------------------------------- # - raw_path = session.db_config.db_raw_path + raw_path = session.db_config.db_url if raw_path is None: - # Build the URL manually when the user didn't provide a full DSN. - dialect = session.db_config.db_dialect or "sqlite" - driver = session.db_config.db_driver or "pysqlite" + # No Url configured -- use the SQLite Path default. + dialect = "sqlite" + driver = "pysqlite" path = session.db_config.db_path if not path: return dm_logger.error("Database path not specified!") - # :memory: is a special SQLite in-memory database. - if dialect == "sqlite" and path != ":memory:": + if path == ":memory:": + path = "/:memory:" + else: real_path = session.resolve_path(path) if not real_path.parent.exists(): dm_logger.debug(f"Creating database directory {real_path.parent}") @@ -118,26 +105,98 @@ def init_engine(session: SessionConfig) -> Engine | None: dialect = sql_type driver = "" + # --------------------------------------------------------------- # + # 2. Mask credentials in the debug log output. + # --------------------------------------------------------------- # + # For non-SQLite URLs like mysql+pymysql://user:pass@host/db, + # replace the user:pass portion with stars so passwords don't + # appear in log files. if dialect != "sqlite": first_element, *parts = path.split("/") if "@" in first_element: - # keep only the “host:port” part, replace user:pass with stars first_element = first_element.split("@")[1] path = "***:***@" + "/".join([first_element, *parts]) dm_logger.debug("Using database [%s:%s] at: %s", dialect, driver, path) - return create_engine(raw_path, isolation_level="AUTOCOMMIT", future=True) + + # --------------------------------------------------------------- # + # 3. Build the engine with backend-specific pool settings. + # --------------------------------------------------------------- # + # All backends use AUTOCOMMIT -- Dementor does individual INSERT/SELECT + # operations, not multi-statement transactions. + # + # pool_reset_on_return=None: the pool's default is to ROLLBACK on + # every connection checkin, which is wasted work under AUTOCOMMIT. + # + # skip_autocommit_rollback=True: tells the dialect itself not to + # emit ROLLBACK either (SQLAlchemy 2.0.43+). Together these two + # settings eliminate every unnecessary ROLLBACK round-trip. + common: dict[str, Any] = { + "isolation_level": "AUTOCOMMIT", + "pool_reset_on_return": None, + "skip_autocommit_rollback": True, + } + + # Three pool strategies, one per backend constraint: + # + # :memory: SQLite -> StaticPool (DB exists only inside one connection; + # a second connection = empty DB. DementorDB.lock + # serializes all access to that one connection.) + # + # File SQLite -> QueuePool (SQLAlchemy 2.0 default for file SQLite. + # Each thread checks out its own connection; + # _release() returns it after each operation.) + # + # MySQL/PostgreSQL -> QueuePool (Connection reuse avoids the ~10-50ms + # TCP+auth overhead of opening a new connection per + # query. LIFO keeps idle connections at the front so + # the server's wait_timeout can expire the rest.) + if dialect == "sqlite": + if path == ":memory:" or path.endswith("/:memory:"): + return create_engine( + raw_path, + **common, + poolclass=StaticPool, + connect_args={"check_same_thread": False}, + ) + # File-based SQLite: QueuePool is the SQLAlchemy 2.0 default. + # check_same_thread=False is set automatically by the dialect. + # DementorDB._release() returns connections after each operation. + return create_engine(raw_path, **common) + + # MySQL / MariaDB / PostgreSQL: QueuePool. + # pool_pre_ping – detect dead connections before checkout. + # pool_use_lifo – reuse most-recent connection so idle ones expire + # naturally via server-side wait_timeout. + # pool_recycle – hard ceiling: close connections older than 1 hour. + # pool_timeout=5 – fail fast on exhaustion (PoolTimeoutError caught + # in model.py); hash file is the primary capture path. + return create_engine( + raw_path, + **common, + pool_pre_ping=True, + pool_use_lifo=True, + pool_size=20, + max_overflow=40, + pool_timeout=5, + pool_recycle=3600, + ) def create_db(session: SessionConfig) -> DementorDB: - """ - High-level helper that returns a fully-initialised :class:`DementorDB`. + """Create a fully initialised :class:`DementorDB` ready for use. - :param session: Current session configuration. + Builds the SQLAlchemy engine via :func:`init_engine` and passes it + to the :class:`~dementor.db.model.DementorDB` constructor, which + creates the tables and sets up the scoped session. + + :param session: Current session configuration holding the + :class:`DatabaseConfig` at ``session.db_config``. :type session: SessionConfig - :return: Ready-to-use :class:`DementorDB` instance. + :return: Ready-to-use database wrapper. :rtype: DementorDB - :raises Exception: If the engine cannot be created. + :raises RuntimeError: If the engine cannot be created (e.g. empty + ``Path`` with no ``Url``). """ engine = init_engine(session) if not engine: diff --git a/dementor/db/model.py b/dementor/db/model.py index 9977efe..c4d3049 100644 --- a/dementor/db/model.py +++ b/dementor/db/model.py @@ -17,7 +17,17 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -# pyright: reportUnusedCallResult=false, reportAny=false, reportExplicitAny=false, reportPrivateUsage=false +# pyright: reportUnusedCallResult=false, reportAny=false, reportExplicitAny=false +"""ORM models and thread-safe database wrapper for Dementor. + +Defines the three ORM tables (``hosts``, ``extras``, ``credentials``) and +the :class:`DementorDB` class that protocol handlers call to store captured +credentials. All public methods are thread-safe via a combination of +:func:`~sqlalchemy.orm.scoped_session` (one session per thread) and a +:class:`threading.Lock` that serializes writes. +""" + +import contextlib import datetime import json import threading @@ -30,6 +40,7 @@ NoInspectionAvailable, NoSuchTableError, OperationalError, + TimeoutError as PoolTimeoutError, ) from sqlalchemy.orm import ( DeclarativeBase, @@ -43,10 +54,10 @@ from dementor.config.session import SessionConfig from dementor.db import ( - _CLEARTEXT, - _NO_USER, + CLEARTEXT, + HOST_INFO, + NO_USER, normalize_client_address, - _HOST_INFO, ) from dementor.log.logger import dm_logger from dementor.log import dm_console_lock @@ -57,12 +68,7 @@ class ModelBase(DeclarativeBase): - """ - Base class for all ORM models. - - It exists solely to give a common ``metadata`` object that can be used - for ``create_all`` / ``drop_all`` calls. - """ + """Base class for all ORM models.""" class HostInfo(ModelBase): @@ -160,12 +166,26 @@ class DementorDB: """ def __init__(self, engine: Engine, config: SessionConfig) -> None: + """Initialise the database wrapper. + + Creates all ORM tables if they do not exist, sets up a + :func:`~sqlalchemy.orm.scoped_session` registry for thread-local + sessions, and allocates the write lock. + + :param engine: A configured SQLAlchemy engine (from :func:`init_engine`). + :type engine: Engine + :param config: The active session configuration. + :type config: SessionConfig + :raises NoSuchTableError: If table creation fails due to a schema issue. + :raises NoInspectionAvailable: If the engine cannot be inspected. + """ self.db_engine: Engine = engine - self.db_path: str = str(engine.url.database) + self.db_path: str = str(engine.url.database or ":memory:") self.metadata: MetaData = ModelBase.metadata self.config: SessionConfig = config - # Ensure tables exist; any problem is reported immediately. + # Verify DB connectivity and create tables on first run. + # checkfirst=True avoids errors on subsequent starts. with self.db_engine.connect(): try: self.metadata.create_all(self.db_engine, checkfirst=True) @@ -173,46 +193,106 @@ def __init__(self, engine: Engine, config: SessionConfig) -> None: dm_logger.error(f"Failed to connect to database {self.db_path}! {exc}") raise - session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=True) - self.session: Session = scoped_session(session_factory)() + # expire_on_commit=False: ORM objects keep their attributes after + # _release() detaches them from the session. Without this, accessing + # host.id after _release() would raise a DetachedInstanceError. + session_factory = sessionmaker(bind=self.db_engine, expire_on_commit=False) + + # Store the scoped_session *registry*, not a Session instance. + # The .session property calls _scoped_session() to get the + # thread-local Session on demand. This is the fix for the original + # concurrency bug where all threads shared one Session/connection. + self._scoped_session: scoped_session[Session] = scoped_session(session_factory) + + # Serializes all DB writes. Both the duplicate check and the INSERT + # run inside this lock to prevent TOCTOU races. Reads (TUI queries) + # do not acquire it -- they get their own session via scoped_session. self.lock: threading.Lock = threading.Lock() + @property + def session(self) -> Session: + """Return the thread-local session from the scoped_session registry. + + Each thread gets its own Session instance, preventing concurrent access + to a shared database connection (which corrupts pymysql's packet + sequence on MySQL/MariaDB backends). + """ + return self._scoped_session() + # --------------------------------------------------------------------- # # Low-level helpers # --------------------------------------------------------------------- # def close(self) -> None: - """Close the underlying SQLAlchemy session.""" - self.session.close() + """Close all thread-local sessions and dispose of the engine.""" + self._scoped_session.remove() + self.db_engine.dispose() + + def _release(self) -> None: + """Return this thread's DB connection to the pool. + + Called at the end of every public method so handler threads don't + hold connections while doing non-DB work (SMB tree-connect, logoff, + Rich rendering, etc.). The scoped_session transparently creates a + fresh session on next access. + """ + # remove() does close() + clears the thread-local registry entry. + # Plain close() would leave a stale registry entry that prevents the + # pool from reclaiming the connection when the thread dies. + self._scoped_session.remove() + + def _handle_db_error(self, exc: OperationalError) -> None: + """Rollback and handle common OperationalError patterns. + + Detects outdated schema errors (``no such column``) and logs a + user-friendly message instead of crashing. All other + OperationalErrors are re-raised after rollback. + + :param exc: The caught OperationalError. + :type exc: OperationalError + :raises OperationalError: If the error is not a known schema issue. + """ + self.session.rollback() + if "no such column" in str(exc).lower(): + dm_logger.error( + "Could not execute SQL - you are probably using an outdated Dementor.db" + ) + else: + raise exc def _execute(self, q: TypedReturnsRows[tuple[_T]]) -> ScalarResult[_T] | None: """Execute a SQLAlchemy query and handle common operational errors. - :param q: SQLAlchemy query object. - :type q: Select | Insert | Update | Delete - :return: Query result or `None` if error occurred. - :rtype: Any + Catches :class:`OperationalError` (schema mismatch), + :class:`PoolTimeoutError` (pool exhaustion), and generic exceptions, + rolling back the session in each case so subsequent operations are + not poisoned. + + :param q: A SQLAlchemy selectable (e.g. from :func:`sqlalchemy.sql.select`). + :type q: TypedReturnsRows[tuple[_T]] + :return: Scalar result set, or ``None`` if a recoverable error occurred. + :rtype: ScalarResult[_T] | None """ try: return self.session.scalars(q) - except OperationalError as e: - if "no such column" in str(e).lower(): - dm_logger.error( - "Could not execute SQL - you are probably using an outdated Dementor.db" - ) - else: - raise - - def commit(self): + except OperationalError as exc: + self._handle_db_error(exc) + return None + except PoolTimeoutError: + dm_logger.warning("Database connection pool exhausted; skipping query") + return None + except Exception: + self.session.rollback() + raise + + def commit(self) -> None: """Commit the current transaction and handle schema-related errors.""" try: self.session.commit() - except OperationalError as e: - if "no such column" in str(e).lower(): - dm_logger.error( - "Could not execute SQL - you are probably using an outdated Dementor.db" - ) - else: - raise + except OperationalError as exc: + self._handle_db_error(exc) + except Exception: + self.session.rollback() + raise # --------------------------------------------------------------------- # # Public CRUD-style helpers @@ -242,48 +322,62 @@ def add_host( :return: The persisted :class:`HostInfo` object or ``None`` on failure. :rtype: HostInfo | None """ - with self.lock: - q = sql.select(HostInfo).where(HostInfo.ip == ip) - result = self._execute(q) - if result is None: - return None - host = result.one_or_none() - if not host: - host = HostInfo(ip=ip, hostname=hostname, domain=domain) - self.session.add(host) - self.commit() - else: - # Preserve existing values; only fill missing data. - host.domain = host.domain or domain or "" - host.hostname = host.hostname or hostname or "" - self.commit() + # try/finally guarantees _release() runs even if an exception + # propagates, so we never leak a DB connection from this thread. + try: + with self.lock: + q = sql.select(HostInfo).where(HostInfo.ip == ip) + result = self._execute(q) + if result is None: + return None + host = result.one_or_none() + if not host: + host = HostInfo(ip=ip, hostname=hostname, domain=domain) + self.session.add(host) + self.commit() + else: + # Preserve existing values; only fill missing data. + new_domain = host.domain or domain or "" + new_hostname = host.hostname or hostname or "" + if host.domain != new_domain or host.hostname != new_hostname: + host.domain = new_domain + host.hostname = new_hostname + self.commit() - if extras: - for key, value in extras.items(): - self.add_host_extra(host.id, key, value, no_lock=True) - return host + if extras: + for key, value in extras.items(): + self.add_host_extra(host.id, key, value, _locked=True) + return host + finally: + self._release() def add_host_extra( - self, host_id: int, key: str, value: str, no_lock: bool = False + self, host_id: int, key: str, value: str, *, _locked: bool = False ) -> None: - """ - Store an arbitrary extra attribute for a host. + """Store an arbitrary extra attribute for a host. - ``extras`` are stored in a separate table to keep the ``hosts`` row - small and to allow multiple values per host. + Values are stored as a JSON array in the ``extras`` table. If the + key already exists for the given host, the new value is appended to + the array; otherwise a new row is created. - :param host_id: Primary key of the target ``HostInfo``. + :param host_id: Primary key of the target :class:`HostInfo`. :type host_id: int - :param key: Attribute name. + :param key: Attribute name (e.g. ``"os"``, ``"service"``). :type key: str - :param value: Attribute value. + :param value: Attribute value to store or append. :type value: str - :param no_lock: Skip acquiring lock if `True` (internal use). - :type no_lock: bool, optional + :param _locked: When ``True``, the caller already holds ``self.lock`` + (internal use by :meth:`add_host`), defaults to ``False``. + :type _locked: bool, optional """ - if not no_lock: - self.lock.acquire() - try: + # When called from add_host() the lock is already held, so we use + # nullcontext() as a no-op context manager to avoid a deadlock. + # When called standalone (e.g. from a protocol handler), we acquire + # the real lock to serialize the read-modify-write on the JSON array. + lock: threading.Lock | contextlib.nullcontext[None] = ( + contextlib.nullcontext() if _locked else self.lock + ) + with lock: q = sql.select(HostExtra).where( HostExtra.host == host_id, HostExtra.key == key ) @@ -296,13 +390,153 @@ def add_host_extra( self.session.add(extra) self.commit() else: - # REVISIT: values: list[str] = json.loads(extra.value) values.append(value) extra.value = json.dumps(values) - finally: - if not no_lock: - self.lock.release() + self.commit() + + # --------------------------------------------------------------------- # + # Credential capture + # --------------------------------------------------------------------- # + def _check_duplicate( + self, + protocol: str, + credtype: str, + username: str, + domain: str | None, + ) -> bool: + """Check if a credential with the same key fields already exists. + + The comparison is case-insensitive on all four fields. Must be + called while ``self.lock`` is held to prevent a TOCTOU race with + the subsequent INSERT. + + :param protocol: Protocol name (e.g. ``"smb"``). + :type protocol: str + :param credtype: Credential type (e.g. ``"NetNTLMv2"``). + :type credtype: str + :param username: Username to match. + :type username: str + :param domain: Domain to match, or ``None`` (matches empty string). + :type domain: str | None + :return: ``True`` if a duplicate exists, ``False`` otherwise. + Returns ``True`` on DB error to avoid silent data loss. + :rtype: bool + """ + q = sql.select(Credential).filter( + sql.func.lower(Credential.domain) == sql.func.lower(domain or ""), + sql.func.lower(Credential.username) == sql.func.lower(username), + sql.func.lower(Credential.credtype) == sql.func.lower(credtype), + sql.func.lower(Credential.protocol) == sql.func.lower(protocol), + ) + result = self._execute(q) + if result is None: + return True # DB error -- treat as exists to avoid silent data loss + return len(result.all()) > 0 + + def _log_credential( + self, + target_logger: Any, + credtype: str, + username: str, + password: str, + domain: str | None, + hostname: str | None, + client_address: str, + extras: dict[str, str] | None, + host_info: str | None, + custom: bool, + *, + is_duplicate: bool, + ) -> None: + """Emit user-facing log messages for a captured or skipped credential. + + For new captures, acquires :data:`dm_console_lock` and emits a + multi-line Rich-formatted block (type, username, hash/password, + extras). For duplicates, emits a single "Skipping" line. + + :param target_logger: Logger instance with ``success``/``highlight`` + methods (typically a :class:`ProtocolLogger`). + :type target_logger: Any + :param credtype: Credential type label (e.g. ``"NetNTLMv2"``). + :type credtype: str + :param username: Captured username. + :type username: str + :param password: Captured password or hashcat-formatted hash line. + :type password: str + :param domain: Domain name, or ``None``. + :type domain: str | None + :param hostname: Hostname of the remote system, or ``None``. + :type hostname: str | None + :param client_address: Normalized client IP address. + :type client_address: str + :param extras: Additional key-value metadata to display, or ``None``. + :type extras: dict[str, str] | None + :param host_info: Human-readable host description for the display + line (e.g. ``"Windows 10 Build 19041 (name: WS01)"``), or ``None``. + :type host_info: str | None + :param custom: When ``True``, omit the "Hash"/"Password" label from + the success line (used for non-standard credential types). + :type custom: bool + :param is_duplicate: When ``True``, only emit the "Skipping" line. + :type is_duplicate: bool + """ + text = "Password" if credtype == CLEARTEXT else "Hash" + username_text = markup.escape(username) + if not str(username).strip(): + username_text = "(blank)" + + full_name = ( + f" for [b]{markup.escape(domain)}[/]/[b]{username_text}[/]" + if domain + else f" for [b]{username_text}[/]" + ) + if host_info: + full_name += f" on [b]{markup.escape(host_info)}[/]" + + if is_duplicate: + target_logger.highlight( + f"Skipping previously captured {credtype} {text}" + f" for {full_name} from {client_address}", + host=hostname or client_address, + ) + return + + with dm_console_lock: + head_text = text if not custom else "" + credtype_esc = markup.escape(credtype) + target_logger.success( + f"Captured {credtype_esc} {head_text}{full_name} from {client_address}:", + host=hostname or client_address, + locked=True, + ) + if username != NO_USER: + target_logger.highlight( + f"{credtype_esc} Username: {username_text}", + host=hostname or client_address, + locked=True, + ) + target_logger.highlight( + ( + f"{credtype_esc} {text}: {markup.escape(password)}" + if not custom + else f"{credtype_esc}: {markup.escape(password)}" + ), + host=hostname or client_address, + locked=True, + ) + if extras: + target_logger.highlight( + f"{credtype_esc} Extras:", + host=hostname or client_address, + locked=True, + ) + for name, value in extras.items(): + target_logger.highlight( + f" {name}: {markup.escape(value)}", + host=hostname or client_address, + locked=True, + ) def add_auth( self, @@ -317,36 +551,42 @@ def add_auth( extras: dict[str, str] | None = None, custom: bool = False, ) -> None: - """ - Store a captured credential in the database and emit user-friendly logs. + """Store a captured credential in the database and emit user-friendly logs. - The method performs a duplicate-check (unless the global config - ``db_duplicate_creds`` is ``True``) and respects read-only database - mode. + The duplicate check and INSERT are atomic (both inside ``self.lock``) + to prevent race conditions. Display logging only runs after a + successful DB write. The connection is released via :meth:`_release` + in a ``finally`` block so handler threads never leak connections. :param client: ``(ip, port)`` tuple of the remote endpoint. :type client: tuple[str, int] - :param credtype: ``_CLEARTEXT`` for passwords or a hash algorithm name. + :param credtype: ``CLEARTEXT`` for passwords, or a hash algorithm + name like ``"NetNTLMv2"``. :type credtype: str :param username: Username that was observed. :type username: str - :param password: Password or hash value. + :param password: Password or hashcat-formatted hash line. :type password: str - :param logger: Optional logger that provides a ``debug``/``success``/… - interface; defaults to the global ``dm_logger``. + :param logger: Protocol logger with ``success``/``highlight`` + methods. When ``None``, ``protocol`` must be supplied + explicitly, defaults to ``None``. :type logger: Any, optional - :param protocol: Protocol name (e.g. ``"ssh"``); if omitted it is taken - from ``logger.extra["protocol"]``. + :param protocol: Protocol name (e.g. ``"smb"``). When ``None``, + it is read from ``logger.extra["protocol"]``, defaults to ``None``. :type protocol: str | None, optional - :param domain: Optional domain name associated with the credential. + :param domain: Domain name associated with the credential, + defaults to ``None``. :type domain: str | None, optional - :param hostname: Optional host name for the remote system. + :param hostname: Hostname of the remote system, + defaults to ``None``. :type hostname: str | None, optional - :param extras: Optional additional key/value data to store alongside - the credential. - :type extras: Mapping[str, str] | None, optional - :param custom: When ``True`` the output omits the standard “Captured …” - prefix (used for artificial credentials). + :param extras: Additional key-value metadata to store alongside + the credential. The special key :data:`HOST_INFO` is popped + for display only, defaults to ``None``. + :type extras: dict[str, str] | None, optional + :param custom: When ``True``, omit the "Hash"/"Password" label + from the success log line (used for non-standard credential + types), defaults to ``False``. :type custom: bool, optional """ if not logger and not protocol: @@ -367,107 +607,103 @@ def add_auth( ) # Ensure the host exists (or create it) before linking the cred. + # add_host() releases its own connection via _release(). host = self.add_host(client_address, hostname, domain) if host is None: return - # Build the duplicate-check query (case-insensitive). - q = sql.select(Credential).filter( - sql.func.lower(Credential.domain) == sql.func.lower(domain or ""), - sql.func.lower(Credential.username) == sql.func.lower(username), - sql.func.lower(Credential.credtype) == sql.func.lower(credtype), - sql.func.lower(Credential.protocol) == sql.func.lower(protocol), - ) - result = self._execute(q) - if result is None: - return - - results = result.all() - text = "Password" if credtype == _CLEARTEXT else "Hash" - username_text = markup.escape(username) - if len(str(username).strip()) == 0: - username_text = "(blank)" - - # Human-readable part used in log messages. - full_name = ( - f" for [b]{markup.escape(domain)}[/]/[b]{username_text}[/]" - if domain - else f" for [b]{username_text}[/]" - ) - host_info: str | None = extras.pop(_HOST_INFO, None) if extras else None - if host_info: - full_name += f" on [b]{markup.escape(host_info)}[/]" - - if not results or self.config.db_config.db_duplicate_creds: - if credtype != _CLEARTEXT: - log_to("hashes", type=credtype, value=password) + # Pop host_info from extras before DB storage. + host_info: str | None = extras.pop(HOST_INFO, None) if extras else None - cred = Credential( - # REVISIT: replace with util.now() - timestamp=datetime.datetime.now(tz=datetime.UTC).strftime( - "%Y-%m-%d %H:%M:%S" - ), - protocol=protocol.lower(), - credtype=credtype.lower(), - client=f"{client_address}:{port}", - hostname=hostname or "", - domain=(domain or "").lower(), - username=username.lower(), - password=password, - host=host.id, - ) - try: - with self.lock: - self.session.add(cred) - self.session.commit() - except OperationalError as e: - # Special handling for read-only SQLite databases. - if "readonly database" in str(e).lower(): - dm_logger.fail( - f"Failed to add {credtype} for {username} on {client_address}: " - + "Database is read-only! (maybe restart in sudo mode?)" - ) - else: - raise + # --- Phase 1: critical section (duplicate check + insert) --- + # Both operations must be inside the same lock acquisition to prevent + # a TOCTOU race where two threads both pass the duplicate check and + # both insert. This was the original race condition bug. + db_write_ok = False + is_duplicate = False + allow_dupes = self.config.db_config.db_duplicate_creds - with dm_console_lock: - head_text = text if not custom else "" - credtype_esc = markup.escape(credtype) - target_logger.success( - f"Captured {credtype_esc} {head_text}{full_name} from {client_address}:", - host=hostname or client_address, - locked=True, - ) - if username != _NO_USER: - target_logger.highlight( - f"{credtype_esc} Username: {username_text}", - host=hostname or client_address, - locked=True, - ) - target_logger.highlight( - ( - f"{credtype_esc} {text}: {markup.escape(password)}" - if not custom - else f"{credtype_esc}: {markup.escape(password)}" - ), - host=hostname or client_address, - locked=True, + try: + with self.lock: + is_duplicate = not allow_dupes and self._check_duplicate( + protocol, credtype, username, domain ) - if extras: - target_logger.highlight( - f"{credtype_esc} Extras:", - host=hostname or client_address, - locked=True, + + if not is_duplicate: + if credtype != CLEARTEXT: + log_to("hashes", type=credtype, value=password) + + cred = Credential( + timestamp=datetime.datetime.now(tz=datetime.UTC).strftime( + "%Y-%m-%d %H:%M:%S" + ), + protocol=protocol.lower(), + credtype=credtype.lower(), + client=f"{client_address}:{port}", + hostname=hostname or "", + domain=(domain or "").lower(), + username=username.lower(), + password=password, + host=host.id, ) - for name, value in extras.items(): - target_logger.highlight( - f" {name}: {markup.escape(value)}", - host=hostname or client_address, - locked=True, + try: + self.session.add(cred) + self.session.commit() + db_write_ok = True + except PoolTimeoutError: + # Pool is temporarily full. The hash was already + # written to the file stream (log_to above), so + # we just skip the DB insert rather than crashing. + dm_logger.warning( + f"Database pool exhausted; dropped {credtype} " + f"for {username} on {client_address}" ) - else: - # Credential already present - only emit a short notice. - target_logger.highlight( - f"Skipping previously captured {credtype} {text} for {full_name} from {client_address}", - host=hostname or client_address, + except OperationalError as e: + # Rollback so the session isn't left in a broken + # state (which would cause PendingRollbackError + # on every subsequent operation from this thread). + self.session.rollback() + if "readonly database" in str(e).lower(): + dm_logger.fail( + f"Failed to add {credtype} for {username} on " + f"{client_address}: Database is read-only! " + "(maybe restart in sudo mode?)" + ) + else: + raise + finally: + self._release() + + # --- Phase 2: display logging OUTSIDE the DB lock --- + # Rich rendering is slow; holding the lock during it would block + # all other handler threads from writing to the database. + # Only log if the write actually succeeded (db_write_ok) or if + # we're reporting a duplicate skip -- never on write failure. + if is_duplicate: + self._log_credential( + target_logger, + credtype, + username, + password, + domain, + hostname, + client_address, + extras, + host_info, + custom, + is_duplicate=True, + ) + elif db_write_ok: + self._log_credential( + target_logger, + credtype, + username, + password, + domain, + hostname, + client_address, + extras, + host_info, + custom, + is_duplicate=False, ) From 6bf7a98aa2c047ae0e6c8dded8855c78af1c1aad Mon Sep 17 00:00:00 2001 From: StrongWind1 <5987034+StrongWind1@users.noreply.github.com> Date: Wed, 18 Mar 2026 00:39:23 -0400 Subject: [PATCH 2/4] test: comprehensive DB test suite (124 tests, no external DB) test_db.py (121 tests across 20 classes): - __init__.py: constants, aliases, __all__ exports, normalize_client_address - connector.py: DatabaseConfig defaults/loading, init_engine for all backends, create_db success and failure paths - model.py: DementorDB init/lifecycle/close, session thread-locality, _release isolation, add_host, add_host_extra, add_auth, protocol resolution, duplicate detection (10 variations), extras handling, logging output, _check_duplicate, error handling, connection release, thread safety test_db_concurrency.py (2 stress tests): - 20-thread concurrent add_auth on SQLite memory and file All tests use SQLite -- no external database required. --- tests/test_db.py | 1360 ++++++++++++++++++++++++++++++++++ tests/test_db_concurrency.py | 102 +++ 2 files changed, 1462 insertions(+) create mode 100644 tests/test_db.py create mode 100644 tests/test_db_concurrency.py diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..eb0f22b --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,1360 @@ +# ruff: noqa: S105, S106 +"""Comprehensive test suite for dementor.db (__init__, connector, model). + +All tests use SQLite :memory: with StaticPool -- no external DB required. +""" + +from __future__ import annotations + +import json +import os +import tempfile +import threading +from pathlib import Path +from unittest.mock import MagicMock + +import pytest +from sqlalchemy import create_engine, inspect, select +from sqlalchemy.exc import OperationalError +from sqlalchemy.pool import StaticPool + +import dementor.db as db_module +from dementor.db import ( + CLEARTEXT, + HOST_INFO, + NO_USER, + _CLEARTEXT, + _HOST_INFO, + _NO_USER, + normalize_client_address, +) +from dementor.db.connector import DatabaseConfig, create_db, init_engine +from dementor.db.model import ( + Credential, + DementorDB, + HostExtra, + HostInfo, +) + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- +@pytest.fixture +def engine(): + """In-memory SQLite engine shared across threads via StaticPool.""" + return create_engine( + "sqlite+pysqlite:///:memory:", + isolation_level="AUTOCOMMIT", + poolclass=StaticPool, + connect_args={"check_same_thread": False}, + ) + + +@pytest.fixture +def config(): + """Mock SessionConfig with DuplicateCreds=False.""" + cfg = MagicMock() + cfg.db_config.db_duplicate_creds = False + return cfg + + +@pytest.fixture +def config_dupes(): + """Mock SessionConfig with DuplicateCreds=True.""" + cfg = MagicMock() + cfg.db_config.db_duplicate_creds = True + return cfg + + +@pytest.fixture +def db(engine, config): + """DementorDB with dedup enabled.""" + d = DementorDB(engine, config) + yield d + d.close() + + +@pytest.fixture +def db_dupes(engine, config_dupes): + """DementorDB with DuplicateCreds=True.""" + d = DementorDB(engine, config_dupes) + yield d + d.close() + + +@pytest.fixture +def logger(): + """Mock logger with protocol=SMB.""" + lg = MagicMock() + lg.extra = {"protocol": "SMB"} + return lg + + +# =================================================================== +# __init__.py +# =================================================================== +class TestConstantValues: + def test_cleartext(self) -> None: + assert CLEARTEXT == "Cleartext" + + def test_no_user(self) -> None: + assert NO_USER == "" + + def test_host_info(self) -> None: + assert HOST_INFO == "_host_info" + + +class TestConstantAliases: + def test_cleartext_alias_identity(self) -> None: + assert _CLEARTEXT is CLEARTEXT + + def test_no_user_alias_identity(self) -> None: + assert _NO_USER is NO_USER + + def test_host_info_alias_identity(self) -> None: + assert _HOST_INFO is HOST_INFO + + +class TestAllExports: + def test_all_contains_public_names(self) -> None: + + assert "CLEARTEXT" in db_module.__all__ + assert "NO_USER" in db_module.__all__ + assert "HOST_INFO" in db_module.__all__ + assert "normalize_client_address" in db_module.__all__ + + def test_all_does_not_contain_aliases(self) -> None: + + assert "_CLEARTEXT" not in db_module.__all__ + assert "_NO_USER" not in db_module.__all__ + assert "_HOST_INFO" not in db_module.__all__ + + +class TestNormalizeClientAddress: + def test_strips_ipv6_mapped_v4(self) -> None: + assert normalize_client_address("::ffff:192.168.1.1") == "192.168.1.1" + + def test_strips_ipv6_mapped_private(self) -> None: + assert normalize_client_address("::ffff:10.0.0.50") == "10.0.0.50" + + def test_leaves_plain_ipv4(self) -> None: + assert normalize_client_address("10.0.0.1") == "10.0.0.1" + + def test_leaves_real_ipv6(self) -> None: + assert normalize_client_address("2001:db8::1") == "2001:db8::1" + + def test_leaves_localhost(self) -> None: + assert normalize_client_address("127.0.0.1") == "127.0.0.1" + + def test_empty_string(self) -> None: + assert normalize_client_address("") == "" + + def test_only_prefix_itself(self) -> None: + assert normalize_client_address("::ffff:") == "" + + +# =================================================================== +# connector.py — DatabaseConfig +# =================================================================== +class TestDatabaseConfig: + def test_default_fields_from_empty_config(self) -> None: + # Note: TomlConfig resolves defaults from the global Dementor.toml, + # so db_duplicate_creds is True (set in shipped config). + cfg = DatabaseConfig({}) + assert cfg.db_url is None + assert cfg.db_path == "Dementor.db" + # The shipped Dementor.toml sets DuplicateCreds = true + assert cfg.db_duplicate_creds is True + + def test_code_default_duplicate_creds(self) -> None: + # The Attribute default in code is False; this is overridden by TOML. + field = next( + f for f in DatabaseConfig._fields_ if f.attr_name == "db_duplicate_creds" + ) + assert field.default_val is False + + def test_loads_url_from_dict(self) -> None: + + cfg = DatabaseConfig({"Url": "sqlite:///:memory:"}) + assert cfg.db_url == "sqlite:///:memory:" + + def test_loads_path_from_dict(self) -> None: + + cfg = DatabaseConfig({"Path": "custom.db"}) + assert cfg.db_path == "custom.db" + + def test_loads_duplicate_creds_from_dict(self) -> None: + + cfg = DatabaseConfig({"DuplicateCreds": True}) + assert cfg.db_duplicate_creds is True + + def test_section_name(self) -> None: + + assert DatabaseConfig._section_ == "DB" + + +# =================================================================== +# connector.py — init_engine +# =================================================================== +class TestInitEngine: + def _make_session( + self, *, db_url=None, db_path="Dementor.db", tmpdir=None + ) -> MagicMock: + session = MagicMock() + session.db_config.db_url = db_url + session.db_config.db_path = db_path + if tmpdir: + session.resolve_path.return_value = Path(tmpdir) / db_path + else: + session.resolve_path.return_value = Path(tempfile.gettempdir()) / db_path + return session + + def test_sqlite_memory_returns_engine(self) -> None: + + session = self._make_session(db_path=":memory:") + engine = init_engine(session) + assert engine is not None + assert "memory" in str(engine.url) + engine.dispose() + + def test_sqlite_file_returns_engine(self) -> None: + + with tempfile.TemporaryDirectory() as tmpdir: + session = self._make_session(db_path="test.db", tmpdir=tmpdir) + engine = init_engine(session) + assert engine is not None + assert "test.db" in str(engine.url) + engine.dispose() + + def test_sqlite_file_creates_directory(self) -> None: + + with tempfile.TemporaryDirectory() as tmpdir: + subdir = os.path.join(tmpdir, "subdir") + session = self._make_session(db_path="test.db") + session.resolve_path.return_value = Path(subdir) / "test.db" + engine = init_engine(session) + assert engine is not None + assert os.path.isdir(subdir) + engine.dispose() + + def test_empty_path_returns_none(self) -> None: + + session = self._make_session(db_path="") + result = init_engine(session) + assert result is None + + def test_url_overrides_path(self) -> None: + + session = self._make_session(db_url="sqlite:///:memory:", db_path="ignored.db") + engine = init_engine(session) + assert engine is not None + assert "memory" in str(engine.url) + engine.dispose() + + def test_mysql_url_parsed(self) -> None: + pytest.importorskip("pymysql") + session = self._make_session(db_url="mysql+pymysql://user:pass@fakehost/fakedb") + engine = init_engine(session) + assert engine is not None + assert engine.dialect.name == "mysql" + engine.dispose() + + def test_url_without_driver(self) -> None: + + session = self._make_session(db_url="sqlite:///:memory:") + engine = init_engine(session) + assert engine is not None + engine.dispose() + + +# =================================================================== +# connector.py — create_db +# =================================================================== +class TestCreateDb: + def test_returns_dementor_db(self) -> None: + + session = MagicMock() + session.db_config.db_url = None + session.db_config.db_path = ":memory:" + db = create_db(session) + assert isinstance(db, DementorDB) + db.close() + + def test_raises_on_engine_failure(self) -> None: + + session = MagicMock() + session.db_config.db_url = None + session.db_config.db_path = "" + with pytest.raises(RuntimeError, match="Failed to create database engine"): + create_db(session) + + +# =================================================================== +# model.py — DementorDB init / lifecycle +# =================================================================== +class TestDementorDBInit: + def test_creates_all_three_tables(self, engine, config) -> None: + db = DementorDB(engine, config) + with engine.connect() as conn: + tables = inspect(conn).get_table_names() + assert "hosts" in tables + assert "extras" in tables + assert "credentials" in tables + db.close() + + def test_db_path_memory(self, engine, config) -> None: + db = DementorDB(engine, config) + assert db.db_path == ":memory:" + db.close() + + def test_db_path_file(self, config) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "test.db") + eng = create_engine( + f"sqlite+pysqlite:///{path}", isolation_level="AUTOCOMMIT" + ) + db = DementorDB(eng, config) + assert db.db_path == path + db.close() + + def test_stores_engine_reference(self, engine, config) -> None: + db = DementorDB(engine, config) + assert db.db_engine is engine + db.close() + + def test_stores_config_reference(self, engine, config) -> None: + db = DementorDB(engine, config) + assert db.config is config + db.close() + + def test_has_lock(self, db) -> None: + assert isinstance(db.lock, type(threading.Lock())) + + +class TestSession: + def test_returns_session_object(self, db) -> None: + assert db.session is not None + + def test_same_session_in_same_thread(self, db) -> None: + s1 = db.session + s2 = db.session + assert s1 is s2 + + def test_different_session_per_thread(self, db) -> None: + main_session = db.session + other: list[object] = [None] + + def worker(): + other[0] = db.session + + t = threading.Thread(target=worker) + t.start() + t.join() + assert other[0] is not main_session + + def test_new_session_after_release(self, db) -> None: + s1 = db.session + db._release() + s2 = db.session + assert s1 is not s2 + + +class TestCloseAndRelease: + def test_close_does_not_raise(self, engine, config) -> None: + db = DementorDB(engine, config) + db.close() + + def test_release_does_not_raise(self, db) -> None: + _ = db.session + db._release() + + def test_release_from_thread_is_isolated(self, db) -> None: + errors: list[Exception] = [] + + def worker(): + try: + _ = db.session + db._release() + except Exception as e: + errors.append(e) + + t = threading.Thread(target=worker) + t.start() + t.join() + assert errors == [] + assert db.session is not None + + def test_session_works_after_release(self, db) -> None: + db.add_host("1.2.3.4") + # add_host calls _release internally + host = db.add_host("1.2.3.4") + assert host is not None + + +# =================================================================== +# model.py — add_host +# =================================================================== +class TestAddHost: + def test_creates_new_host(self, db) -> None: + host = db.add_host("10.0.0.1") + assert host is not None + assert host.ip == "10.0.0.1" + assert host.id is not None + + def test_with_hostname(self, db) -> None: + host = db.add_host("10.0.0.2", hostname="WS01") + assert host is not None + assert host.hostname == "WS01" + + def test_with_domain(self, db) -> None: + host = db.add_host("10.0.0.3", domain="CORP") + assert host is not None + assert host.domain == "CORP" + + def test_with_hostname_and_domain(self, db) -> None: + host = db.add_host("10.0.0.4", hostname="WS01", domain="CORP") + assert host is not None + assert host.hostname == "WS01" + assert host.domain == "CORP" + + def test_idempotent_returns_same_id(self, db) -> None: + h1 = db.add_host("10.0.0.5") + h2 = db.add_host("10.0.0.5") + assert h1 is not None + assert h2 is not None + assert h1.id == h2.id + + def test_fills_missing_hostname(self, db) -> None: + db.add_host("10.0.0.6") + h2 = db.add_host("10.0.0.6", hostname="LATE") + assert h2 is not None + assert h2.hostname == "LATE" + + def test_fills_missing_domain(self, db) -> None: + db.add_host("10.0.0.7") + h2 = db.add_host("10.0.0.7", domain="LATE") + assert h2 is not None + assert h2.domain == "LATE" + + def test_does_not_overwrite_existing_hostname(self, db) -> None: + db.add_host("10.0.0.8", hostname="FIRST") + h2 = db.add_host("10.0.0.8", hostname="SECOND") + assert h2 is not None + assert h2.hostname == "FIRST" + + def test_does_not_overwrite_existing_domain(self, db) -> None: + db.add_host("10.0.0.9", domain="FIRST") + h2 = db.add_host("10.0.0.9", domain="SECOND") + assert h2 is not None + assert h2.domain == "FIRST" + + def test_no_extras(self, db) -> None: + host = db.add_host("10.0.0.10", extras=None) + assert host is not None + + def test_empty_extras(self, db) -> None: + host = db.add_host("10.0.0.11", extras={}) + assert host is not None + + def test_with_single_extra(self, db) -> None: + host = db.add_host("10.0.0.12", extras={"os": "Win10"}) + assert host is not None + result = db.session.scalars( + select(HostExtra).where(HostExtra.host == host.id) + ).all() + db._release() + assert len(result) == 1 + + def test_with_multiple_extras(self, db) -> None: + host = db.add_host("10.0.0.13", extras={"os": "Win10", "arch": "x64"}) + assert host is not None + result = db.session.scalars( + select(HostExtra).where(HostExtra.host == host.id) + ).all() + db._release() + assert len(result) == 2 + + def test_different_ips_create_different_hosts(self, db) -> None: + h1 = db.add_host("10.0.0.14") + h2 = db.add_host("10.0.0.15") + assert h1 is not None + assert h2 is not None + assert h1.id != h2.id + + +# =================================================================== +# model.py — add_host_extra +# =================================================================== +class TestAddHostExtra: + def test_creates_new_extra(self, db) -> None: + host = db.add_host("10.0.0.20") + assert host is not None + db.add_host_extra(host.id, "service", "smb") + result = db.session.scalars( + select(HostExtra).where(HostExtra.host == host.id, HostExtra.key == "service") + ).one() + db._release() + assert json.loads(result.value) == ["smb"] + + def test_appends_to_existing_key(self, db) -> None: + host = db.add_host("10.0.0.21") + assert host is not None + db.add_host_extra(host.id, "port", "445") + db.add_host_extra(host.id, "port", "139") + result = db.session.scalars( + select(HostExtra).where(HostExtra.host == host.id, HostExtra.key == "port") + ).one() + db._release() + assert json.loads(result.value) == ["445", "139"] + + def test_appends_three_values(self, db) -> None: + host = db.add_host("10.0.0.22") + assert host is not None + for v in ["a", "b", "c"]: + db.add_host_extra(host.id, "tag", v) + result = db.session.scalars( + select(HostExtra).where(HostExtra.host == host.id, HostExtra.key == "tag") + ).one() + db._release() + assert json.loads(result.value) == ["a", "b", "c"] + + def test_different_keys_are_separate_rows(self, db) -> None: + host = db.add_host("10.0.0.23") + assert host is not None + db.add_host_extra(host.id, "os", "Linux") + db.add_host_extra(host.id, "arch", "x86_64") + result = db.session.scalars( + select(HostExtra).where(HostExtra.host == host.id) + ).all() + db._release() + assert len(result) == 2 + + def test_locked_parameter_works(self, db) -> None: + """_locked=True used internally by add_host (via extras dict).""" + host = db.add_host("10.0.0.24") + assert host is not None + # Simulate what add_host does: call with _locked=True while holding lock + with db.lock: + db.add_host_extra(host.id, "test_key", "test_val", _locked=True) + result = db.session.scalars( + select(HostExtra).where(HostExtra.host == host.id) + ).one() + db._release() + assert json.loads(result.value) == ["test_val"] + + def test_value_stored_as_string(self, db) -> None: + host = db.add_host("10.0.0.25") + assert host is not None + db.add_host_extra(host.id, "count", "42") + result = db.session.scalars( + select(HostExtra).where(HostExtra.host == host.id, HostExtra.key == "count") + ).one() + db._release() + assert json.loads(result.value) == ["42"] + + +# =================================================================== +# model.py — add_auth: basic storage +# =================================================================== +class TestAddAuth: + def test_stores_all_fields(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.30", 12345), + credtype="NetNTLMv2", + username="admin", + password="hash123", + logger=logger, + domain="CORP", + hostname="WS01", + ) + cred = db.session.scalars(select(Credential)).one() + db._release() + assert cred.username == "admin" + assert cred.password == "hash123" + assert cred.protocol == "smb" + assert cred.domain == "corp" + assert cred.hostname == "WS01" + assert cred.client == "10.0.0.30:12345" + + def test_credtype_lowercased(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.31", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + ) + cred = db.session.scalars(select(Credential)).one() + db._release() + assert cred.credtype == "netntlmv2" + + def test_stores_cleartext(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.32", 445), + credtype=CLEARTEXT, + username="u", + password="P@ss", + logger=logger, + ) + cred = db.session.scalars(select(Credential)).one() + db._release() + assert cred.credtype == "cleartext" + + def test_creates_host_row(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.33", 445), + credtype="NetNTLMv1", + username="u", + password="h", + logger=logger, + ) + hosts = db.session.scalars(select(HostInfo)).all() + db._release() + assert len(hosts) == 1 + assert hosts[0].ip == "10.0.0.33" + + def test_reuses_existing_host(self, db, logger) -> None: + db.add_host("10.0.0.34", hostname="PRE") + db.add_auth( + client=("10.0.0.34", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + ) + hosts = db.session.scalars(select(HostInfo)).all() + db._release() + assert len(hosts) == 1 + + def test_normalizes_ipv6(self, db, logger) -> None: + db.add_auth( + client=("::ffff:10.0.0.35", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + ) + cred = db.session.scalars(select(Credential)).one() + db._release() + assert cred.client == "10.0.0.35:445" + + def test_lowercases_username(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.36", 445), + credtype="NetNTLMv2", + username="ADMIN", + password="h", + logger=logger, + ) + cred = db.session.scalars(select(Credential)).one() + db._release() + assert cred.username == "admin" + + def test_lowercases_domain(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.37", 445), + credtype="NetNTLMv2", + username="u", + password="h", + domain="CORP.LOCAL", + logger=logger, + ) + cred = db.session.scalars(select(Credential)).one() + db._release() + assert cred.domain == "corp.local" + + def test_none_domain_stored_as_empty(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.38", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + domain=None, + ) + cred = db.session.scalars(select(Credential)).one() + db._release() + assert cred.domain == "" + + def test_none_hostname_stored_as_empty(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.39", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + hostname=None, + ) + cred = db.session.scalars(select(Credential)).one() + db._release() + assert cred.hostname == "" + + def test_password_stored_verbatim(self, db, logger) -> None: + raw = "Admin::CORP:544553544348414c:AABBCCDD:0101blob" + db.add_auth( + client=("10.0.0.40", 445), + credtype="NetNTLMv2", + username="u", + password=raw, + logger=logger, + ) + cred = db.session.scalars(select(Credential)).one() + db._release() + assert cred.password == raw + + def test_timestamp_format(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.41", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + ) + cred = db.session.scalars(select(Credential)).one() + db._release() + date, time = cred.timestamp.split(" ") + assert len(date.split("-")) == 3 + assert len(time.split(":")) == 3 + + def test_credential_fk_links_to_host(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.42", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + ) + cred = db.session.scalars(select(Credential)).one() + host = db.session.scalars(select(HostInfo).where(HostInfo.id == cred.host)).one() + db._release() + assert host.ip == "10.0.0.42" + + +# =================================================================== +# model.py — add_auth: protocol resolution +# =================================================================== +class TestAddAuthProtocol: + def test_from_logger_extra(self, db) -> None: + lg = MagicMock() + lg.extra = {"protocol": "HTTP"} + db.add_auth( + client=("10.0.0.50", 80), + credtype="Token", + username="u", + password="t", + logger=lg, + ) + cred = db.session.scalars(select(Credential)).one() + db._release() + assert cred.protocol == "http" + + def test_from_parameter(self, db) -> None: + db.add_auth( + client=("10.0.0.51", 1433), + credtype="NetNTLMv2", + username="u", + password="h", + protocol="MSSQL", + ) + cred = db.session.scalars(select(Credential)).one() + db._release() + assert cred.protocol == "mssql" + + def test_parameter_overrides_logger(self, db) -> None: + lg = MagicMock() + lg.extra = {"protocol": "HTTP"} + db.add_auth( + client=("10.0.0.52", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=lg, + protocol="LDAP", + ) + cred = db.session.scalars(select(Credential)).one() + db._release() + assert cred.protocol == "ldap" + + def test_no_protocol_no_logger_skips(self, db) -> None: + db.add_auth( + client=("10.0.0.53", 445), + credtype="NetNTLMv2", + username="u", + password="h", + ) + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 0 + + +# =================================================================== +# model.py — add_auth: duplicate detection +# =================================================================== +class TestDuplicateDetection: + def test_dedup_skips_second(self, db, logger) -> None: + for _ in range(2): + db.add_auth( + client=("10.0.0.60", 445), + credtype="NetNTLMv2", + username="admin", + password="h", + domain="CORP", + logger=logger, + ) + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 1 + + def test_dedup_case_insensitive_username(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.61", 445), + credtype="NetNTLMv2", + username="ADMIN", + password="h1", + logger=logger, + ) + db.add_auth( + client=("10.0.0.61", 445), + credtype="NetNTLMv2", + username="admin", + password="h2", + logger=logger, + ) + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 1 + + def test_dedup_case_insensitive_credtype(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.62", 445), + credtype="NetNTLMv2", + username="u", + password="h1", + logger=logger, + ) + db.add_auth( + client=("10.0.0.62", 445), + credtype="netntlmv2", + username="u", + password="h2", + logger=logger, + ) + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 1 + + def test_dedup_case_insensitive_domain(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.63", 445), + credtype="NetNTLMv2", + username="u", + password="h1", + domain="CORP", + logger=logger, + ) + db.add_auth( + client=("10.0.0.63", 445), + credtype="NetNTLMv2", + username="u", + password="h2", + domain="corp", + logger=logger, + ) + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 1 + + def test_different_credtype_not_deduped(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.64", 445), + credtype="NetNTLMv1", + username="u", + password="h1", + logger=logger, + ) + db.add_auth( + client=("10.0.0.64", 445), + credtype="NetNTLMv2", + username="u", + password="h2", + logger=logger, + ) + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 2 + + def test_different_user_not_deduped(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.65", 445), + credtype="NetNTLMv2", + username="admin", + password="h1", + logger=logger, + ) + db.add_auth( + client=("10.0.0.65", 445), + credtype="NetNTLMv2", + username="guest", + password="h2", + logger=logger, + ) + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 2 + + def test_different_domain_not_deduped(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.66", 445), + credtype="NetNTLMv2", + username="u", + password="h1", + domain="A", + logger=logger, + ) + db.add_auth( + client=("10.0.0.66", 445), + credtype="NetNTLMv2", + username="u", + password="h2", + domain="B", + logger=logger, + ) + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 2 + + def test_different_protocol_not_deduped(self, db) -> None: + db.add_auth( + client=("10.0.0.67", 445), + credtype="NetNTLMv2", + username="u", + password="h1", + protocol="SMB", + ) + db.add_auth( + client=("10.0.0.67", 80), + credtype="NetNTLMv2", + username="u", + password="h2", + protocol="HTTP", + ) + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 2 + + def test_duplicate_creds_true_stores_all(self, db_dupes, logger) -> None: + for i in range(3): + db_dupes.add_auth( + client=("10.0.0.68", 445), + credtype="NetNTLMv2", + username="u", + password=f"h{i}", + domain="D", + logger=logger, + ) + creds = db_dupes.session.scalars(select(Credential)).all() + db_dupes._release() + assert len(creds) == 3 + + def test_same_ip_different_port_still_deduped(self, db, logger) -> None: + """Dedup keys are domain/user/credtype/protocol, NOT client IP:port.""" + db.add_auth( + client=("10.0.0.69", 445), + credtype="NetNTLMv2", + username="u", + password="h1", + logger=logger, + ) + db.add_auth( + client=("10.0.0.69", 12345), + credtype="NetNTLMv2", + username="u", + password="h2", + logger=logger, + ) + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 1 + + +# =================================================================== +# model.py — add_auth: HOST_INFO extras +# =================================================================== +class TestAddAuthExtras: + def test_host_info_popped(self, db, logger) -> None: + extras = {HOST_INFO: "WS.corp", "os": "Win10"} + db.add_auth( + client=("10.0.0.70", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + extras=extras, + ) + assert HOST_INFO not in extras + assert "os" in extras # other keys preserved + + def test_none_extras(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.71", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + extras=None, + ) + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 1 + + def test_empty_extras(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.72", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + extras={}, + ) + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 1 + + +# =================================================================== +# model.py — add_auth: logging +# =================================================================== +class TestAddAuthLogging: + def test_success_logs_captured(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.80", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + ) + logger.success.assert_called_once() + assert "Captured" in logger.success.call_args[0][0] + + def test_duplicate_logs_skipping(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.81", 445), + credtype="NetNTLMv2", + username="u", + password="h1", + logger=logger, + ) + logger.reset_mock() + db.add_auth( + client=("10.0.0.81", 445), + credtype="NetNTLMv2", + username="u", + password="h2", + logger=logger, + ) + assert any("Skipping" in str(c) for c in logger.highlight.call_args_list) + + def test_no_user_skips_username_line(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.82", 445), + credtype="Token", + username=NO_USER, + password="tok", + logger=logger, + ) + calls = [str(c) for c in logger.highlight.call_args_list] + assert not any("Username" in c for c in calls) + + def test_custom_flag_omits_hash_label(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.83", 445), + credtype="Custom", + username="u", + password="v", + logger=logger, + custom=True, + ) + msg = logger.success.call_args[0][0] + assert "Hash" not in msg + assert "Password" not in msg + + def test_cleartext_label_says_password(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.84", 445), + credtype=CLEARTEXT, + username="u", + password="p", + logger=logger, + ) + msg = logger.success.call_args[0][0] + assert "Password" in msg + + def test_hash_label_says_hash(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.85", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + ) + msg = logger.success.call_args[0][0] + assert "Hash" in msg + + def test_domain_appears_in_log(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.86", 445), + credtype="NetNTLMv2", + username="u", + password="h", + domain="TESTDOM", + logger=logger, + ) + msg = logger.success.call_args[0][0] + assert "TESTDOM" in msg + + def test_extras_logged(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.87", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + extras={"SPN": "cifs/server"}, + ) + calls = [str(c) for c in logger.highlight.call_args_list] + assert any("SPN" in c for c in calls) + assert any("cifs/server" in c for c in calls) + + def test_no_log_on_failed_write(self, db, logger) -> None: + """When add_host returns None, no credential is stored or logged.""" + db.add_auth( + client=("10.0.0.88", 445), + credtype="NetNTLMv2", + username="u", + password="h", + # no logger, no protocol -> early return before DB write + ) + logger.success.assert_not_called() + + +# =================================================================== +# model.py — _check_duplicate +# =================================================================== +class TestCheckDuplicate: + def test_false_on_empty_db(self, db) -> None: + assert db._check_duplicate("smb", "NetNTLMv2", "u", "D") is False + + def test_true_after_insert(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.90", 445), + credtype="NetNTLMv2", + username="u", + password="h", + domain="D", + logger=logger, + ) + assert db._check_duplicate("smb", "netntlmv2", "u", "d") is True + + def test_case_insensitive(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.91", 445), + credtype="NetNTLMv2", + username="ADMIN", + password="h", + domain="CORP", + logger=logger, + ) + assert db._check_duplicate("SMB", "NETNTLMV2", "admin", "corp") is True + + def test_none_domain_matches_empty(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.92", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + ) + assert db._check_duplicate("smb", "netntlmv2", "u", None) is True + + def test_different_domain_returns_false(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.93", 445), + credtype="NetNTLMv2", + username="u", + password="h", + domain="A", + logger=logger, + ) + assert db._check_duplicate("smb", "netntlmv2", "u", "B") is False + + +# =================================================================== +# model.py — error handling +# =================================================================== +class TestErrorHandling: + def test_handle_db_error_reraises_unknown(self, db) -> None: + exc = OperationalError("random error", {}, Exception()) + with pytest.raises(OperationalError): + db._handle_db_error(exc) + + def test_handle_db_error_swallows_schema_error(self, db) -> None: + exc = OperationalError("no such column: foo", {}, Exception()) + db._handle_db_error(exc) + + def test_handle_db_error_swallows_case_variants(self, db) -> None: + exc = OperationalError("No Such Column: bar", {}, Exception()) + db._handle_db_error(exc) + + def test_execute_succeeds(self, db) -> None: + result = db._execute(select(Credential)) + assert result is not None + + def test_commit_succeeds(self, db) -> None: + db.session.add(HostInfo(ip="99.99.99.99")) + db.commit() + hosts = db.session.scalars(select(HostInfo)).all() + db._release() + assert any(h.ip == "99.99.99.99" for h in hosts) + + +# =================================================================== +# model.py — connection release +# =================================================================== +class TestConnectionRelease: + def test_add_host_releases(self, db) -> None: + db.add_host("10.0.0.100") + assert db.session is not None + db._release() + + def test_add_auth_releases(self, db, logger) -> None: + db.add_auth( + client=("10.0.0.101", 445), + credtype="NetNTLMv2", + username="u", + password="h", + logger=logger, + ) + assert db.session is not None + db._release() + + def test_early_return_no_leak(self, db) -> None: + db.add_auth( + client=("10.0.0.102", 445), credtype="NetNTLMv2", username="u", password="h" + ) + assert db.session is not None + db._release() + + def test_sequential_operations_work(self, db, logger) -> None: + """Multiple add_auth calls in sequence (simulates a handler thread).""" + for i in range(5): + db.add_auth( + client=(f"10.0.0.{110 + i}", 445), + credtype="NetNTLMv2", + username=f"user{i}", + password=f"hash{i}", + logger=logger, + ) + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 5 + + +# =================================================================== +# model.py — thread safety +# =================================================================== +class TestThreadSafety: + def test_concurrent_add_host_same_ip(self, db) -> None: + errors: list[Exception] = [] + + def worker(): + try: + db.add_host("10.0.0.120") + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + assert errors == [] + hosts = db.session.scalars( + select(HostInfo).where(HostInfo.ip == "10.0.0.120") + ).all() + db._release() + assert len(hosts) == 1 + + def test_concurrent_add_auth_different_users(self, db_dupes) -> None: + errors: list[Exception] = [] + + def worker(i: int): + try: + lg = MagicMock() + lg.extra = {"protocol": "SMB"} + db_dupes.add_auth( + client=(f"10.0.0.{130 + i}", 445), + credtype="NetNTLMv2", + username=f"user{i}", + password=f"hash{i}", + logger=lg, + ) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + assert errors == [] + creds = db_dupes.session.scalars(select(Credential)).all() + db_dupes._release() + assert len(creds) == 10 + + def test_atomic_dedup_insert(self, db) -> None: + """Concurrent threads with same cred: exactly 1 stored.""" + errors: list[Exception] = [] + + def worker(): + try: + lg = MagicMock() + lg.extra = {"protocol": "SMB"} + db.add_auth( + client=("10.0.0.140", 445), + credtype="NetNTLMv2", + username="shared", + password="h", + domain="D", + logger=lg, + ) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + assert errors == [] + creds = db.session.scalars(select(Credential)).all() + db._release() + assert len(creds) == 1 + + def test_concurrent_add_host_extra(self, db) -> None: + host = db.add_host("10.0.0.150") + assert host is not None + errors: list[Exception] = [] + + def worker(i: int): + try: + db.add_host_extra(host.id, "tag", f"val{i}") + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + assert errors == [] diff --git a/tests/test_db_concurrency.py b/tests/test_db_concurrency.py new file mode 100644 index 0000000..7baaf3d --- /dev/null +++ b/tests/test_db_concurrency.py @@ -0,0 +1,102 @@ +"""Concurrency stress tests for DementorDB. + +Spawns multiple threads that call add_auth() simultaneously and verifies: + - Zero exceptions (no PendingRollbackError, no packet sequence errors) + - All credentials are persisted (no lost writes) + +Tests run against SQLite backends only (no external DB required): + 1. SQLite :memory: (StaticPool, single shared connection) + 2. SQLite file (QueuePool, SQLAlchemy 2.0 default) +""" + +import os +import tempfile +import threading +from unittest.mock import MagicMock + +from sqlalchemy import create_engine, select +from sqlalchemy.pool import StaticPool + +from dementor.db.model import Credential, DementorDB + +THREAD_COUNT = 20 + + +def _make_config(*, duplicate_creds: bool = True) -> MagicMock: + config = MagicMock() + config.db_config.db_duplicate_creds = duplicate_creds + return config + + +def _make_logger() -> MagicMock: + logger = MagicMock() + logger.extra = {"protocol": "SMB"} + return logger + + +def _worker( + db: DementorDB, + index: int, + errors: list[Exception], +) -> None: + try: + db.add_auth( + client=(f"10.0.0.{index}", 12345), + credtype="NetNTLMv2", + username=f"user{index}", + password=f"hash_value_{index}", + protocol="SMB", + domain=f"DOMAIN{index}", + logger=_make_logger(), + ) + except Exception as exc: + errors.append(exc) + + +def _run_concurrent_test(db: DementorDB) -> list[Exception]: + errors: list[Exception] = [] + threads = [ + threading.Thread(target=_worker, args=(db, i, errors)) + for i in range(THREAD_COUNT) + ] + for t in threads: + t.start() + for t in threads: + t.join(timeout=30) + return errors + + +# --- SQLite :memory: --------------------------------------------------------- +def test_concurrent_add_auth_sqlite_memory() -> None: + engine = create_engine( + "sqlite+pysqlite:///:memory:", + isolation_level="AUTOCOMMIT", + poolclass=StaticPool, + connect_args={"check_same_thread": False}, + ) + db = DementorDB(engine, _make_config()) + + errors = _run_concurrent_test(db) + assert errors == [], f"Got {len(errors)} errors: {errors}" + + creds = db.session.scalars(select(Credential)).all() + assert len(creds) == THREAD_COUNT, f"Expected {THREAD_COUNT}, got {len(creds)}" + db.close() + + +# --- SQLite file -------------------------------------------------------------- +def test_concurrent_add_auth_sqlite_file() -> None: + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test.db") + engine = create_engine( + f"sqlite+pysqlite:///{db_path}", + isolation_level="AUTOCOMMIT", + ) + db = DementorDB(engine, _make_config()) + + errors = _run_concurrent_test(db) + assert errors == [], f"Got {len(errors)} errors: {errors}" + + creds = db.session.scalars(select(Credential)).all() + assert len(creds) == THREAD_COUNT, f"Expected {THREAD_COUNT}, got {len(creds)}" + db.close() From 4bcebef4ea9e5c6d7d74657c299b53d9690b3f7a Mon Sep 17 00:00:00 2001 From: StrongWind1 <5987034+StrongWind1@users.noreply.github.com> Date: Wed, 18 Mar 2026 00:39:23 -0400 Subject: [PATCH 3/4] config: rewrite [DB] section with clear separated options Three options, each with its own section header and examples: - Url: advanced, for MySQL/MariaDB/PostgreSQL (makes Path ignored) - Path: default SQLite backend (relative, absolute, or :memory:) - DuplicateCreds: store all hashes or deduplicate Removed stale commented-out Dialect/Driver fields. --- dementor/assets/Dementor.toml | 38 +++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/dementor/assets/Dementor.toml b/dementor/assets/Dementor.toml index 72ad0c8..190a309 100755 --- a/dementor/assets/Dementor.toml +++ b/dementor/assets/Dementor.toml @@ -204,17 +204,39 @@ UPnP = true # ============================================================================= [DB] -# If true, allows duplicate credentials to be stored. If false, only unique -# credentials will be stored and printed once. -# The default value is: true - -DuplicateCreds = true +# --- Url (advanced) ---------------------------------------------------------- +# Full SQLAlchemy database URL. When set, Path is ignored. +# Use this to connect to an external database server for multi-session or +# shared access. Leave empty (the default) to use the SQLite Path below. +# +# Examples: +# Url = "mysql+pymysql://user:pass@127.0.0.1/dementor" # MySQL / MariaDB +# Url = "postgresql+psycopg2://user:pass@127.0.0.1/dementor" # PostgreSQL +# +# Url = -# Dialect = "sqlite" -# Driver = "pysqlite" -# Url = "sqlite:///:memory:" +# --- Path (default backend) -------------------------------------------------- +# Path to the SQLite database file. Only used when Url is empty. +# +# Relative paths are resolved from the workspace directory: +# Path = "Dementor.db" (default -- file in workspace) +# Path = "data/captures.db" (subfolder, created automatically) +# +# Absolute paths are used as-is: +# Path = "/opt/dementor/creds.db" +# +# In-memory database (fast, but all data is lost when Dementor exits; +# the TUI can still query captured creds while running): +# Path = ":memory:" +# # Path = "Dementor.db" +# --- DuplicateCreds ----------------------------------------------------------- +# When true, every captured hash is stored even if an identical credential +# (same domain + username + type + protocol) was seen before. When false, +# only the first capture is kept and repeats are silently skipped. +DuplicateCreds = true + # ============================================================================= # mDNS # ============================================================================= From 722544e9b7c891cbd6b00c882e173ad47f8ea34c Mon Sep 17 00:00:00 2001 From: StrongWind1 <5987034+StrongWind1@users.noreply.github.com> Date: Wed, 18 Mar 2026 00:39:23 -0400 Subject: [PATCH 4/4] docs: rewrite database.rst to match current code and Sphinx style - Add Choosing a backend section with comparison table - Document all 3 active options (Url, Path, DuplicateCreds) with py:attribute directives, examples, bullet points, tips, and notes - Mark Dialect and Driver as removed with versionremoved directives - Sphinx HTML build passes with -W --keep-going (zero warnings) --- docs/source/config/database.rst | 181 ++++++++++++++++++++++++++------ 1 file changed, 148 insertions(+), 33 deletions(-) diff --git a/docs/source/config/database.rst b/docs/source/config/database.rst index c567c90..5e47ebd 100644 --- a/docs/source/config/database.rst +++ b/docs/source/config/database.rst @@ -4,70 +4,188 @@ Database ======== -Section ``[Database]`` ----------------------- +Section ``[DB]`` +---------------- -.. py:currentmodule:: Database +Dementor stores every captured credential (hashes and cleartext passwords) in a +database so you can query them later through the TUI or export them for offline +cracking. The ``[DB]`` section controls where that database lives and how +duplicates are handled. -.. py:attribute:: DuplicateCreds - :type: bool - :value: false +.. tip:: - *Maps to* :attr:`db.connector.DatabaseConfig.db_duplicate_creds` + Most users don't need to touch this section. With no configuration at all, + Dementor creates a SQLite file called ``Dementor.db`` in your workspace + directory. That works out of the box for most engagements. - Controls whether duplicate credentials are stored. If set to ``false``, each unique - credential set is stored and displayed only once. -.. py:attribute:: Dialect - :type: str - :value: "sqlite" +Choosing a backend +~~~~~~~~~~~~~~~~~~ - *Maps to* :attr:`db.connector.DatabaseConfig.db_dialect` +Dementor supports three database backends. Pick the one that fits your +use case: - .. versionadded:: 1.0.0.dev14 +.. list-table:: + :widths: 25 45 30 + :header-rows: 1 - Specifies the SQL dialect to use + * - Backend + - When to use it + - How to configure + * - **SQLite file** *(default)* + - Credentials persist to disk across restarts. Good for most + engagements. + - Leave ``Url`` empty. Optionally set ``Path``. + * - **SQLite in-memory** + - Fast, no disk I/O. Credentials are lost when Dementor exits, but + the TUI can still query them while running. Good for quick tests. + - ``Path = ":memory:"`` + * - **MySQL / MariaDB / PostgreSQL** + - Shared access across multiple Dementor instances or integration + with external tooling. Requires a running database server. + - Set ``Url`` to a full connection string. -.. py:attribute:: Driver + +Options +~~~~~~~ + +.. py:currentmodule:: DB + + +.. py:attribute:: Url :type: str - :value: "pysqlite" + :value: *(empty)* - *Maps to* :attr:`db.connector.DatabaseConfig.db_driver` + *Maps to* :attr:`db.connector.DatabaseConfig.db_url` .. versionadded:: 1.0.0.dev14 - Specifies the SQL driver (external packages allowed) to be used for the database connection. - Additional third-party packages must be installed before they can be used. + .. versionchanged:: 1.0.0.dev22 + + Renamed internally from ``db_raw_path`` to ``db_url``. The TOML key + ``Url`` is unchanged. + + Full `SQLAlchemy database URL `_ + for connecting to an external database server. When set, :attr:`Path` is + ignored. Leave empty (the default) to use SQLite via :attr:`Path`. + + .. code-block:: toml + + # MySQL / MariaDB + Url = "mysql+pymysql://user:pass@127.0.0.1/dementor" + + # PostgreSQL + Url = "postgresql+psycopg2://user:pass@127.0.0.1/dementor" + + .. note:: + + The database driver (e.g. ``pymysql``, ``psycopg2``) must be installed + separately — it is not bundled with Dementor. + .. py:attribute:: Path - :type: RelativePath | RelativeWorkspacePath | AbsolutePath + :type: str :value: "Dementor.db" *Maps to* :attr:`db.connector.DatabaseConfig.db_path` .. versionadded:: 1.0.0.dev14 - Specifies the database filename. Not used if :attr:`~DB.Url` is set. + Path to the SQLite database file. Only used when :attr:`Url` is empty. -.. py:attribute:: Url - :type: str + * **Relative paths** are resolved from the workspace directory + (:attr:`Dementor.Workspace`). + * **Absolute paths** are used as-is. + * ``:memory:`` creates an in-memory database — fast, but all data is lost + when Dementor exits. The TUI can still query credentials while running. + + .. code-block:: toml + + # Default — file in the workspace directory + Path = "Dementor.db" + + # Subfolder (created automatically if it doesn't exist) + Path = "data/captures.db" + + # Absolute path + Path = "/opt/dementor/creds.db" + + # In-memory — fast, but data is lost on exit + Path = ":memory:" + + .. tip:: + + Use ``:memory:`` for quick tests where you don't need persistence. + The TUI can still query captured credentials while Dementor is running. - *Maps to* :attr:`db.connector.DatabaseConfig.db_raw_path` + +.. py:attribute:: DuplicateCreds + :type: bool + :value: true + + *Maps to* :attr:`db.connector.DatabaseConfig.db_duplicate_creds` + + Controls whether duplicate credentials are stored in the database. + + * ``true`` *(default)* — Every captured hash is stored, even if the same + credential was already seen in this session. + * ``false`` — Only the first capture of each unique credential is stored. + Subsequent duplicates are silently skipped. + + A credential is considered a duplicate when all four of these fields match + (case-insensitive): + + * Domain + * Username + * Credential type (e.g. ``NetNTLMv2``, ``Cleartext``) + * Protocol (e.g. ``smb``, ``http``) + + .. note:: + + The hash is always written to the log file stream regardless of this + setting, so no captured data is ever lost — only the database storage + is affected. + + .. tip:: + + Set to ``false`` on long-running engagements to keep the database small + and the TUI output clean. + + +Removed options +~~~~~~~~~~~~~~~ + +The following options have been removed in previous versions. They are silently +ignored if still present in your configuration file. + +.. py:attribute:: Dialect + :type: str + :value: "sqlite" .. versionadded:: 1.0.0.dev14 - Custom database connection URL to use. Overwrites driver, dialect and path. + .. versionremoved:: 1.0.0.dev22 + **Removed.** The SQL dialect is now determined automatically — from + :attr:`Url` when set, or defaults to ``sqlite`` when using :attr:`Path`. +.. py:attribute:: Driver + :type: str + :value: "pysqlite" + + .. versionadded:: 1.0.0.dev14 + + .. versionremoved:: 1.0.0.dev22 + + **Removed.** The SQL driver is now determined automatically — from + :attr:`Url` when set, or defaults to ``pysqlite`` when using :attr:`Path`. .. py:attribute:: Directory :type: str .. versionremoved:: 1.0.0.dev14 - **DEPRECATED** Specifies a custom directory for storing the database. This setting overrides the default - directory configured via :attr:`Dementor.Workspace`. - + **Removed.** Use :attr:`Path` with an absolute path instead. .. py:attribute:: Name :type: str @@ -75,7 +193,4 @@ Section ``[Database]`` .. versionremoved:: 1.0.0.dev14 - **DEPRECATED** Sets the filename of the database to be used. - - - + **Removed.** Use :attr:`Path` instead.