diff --git a/dementor/assets/Dementor.toml b/dementor/assets/Dementor.toml index 85aac70..35f1cca 100644 --- a/dementor/assets/Dementor.toml +++ b/dementor/assets/Dementor.toml @@ -19,6 +19,18 @@ # 3. [NTLM] section -- shared default for all NTLM-enabled protocols # 4. [Globals] section -- last resort # +# Host identity settings (NetBIOSComputer, NetBIOSDomain, DnsComputer, +# DnsDomain) follow the same chain as NTLM: +# 1. [[Protocol.Server]] entry +# 2. [Protocol] section +# 3. [NTLM] section +# 4. [Globals] section -- derived automatically from Globals.Host +# +# Protocol Host values (SMTP, LDAP, MSSQL, ...) resolve as: +# 1. [[Protocol.Server]] entry +# 2. [Protocol] section +# 3. [Globals] section -- Host derived from Globals.Host +# # All other settings stop at step 2. # # Note: Some settings can only be used in the most local section (e.g. "Port"). @@ -171,6 +183,28 @@ UPnP = true # be used if the local (protocol-specific) section does not define the value. [Globals] +# --- Host ------------------------------------------------------------------- +# Single Host (FQDN or bare hostname) that defines the server's identity for all +# protocol responses. When set, the following values are automatically derived +# and used as fallbacks throughout all protocol servers: +# +# NetBIOSDomainName = Host.split(".", 1)[1].upper() +# DNSHostName = Host.split(".", 1)[0] +# NetBIOSName = Host.split(".", 1)[0][:15].upper() +# DNSDomainName = Host.split(".", 1)[1].lower() +# Host (FQDN) = Host +# +# Individual values (NetBIOSComputer, NetBIOSDomain, DnsComputer, DnsDomain, +# Host) can be overridden in [NTLM] or any [Protocol] section. The Host +# entry is the last-resort fallback for all of them. +# +# Can also be set via the CLI: +# sudo dementor -I eth0 -H DC01.contoso.lab +# sudo dementor -I eth0 -O Globals.Host="DC01.contoso.lab" + +# Host = "DC01.contoso.lab" + +# --- Filtering -------------------------------------------------------------- # Describes a list of hosts to *include* for poisoning (whitelist approach). # Supported filter formats: # @@ -210,8 +244,8 @@ UPnP = true # shared access. Leave empty (the default) to use the SQLite Path below. # # Examples: -# Url = "mysql+pymysql://user:pass@127.0.0.1/dementor" # MySQL / MariaDB -# Url = "postgresql+psycopg2://user:pass@127.0.0.1/dementor" # PostgreSQL +# Url = "mysql-pymysql://user:pass@127.0.0.1/dementor" # MySQL / MariaDB +# Url = "postgresql-psycopg2://user:pass@127.0.0.1/dementor" # PostgreSQL # # Url = @@ -233,7 +267,7 @@ UPnP = true # --- DuplicateCreds ----------------------------------------------------------- # When true, every captured hash is stored even if an identical credential -# (same domain + username + type + protocol) was seen before. When false, +# (same domain - username - type - protocol) was seen before. When false, # only the first capture is kept and repeats are silently skipped. DuplicateCreds = true @@ -321,13 +355,15 @@ CapturesPerConnection = 0 # Optional. NTSTATUS code returned after the final capture. Accepts an integer # value or a string name from impacket.nt_errors (e.g. "STATUS_ACCESS_DENIED", # "STATUS_LOGON_FAILURE"). The string is resolved via getattr(nt_errors, value). -# Default: "STATUS_SMB_BAD_UID" -ErrorCode = "STATUS_SMB_BAD_UID" +# +# WARNING: Enabling this feature with disable all file path capture events +# Default: Not set +# ErrorCode = "STATUS_SMB_BAD_UID" # --- SMB1 Identity (optional) --- # These strings appear only in SMB1 negotiate and session-setup responses. # They are purely SMB-layer and do NOT affect the NTLM CHALLENGE_MESSAGE. -# SMB2/3 has no equivalent fields — these are ignored for modern clients. +# SMB2/3 has no equivalent fields - these are ignored for modern clients. # To control the NTLM identity (AV_PAIRs), use the [NTLM] section instead. # Optional. ServerName in the SMB1 non-extended-security negotiate response. @@ -355,7 +391,7 @@ Port = 139 [[SMB.Server]] Port = 445 # Per-server overrides (highest priority -- override [SMB] and [NTLM] for this port only): -# FQDN = "other.corp.com" +# Host = "other.corp.com" # ServerOS = "Windows Server 2022" # ErrorCode = "STATUS_ACCESS_DENIED" # SMB2Support = true @@ -371,16 +407,16 @@ Port = 445 # Specifies the NetBIOS domain name to advertise in NETLOGON responses. # This value is used when responding to PDC queries (LOGON_PRIMARY_QUERY) # and DC discovery requests (LOGON_SAM_LOGON_REQUEST). -# The default value is: "CONTOSO" +# The default value is: "WORKGROUP" -# DomainName = "WORKGROUP" +# DNSDomain = "WORKGROUP" # Specifies the hostname to advertise as the domain controller in NETLOGON # responses. This value is used when responding to PDC queries and DC # discovery requests. -# The default value is: "DC01" +# The default value is: "DEMENTOR" -# Hostname = "FILESERVER" +# DnsComputer = "DC01" # ============================================================================= @@ -392,10 +428,10 @@ Port = 445 AuthMechanisms = ["NTLM", "PLAIN", "LOGIN"] -# Fully Qualified Domain Name (FQDN) used by the SMTP server. -# The first part of the FQDN is used as the hostname in responses. +# Fully Qualified Domain Name used by the SMTP server in banner and +# NTLM identity responses. Derives from Globals.Host when not set here. -FQDN = "DEMENTOR" +# Host = "DC01.contoso.lab" # SMTP server banner (identifier and version sent to clients). @@ -470,29 +506,6 @@ Port = 25 Challenge = "1337LEET" -# When true, the ESS flag (NTLMSSP_NEGOTIATE_EXTENDED_SESSIONSECURITY) is -# stripped from the CHALLENGE_MESSAGE sent to the client. -# -# false (default): ESS is echoed back when the client requests it. Clients -# that support ESS produce NTLMv1-ESS hashes (hashcat mode 5500 with -# MD5-mixed challenge). This is the modern, common NTLMv1 variant. -# -# true: ESS is suppressed regardless of what the client requests. Clients -# fall back to plain NTLMv1 (DES-only). With a fixed Challenge above, -# these hashes are vulnerable to precomputed rainbow table attacks. - -DisableExtendedSessionSecurity = false - -# When true, TargetInfoFields are omitted from the CHALLENGE_MESSAGE. -# Without TargetInfoFields clients cannot construct the NTLMv2 Blob -# (MS-NLMP S3.3.2), which has the following effect by client security level: -# Level 0-2 (older Windows, manually downgraded): fall back to NTLMv1. -# Level 3+ (all modern Windows defaults): refuse to authenticate -- zero -# hashes captured from these clients. -# -# Leave false unless specifically targeting legacy NTLMv1-only environments. - -DisableNTLMv2 = false # Optional. Remove the NTLMSSP_NEGOTIATE_EXTENDED_SESSIONSECURITY flag from the # CHALLENGE_MESSAGE, controlling which NTLMv1 hash variant clients produce. # false (default): NTLMv1 clients produce NetNTLMv1-ESS (MD5-mixed, hashcat -m 5500). @@ -505,15 +518,21 @@ DisableExtendedSessionSecurity = false # prevents clients from constructing NTLMv2 responses. # false (default): AV_PAIRs present. All clients (LmCompat 0-5) can authenticate. # true: AV_PAIRs absent. LmCompat 0-2 clients fall back to NTLMv1. -# LmCompat 3+ clients (all modern Windows defaults) will REFUSE +# LmCompat 3- clients (all modern Windows defaults) will REFUSE # to authenticate, producing zero hashes. Use with caution. # Default: false DisableNTLMv2 = false # --- Server Identity (optional) --- # These control the identity fields inside the NTLMSSP CHALLENGE_MESSAGE. -# They are INDEPENDENT from SMB identity — [SMB] NetBIOSComputer is the SMB1 +# They are INDEPENDENT from SMB identity - [SMB] NetBIOSComputer is the SMB1 # negotiate ServerName, while [NTLM] NetBIOSComputer is AV_PAIR 0x0001. +# +# Resolution order (highest priority first): +# 1. [[Protocol.Server]] entry +# 2. [Protocol] section +# 3. [NTLM] section <- override here to apply to ALL NTLM protocols +# 4. [Globals] section <- derived automatically from Globals.Host # Optional. Controls the NTLMSSP_TARGET_TYPE flag and the TargetName field # in the CHALLENGE_MESSAGE. @@ -536,28 +555,32 @@ DisableNTLMv2 = false # These appear inside the TargetInfoFields of the CHALLENGE_MESSAGE. # AV_PAIRs 0x0001 and 0x0002 are always sent (required by MS-NLMP spec). # AV_PAIRs 0x0003, 0x0004, 0x0005 are omitted when set to "" (empty string). +# +# When Globals.Host is configured these values are derived automatically. +# Set them here to override only the NTLM identity without changing other +# protocol server names. # Required. MsvAvNbComputerName (AV_PAIR 0x0001). NetBIOS name of the server. # Also used as TargetName when TargetType="server". -# Default: "DEMENTOR" -# NetBIOSComputer = "DEMENTOR" +# Default: derived from Globals.Host (hostname[:15].upper()), else "DEMENTOR" +# NetBIOSComputer = "DC01" # Required. MsvAvNbDomainName (AV_PAIR 0x0002). NetBIOS domain name. # Also used as TargetName when TargetType="domain". -# Default: "WORKGROUP" -# NetBIOSDomain = "WORKGROUP" +# Default: derived from Globals.Host (domain.upper()), else "WORKGROUP" +# NetBIOSDomain = "CONTOSO" # Optional. MsvAvDnsComputerName (AV_PAIR 0x0003). DNS FQDN of the server. -# Default: "" (omitted from CHALLENGE_MESSAGE) -# DnsComputer = "server.corp.local" +# Default: derived from Globals.Host (full FQDN), else "" (omitted) +# DnsComputer = "dc01.contoso.lab" # Optional. MsvAvDnsDomainName (AV_PAIR 0x0004). DNS domain name. -# Default: "" (omitted from CHALLENGE_MESSAGE) -# DnsDomain = "corp.local" +# Default: derived from Globals.Host (domain.lower()), else "" (omitted) +# DnsDomain = "contoso.lab" # Optional. MsvAvDnsTreeName (AV_PAIR 0x0005). Active Directory forest name. # Default: "" (omitted from CHALLENGE_MESSAGE) -# DnsTree = "corp.local" +# DnsTree = "contoso.lab" # ============================================================================= # FTP Server(s) @@ -610,10 +633,10 @@ EncType = "aes256_cts_hmac_sha1_96" Timeout = 0 -# Hostname + fully qualified domain name, whereby the domain name is optional -# Full example: "HOSTNAME.domain.local" +# Hostname - fully qualified domain name used in LDAP responses. +# Derives from Globals.Host when not set here. -FQDN = "DEMENTOR" +# Host = "DC01.contoso.lab" # Global TLS option, will enable TLS on all TCP servers @@ -904,7 +927,7 @@ InstanceName = "MSSQLServer" [SSRP] # The following values can be defined here or inherited from the [MSSQL] section: -# - FQDN +# - Host # - ServerVersion # - InstanceName diff --git a/dementor/config/__init__.py b/dementor/config/__init__.py index f90efe9..bf3e481 100644 --- a/dementor/config/__init__.py +++ b/dementor/config/__init__.py @@ -21,6 +21,7 @@ import sys import pathlib import tomllib +import warnings from typing import Any @@ -51,7 +52,7 @@ def _set_global_config(config: dict[str, Any]) -> None: :param config: New configuration dictionary. :type config: dict """ - sys.modules[__name__].dm_config = config + sys.modules[__name__].dm_config = config # ty:ignore[unresolved-attribute] def init_from_file(path: str) -> None: @@ -82,8 +83,42 @@ def init_from_file(path: str) -> None: # --------------------------------------------------------------------------- # -# Default initialisation - performed on import so that the rest of the -# package can rely on ``dementor.config.dm_config`` being available. +# Explicit entry point for application startup. # --------------------------------------------------------------------------- # -init_from_file(DEFAULT_CONFIG_PATH) # 1. bundled defaults -init_from_file(CONFIG_PATH) # 2. user-provided overrides +def init_config( + default_path: str | None = None, + user_path: str | None = None, +) -> None: + """Load the default and user TOML configuration files. + + Call this explicitly from the application entry point (e.g. ``standalone.py``). + Using the defaults, it loads the bundled Dementor.toml first, then the + user-provided override, so user settings win. + + :param default_path: Path to bundled defaults (uses :data:`DEFAULT_CONFIG_PATH`). + :param user_path: Path to user overrides (uses :data:`CONFIG_PATH`). + """ + try: + init_from_file(default_path or DEFAULT_CONFIG_PATH) # 1. bundled defaults + init_from_file(user_path or CONFIG_PATH) # 2. user-provided overrides + except Exception as exc: # pragma: no cover + warnings.warn( + f"dementor.config: failed to load configuration at startup: {exc}", + RuntimeWarning, + stacklevel=1, + ) + + +# --------------------------------------------------------------------------- # +# Default initialisation - runs on first import so that protocol modules +# that call get_global_config() / get_value() without going through +# standalone.py still get the bundled defaults. standalone.py re-runs +# init_config() with the user config path, which overwrites these defaults. +# --------------------------------------------------------------------------- # +init_config() + +__all__ = [ + "get_global_config", + "init_config", + "init_from_file", +] diff --git a/dementor/config/attr.py b/dementor/config/attr.py index c817315..83ed9cf 100644 --- a/dementor/config/attr.py +++ b/dementor/config/attr.py @@ -20,7 +20,7 @@ # # Shared attributes for all configuration classes from dementor.config.toml import Attribute -from dementor.config.util import is_true +from dementor.config.util import HostValue, is_true # TLS/Certificate Configuration Attributes # These attributes are shared across protocols that support TLS and @@ -113,3 +113,18 @@ ATTR_CERT_STATE, ATTR_CERT_VALIDITY_DAYS, ] + + +# Host Configuration Attribute +# Single attribute representing the host FQDN from [Globals]. +# Protocols that need the full HostValue object (e.g. to call get_value()) +# can include this in their _fields_ list. The HostValue factory derives +# NetBIOSComputer, NetBIOSDomain, DnsComputer, DnsDomain, FQDN, etc. + +ATTR_GLOBALS_HOST = Attribute( + attr_name="host", + qname="Host", + default_val=HostValue.DEFAULT, + section_local=False, + factory=HostValue, +) diff --git a/dementor/config/session.py b/dementor/config/session.py index bbac709..d022422 100644 --- a/dementor/config/session.py +++ b/dementor/config/session.py @@ -65,7 +65,7 @@ class SessionConfig(TomlConfig): Attribute("extra_modules", "ExtraModules", list), Attribute("workspace_path", "Workspace", DEMENTOR_PATH), ] + [ - # TODO: place this somewhere else + # TODO: move per-protocol enabled flags into a dedicated ProtocolFlags config class Attribute(f"{name.lower()}_enabled", name, True, factory=is_true) for name in ( "LLMNR", @@ -93,7 +93,6 @@ class SessionConfig(TomlConfig): ) ] - # TODO: move into .pyi if typing.TYPE_CHECKING: workspace_path: str extra_modules: list[str] @@ -111,7 +110,7 @@ class SessionConfig(TomlConfig): llmnr_config: llmnr.LLMNRConfig netbiosns_config: netbios.NBTNSConfig ldap_config: list[ldap.LDAPServerConfig] - smtp_servers: list[smtp.SMTPServerConfig] + smtp_config: list[smtp.SMTPServerConfig] smb_config: list[smb.SMBServerConfig] ftp_config: list[ftp.FTPServerConfig] proxy_config: http.ProxyAutoConfig diff --git a/dementor/config/tls.py b/dementor/config/tls.py new file mode 100644 index 0000000..ffa9f00 --- /dev/null +++ b/dementor/config/tls.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025-Present MatrixEditor +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""TLS certificate generation utilities.""" + +import datetime +import pathlib +import random +import string +import tempfile + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + + +def generate_self_signed_cert( + cn: str, + org: str, + country: str, + state: str, + locality: str, + validity_days: int, +) -> tuple[str, str, tempfile.TemporaryDirectory]: + """ + Generate a self-signed certificate and private key in a temporary directory. + + :param cn: Common name for the certificate. + :param org: Organization name. + :param country: Country code. + :param state: State or province. + :param locality: Locality or city. + :param validity_days: Number of days the certificate is valid. + :return: Tuple of (certificate path, key path, temporary directory object). + """ + temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) + + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, country), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state), + x509.NameAttribute(NameOID.LOCALITY_NAME, locality), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, org), + x509.NameAttribute(NameOID.COMMON_NAME, cn), + ] + ) + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.UTC)) + .not_valid_after( + datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=validity_days) + ) + .add_extension( + x509.SubjectAlternativeName([x509.DNSName(cn)]), + critical=False, + ) + .sign(private_key, hashes.SHA256()) + ) + + key_id = "".join(random.choices(string.hexdigits)) + key_path = pathlib.Path(temp_dir.name) / f"key_{key_id}.pem" + key_path.write_bytes( + private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + + cert_id = "".join(random.choices(string.hexdigits)) + cert_path = pathlib.Path(temp_dir.name) / f"cert_{cert_id}.pem" + cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM)) + + return str(cert_path), str(key_path), temp_dir diff --git a/dementor/config/toml.py b/dementor/config/toml.py index 2a0663d..309cce0 100644 --- a/dementor/config/toml.py +++ b/dementor/config/toml.py @@ -225,7 +225,7 @@ def _set_field( if alt_section: # 1. Nested sub-dict within own section (e.g. SMB.NTLM.X) sections.append(own_section_dict.get(alt_section, {})) - # 2. Own section flat key (e.g. SMB.X — doubles as default) + # 2. Own section flat key (e.g. SMB.X - doubles as default) sections.append(own_section_dict) # 3. Alt section (e.g. [NTLM]) sections.append(get_value(alt_section or "", key=None, default={})) diff --git a/dementor/config/util.py b/dementor/config/util.py index c639e6f..ca18aa3 100644 --- a/dementor/config/util.py +++ b/dementor/config/util.py @@ -22,18 +22,11 @@ import random import string import secrets -import os -import tempfile from typing import Any +from collections.abc import Callable from jinja2.sandbox import SandboxedEnvironment -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography import x509 -from cryptography.x509.oid import NameOID -from cryptography.hazmat.primitives import hashes - from dementor.config import get_global_config # --------------------------------------------------------------------------- # @@ -176,6 +169,140 @@ def _parse_str(self, value: str) -> bytes: return stripped.encode() +class HostValue: + """Parse a host FQDN and derive all host-related configuration values. + + A single ``Host = "HOSTNAME.DOMAIN"`` entry in ``[Globals]`` derives: + + - ``NetBIOSDomainName`` / ``NetBIOSDomain`` -> ``domain.upper()`` + - ``DNSHostName`` -> ``hostname`` + - ``NetBIOSName`` / ``NetBIOSComputer`` -> ``hostname[:15].upper()`` + - ``DNSDomainName`` / ``DnsDomain`` -> ``domain.lower()`` + - ``DnsComputer`` / ``FQDN`` -> full ``"hostname.domain"`` + + :param value: Raw host string, e.g. ``"DC01.contoso.lab"`` + :type value: str + """ + + DEFAULT = "DEMENTOR" + DNS_COMPUTER = "DnsComputer" + FQDN = "FQDN" + HOST = "Host" + DNS_HOSTNAME = "DNSHostName" + DNS_DOMAIN = "DNSDomain" + DNS_TREE = "DNSTree" + NETBIOS_NAME = "NetBIOSName" + NETBIOS_COMPUTER = "NetBIOSComputer" + NETBIOS_DOMAIN = "NetBIOSDomain" + + def __init__(self, value: Any) -> None: + self._raw = str(value).strip() if value is not None else self.DEFAULT + if "." in self._raw: + self.hostname, self.domain = self._raw.split(".", 1) + else: + self.hostname = self._raw + self.domain = "" + + def get_value(self, field: str) -> str: + """Return the derived configuration value for *field*. + + :param field: Supported field names: ``Host``, ``FQDN``, ``DnsComputer``, + ``DNSHostName``, ``NetBIOSName``, ``NetBIOSComputer``, + ``NetBIOSDomainName``, ``NetBIOSDomain``, ``DNSDomainName``, + ``DnsDomain``, ``DnsTree``. + :type field: str + :return: Derived string value. + :rtype: str + """ + value: str = self.hostname + match field: + case "Host" | "FQDN": + value = self._raw + case "DnsComputer": + # Full FQDN when a domain is present; empty (omit AV_PAIR) otherwise + value = self._raw if self.domain else "" + case "DNSHostName": + value = self.hostname + case "NetBIOSName" | "NetBIOSComputer": + value = self.hostname[:15].upper() + case "NetBIOSDomainName" | "NetBIOSDomain": + value = self.domain.upper() if self.domain else "WORKGROUP" + case "DNSDomainName" | "DnsDomain" | "DnsTree": + value = self.domain.lower() if self.domain else "WORKGROUP" + case _: + pass + return value + + def __str__(self) -> str: + return self._raw + + def __call__(self, value: Any) -> "HostValue": + """Allow a :class:`HostValue` instance to serve as a factory callable.""" + return HostValue(value) + + +class HostFallbackValue: + """Attribute factory that applies an explicit-first, Host-derived fallback strategy. + + Resolution order when used as an :class:`~dementor.config.toml.Attribute` + factory: + + 1. Any explicit value resolved by the Attribute system (e.g. + ``[Server].FQDN``, ``[Protocol].FQDN``, ``[Globals].FQDN``) - returned + as-is. + 2. When the resolved value is ``None`` (nothing set anywhere): derive the + field from ``Globals.Host`` via :class:`HostValue`. + 3. When ``Globals.Host`` is also absent: return *fallback*. + + Unlike :class:`HostDerivedValue` (which always parses the input through + :class:`HostValue` and is used with ``qname="Host"``), this class treats + explicit values as opaque strings and invokes :class:`HostValue` derivation + only as a last resort. Use this with the actual field's own ``qname`` + (e.g. ``"FQDN"``, ``"NetBIOSComputer"``) so that each value can be + configured independently before Host derivation kicks in. + + :param field: :class:`HostValue` field used for Host-based derivation + (e.g. ``"FQDN"``, ``"NetBIOSComputer"``). + :type field: str + :param fallback: Hard-coded last-resort value. + :type fallback: str + :param post_factory: Optional callable applied after the value is resolved + or derived. Useful for chaining with e.g. :func:`format_string`. + :type post_factory: Callable[[str], str] | None + """ + + def __init__( + self, + field: str, + fallback: str = "", + post_factory: Callable[[str], str] | None = None, + ) -> None: + self.field = field + self.fallback = fallback + self.post_factory = post_factory + + def __call__(self, value: Any) -> str: + """Resolve *value* with explicit-first, Host-derived fallback. + + :param value: Raw value from the Attribute system, or ``None`` when no + configuration key matched. + :type value: Any + :return: Final string value. + :rtype: str + """ + if value is not None: + result = str(value) + else: + explicit_value = get_value("Globals", self.field, default=None) + if explicit_value is not None: + result = str(explicit_value) + else: + host = get_value("Globals", "Host", default=None) + derived = HostValue(host).get_value(self.field) if host else "" + result = derived or self.fallback + return self.post_factory(result) if self.post_factory else result + + def random_value(size: int) -> str: """ Produce a random alphabetic string of *size* characters. @@ -214,8 +341,10 @@ def format_string(value: str, locals: dict[str, Any] | None = None) -> str: try: template = _SANDBOX.from_string(value) return template.render(config=config, random=random_value, **(locals or {})) - except Exception: # pragma: no cover - defensive fallback - # TODO: replace with proper logging once the logging subsystem is ready. + except Exception as exc: # pragma: no cover - defensive fallback + from dementor.log.logger import dm_logger # noqa: PLC0415 + + dm_logger.debug("Template render failed: %s", exc) return value @@ -227,84 +356,3 @@ def now() -> str: :rtype: str """ return datetime.datetime.now(tz=datetime.UTC).strftime("%Y-%m-%d-%H-%M-%S") - - -def generate_self_signed_cert( - cn: str, - org: str, - country: str, - state: str, - locality: str, - validity_days: int, -) -> tuple[str, str, tempfile.TemporaryDirectory]: - """ - Generate a self-signed certificate and private key in a temporary directory. - - :param cn: Common name for the certificate. - :param org: Organization name. - :param country: Country code. - :param state: State or province. - :param locality: Locality or city. - :param validity_days: Number of days the certificate is valid. - :return: Tuple of (certificate path, key path, temporary directory object). - """ - # Create temp dir - temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) - - # Generate private key - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - ) - - # Create certificate - subject = issuer = x509.Name( - [ - x509.NameAttribute(NameOID.COUNTRY_NAME, country), - x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state), - x509.NameAttribute(NameOID.LOCALITY_NAME, locality), - x509.NameAttribute(NameOID.ORGANIZATION_NAME, org), - x509.NameAttribute(NameOID.COMMON_NAME, cn), - ] - ) - - cert = ( - x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) - .public_key(private_key.public_key()) - .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.datetime.now(datetime.UTC)) - .not_valid_after( - datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=validity_days) - ) - .add_extension( - x509.SubjectAlternativeName( - [ - x509.DNSName(cn), - ] - ), - critical=False, - ) - .sign(private_key, hashes.SHA256()) - ) - - # Save private key - key_id = "".join(random.choices(string.hexdigits)) - key_path = os.path.join(temp_dir.name, f"key_{key_id}.pem") - with open(key_path, "wb") as f: - f.write( - private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ) - ) - - # Save certificate - cert_id = "".join(random.choices(string.hexdigits)) - cert_path = os.path.join(temp_dir.name, f"cert_{cert_id}.pem") - with open(cert_path, "wb") as f: - f.write(cert.public_bytes(serialization.Encoding.PEM)) - - return cert_path, key_path, temp_dir diff --git a/dementor/db/__init__.py b/dementor/db/__init__.py index 8c1d473..b0658e6 100644 --- a/dementor/db/__init__.py +++ b/dementor/db/__init__.py @@ -40,6 +40,12 @@ HOST_INFO = "_host_info" """Key used in extras dict to store host information for credential logging.""" +BEARER_TOKEN = "BearerToken" # noqa: S105 +"""Credential type for HTTP Bearer token authentication.""" + +DIGEST_MD5 = "digest-md5" +"""Credential type for SASL DIGEST-MD5 authentication.""" + # Backward-compatible aliases so existing imports like # from dementor.db import _CLEARTEXT # keep working without a mass-rename across all protocol files. diff --git a/dementor/filters.py b/dementor/filters.py index 57f325e..3e6bc8d 100644 --- a/dementor/filters.py +++ b/dementor/filters.py @@ -97,19 +97,6 @@ def matches(self, source: str) -> bool: else self.target == source ) - @staticmethod - def from_string(target: str, extra: Any | None = None) -> "FilterObj": - """Create a `FilterObj` from a string pattern. - - :param target: Pattern string. - :type target: str - :param extra: Optional metadata. - :type extra: Any, optional - :return: Filter object. - :rtype: FilterObj - """ - return FilterObj(target, extra) - @staticmethod def from_file(source: str, extra: Any | None) -> list["FilterObj"]: """Load multiple filters from a text file (one per line). @@ -263,7 +250,7 @@ def __init__(self, config: list[str | dict[str, Any]]) -> None: # String means simple filter expression without extra config if not filter_config: continue - self.filters.append(FilterObj.from_string(filter_config)) + self.filters.append(FilterObj(filter_config)) else: # must be a dictionary # 1. Direct target specification diff --git a/dementor/loader.py b/dementor/loader.py index 8832d1c..4f5f6ec 100644 --- a/dementor/loader.py +++ b/dementor/loader.py @@ -17,7 +17,6 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import os import types import typing import pathlib @@ -249,9 +248,9 @@ def __init__(self) -> None: 1. Dementor package's internal `protocols/` directory 2. External `DEMENTOR_PATH/protocols/` directory (for user extensions) """ - self.rs_path: str = os.path.join(DEMENTOR_PATH, "protocols") + self.rs_path: str = str(pathlib.Path(DEMENTOR_PATH) / "protocols") self.search_path: list[str] = [ - os.path.join(os.path.dirname(dementor.__file__), "protocols"), + str(pathlib.Path(dementor.__file__).parent / "protocols"), self.rs_path, ] @@ -304,24 +303,22 @@ def resolve_protocols( if session is not None: protocol_paths.extend(session.extra_modules) - for path in protocol_paths: - if not os.path.exists(path): + for path_str in protocol_paths: + path = pathlib.Path(path_str) + if not path.exists(): # Missing entries are ignored - they may be optional. continue - if os.path.isfile(path): - if not path.endswith(".py"): + if path.is_file(): + if path.suffix != ".py": continue - name = os.path.basename(path)[:-3] # strip .py - protocols[name] = path + protocols[path.stem] = str(path) continue - for filename in os.listdir(path): - if not filename.endswith(".py") or filename == "__init__.py": + for child in path.iterdir(): + if child.suffix != ".py" or child.name == "__init__.py": continue - protocol_path = os.path.join(path, filename) - name = filename[:-3] # strip extension - protocols[name] = protocol_path + protocols[child.stem] = str(child) return protocols @@ -421,7 +418,6 @@ def create_all_threads(self) -> None: servers = self.loader.create_servers(protocol, self.session) self.threads[name.lower()] = list(servers) except Exception as e: - # Log error if needed, but for now pass dm_logger.error(f"Error creating servers for protocol '{name}': {e}") self.threads[name.lower()] = [] @@ -433,7 +429,6 @@ def create_threads(self, name: str) -> None: servers = self.loader.create_servers(protocol, self.session) self.threads[name.lower()] = list(servers) except Exception as e: - # Log error if needed, but for now pass dm_logger.error(f"Error creating servers for protocol '{name}': {e}") self.threads[name.lower()] = [] diff --git a/dementor/log/hexdump.py b/dementor/log/hexdump.py index 4b0bff7..f964eeb 100644 --- a/dementor/log/hexdump.py +++ b/dementor/log/hexdump.py @@ -18,7 +18,9 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -# just reusing impacket's hexdump +# Thin re-export shim that isolates the impacket.structure dependency boundary. +# All internal code imports hexdump from here; if the upstream location ever +# changes, only this file needs updating. from impacket.structure import hexdump __all__ = ["hexdump"] diff --git a/dementor/log/logger.py b/dementor/log/logger.py index 3435388..e728a34 100644 --- a/dementor/log/logger.py +++ b/dementor/log/logger.py @@ -194,47 +194,19 @@ def _get_extra( # Accessors used by the formatting helpers # ----------------------------------------------------------------- def get_protocol_name(self, extra: dict[str, Any] | None = None) -> str: - """ - Return the protocol name (or an empty string). - - :param extra: Optional per-call extra mapping. - :type extra: dict | None - :return: Protocol name. - :rtype: str - """ + """Return the protocol name (or an empty string).""" return str(self._get_extra("protocol", extra, "")) def get_protocol_color(self, extra: dict[str, Any] | None = None) -> str: - """ - Return the colour used for the protocol prefix - defaults to ``white``. - - :param extra: Optional per-call extra mapping. - :type extra: dict | None - :return: Colour name. - :rtype: str - """ + """Return the colour used for the protocol prefix - defaults to ``white``.""" return str(self._get_extra("protocol_color", extra, "white")) def get_host(self, extra: dict[str, Any] | None = None) -> str: - """ - Return the host string (or empty). - - :param extra: Optional per-call extra mapping. - :type extra: dict | None - :return: Host. - :rtype: str - """ + """Return the host string (or empty).""" return str(self._get_extra("host", extra, "")) def get_port(self, extra: dict[str, Any] | None = None) -> str: - """ - Return the port string (or empty). - - :param extra: Optional per-call extra mapping. - :type extra: dict | None - :return: Port. - :rtype: str - """ + """Return the port string (or empty).""" return str(self._get_extra("port", extra, "")) # ----------------------------------------------------------------- @@ -418,7 +390,7 @@ def highlight(self, msg: str, *args: Any, **kwargs: Any) -> None: dm_print(msg, *args, **kwargs) self._emit_log_entry(msg, logging.INFO, *args) - def fail(self, msg: str, color: str | None = None, *args: Any, **kwargs: Any) -> None: + def fail(self, msg: str, *args: Any, color: str | None = None, **kwargs: Any) -> None: """ Log an error condition (red ``[-]`` prefix). @@ -426,10 +398,10 @@ def fail(self, msg: str, color: str | None = None, *args: Any, **kwargs: Any) -> :type msg: str :param color: Override the colour of the ``[-]`` marker. :type color: str | None - :param _args: Positional arguments forwarded to :func:`dm_print`. - :type _args: typing.Any - :param _kwargs: Keyword arguments forwarded to :func:`dm_print`. - :type _kwargs: dict + :param args: Positional arguments forwarded to :func:`dm_print`. + :type args: typing.Any + :param kwargs: Keyword arguments forwarded to :func:`dm_print`. + :type kwargs: dict """ colour = color or "red" prefix = f"[bold {colour}]" + r"\[-]" + f"[/bold {colour}]" diff --git a/dementor/log/stream.py b/dementor/log/stream.py index 605ae04..53457c0 100644 --- a/dementor/log/stream.py +++ b/dementor/log/stream.py @@ -362,7 +362,7 @@ def add(self, **kwargs: Any) -> None: write_to(f"HASH_{hash_type}", str(hash_value)) -def init_streams(session: SessionConfig): +def init_streams(session: SessionConfig) -> None: """Initialize all configured logging streams at startup. Calls `.start()` on each stream class to load config and register instances. @@ -376,7 +376,7 @@ def init_streams(session: SessionConfig): session.streams = dm_streams -def add_stream(name: str, stream: LoggingStream[_T]): +def add_stream(name: str, stream: LoggingStream[_T]) -> None: """Manually register a stream instance. Useful for dynamic or custom streams. @@ -400,7 +400,7 @@ def get_stream(name: str) -> LoggingStream[_T] | None: return dm_streams.get(name) -def close_streams(session: SessionConfig): +def close_streams(session: SessionConfig) -> None: """Close all active streams. Called during graceful shutdown. @@ -412,7 +412,7 @@ def close_streams(session: SessionConfig): stream.close() -def log_to(__name: str, /, **kwargs: Any): +def log_to(__name: str, /, **kwargs: Any) -> None: """Write structured data to a registered stream. :param __name: Stream name (e.g., `"hosts"`, `"hashes"`). @@ -423,7 +423,7 @@ def log_to(__name: str, /, **kwargs: Any): dm_streams[__name].add(**kwargs) -def write_to(name: str, line: str): +def write_to(name: str, line: str) -> None: """Write a raw line to a stream (no formatting). :param name: Stream name. diff --git a/dementor/paths.py b/dementor/paths.py index 905c342..993e0d1 100644 --- a/dementor/paths.py +++ b/dementor/paths.py @@ -28,6 +28,21 @@ HTTP_TEMPLATES_PATH = os.path.join(ASSETS_PATH, "www") +def get_dementor_path() -> str: + """Return the user workspace directory, re-resolving HOME each call. + + Unlike the module-level :data:`DEMENTOR_PATH` constant (which is frozen + at import time), this function reads ``HOME`` on every call. Use it + when the path must reflect a ``HOME`` change after import (e.g., tests). + """ + return os.path.expanduser("~/.dementor") + + +def get_config_path() -> str: + """Return the user config file path, re-resolving HOME each call.""" + return os.path.join(get_dementor_path(), "Dementor.toml") + + def main() -> None: print(f"DefaultWorkspace : {DEMENTOR_PATH}") print(f"AssetsPath : {ASSETS_PATH}") diff --git a/dementor/protocols/ftp.py b/dementor/protocols/ftp.py index 1085681..80fd574 100644 --- a/dementor/protocols/ftp.py +++ b/dementor/protocols/ftp.py @@ -36,7 +36,7 @@ ServerThread, BaseServerThread, ) -from dementor.db import _CLEARTEXT # pyright: ignore[reportPrivateUsage] +from dementor.db import CLEARTEXT __proto__ = ["FTP"] @@ -69,8 +69,8 @@ class FTPServerConfig(TomlConfig): :type ftp_port: int """ - _section_: ClassVar[str] = "FTP" - _fields_: ClassVar[list[A]] = [A("ftp_port", "Port")] + _section_ = "FTP" + _fields_ = [A("ftp_port", "Port")] if typing.TYPE_CHECKING: # pragma: no cover ftp_port: int @@ -86,13 +86,13 @@ class FTP(BaseProtocolModule[FTPServerConfig]): @override def create_server_thread( self, session: SessionConfig, server_config: FTPServerConfig - ) -> BaseServerThread: - """Build :class:`ServerThread` objects for each configured FTP server. + ) -> BaseServerThread[FTPServerConfig]: + """Build a :class:`ServerThread` for the given FTP server config. - :param session: Session containing the ``ftp_config`` list. + :param session: Active session configuration. :type session: :class:`dementor.config.session.SessionConfig` - :return: List of ready-to-start :class:`ServerThread` objects. - :rtype: list[ServerThread] + :return: A ready-to-start :class:`ServerThread` object. + :rtype: ServerThread """ return ServerThread( session, @@ -188,7 +188,7 @@ def handle_data(self, data: bytes | None, transport: socket) -> None: self.config.db.add_auth( client=self.client_address, - credtype=_CLEARTEXT, # intentional clear-text + credtype=CLEARTEXT, username=username, password=password, logger=self.logger, diff --git a/dementor/protocols/http.py b/dementor/protocols/http.py index 3dde031..66dab5b 100644 --- a/dementor/protocols/http.py +++ b/dementor/protocols/http.py @@ -40,22 +40,28 @@ from dementor.loader import BaseProtocolModule, DEFAULT_ATTR from dementor.config.session import SessionConfig from dementor.config.toml import TomlConfig, Attribute as A -from dementor.config.util import format_string, get_value, is_true +from dementor.config.util import ( + format_string, + get_value, + is_true, + HostValue, + HostFallbackValue, +) from dementor.log.logger import ProtocolLogger, dm_logger from dementor.servers import ServerThread, bind_server, BaseServerThread -from dementor.db import _CLEARTEXT, normalize_client_address, _NO_USER +from dementor.db import BEARER_TOKEN, CLEARTEXT, normalize_client_address, NO_USER from dementor.paths import HTTP_TEMPLATES_PATH from dementor.protocols.ntlm import ( - NTLM_build_challenge_message, - NTLM_handle_negotiate_message, - NTLM_handle_authenticate_message, + ntlm_build_challenge_message, + ntlm_handle_negotiate_message, + ntlm_handle_authenticate_message, ) __proto__ = ["HTTP", "WinRM"] -def apply_config(session: SessionConfig): +def apply_config(session: SessionConfig) -> None: session.proxy_config = ProxyAutoConfig(get_value("Proxy", key=None, default={})) @@ -68,7 +74,7 @@ class ProxyAutoConfig(TomlConfig): if typing.TYPE_CHECKING: proxy_script: str | None - def set_proxy_script(self, script): + def set_proxy_script(self, script: str | None) -> None: self.proxy_script = None match script: case str(): @@ -123,10 +129,12 @@ class HTTPServerConfig(TomlConfig): ), A( "http_fqdn", - "FQDN", - "DEMENTOR", + "Host", + None, section_local=False, - factory=format_string, + factory=HostFallbackValue( + HostValue.HOST, "DEMENTOR", post_factory=format_string + ), ), A("http_extra_headers", "ExtraHeaders", list), A("http_wpad_enabled", "WPAD", True, factory=is_true), @@ -154,7 +162,7 @@ class HTTPServerConfig(TomlConfig): http_cert_key: str | None http_use_ssl: bool - def set_http_templates(self, templates_dirs: list[str]): + def set_http_templates(self, templates_dirs: list[str]) -> None: dirs: list[str] = [] for templates_dir in templates_dirs: path = pathlib.Path(templates_dir) @@ -182,7 +190,7 @@ class HTTP(BaseProtocolModule[HTTPServerConfig]): @override def create_server_thread( self, session: SessionConfig, server_config: HTTPServerConfig - ) -> BaseServerThread: + ) -> BaseServerThread[HTTPServerConfig]: return ServerThread( session, server_config, @@ -203,7 +211,7 @@ class WinRM(BaseProtocolModule[HTTPServerConfig]): @override def create_server_thread( self, session: SessionConfig, server_config: HTTPServerConfig - ) -> BaseServerThread: + ) -> BaseServerThread[HTTPServerConfig]: return ServerThread( session, server_config, @@ -216,12 +224,22 @@ def create_server_thread( @override def apply_config(self, session: SessionConfig) -> None: + # Load from TOML [WinRM].Server entries first; fall back to hardcoded + # defaults (5985/5986) only when the user has not configured anything. + super().apply_config(session) + if session.winrm_config: + return + winrm_config: list[HTTPServerConfig] = [] config = HTTPServerConfig({"Port": 5985}) config.http_wpad_enabled = False config.http_webdav_enabled = False - ssl_enabled = bool(config.http_cert) + # Use a fully-loaded config to check global SSL cert availability. + # HTTPServerConfig({"Port": 5985}) is a minimal dict with no TOML context, + # so config.http_cert is always None there. We need the global defaults instead. + global_config = TomlConfig.build_config(HTTPServerConfig) + ssl_enabled = bool(global_config.http_cert) config.http_cert = None config.http_cert_key = None winrm_config.append(config) @@ -230,6 +248,8 @@ def apply_config(self, session: SessionConfig) -> None: ssl_config.http_wpad_enabled = False ssl_config.http_webdav_enabled = False ssl_config.http_use_ssl = True + ssl_config.http_cert = global_config.http_cert + ssl_config.http_cert_key = global_config.http_cert_key winrm_config.append(ssl_config) if not session.winrm_enabled: @@ -243,7 +263,7 @@ class HTTPHeaders: class HTTPHandler(BaseHTTPRequestHandler): - # NTLM is a connection-based auth — the 3-message handshake must happen + # NTLM is a connection-based auth - the 3-message handshake must happen # on a single persistent connection. HTTP/1.0 closes after each response, # breaking the handshake. HTTP/1.1 keeps the connection alive by default. protocol_version = "HTTP/1.1" @@ -251,17 +271,17 @@ class HTTPHandler(BaseHTTPRequestHandler): def __init__( self, session: SessionConfig, - config: HTTPServerConfig, + server_config: HTTPServerConfig, request, client_address: tuple[str, int], server, ) -> None: - self.config = config # REVISIT: this is confusing - self.session = session + self.config = session + self.server_config = server_config self.client_address = client_address self.challenge = None self.setup_proto_logger() - for http_method in config.http_methods: + for http_method in server_config.http_methods: if http_method in ("OPTIONS", "PROPFIND"): # reserved options continue @@ -274,13 +294,13 @@ def __init__( super().__init__(request, client_address, server) - def setup_proto_logger(self): + def setup_proto_logger(self) -> None: self.logger: ProtocolLogger = ProtocolLogger( extra={ "protocol": "HTTP", "protocol_color": "chartreuse3", "host": normalize_client_address(self.client_address[0]), - "port": self.config.http_port, + "port": self.server_config.http_port, } ) self.webdav_logger: ProtocolLogger = ProtocolLogger( @@ -288,34 +308,34 @@ def setup_proto_logger(self): "protocol": "WebDAV", "protocol_color": "sea_green3", "host": normalize_client_address(self.client_address[0]), - "port": self.config.http_port, + "port": self.server_config.http_port, } ) - def do_PROPFIND(self): - if self.config.http_webdav_enabled: + def do_PROPFIND(self) -> None: + if self.server_config.http_webdav_enabled: self.handle_request(self.webdav_logger) else: self.send_error(HTTPStatus.NOT_FOUND, "Not Found") - def do_OPTIONS(self): + def do_OPTIONS(self) -> None: # always support everything self.send_response(HTTPStatus.OK) self.send_header("Allow", "OPTIONS,GET,HEAD,POST,TRACE,PROPFIND") self.end_headers() - def do_HEAD(self): + def do_HEAD(self) -> None: self.send_response(HTTPStatus.OK) self.send_header("Content-Length", "0") self.end_headers() def version_string(self) -> str: - return self.config.http_server_type + return self.server_config.http_server_type def log_message(self, format: str, *args) -> None: # let us log mssages text = format % args - msg = text.translate(self._control_char_table) + msg = text.translate(self._control_char_table) # ty:ignore[unresolved-attribute] self.logger.debug(msg) def send_response(self, code: int, message: str | None = None) -> None: @@ -323,7 +343,7 @@ def send_response(self, code: int, message: str | None = None) -> None: self._headers_buffer = [] super().send_response(code, message) - for header in self.config.http_extra_headers: + for header in self.server_config.http_extra_headers: self._headers_buffer.append(f"{header}\r\n".encode("latin-1", "strict")) def send_error( @@ -359,27 +379,29 @@ def send_error( if body: self.wfile.write(body) - def is_wpad_request(self): + def is_wpad_request(self) -> bool: path = pathlib.Path(self.path) return path.suffix == ".pac" or path.stem == "wpad" - def display_request(self, req_type: str | None = None, logger=None): + def display_request(self, req_type: str | None = None, logger=None) -> None: line = f"{self.command} request for {markup.escape(self.path)}" if req_type: line = f"{line} ({req_type})" (logger or self.logger).display(line) - def send_wpad_script(self): - if self.config.proxy_config.proxy_script: + def send_wpad_script(self) -> None: + if self.server_config.proxy_config.proxy_script: # try to render the custom script - template = Template(self.config.proxy_config.proxy_script, autoescape=True) + template = Template( + self.server_config.proxy_config.proxy_script, autoescape=True + ) script = template.render( - server=self.config, - session=self.session, + server=self.server_config, + session=self.config, ) else: script = self.server.render_page("wpad.dat") - if self.config.http_wpad_enabled and not script: + if self.server_config.http_wpad_enabled and not script: self.logger.fail("WPAD enabled but script not configured") return self.send_error(HTTPStatus.NOT_FOUND) @@ -392,13 +414,13 @@ def send_wpad_script(self): self.end_headers() self.wfile.write(body) - def handle_request(self, logger): + def handle_request(self, logger: "ProtocolLogger") -> None: if HTTPHeaders.AUTHORIZATION not in self.headers: # make sure the client authenticates to us if ( - self.config.http_wpad_enabled + self.server_config.http_wpad_enabled and self.is_wpad_request() - and not self.config.http_wpad_auth + and not self.server_config.http_wpad_auth ): return self.send_wpad_script() @@ -408,7 +430,7 @@ def handle_request(self, logger): "Unauthorized", headers=[ (HTTPHeaders.WWW_AUTHENTICATE, scheme) - for scheme in self.config.http_auth_schemes + for scheme in self.server_config.http_auth_schemes ], ) else: @@ -420,7 +442,7 @@ def handle_request(self, logger): logger.debug(f"Unknown authentication scheme: {name}") self.send_error(HTTPStatus.NOT_FOUND, "Not Found") - def auth_negotiate(self, token, logger): + def auth_negotiate(self, token: str, logger: "ProtocolLogger") -> None: # try to decode negotiate token if token.startswith("YII"): # possible kerberos authentication attempt, try to downgrade @@ -428,7 +450,9 @@ def auth_negotiate(self, token, logger): self.auth_ntlm(token, logger, scheme="Negotiate") - def auth_ntlm(self, token, logger, scheme=None): + def auth_ntlm( + self, token: str, logger: "ProtocolLogger", scheme: str | None = None + ) -> None: try: message = ntlm.NTLM_HTTP.get_instance(f"NTLM {token}") except Exception: @@ -440,17 +464,24 @@ def auth_ntlm(self, token, logger, scheme=None): match message: case ntlm.NTLM_HTTP_AuthNegotiate(): self.display_request("NTLMSSP_NEGOTIATE", logger) - self._ntlm_negotiate_fields = NTLM_handle_negotiate_message( + self._ntlm_negotiate_fields = ntlm_handle_negotiate_message( message, logger ) - challenge = NTLM_build_challenge_message( + host = HostValue(self.server_config.http_fqdn) + challenge = ntlm_build_challenge_message( message, - challenge=self.session.ntlm_challenge, - nb_computer=self.session.ntlm_nb_computer, - nb_domain=self.session.ntlm_nb_domain, - disable_ess=self.session.ntlm_disable_ess, - disable_ntlmv2=self.session.ntlm_disable_ntlmv2, - log=logger, + challenge=self.config.ntlm_challenge, + nb_computer=host.get_value(HostValue.NETBIOS_COMPUTER), + nb_domain=host.get_value(HostValue.NETBIOS_DOMAIN), + disable_ess=self.config.ntlm_disable_ess, + disable_ntlmv2=self.config.ntlm_disable_ntlmv2, + target_type=self.config.ntlm_target_type, + version=self.config.ntlm_version, + dns_computer=host.get_value(HostValue.DNS_COMPUTER), + dns_domain=host.get_value(HostValue.DNS_DOMAIN), + # REVISIT: capture DNSTree too + # dns_tree=self.config.ntlm_dns_tree, + log=self.logger, ) self.send_response(HTTPStatus.UNAUTHORIZED, "Unauthorized") data = base64.b64encode(challenge.getData()).decode() @@ -462,35 +493,35 @@ def auth_ntlm(self, token, logger, scheme=None): case ntlm.NTLM_HTTP_AuthChallengeResponse(): self.display_request("NTLMSSP_AUTH", logger) - NTLM_handle_authenticate_message( + ntlm_handle_authenticate_message( message, - challenge=self.session.ntlm_challenge, + challenge=self.config.ntlm_challenge, client=self.client_address, - session=self.session, + session=self.config, logger=logger, extras=self.get_extras(), negotiate_fields=getattr(self, "_ntlm_negotiate_fields", None), ) - self.finish_request(logger) + self._complete_auth_request(logger) case _: logger.fail(f"Invalid negotiate authentication: {token}") self.send_error(HTTPStatus.INTERNAL_SERVER_ERROR, "Internal Server Error") - def auth_bearer(self, token, logger): + def auth_bearer(self, token: str, logger: "ProtocolLogger") -> None: self.display_request("Bearer", logger) - self.session.db.add_auth( + self.config.db.add_auth( client=self.client_address, - credtype="BearerToken", - username=_NO_USER, + credtype=BEARER_TOKEN, + username=NO_USER, password=token.encode().hex(), logger=logger, extras=self.get_extras(), custom=True, ) - self.finish_request(logger) + self._complete_auth_request(logger) - def auth_basic(self, token, logger): + def auth_basic(self, token: str, logger: "ProtocolLogger") -> None: self.display_request("Basic", logger) try: username, password = base64.b64decode(token).decode().split(":", 1) @@ -499,24 +530,24 @@ def auth_basic(self, token, logger): self.send_error(HTTPStatus.INTERNAL_SERVER_ERROR, "Internal Server Error") return - self.session.db.add_auth( + self.config.db.add_auth( client=self.client_address, - credtype=_CLEARTEXT, + credtype=CLEARTEXT, password=password, logger=logger, username=username, extras=self.get_extras(), ) - self.finish_request(logger) + self._complete_auth_request(logger) - def finish_request(self, logger): + def _complete_auth_request(self, logger): # inspect the path first, WPAD Auth and custom files are handled separately - if self.is_wpad_request() and self.config.http_wpad_enabled: + if self.is_wpad_request() and self.server_config.http_wpad_enabled: return self.send_wpad_script() self.send_error(418, "I'm a teapot") - def get_extras(self): + def get_extras(self) -> dict[str, str]: extras = {} if "User-Agent" in self.headers: extras["User-Agent"] = self.headers["User-Agent"] @@ -533,13 +564,13 @@ def get_extras(self): class WinRMHandler(HTTPHandler): - def setup_proto_logger(self): + def setup_proto_logger(self) -> None: self.logger = ProtocolLogger( extra={ "protocol": "WinRM", "protocol_color": "spring_green1", "host": normalize_client_address(self.client_address[0]), - "port": self.config.http_port, + "port": self.server_config.http_port, } ) diff --git a/dementor/protocols/imap.py b/dementor/protocols/imap.py index cdaa733..ec46642 100644 --- a/dementor/protocols/imap.py +++ b/dementor/protocols/imap.py @@ -28,15 +28,16 @@ import shlex import typing +from typing import Literal, overload from typing_extensions import override from impacket import ntlm from dementor.config.session import SessionConfig from dementor.loader import BaseProtocolModule, DEFAULT_ATTR from dementor.protocols.ntlm import ( - NTLM_build_challenge_message, - NTLM_handle_authenticate_message, - NTLM_handle_negotiate_message, + ntlm_build_challenge_message, + ntlm_handle_authenticate_message, + ntlm_handle_negotiate_message, ) from dementor.servers import ( ServerThread, @@ -46,35 +47,17 @@ BaseServerThread, ) from dementor.log.logger import ProtocolLogger -from dementor.db import _CLEARTEXT +from dementor.db import CLEARTEXT from dementor.config.toml import ( TomlConfig, Attribute as A, ) from dementor.config.attr import ATTR_TLS, ATTR_CERT, ATTR_KEY -from dementor.config.util import get_value +from dementor.config.util import HostValue, HostFallbackValue __proto__ = ["IMAP"] -def apply_config(session: SessionConfig): - session.imap_config = list( - map(IMAPServerConfig, get_value("IMAP", "Server", default=[])) - ) - - -def create_server_threads(session: SessionConfig): - return [ - ServerThread( - session, - IMAPServer, - server_config=server_config, - server_address=(session.bind_address, server_config.imap_port), - ) - for server_config in (session.imap_config if session.imap_enabled else []) - ] - - IMAP_CAPABILITIES = [ # NOTE: support STARTTLS is currently not avaialble # "STARTTLS", @@ -89,7 +72,13 @@ class IMAPServerConfig(TomlConfig): _section_ = "IMAP" _fields_ = [ A("imap_port", "Port"), - A("imap_fqdn", "FQDN", "Dementor", section_local=False), + A( + "imap_fqdn", + "Host", + None, + section_local=False, + factory=HostFallbackValue(HostValue.HOST, "DEMENTOR"), + ), A("imap_caps", "Capabilities", IMAP_CAPABILITIES), A("imap_auth_mechanisms", "AuthMechanisms", IMAP_AUTH_MECHS), A("imap_banner", "Banner", "IMAP4rev2 service ready"), @@ -118,7 +107,7 @@ class IMAP(BaseProtocolModule[IMAPServerConfig]): @override def create_server_thread( self, session: SessionConfig, server_config: IMAPServerConfig - ) -> BaseServerThread: + ) -> BaseServerThread[IMAPServerConfig]: return ServerThread( session, server_config, @@ -133,8 +122,10 @@ class StopHandler(Exception): class IMAPHandler(BaseProtoHandler): - def __init__(self, config, server_config, request, client_address, server) -> None: - self.server_config = server_config + def __init__( + self, config, server_config: IMAPServerConfig, request, client_address, server + ) -> None: + self.server_config: IMAPServerConfig = server_config self.seq_id = None super().__init__(config, request, client_address, server) @@ -173,13 +164,13 @@ def _write_line(self, msg: str) -> None: # - NO (indicating failure), or # - BAD (indicating a protocol error such as unrecognized command or # command syntax error). - def ok(self, msg: str, seq=True): + def ok(self, msg: str, seq: bool = True) -> None: self._push(f"OK {msg}", seq) - def no(self, msg: str, seq=True): + def no(self, msg: str, seq: bool = True) -> None: self._push(f"NO {msg}", seq) - def bad(self, msg: str, seq=True): + def bad(self, msg: str, seq: bool = True) -> None: self._push(f"BAD {msg}", seq) # NOTE: Section 2.2.2 states: @@ -195,6 +186,22 @@ def unquoted(self, data: str) -> str: # CR and LF, encoded in UTF-8, with double quote (<">) characters at each end. return data.removeprefix('"').removesuffix('"') + @overload + def challenge_auth( + self, + token: bytes | None = ..., + decode: Literal[False] = ..., + prefix: str | None = ..., + ) -> bytes: ... + + @overload + def challenge_auth( + self, + token: bytes | None = ..., + decode: Literal[True] = ..., + prefix: str | None = ..., + ) -> str: ... + def challenge_auth( self, token: bytes | None = None, @@ -277,7 +284,7 @@ def recv_line(self, size: int) -> str | None: # implementation # 7.2.2. CAPABILITY Response - def do_CAPABILITY(self, args): + def do_CAPABILITY(self, args: list[str]) -> None: # The CAPABILITY response occurs as a result of a CAPABILITY command. The # capability listing contains a space-separated listing of capability names # that the server supports. The capability listing MUST include the atom @@ -289,17 +296,17 @@ def do_CAPABILITY(self, args): self.ok("CAPABILITY completed") # 6.1.2. NOOP Command - def do_NOOP(self, args): + def do_NOOP(self, args: list[str]) -> None: # The NOOP command always succeeds. It does nothing. self.ok("NOOP completed") # 6.4.1. CLOSE Command - def do_CLOSE(self, args): + def do_CLOSE(self, args: list[str]) -> None: self.ok("CLOSE completed") raise StopHandler # 6.2.3. LOGIN Command - def do_LOGIN(self, args: str): + def do_LOGIN(self, args: str) -> None: if len(args) != 2: return self.bad("Invalid number of arguments") @@ -309,12 +316,12 @@ def do_LOGIN(self, args: str): username=self.unquoted(username), password=self.unquoted(password), logger=self.logger, - credtype=_CLEARTEXT, + credtype=CLEARTEXT, ) self.no("LOGIN failed") # 6.2.2. AUTHENTICATE Command - def do_AUTHENTICATE(self, args): + def do_AUTHENTICATE(self, args: list[str]) -> None: if len(args) < 1: return self.bad("Invalid number of arguments") @@ -326,12 +333,12 @@ def do_AUTHENTICATE(self, args): self.bad("Unknown authentication mechanism") # 6.2.1. STARTTLS Command - def do_STARTTLS(self, args): + def do_STARTTLS(self, args: list[str]) -> None: # NO - TLS negotiation can't be initiated, due to server configuration error self.no("STARTTLS not supported") # [MS-OXIMAP] 2.2.1 IMAP4 NTLM - def auth_NTLM(self, initial_response=None): + def auth_NTLM(self, initial_response: bytes | None = None) -> None: # IMAP4_AUTHENTICATE_NTLM_Supported_Response if not initial_response: token = self.challenge_auth() @@ -353,14 +360,21 @@ def auth_NTLM(self, initial_response=None): return self.bad("NTLM negotiation failed") # IMAP4_AUTHENTICATE_NTLM_Blob_Response - negotiate_fields = NTLM_handle_negotiate_message(negotiate, self.logger) - challenge = NTLM_build_challenge_message( + negotiate_fields = ntlm_handle_negotiate_message(negotiate, self.logger) + host = HostValue(self.server_config.imap_fqdn) + challenge = ntlm_build_challenge_message( negotiate, challenge=self.config.ntlm_challenge, - nb_computer=self.config.ntlm_nb_computer, - nb_domain=self.config.ntlm_nb_domain, + nb_computer=host.get_value(HostValue.NETBIOS_COMPUTER), + nb_domain=host.get_value(HostValue.NETBIOS_DOMAIN), disable_ess=self.config.ntlm_disable_ess, disable_ntlmv2=self.config.ntlm_disable_ntlmv2, + target_type=self.config.ntlm_target_type, + version=self.config.ntlm_version, + dns_computer=host.get_value(HostValue.DNS_COMPUTER), + dns_domain=host.get_value(HostValue.DNS_DOMAIN), + # REVISIT: capture DNSTree too + # dns_tree=self.config.ntlm_dns_tree, log=self.logger, ) @@ -373,7 +387,7 @@ def auth_NTLM(self, initial_response=None): self.logger.debug(f"NTLM authentication failed: {e}") return self.bad("NTLM authentication failed") - NTLM_handle_authenticate_message( + ntlm_handle_authenticate_message( auth_message, challenge=self.config.ntlm_challenge, client=self.client_address, @@ -388,7 +402,7 @@ def auth_NTLM(self, initial_response=None): self.ok("AUTHENTICATE completed") - def auth_PLAIN(self, initial_response=None): + def auth_PLAIN(self, initial_response: bytes | None = None) -> None: if initial_response: login_and_password = base64.b64decode(initial_response) else: @@ -405,11 +419,11 @@ def auth_PLAIN(self, initial_response=None): username=login.decode(errors="replace"), password=password.decode(errors="replace"), logger=self.logger, - credtype=_CLEARTEXT, + credtype=CLEARTEXT, ) self.no("LOGIN failed") - def auth_LOGIN(self): + def auth_LOGIN(self) -> None: username = self.challenge_auth(decode=True) password = self.challenge_auth(decode=True) self.config.db.add_auth( @@ -417,7 +431,7 @@ def auth_LOGIN(self): username=username, password=password, logger=self.logger, - credtype=_CLEARTEXT, + credtype=CLEARTEXT, ) self.no("LOGIN failed") diff --git a/dementor/protocols/ipp.py b/dementor/protocols/ipp.py index c7da201..5f49ab2 100644 --- a/dementor/protocols/ipp.py +++ b/dementor/protocols/ipp.py @@ -169,13 +169,13 @@ class IPPConfig(TomlConfig): ipp_remote_cmd_attr: str ipp_remote_cmd_filter: str - def set_ipp_supported_operations(self, value): + def set_ipp_supported_operations(self, value) -> None: self.ipp_supported_operations = [ IppOperation[operation] if isinstance(operation, str) else operation for operation in (value or []) ] - def set_ipp_extra_attrib(self, extra: list[dict[str, Any]] | None): + def set_ipp_extra_attrib(self, extra: list[dict[str, Any]] | None) -> None: # A list of attributes to add to the GET-PRINTER-ATTRIBUTES response. # This settings can also be used to add custom attributes to the # ATTRIBUTE_TAG_MAP. @@ -238,7 +238,7 @@ def __init__(self, session, config, request, client_address, server) -> None: self.setup_proto_logger() super().__init__(request, client_address, server) - def setup_proto_logger(self): + def setup_proto_logger(self) -> None: self.logger = ProtocolLogger( extra={ "protocol": "IPP", @@ -287,7 +287,7 @@ def log_message(self, format: str, *args: Any) -> None: # let us log mssages pass - def do_POST(self): + def do_POST(self) -> None: # handle IPP request try: data = self.rfile.read(self.content_length) @@ -311,7 +311,7 @@ def do_POST(self): else: method(req) - def ipp_get_printer_attributes(self, req: dict[str, Any]): + def ipp_get_printer_attributes(self, req: dict[str, Any]) -> None: # [4.2.5. Get-Printer-Attributes Operation] # This REQUIRED operation allows a Client to request the values of the # attributes of a Printer. In the request, the Client supplies the set diff --git a/dementor/protocols/kerberos.py b/dementor/protocols/kerberos.py index 3a25a68..c991794 100644 --- a/dementor/protocols/kerberos.py +++ b/dementor/protocols/kerberos.py @@ -74,13 +74,13 @@ class KerberosConfig(TomlConfig): krb5_etype: int krb5_error_code: int - def set_krb5_salt(self, value): + def set_krb5_salt(self, value) -> None: if isinstance(value, bytes): self.krb5_salt = value else: self.krb5_salt = str(value).encode("utf-8", errors="replace") - def set_krb5_etype(self, value): + def set_krb5_etype(self, value) -> None: match value: case int(): self.krb5_etype = value @@ -89,7 +89,7 @@ def set_krb5_etype(self, value): case _: self.krb5_etype = EncryptionTypes[value].value - def set_krb5_error_code(self, value): + def set_krb5_error_code(self, value) -> None: match value: case int(): self.krb5_error_code = value @@ -116,7 +116,7 @@ def create_server_threads(self, session: SessionConfig) -> list[BaseServerThread ) -def KRB5_Err( +def krb5_err( error_code: int, realm: str | None = None, sname: list[str] | None = None, @@ -161,13 +161,13 @@ def KRB5_Err( return krb_error -def KRB5_ASREQ_to_hashcat_format( +def krb5_asreq_to_hashcat_format( etype: int, username: str | bytes, realm: str | bytes, enc_timestamp: bytes, salt: bytes, -) -> tuple: +) -> tuple[str, str]: if isinstance(username, bytes): username = username.decode("utf-8", errors="replace") @@ -270,7 +270,7 @@ def handle_data(self, data, transport) -> None: user_name = str(req_body["cname"]["name-string"][0]) domain = str(req_body["realm"]) - hashname, hashvalue = KRB5_ASREQ_to_hashcat_format( + hashname, hashvalue = krb5_asreq_to_hashcat_format( encrypted_data["etype"], username=user_name, realm=domain, @@ -292,10 +292,10 @@ def handle_data(self, data, transport) -> None: realm = str(as_req["req-body"]["realm"]) sname = ["krbtgt", realm] if error_code != ErrorCodes.KDC_ERR_PREAUTH_REQUIRED.value: - krb_error = KRB5_Err(error_code, realm, sname) + krb_error = krb5_err(error_code, realm, sname) else: # make sure we require pre-authentication - krb_error = KRB5_Err( + krb_error = krb5_err( error_code, realm=realm, sname=sname, diff --git a/dementor/protocols/ldap.py b/dementor/protocols/ldap.py index d548d9e..0e79164 100644 --- a/dementor/protocols/ldap.py +++ b/dementor/protocols/ldap.py @@ -91,8 +91,9 @@ from dementor.config.session import SessionConfig from dementor.config.toml import Attribute as A from dementor.config.toml import TomlConfig -from dementor.config.util import generate_self_signed_cert -from dementor.db import _CLEARTEXT +from dementor.config.tls import generate_self_signed_cert +from dementor.config.util import HostValue, HostFallbackValue +from dementor.db import CLEARTEXT, DIGEST_MD5 from dementor.loader import DEFAULT_ATTR, BaseProtocolModule from dementor.log import hexdump from dementor.log.logger import ProtocolLogger @@ -100,9 +101,10 @@ ATTR_NTLM_CHALLENGE, ATTR_NTLM_DISABLE_ESS, ATTR_NTLM_DISABLE_NTLMV2, - NTLM_AUTH_CreateChallenge, - NTLM_report_auth, - NTLM_split_fqdn, + ntlm_auth_create_challenge, + ntlm_report_auth, + ntlm_split_fqdn, + ntlm_build_challenge_message, ) from dementor.protocols.spnego import SPNEGO_NTLMSSP_MECH, SPNEGONegotiator from dementor.servers import ( @@ -335,7 +337,13 @@ class LDAPServerConfig(TomlConfig): A("ldap_mech", "SASLMechanisms", LDAP_DEFAULT_MECH), # Connection settings A("ldap_timeout", "Timeout", 0), - A("ldap_fqdn", "FQDN", "DEMENTOR", section_local=False), + A( + "ldap_fqdn", + "Host", + None, + section_local=False, + factory=HostFallbackValue(HostValue.HOST, "DEMENTOR"), + ), # TLS configuration ATTR_TLS, ATTR_KEY, @@ -639,11 +647,14 @@ def enable_sealing(self) -> None: """ self.sealing_active = True - def set_negotiated_qop(self, qop: str) -> None: - """Set the negotiated quality of protection. + def apply_negotiated_qop(self, qop: str) -> None: + """Set the negotiated quality of protection and activate security layers. :param qop: Quality of protection ("auth", "auth-int", or "auth-conf") :type qop: str + + Side effects: "auth-int" activates signing; "auth-conf" activates both + signing and sealing. """ self.negotiated_qop = qop if qop == "auth-int": @@ -728,7 +739,7 @@ def add_if_requested(attr_name: str, values: list[str]) -> None: "supportedSASLQoPOptions", self.server_config.ldap_sasl_qop_options ) - dns_hostname, dns_domain = NTLM_split_fqdn(self.server_config.ldap_fqdn) + dns_hostname, dns_domain = ntlm_split_fqdn(self.server_config.ldap_fqdn) add_if_requested("dnsHostName", [dns_hostname]) naming_context = self.server_config._parse_domain_to_dn(dns_domain) @@ -1085,7 +1096,7 @@ def _is_encrypted_message(self, data: bytes) -> bool: return False @override - def recv(self, size: int) -> LDAPMessage | None: # type: ignore[override] + def recv(self, size: int) -> LDAPMessage | None: # type: ignore[override] # ty:ignore[invalid-method-override] """Receive and decode an LDAP message from the client. :param size: Maximum bytes to receive (ignored, uses 8192) @@ -1121,7 +1132,7 @@ def recv(self, size: int) -> LDAPMessage | None: # type: ignore[override] return message @override - def send(self, data: LDAPMessage | list[LDAPMessage] | None) -> None: # type: ignore[override] + def send(self, data: LDAPMessage | list[LDAPMessage] | None) -> None: # type: ignore[override] # ty:ignore[invalid-method-override] """Send an LDAP message or list of messages to the client. :param data: LDAP message(s) to send, or None to skip @@ -1217,19 +1228,28 @@ def _handle_spnego_ntlm_mech( else: token = ntlm.NTLMAuthNegotiate() - ntlm_challenge = NTLM_AUTH_CreateChallenge( + host = HostValue(self.server.server_config.ldap_fqdn) + ntlm_challenge = ntlm_build_challenge_message( token, - *NTLM_split_fqdn(self.server.server_config.ldap_fqdn), - challenge=self.server.server_config.ntlm_challenge, - disable_ess=self.server.server_config.ntlm_disable_ess, - disable_ntlmv2=self.server.server_config.ntlm_disable_ntlmv2, + challenge=self.config.ntlm_challenge, + nb_computer=host.get_value(HostValue.NETBIOS_COMPUTER), + nb_domain=host.get_value(HostValue.NETBIOS_DOMAIN), + disable_ess=self.config.ntlm_disable_ess, + disable_ntlmv2=self.config.ntlm_disable_ntlmv2, + target_type=self.config.ntlm_target_type, + version=self.config.ntlm_version, + dns_computer=host.get_value(HostValue.DNS_COMPUTER), + dns_domain=host.get_value(HostValue.DNS_DOMAIN), + # REVISIT: capture DNSTree too + # dns_tree=self.config.ntlm_dns_tree, + log=self.logger, ) return ntlm_challenge.getData(), False if mech_token: token = ntlm.NTLMAuthChallengeResponse() token.fromString(mech_token) - NTLM_report_auth( + ntlm_report_auth( auth_token=token, challenge=self.server.server_config.ntlm_challenge, client=self.client_address, @@ -1347,7 +1367,8 @@ def handle_bindRequest( :type message: LDAPMessage :param bind_req: Bind request protocol operation :type bind_req: BindRequest - :return: None (response is sent directly to client) + :return: ``None`` when the response is sent directly to the client, or + a :class:`LDAPMessage` on the unsupported-version path. :rtype: LDAPMessage | None """ self.logger.debug(f"LDAP Bind Request from {self.client_address}") @@ -1465,7 +1486,7 @@ def _handle_simple_bind( username, domain, extras = self._parse_cleartext_user(bind_name) self.config.db.add_auth( client=self.client_address, - credtype=_CLEARTEXT, + credtype=CLEARTEXT, username=username, password=bind_password, domain=domain, @@ -1536,12 +1557,9 @@ def _handle_sicily_negotiate( negotiate.fromString(nego_token_raw) fqdn = self.server.server_config.ldap_fqdn - name, domain = fqdn.split(".", 1) if "." in fqdn else (fqdn, "") - - ntlm_challenge = NTLM_AUTH_CreateChallenge( + ntlm_challenge = ntlm_auth_create_challenge( negotiate, - name, - domain, + *ntlm_split_fqdn(fqdn), challenge=self.server.server_config.ntlm_challenge, disable_ess=self.server.server_config.ntlm_disable_ess, disable_ntlmv2=self.server.server_config.ntlm_disable_ntlmv2, @@ -1587,7 +1605,7 @@ def _handle_sicily_response(self, message: LDAPMessage, blob: bytes) -> None: self.logger.debug("NTLM authenticate phase") auth_message = NTLMAuthChallengeResponse() auth_message.fromString(blob) - NTLM_report_auth( + ntlm_report_auth( auth_token=auth_message, challenge=self.server.server_config.ntlm_challenge, client=self.client_address, @@ -1834,27 +1852,23 @@ def handle_extendedReq( else: self.logger.debug(f"Extended OID: {req_oid}, no value") - print(req_oid) if req_oid == LDAP_STARTTLS_OID: - """Handle StartTLS Extended Operation per RFC 4513 §3. - - Per RFC 4513 §3: StartTLS upgrades an existing LDAP connection to use - TLS encryption. This protects subsequent operations from eavesdropping. - - StartTLS Protocol Flow (RFC 4513 §3.1): - 1. Client sends StartTLS extended request - 2. Server sends success response - 3. Client and server perform TLS handshake - 4. Connection is now encrypted, authentication can proceed - - Restrictions (RFC 4513 §3.1.1): - - MUST NOT be used if TLS is already active - - MUST NOT be used during SASL negotiation - - MUST NOT be used after successful bind (some implementations) - - Per RFC 5929: After StartTLS, channel binding can be used to bind - SASL authentication to the TLS channel, preventing relay attacks. - """ + # Handle StartTLS Extended Operation per RFC 4513 §3. + # StartTLS upgrades an existing LDAP connection to use TLS encryption. + # + # Protocol flow (RFC 4513 §3.1): + # 1. Client sends StartTLS extended request + # 2. Server sends success response + # 3. Client and server perform TLS handshake + # 4. Connection is now encrypted, authentication can proceed + # + # Restrictions (RFC 4513 §3.1.1): + # - MUST NOT be used if TLS is already active + # - MUST NOT be used during SASL negotiation + # - MUST NOT be used after successful bind (some implementations) + # + # Per RFC 5929: After StartTLS, channel binding can bind SASL auth + # to the TLS channel, preventing relay attacks. self.logger.debug("Processing StartTLS extended operation") # RFC 4513 §3.1.1: StartTLS MUST NOT be used if TLS is already active @@ -2048,11 +2062,11 @@ def _handle_direct_ntlm( msg_type = data[8] if msg_type == 1: self.logger.debug("Direct NTLM: detected NTLM negotiate message") - return self._handle_NTLM_Negotiate(message, data) + return self._handle_ntlm_negotiate(message, data) if msg_type == 3: self.logger.debug("Direct NTLM: detected NTLM authenticate message") - return self._handle_NTLM_Auth(message, data) + return self._handle_ntlm_auth(message, data) self.logger.debug(f"Direct NTLM: unsupported NTLM message type {msg_type}") return self.server.bind_result(message, reason=LDAP_AUTH_METHOD_NOT_SUPPORTED) @@ -2061,7 +2075,7 @@ def _handle_direct_ntlm( self.logger.debug(f"Direct NTLM: failed to parse token: {e}") return self.server.bind_result(message, reason=LDAP_AUTH_METHOD_NOT_SUPPORTED) - def _handle_NTLM_Negotiate(self, message: LDAPMessage, nego_token_raw: bytes) -> None: + def _handle_ntlm_negotiate(self, message: LDAPMessage, nego_token_raw: bytes) -> None: """Handle NTLM Negotiate message. :param message: LDAP message containing bind request @@ -2073,19 +2087,16 @@ def _handle_NTLM_Negotiate(self, message: LDAPMessage, nego_token_raw: bytes) -> negotiate.fromString(nego_token_raw) fqdn = self.server.server_config.ldap_fqdn - name, domain = fqdn.split(".", 1) if "." in fqdn else (fqdn, "") - - ntlm_challenge = NTLM_AUTH_CreateChallenge( + ntlm_challenge = ntlm_auth_create_challenge( negotiate, - name, - domain, + *ntlm_split_fqdn(fqdn), challenge=self.server.server_config.ntlm_challenge, disable_ess=self.server.server_config.ntlm_disable_ess, disable_ntlmv2=self.server.server_config.ntlm_disable_ntlmv2, ) self.send(self.server.bind_result(message, matched_dn=ntlm_challenge.getData())) - def _handle_NTLM_Auth(self, message: LDAPMessage, blob: bytes) -> None: + def _handle_ntlm_auth(self, message: LDAPMessage, blob: bytes) -> None: """Handle NTLM Authenticate message. :param message: LDAP message containing bind request @@ -2096,7 +2107,7 @@ def _handle_NTLM_Auth(self, message: LDAPMessage, blob: bytes) -> None: """ auth_message = NTLMAuthChallengeResponse() auth_message.fromString(blob) - NTLM_report_auth( + ntlm_report_auth( auth_token=auth_message, challenge=self.server.server_config.ntlm_challenge, client=self.client_address, @@ -2160,17 +2171,14 @@ def _handle_sasl_DIGEST_MD5( # RFC 2831 §2.1: Client may send empty initial response # Server responds with challenge containing realm, nonce, qop, etc. - if not credentials or len(credentials) == 0: + if not credentials: self.logger.debug("DIGEST-MD5: Sending initial challenge") nonce = f"+Upgraded+v1{secrets.token_hex(32)}" - timestamp = str(int(time.time())) # Use configured domain as realm instead of hardcoded value - _, realm = NTLM_split_fqdn(self.server.server_config.ldap_fqdn) + _, realm = ntlm_split_fqdn(self.server.server_config.ldap_fqdn) self.digest_md5_state = { "nonce": nonce, - "timestamp": timestamp, - "realm": realm, } # Only offer QoP options that are actually supported @@ -2246,7 +2254,7 @@ def _handle_sasl_DIGEST_MD5( ) self.config.db.add_auth( client=self.client_address, - credtype="digest-md5", + credtype=DIGEST_MD5, username=username, password=digest_hash, domain=domain, @@ -2258,7 +2266,7 @@ def _handle_sasl_DIGEST_MD5( "DIGEST-MD5: Captured digest response for offline analysis (WARNING: not currently crackable with hashcat)" ) if qop and qop in ["auth", "auth-int", "auth-conf"]: - self.auth_state.set_negotiated_qop(qop) + self.auth_state.apply_negotiated_qop(qop) self.logger.debug(f"DIGEST-MD5: Negotiated QOP: {qop}") if self.server.server_config.ldap_require_sealing: @@ -2302,7 +2310,7 @@ def _parse_digest_response(self, response_str: str) -> dict[str, str] | None: try: directives = parse_http_list(response_str) parsed = parse_keqv_list(directives) - except Exception as e: + except ValueError as e: self.logger.debug(f"DIGEST-MD5: Failed to parse response: {e}") return None else: @@ -2335,7 +2343,7 @@ def _handle_sasl_PLAIN( ) return self.server.bind_result(message, reason=LDAP_INVALID_CREDENTIALS) - except Exception as e: + except (ValueError, AttributeError) as e: self.logger.debug(f"SASL PLAIN: Failed to parse credentials: {e}") return self.server.bind_result(message, reason=LDAP_INVALID_CREDENTIALS) else: @@ -2351,7 +2359,7 @@ def _handle_sasl_PLAIN( username, domain, extras = self._parse_cleartext_user(authcid) self.config.db.add_auth( client=self.client_address, - credtype="plain", + credtype=CLEARTEXT, username=username, password=passwd, domain=domain, @@ -2442,12 +2450,9 @@ def _handle_sasl_ntlm_negotiate( return self.server.bind_result(message, reason=LDAP_AUTH_METHOD_NOT_SUPPORTED) fqdn = self.server.server_config.ldap_fqdn - name, domain = fqdn.split(".", 1) if "." in fqdn else (fqdn, "") - - ntlm_challenge = NTLM_AUTH_CreateChallenge( + ntlm_challenge = ntlm_auth_create_challenge( negotiate, - name, - domain, + *ntlm_split_fqdn(fqdn), challenge=self.server.server_config.ntlm_challenge, disable_ess=self.server.server_config.ntlm_disable_ess, disable_ntlmv2=self.server.server_config.ntlm_disable_ntlmv2, @@ -2485,7 +2490,7 @@ def _handle_sasl_ntlm_authenticate( self.sasl_state.transition(SASLAuthState.FAILED) return self.server.bind_result(message, reason=LDAP_AUTH_METHOD_NOT_SUPPORTED) - NTLM_report_auth( + ntlm_report_auth( auth_token=auth_message, challenge=self.server.server_config.ntlm_challenge, client=self.client_address, diff --git a/dementor/protocols/llmnr.py b/dementor/protocols/llmnr.py index f8605a0..f6a22df 100644 --- a/dementor/protocols/llmnr.py +++ b/dementor/protocols/llmnr.py @@ -98,7 +98,7 @@ def handle_data(self, data: bytes, transport) -> None: self.send_poisoned_answer(packet, question, transport) def send_poisoned_answer(self, req, question: dns.DNSQR, transport) -> None: - # check if we can send a response + # skip response if the address family (A vs AAAA) doesn't match local config if question.qtype == 28 and not self.config.ipv6: self.logger.highlight( "Client requested AAAA record (IPv6) but local config does not " diff --git a/dementor/protocols/mdns.py b/dementor/protocols/mdns.py index 0fb0b7a..e5188a9 100644 --- a/dementor/protocols/mdns.py +++ b/dementor/protocols/mdns.py @@ -66,7 +66,7 @@ class MDNSConfig(TomlConfig): ignored: Filters | None targets: Filters | None - def set_mdns_qtypes(self, value: list[str | int]): + def set_mdns_qtypes(self, value: list[str | int]) -> None: # REVISIT: maybe add error check here self.mdns_qtypes = [x if isinstance(x, int) else QTYPES[x] for x in value] @@ -108,10 +108,9 @@ def build_dns_answer(req_id: int, question: dns.DNSQR, config: SessionConfig): # --- Poisoner/Server --------------------------------------------------------- class MDNSPoisoner(BaseProtoHandler): - def proto_logger(self): + def proto_logger(self) -> ProtocolLogger: return ProtocolLogger( - extra={ - "protocol": "MDNS", + { "protocol_color": "deep_sky_blue1", "host": self.client_host, "port": self.client_port, @@ -167,7 +166,7 @@ def handle_data(self, data: bytes, transport) -> None: def send_poisoned_answer( self, req, question: dns.DNSQR, transport, name: str ) -> None: - # check if we can send a response + # skip response if the address family (A vs AAAA) doesn't match local config if question.qtype == 28 and not self.config.ipv6: self.logger.highlight( "Client requested AAAA record (IPv6) but local config does not specify IPv6 address. Ignoring..." diff --git a/dementor/protocols/msrpc/rpc.py b/dementor/protocols/msrpc/rpc.py index 1bc80f0..1acc67e 100644 --- a/dementor/protocols/msrpc/rpc.py +++ b/dementor/protocols/msrpc/rpc.py @@ -33,11 +33,12 @@ from impacket import ntlm from dementor.config.toml import TomlConfig, Attribute as A +from dementor.config.util import HostValue, HostFallbackValue from dementor.log.logger import ProtocolLogger, dm_logger from dementor.protocols.ntlm import ( - NTLM_build_challenge_message, - NTLM_handle_negotiate_message, - NTLM_handle_authenticate_message, + ntlm_build_challenge_message, + ntlm_handle_negotiate_message, + ntlm_handle_authenticate_message, ) from dementor.servers import ThreadingTCPServer, BaseProtoHandler from dementor.loader import ProtocolLoader @@ -76,7 +77,13 @@ class RPCModule(typing.Protocol): class RPCConfig(TomlConfig): _section_ = "RPC" _fields_ = [ - A("rpc_fqdn", "FQDN", "DEMENTOR", section_local=False), + A( + "rpc_fqdn", + "Host", + None, + section_local=False, + factory=HostFallbackValue(HostValue.HOST, "DEMENTOR"), + ), A("epm_port", "EPM.TargetPort", 49000), A("epm_port_range", "EPM.TargetPortRange", None), A("rpc_modules", "Interfaces", list), @@ -90,7 +97,7 @@ class RPCConfig(TomlConfig): rpc_modules: list[RPCModule] rpc_error_code: int - def set_rcp_error_code(self, value: str | int): + def set_rpc_error_code(self, value: str | int) -> None: if isinstance(value, str): value = rev_rpc_status_codes[value] @@ -102,7 +109,7 @@ def set_rcp_error_code(self, value: str | int): self.rpc_error_code = value - def set_epm_port_range(self, value: str | dict): + def set_epm_port_range(self, value: str | dict) -> None: start = end = None match value: case dict(): @@ -128,7 +135,7 @@ def set_epm_port_range(self, value: str | dict): self.epm_port_range = (start, end) self.epm_port = random.randrange(start, end) - def set_rpc_modules(self, extra_paths: list): + def set_rpc_modules(self, extra_paths: list) -> None: loader = ProtocolLoader() loader.search_path = [os.path.dirname(__file__)] loader.search_path.extend(extra_paths) @@ -152,7 +159,7 @@ class RPCHandler(BaseProtoHandler): server: "MSRPCServer" def __init__(self, config, request, client_address, server) -> None: - self.rpc_config = config.rpc_config + self.rpc_config: RPCConfig = config.rpc_config super().__init__(config, request, client_address, server) def proto_logger(self) -> ProtocolLogger: @@ -234,7 +241,7 @@ def handle_bind( ctx_items = [] data = bind_req["ctx_items"] - conn = self.server.get_conn_by_call_id(header["call_id"]) + conn = self.server.get_or_create_conn_by_call_id(header["call_id"]) endpoints = set() for _ in range(bind_req["ctx_num"]): result = rpcrt.MSRPC_CONT_RESULT_PROV_REJECT @@ -289,15 +296,22 @@ def handle_bind( # generate challenge negotiate = ntlm.NTLMAuthNegotiate() negotiate.fromString(token) - negotiate_fields = NTLM_handle_negotiate_message(negotiate, self.logger) + negotiate_fields = ntlm_handle_negotiate_message(negotiate, self.logger) conn.negotiate_fields = negotiate_fields - challenge = NTLM_build_challenge_message( + host = HostValue(self.rpc_config.rpc_fqdn) + challenge = ntlm_build_challenge_message( negotiate, challenge=self.config.ntlm_challenge, - nb_computer=self.config.ntlm_nb_computer, - nb_domain=self.config.ntlm_nb_domain, + nb_computer=host.get_value(HostValue.NETBIOS_COMPUTER), + nb_domain=host.get_value(HostValue.NETBIOS_DOMAIN), disable_ess=self.config.ntlm_disable_ess, disable_ntlmv2=self.config.ntlm_disable_ntlmv2, + target_type=self.config.ntlm_target_type, + version=self.config.ntlm_version, + dns_computer=host.get_value(HostValue.DNS_COMPUTER), + dns_domain=host.get_value(HostValue.DNS_DOMAIN), + # REVISIT: capture DNSTree too + # dns_tree=self.config.ntlm_dns_tree, log=self.logger, ) bind_ack["auth_data"] = challenge.getData() @@ -330,7 +344,7 @@ def handle_auth3(self, header: rpcrt.MSRPCHeader): self.logger.display(f"Rejecting AUTH3 request using AuthType: {auth_type:#x}") return rev_rpc_status_codes["nca_s_unsupported_authn_level"] - conn = self.server.get_conn_by_call_id(header["call_id"]) + conn = self.server.get_or_create_conn_by_call_id(header["call_id"]) token = header["auth_data"] if not conn.challenge: # challenge not set, invalid request @@ -338,7 +352,7 @@ def handle_auth3(self, header: rpcrt.MSRPCHeader): auth_resp = ntlm.NTLMAuthChallengeResponse() auth_resp.fromString(token) - NTLM_handle_authenticate_message( + ntlm_handle_authenticate_message( auth_token=auth_resp, challenge=conn.challenge["challenge"], client=self.client_address, @@ -350,7 +364,7 @@ def handle_auth3(self, header: rpcrt.MSRPCHeader): def handle_request(self, data): request = rpcrt.MSRPCRequestHeader(data) - conn = self.server.get_conn_by_call_id(request["call_id"]) + conn = self.server.get_or_create_conn_by_call_id(request["call_id"]) conn.ctx_id = request["ctx_id"] if not conn.target: # Interface not set, we can't handle this @@ -375,14 +389,14 @@ def __init__( self._conn_lock = threading.Lock() super().__init__(config, server_address, RequestHandlerClass) - def get_conn_by_call_id(self, call_id: int) -> RPCConnection: + def get_or_create_conn_by_call_id(self, call_id: int) -> RPCConnection: with self._conn_lock: conn = self.conn_data[call_id] if conn.call_id == -1: conn.call_id = call_id return conn - def get_conn_by_auth_ctx_id(self, auth_ctx_id: int) -> RPCConnection: + def get_or_create_conn_by_auth_ctx_id(self, auth_ctx_id: int) -> RPCConnection: conn = next( filter(lambda x: x.auth_ctx_id == auth_ctx_id, self.conn_data.values()), None, @@ -405,6 +419,8 @@ def _module_handler(self, module) -> RPCEndpointHandlerFunc | None: if handler_cls: return handler_cls() + return None + def get_handler_by_uuid(self, uuid: bytes) -> RPCEndpointHandlerFunc | None: uuid_str, _ = rpcrt.bin_to_uuidtup(uuid) for module in self.config.rpc_config.rpc_modules: diff --git a/dementor/protocols/mssql.py b/dementor/protocols/mssql.py index 3ce12e1..0edcfa8 100644 --- a/dementor/protocols/mssql.py +++ b/dementor/protocols/mssql.py @@ -45,14 +45,15 @@ from dementor.config.session import SessionConfig from dementor.loader import BaseProtocolModule, DEFAULT_ATTR -from dementor.db import _CLEARTEXT +from dementor.db import CLEARTEXT from dementor.config.toml import TomlConfig, Attribute as A +from dementor.config.util import HostValue, HostFallbackValue from dementor.log.hexdump import hexdump from dementor.log.logger import ProtocolLogger from dementor.protocols.ntlm import ( - NTLM_build_challenge_message, - NTLM_handle_authenticate_message, - NTLM_handle_negotiate_message, + ntlm_build_challenge_message, + ntlm_handle_authenticate_message, + ntlm_handle_negotiate_message, ) from dementor.servers import ( BaseProtoHandler, @@ -126,7 +127,13 @@ class SVR_RESP_DAC: class SSRPConfig(TomlConfig): _section_ = "SSRP" _fields_ = [ - A("ssrp_server_name", "MSSQL.FQDN", "DEMENTOR"), + A( + "ssrp_server_name", + "MSSQL.Host", + None, + section_local=False, + factory=HostFallbackValue(HostValue.HOST, "DEMENTOR"), + ), A("ssrp_server_version", "MSSQL.Version", "9.00.1399.06"), A("ssrp_server_instance", "MSSQL.InstanceName", "MSSQLServer"), A("ssrp_instance_config", "InstanceConfig", ""), @@ -189,7 +196,10 @@ def handle_data(self, data, transport) -> None: ) resp = SVR_RESP( data=( - f"ServerName;{self.config.ssrp_config.ssrp_server_name};InstanceName;{instance_name};IsClustered;No;Version;{self.config.ssrp_config.ssrp_server_version};tcp;{self.config.mssql_config.mssql_port}{self.config.ssrp_config.ssrp_instance_config};;" + f"ServerName;{self.config.ssrp_config.ssrp_server_name};" + f"InstanceName;{instance_name};IsClustered;No;Version;" + f"{self.config.ssrp_config.ssrp_server_version};tcp;" + f"{self.config.mssql_config.mssql_port}{self.config.ssrp_config.ssrp_instance_config};;" ) ) self.send(pack(resp)) @@ -209,7 +219,13 @@ class MSSQLConfig(TomlConfig): _fields_ = [ A("mssql_port", "Port", 1433), A("mssql_server_version", "Version", "9.00.1399.06"), - A("mssql_fqdn", "FQDN", "DEMENTOR", section_local=False), + A( + "mssql_fqdn", + "Host", + None, + section_local=False, + factory=HostFallbackValue(HostValue.HOST, "DEMENTOR"), + ), A("mssql_instance", "InstanceName", "MSSQLSerevr"), A("mssql_error_code", "ErrorCode", 1205), # LK_VICTIM A("mssql_error_state", "ErrorState", 1), @@ -295,7 +311,7 @@ def length_hint(self) -> int: ) -class MSSQLHandler(BaseProtoHandler): +class MSSQLHandler(BaseProtoHandler["MSSQLServer"]): def __init__(self, config, request, client_address, server) -> None: self.challenge = None super().__init__(config, request, client_address, server) @@ -356,11 +372,13 @@ def handle_pre_login(self, packet: tds.TDSPacket) -> int: tds.TDS_ENCRYPT_ON, ): self.logger.display( - f"Pre-Login request for [i]{escape(instance)}[/] ([bold red]Encryption requested[/])" + f"Pre-Login request for [i]{escape(instance)}[/] " + f"([bold red]Encryption requested[/])" ) else: self.logger.display( - f"PreLogin request for [i]{escape(instance)}[/] (version: {unpack(PL_OPTION_TOKEN_VERSION, version)})" + f"PreLogin request for [i]{escape(instance)}[/] " + f"(version: {unpack(PL_OPTION_TOKEN_VERSION, version)})" ) pre_login = tds.TDS_PRELOGIN() @@ -399,7 +417,7 @@ def handle_login(self, packet: tds.TDSPacket) -> int: cleartext_password = self.decode_password(password) self.config.db.add_auth( client=self.client_address, - credtype=_CLEARTEXT, + credtype=CLEARTEXT, username=username, password=cleartext_password, logger=self.logger, @@ -419,14 +437,21 @@ def handle_login(self, packet: tds.TDSPacket) -> int: self.send_error(packet) return 1 - self.negotiate_fields = NTLM_handle_negotiate_message(negotiate, self.logger) - self.challenge = NTLM_build_challenge_message( + self.negotiate_fields = ntlm_handle_negotiate_message(negotiate, self.logger) + host = HostValue(self.config.mssql_config.mssql_fqdn) + self.challenge = ntlm_build_challenge_message( negotiate, challenge=self.config.ntlm_challenge, - nb_computer=self.config.ntlm_nb_computer, - nb_domain=self.config.ntlm_nb_domain, + nb_computer=host.get_value(HostValue.NETBIOS_COMPUTER), + nb_domain=host.get_value(HostValue.NETBIOS_DOMAIN), disable_ess=self.config.ntlm_disable_ess, disable_ntlmv2=self.config.ntlm_disable_ntlmv2, + target_type=self.config.ntlm_target_type, + version=self.config.ntlm_version, + dns_computer=host.get_value(HostValue.DNS_COMPUTER), + dns_domain=host.get_value(HostValue.DNS_DOMAIN), + # REVISIT: capture DNSTree too + # dns_tree=self.config.ntlm_dns_tree, log=self.logger, ) @@ -443,6 +468,12 @@ def handle_login(self, packet: tds.TDSPacket) -> int: return 1 # terminate connection def handle_sspi(self, packet: tds.TDSPacket) -> int: + """Handle TDS SSPI (type 17) auth packet; always terminates the session. + + Captures the NTLM hash, then forces a login error so the client + retries (which may produce additional credential material). Return + value is always 1 (terminate) - SSPI auth is never continued. + """ raw_data = packet["Data"] try: auth_message = ntlm.NTLMAuthChallengeResponse() @@ -452,7 +483,7 @@ def handle_sspi(self, packet: tds.TDSPacket) -> int: self.send_error(packet) return 1 - NTLM_handle_authenticate_message( + ntlm_handle_authenticate_message( auth_message, challenge=self.challenge["challenge"], client=self.client_address, @@ -521,7 +552,7 @@ class MSSQL(BaseProtocolModule[MSSQLConfig]): @override def create_server_thread( self, session: SessionConfig, server_config: MSSQLConfig - ) -> BaseServerThread: + ) -> BaseServerThread[MSSQLConfig]: return ServerThread(session, server_config, MSSQLServer) @@ -535,5 +566,5 @@ class SSRP(BaseProtocolModule[SSRPConfig]): @override def create_server_thread( self, session: SessionConfig, server_config: SSRPConfig - ) -> BaseServerThread: + ) -> BaseServerThread[SSRPConfig]: return ServerThread(session, server_config, SSRPServer) diff --git a/dementor/protocols/mysql.py b/dementor/protocols/mysql.py index 265d96f..753d14b 100644 --- a/dementor/protocols/mysql.py +++ b/dementor/protocols/mysql.py @@ -46,7 +46,6 @@ LittleEndian, uint16, uint64, - singleton, unpack, ) from caterpillar.exception import DynamicSizeError, StructException @@ -65,7 +64,7 @@ from dementor.log.logger import ProtocolLogger from dementor.config.attr import Attribute as A, ATTR_TLS, ATTR_CERT, ATTR_KEY from dementor.config.toml import TomlConfig -from dementor.db import _CLEARTEXT +from dementor.db import CLEARTEXT __proto__ = ["MySQL"] @@ -104,7 +103,7 @@ class MySQL(BaseProtocolModule[MySQLConfig]): @override def create_server_thread( self, session: SessionConfig, server_config: MySQLConfig - ) -> BaseServerThread: + ) -> BaseServerThread[MySQLConfig]: return ServerThread( session, server_config, @@ -153,7 +152,7 @@ def create_server_thread( CLIENT_REMEMBER_OPTIONS = 1 << 31 -class SERVER_STATUS_flags_enum(enum.IntEnum): +class ServerStatusFlags(enum.IntEnum): __struct__ = uint16 SERVER_STATUS_IN_TRANS = 1 @@ -179,7 +178,7 @@ class SERVER_STATUS_flags_enum(enum.IntEnum): # 251 216 0xFC + 2-byte integer # 216 224 0xFD + 3-byte integer # 224 264 0xFE + 8-byte integer -@singleton +# @singleton class LengthEncodedInteger: def __type__(self): return int @@ -283,7 +282,7 @@ class HandshakeV10: # default server a_protocol_character_set, only the lower 8-bits character_set: uint8_t = 0x3F - status_flags: SERVER_STATUS_flags_enum = 0x00 + status_flags: ServerStatusFlags = 0x00 flags_upper: uint16_t = 0x000 # length of the combined auth_plugin_data (scramble), if @@ -298,11 +297,11 @@ class HandshakeV10: # name of the auth_method that the auth_plugin_data belongs to auth_plugin_name: f[str | None, CString() // _has_auth_plugin_data] - def set_flags(self, flags): + def set_flags(self, flags) -> None: self.flags_lower = flags & 0xFFFF self.flags_upper = (flags >> 16) & 0xFFFF - def get_flags(self): + def get_flags(self) -> int: return self.flags_lower | (self.flags_upper << 16) @@ -330,8 +329,8 @@ class SSLRequest: @struct(order=LittleEndian) class ConnectionAttribute(struct_factory.mixin): # will be stored as bytes rather than string - key: f[bytes, Prefixed(LengthEncodedInteger)] - value: f[bytes, Prefixed(LengthEncodedInteger)] + key: f[bytes, Prefixed(LengthEncodedInteger())] + value: f[bytes, Prefixed(LengthEncodedInteger())] # [Protocol::HandshakeResponse] @@ -374,7 +373,7 @@ class HandshakeResponse: # --- MySQL Handler --- class MySQLHandler(BaseProtoHandler): @property - def mysql_config(self): + def mysql_config(self) -> "MySQLConfig": return self.config.mysql_config def proto_logger(self) -> ProtocolLogger: @@ -441,7 +440,7 @@ def handle_data(self, data, transport) -> None: server_version=self.mysql_config.mysql_version, thread_id=10, salt=b"A" * 8, - status_flags=SERVER_STATUS_flags_enum.SERVER_STATUS_AUTOCOMMIT, + status_flags=ServerStatusFlags.SERVER_STATUS_AUTOCOMMIT, auth_plugin_data_len=21, # REVISIT: maybe add automatic calculation here salt2=b"A" * 12 + b"\0", auth_plugin_name=plugin_name, @@ -507,7 +506,9 @@ def handle_data(self, data, transport) -> None: else: self.logger.debug(f"Unknown authentication plugin: {resp_plugin_name}") - def mysql_clear_password(self, greeting: HandshakeV10, response: HandshakeResponse): + def mysql_clear_password( + self, greeting: HandshakeV10, response: HandshakeResponse + ) -> None: username = response.username password = response.auth_response.decode(errors="replace").strip("\x00") @@ -525,7 +526,7 @@ def mysql_clear_password(self, greeting: HandshakeV10, response: HandshakeRespon username=username, password=password, logger=self.logger, - credtype=_CLEARTEXT, + credtype=CLEARTEXT, extras=OrderedDict(sorted(extras.items())), ) diff --git a/dementor/protocols/netbios.py b/dementor/protocols/netbios.py index 31c950b..6a8217c 100644 --- a/dementor/protocols/netbios.py +++ b/dementor/protocols/netbios.py @@ -33,6 +33,7 @@ from dementor.log.logger import ProtocolLogger from dementor.config.session import TomlConfig from dementor.config.toml import Attribute as A +from dementor.config.util import HostValue, HostFallbackValue from dementor.filters import ATTR_BLACKLIST, ATTR_WHITELIST, in_scope from dementor.protocols import mailslot, netlogon @@ -54,8 +55,18 @@ class NBTNSConfig(TomlConfig): class BrowserConfig(TomlConfig): _section_ = "Browser" _fields_ = [ - A("browser_domain_name", "DomainName", "CONTOSO"), - A("browser_hostname", "Hostname", "DC01"), + A( + "browser_domain_name", + f"NTLM.{HostValue.DNS_DOMAIN}", + default_val=None, + factory=HostFallbackValue(HostValue.DNS_DOMAIN, "WORKGROUP"), + ), + A( + "browser_hostname", + f"NTLM.{HostValue.DNS_COMPUTER}", + default_val=None, + factory=HostFallbackValue(HostValue.DNS_COMPUTER, HostValue.DEFAULT), + ), ATTR_WHITELIST, ATTR_BLACKLIST, ] @@ -98,7 +109,7 @@ class BrowserConfig(TomlConfig): class NetBiosNSPoisoner(BaseProtoHandler): - def proto_logger(self): + def proto_logger(self) -> ProtocolLogger: return ProtocolLogger( extra={ "protocol": "NetBIOS", @@ -108,7 +119,7 @@ def proto_logger(self): } ) - def handle_data(self, data: bytes, transport) -> None: + def handle_data(self, data: bytes, transport) -> None: # ty:ignore[invalid-method-override] header = netbios.NBNSHeader(data) if header.RESPONSE: # response sent by server, ignore diff --git a/dementor/protocols/ntlm.py b/dementor/protocols/ntlm.py index a08dd49..208562d 100644 --- a/dementor/protocols/ntlm.py +++ b/dementor/protocols/ntlm.py @@ -62,6 +62,7 @@ """ import time +import struct import calendar import secrets @@ -72,8 +73,8 @@ from dementor.config.toml import Attribute from dementor.config.session import SessionConfig -from dementor.config.util import is_true, get_value, BytesValue -from dementor.db import _HOST_INFO +from dementor.config.util import is_true, get_value, BytesValue, HostFallbackValue +from dementor.db import CLEARTEXT, HOST_INFO from dementor.log.logger import ProtocolLogger, dm_logger # --- Constants --------------------------------------------------------------- @@ -103,7 +104,7 @@ # VERSION structure per [MS-NLMP section 2.2.2.10] NTLM_VERSION_LEN: int = 8 -# NTLMSSP_REVISION_W2K3 per [MS-NLMP] §2.2.2.10 — all modern Windows use 0x0F. +# NTLMSSP_REVISION_W2K3 per [MS-NLMP] §2.2.2.10 - all modern Windows use 0x0F. NTLM_REVISION_W2K3: int = 0x0F # Offset from the Unix epoch (1 Jan 1970) to the Windows FILETIME epoch @@ -160,7 +161,7 @@ # Challenge parsing is handled by BytesValue(NTLM_CHALLENGE_LEN) from -# dementor.config.util — supports hex:/ascii: prefixes, auto-detect, +# dementor.config.util - supports hex:/ascii: prefixes, auto-detect, # and length validation in a single reusable helper. _parse_challenge = BytesValue(NTLM_CHALLENGE_LEN) @@ -228,7 +229,7 @@ def _config_version_to_bytes(value: str | None) -> bytes: ) # These control the server identity inside the NTLMSSP CHALLENGE_MESSAGE. -# None means "derive from the protocol's own identity config" — each +# None means "derive from the protocol's own identity config" - each # protocol handler resolves the fallback chain. ATTR_NTLM_TARGET_TYPE = Attribute( @@ -248,40 +249,75 @@ def _config_version_to_bytes(value: str | None) -> bytes: ATTR_NTLM_NB_COMPUTER = Attribute( "ntlm_nb_computer", - "NTLM.NetBIOSComputer", - "DEMENTOR", # MsvAvNbComputerName (AV_PAIR 0x0001) + "NTLM.NetBIOSComputer", # explicit: [NTLM] > [Globals]; fallback: derived from [Globals].Host + None, section_local=False, + factory=HostFallbackValue("NetBIOSComputer", "DEMENTOR"), ) ATTR_NTLM_NB_DOMAIN = Attribute( "ntlm_nb_domain", - "NTLM.NetBIOSDomain", - "WORKGROUP", # MsvAvNbDomainName (AV_PAIR 0x0002) + "NTLM.NetBIOSDomain", # explicit: [NTLM] > [Globals]; fallback: derived from [Globals].Host + None, section_local=False, + factory=HostFallbackValue("NetBIOSDomain", "WORKGROUP"), ) ATTR_NTLM_DNS_COMPUTER = Attribute( "ntlm_dns_computer", - "NTLM.DnsComputer", - "", # MsvAvDnsComputerName (AV_PAIR 0x0003); "" → omitted from AV_PAIRs + "NTLM.DnsComputer", # explicit: [NTLM] > [Globals]; fallback: derived from [Globals].Host + None, section_local=False, + factory=HostFallbackValue("DnsComputer", ""), ) ATTR_NTLM_DNS_DOMAIN = Attribute( "ntlm_dns_domain", - "NTLM.DnsDomain", - "", # MsvAvDnsDomainName (AV_PAIR 0x0004); "" → omitted from AV_PAIRs + "NTLM.DnsDomain", # explicit: [NTLM] > [Globals]; fallback: derived from [Globals].Host + None, section_local=False, + factory=HostFallbackValue("DnsDomain", ""), ) ATTR_NTLM_DNS_TREE = Attribute( "ntlm_dns_tree", - "NTLM.DnsTree", - "", # MsvAvDnsTreeName (AV_PAIR 0x0005); "" → omitted from AV_PAIRs + "NTLM.DnsTree", # explicit: [NTLM] > [Globals]; fallback: derived from [Globals].Host + None, section_local=False, + factory=HostFallbackValue("DnsTree", ""), ) +def _apply_ntlm_field( + session: SessionConfig, + key: str, + attr: str, + transform, + factory, +) -> None: + """Apply a single NTLM config field to the session, logging on failure. + + Resolution order: ``[NTLM]`` -> ``[Globals]`` (explicit) -> *factory*. + When *factory* is a :class:`~dementor.config.util.HostFallbackValue`, a + ``None`` raw value triggers lazy derivation from ``Globals.Host``. + + :param session: Session object to update. + :param key: TOML key name within the ``[NTLM]`` section. + :param attr: Session attribute name to set. + :param transform: Callable to coerce the raw value. + :param factory: Callable ``(value_or_none) -> resolved_value``. For + identity fields pass a :class:`~dementor.config.util.HostFallbackValue` + instance; for plain defaults pass a simple fallback factory. + """ + try: + raw = get_value("NTLM", key, default=None) + if raw is None: + raw = get_value("Globals", key, default=None) + setattr(session, attr, transform(factory(raw))) + except (TypeError, ValueError): + dm_logger.exception("Failed to apply NTLM.%s; using default", key) + + def apply_config(session: SessionConfig) -> None: """Apply global NTLM settings from the ``[NTLM]`` TOML section to the session. @@ -295,22 +331,23 @@ def apply_config(session: SessionConfig) -> None: :type session: SessionConfig """ # Safe defaults (session remains valid even if config parsing fails). + # factory(None) triggers HostFallbackValue to read [Globals].Host if available. session.ntlm_challenge = secrets.token_bytes(NTLM_CHALLENGE_LEN) session.ntlm_disable_ess = False session.ntlm_disable_ntlmv2 = False session.ntlm_target_type = str(ATTR_NTLM_TARGET_TYPE.default_val) session.ntlm_version = _config_version_to_bytes(ATTR_NTLM_VERSION.default_val) - session.ntlm_nb_computer = str(ATTR_NTLM_NB_COMPUTER.default_val) - session.ntlm_nb_domain = str(ATTR_NTLM_NB_DOMAIN.default_val) - session.ntlm_dns_computer = str(ATTR_NTLM_DNS_COMPUTER.default_val) - session.ntlm_dns_domain = str(ATTR_NTLM_DNS_DOMAIN.default_val) - session.ntlm_dns_tree = str(ATTR_NTLM_DNS_TREE.default_val) + session.ntlm_nb_computer = str(ATTR_NTLM_NB_COMPUTER.factory(None)) # ty:ignore[call-non-callable] + session.ntlm_nb_domain = str(ATTR_NTLM_NB_DOMAIN.factory(None)) # ty:ignore[call-non-callable] + session.ntlm_dns_computer = str(ATTR_NTLM_DNS_COMPUTER.factory(None)) # ty:ignore[call-non-callable] + session.ntlm_dns_domain = str(ATTR_NTLM_DNS_DOMAIN.factory(None)) # ty:ignore[call-non-callable] + session.ntlm_dns_tree = str(ATTR_NTLM_DNS_TREE.factory(None)) # ty:ignore[call-non-callable] # -- ServerChallenge --------------------------------------------------- try: raw_challenge = get_value("NTLM", "Challenge", default=None) session.ntlm_challenge = _parse_challenge(raw_challenge) - except Exception: + except (TypeError, ValueError): dm_logger.exception("Failed to parse NTLM Challenge; using random bytes") dm_logger.debug( "NTLM Challenge set to value: %s with len %d", @@ -318,29 +355,24 @@ def apply_config(session: SessionConfig) -> None: len(session.ntlm_challenge), ) - # -- Extended Session Security ----------------------------------------- - try: - raw = get_value("NTLM", "DisableExtendedSessionSecurity", default=False) - session.ntlm_disable_ess = bool(is_true(raw)) - except Exception: - session.ntlm_disable_ess = False - dm_logger.exception( - "Failed to apply NTLM.DisableExtendedSessionSecurity; defaulting to False" - ) - else: - dm_logger.debug( - "NTLM DisableExtendedSessionSecurity: %s", session.ntlm_disable_ess - ) - - # -- Disable NTLMv2 ---------------------------------------------------- - try: - raw = get_value("NTLM", "DisableNTLMv2", default=False) - session.ntlm_disable_ntlmv2 = bool(is_true(raw)) - except Exception: - session.ntlm_disable_ntlmv2 = False - dm_logger.exception("Failed to apply NTLM.DisableNTLMv2; defaulting to False") - else: - dm_logger.debug("NTLM DisableNTLMv2: %s", session.ntlm_disable_ntlmv2) + # -- Boolean flags (ESS and NTLMv2) ------------------------------------ + _apply_ntlm_field( + session, + "DisableExtendedSessionSecurity", + "ntlm_disable_ess", + is_true, + lambda v: v if v is not None else False, + ) + dm_logger.debug("NTLM DisableExtendedSessionSecurity: %s", session.ntlm_disable_ess) + + _apply_ntlm_field( + session, + "DisableNTLMv2", + "ntlm_disable_ntlmv2", + is_true, + lambda v: v if v is not None else False, + ) + dm_logger.debug("NTLM DisableNTLMv2: %s", session.ntlm_disable_ntlmv2) if session.ntlm_disable_ntlmv2: dm_logger.warning( @@ -350,67 +382,42 @@ def apply_config(session: SessionConfig) -> None: + "Use with caution." ) - # -- Target Type ------------------------------------------------------- - try: - raw = get_value("NTLM", "TargetType", default=ATTR_NTLM_TARGET_TYPE.default_val) - session.ntlm_target_type = str(raw) - except Exception: - dm_logger.exception("Failed to apply NTLM.TargetType; using default") - - # -- Version ----------------------------------------------------------- - try: - raw = get_value("NTLM", "Version", default=ATTR_NTLM_VERSION.default_val) - session.ntlm_version = _config_version_to_bytes(raw) - except Exception: - dm_logger.exception("Failed to apply NTLM.Version; using default") - - # -- NetBIOS Computer -------------------------------------------------- - try: - raw = get_value( - "NTLM", "NetBIOSComputer", default=ATTR_NTLM_NB_COMPUTER.default_val - ) - session.ntlm_nb_computer = str(raw) - except Exception: - dm_logger.exception("Failed to apply NTLM.NetBIOSComputer; using default") - - # -- NetBIOS Domain ---------------------------------------------------- - try: - raw = get_value("NTLM", "NetBIOSDomain", default=ATTR_NTLM_NB_DOMAIN.default_val) - session.ntlm_nb_domain = str(raw) - except Exception: - dm_logger.exception("Failed to apply NTLM.NetBIOSDomain; using default") - - # -- DNS Computer ------------------------------------------------------ - try: - raw = get_value("NTLM", "DnsComputer", default=ATTR_NTLM_DNS_COMPUTER.default_val) - session.ntlm_dns_computer = str(raw) - except Exception: - dm_logger.exception("Failed to apply NTLM.DnsComputer; using default") - - # -- DNS Domain -------------------------------------------------------- - try: - raw = get_value("NTLM", "DnsDomain", default=ATTR_NTLM_DNS_DOMAIN.default_val) - session.ntlm_dns_domain = str(raw) - except Exception: - dm_logger.exception("Failed to apply NTLM.DnsDomain; using default") - - # -- DNS Tree ---------------------------------------------------------- - try: - raw = get_value("NTLM", "DnsTree", default=ATTR_NTLM_DNS_TREE.default_val) - session.ntlm_dns_tree = str(raw) - except Exception: - dm_logger.exception("Failed to apply NTLM.DnsTree; using default") + # -- String / typed fields --------------------------------------------- + _target_type_default = ATTR_NTLM_TARGET_TYPE.default_val + _version_default = ATTR_NTLM_VERSION.default_val + for toml_key, attr_name, transform, factory in ( + ( + "TargetType", + "ntlm_target_type", + str, + lambda v: v if v is not None else _target_type_default, + ), + ( + "Version", + "ntlm_version", + _config_version_to_bytes, + lambda v: v if v is not None else _version_default, + ), + # Identity fields: each checked explicitly in [NTLM] -> [Globals] first; + # HostFallbackValue derives from [Globals].Host only when both are absent. + ("NetBIOSComputer", "ntlm_nb_computer", str, ATTR_NTLM_NB_COMPUTER.factory), + ("NetBIOSDomain", "ntlm_nb_domain", str, ATTR_NTLM_NB_DOMAIN.factory), + ("DnsComputer", "ntlm_dns_computer", str, ATTR_NTLM_DNS_COMPUTER.factory), + ("DnsDomain", "ntlm_dns_domain", str, ATTR_NTLM_DNS_DOMAIN.factory), + ("DnsTree", "ntlm_dns_tree", str, ATTR_NTLM_DNS_TREE.factory), + ): + _apply_ntlm_field(session, toml_key, attr_name, transform, factory) # --- Encoding ---------------------------------------------------------------- # # NEGOTIATE_MESSAGE fields: always OEM (Unicode not yet negotiated). # CHALLENGE_MESSAGE / AUTHENTICATE_MESSAGE: governed by NegotiateFlags: -# NTLMSSP_NEGOTIATE_UNICODE (0x01) → UTF-16LE (no BOM) -# NTLM_NEGOTIATE_OEM (0x02) → cp437 baseline +# NTLMSSP_NEGOTIATE_UNICODE (0x01) -> UTF-16LE (no BOM) +# NTLM_NEGOTIATE_OEM (0x02) -> cp437 baseline -def NTLM_decode_string( +def ntlm_decode_string( data: bytes | None, negotiate_flags: int, is_negotiate_oem: bool = False, @@ -423,7 +430,7 @@ def NTLM_decode_string( Unicode has not been negotiated yet per [MS-NLMP] §2.2.1.1. * **AUTHENTICATE_MESSAGE** (``is_negotiate_oem=False``): encoding is determined by ``NTLMSSP_NEGOTIATE_UNICODE`` (flag A, 0x00000001) - in the message's NegotiateFlags. When set → UTF-16LE, else OEM + in the message's NegotiateFlags. When set -> UTF-16LE, else OEM (cp437 as baseline). Per [MS-NLMP] §2.2.1.3. :param data: Raw bytes from the NTLM message field @@ -446,11 +453,11 @@ def NTLM_decode_string( if negotiate_flags & ntlm.NTLMSSP_NEGOTIATE_UNICODE: return data.decode("utf-16-le", errors="replace").rstrip().rstrip("\x00") - # OEM fallback — cp437 as baseline; actual code page is system-dependent + # OEM fallback - cp437 as baseline; actual code page is system-dependent return data.decode("cp437", errors="replace") -def NTLM_encode_string(string: str | None, negotiate_flags: int) -> bytes: +def ntlm_encode_string(string: str | None, negotiate_flags: int) -> bytes: """Encode a Python str for inclusion in a CHALLENGE_MESSAGE. :param string: The string to encode (server name, domain, etc.) @@ -500,7 +507,7 @@ def _decode_ntlmssp_os_version( major = ver_raw[0] minor = ver_raw[1] build = uint16.from_bytes(ver_raw[2:4], order=LittleEndian) - except Exception: + except (IndexError, KeyError, ValueError, struct.error): dm_logger.debug("Failed to parse VERSION bytes from NTLM message") elif "os_version" in token.fields: try: @@ -508,7 +515,7 @@ def _decode_ntlmssp_os_version( major = ver_obj["ProductMajorVersion"] minor = ver_obj["ProductMinorVersion"] build = ver_obj["ProductBuild"] - except Exception: + except (KeyError, TypeError): dm_logger.debug("Failed to parse os_version from NTLM message") if major is None or minor is None or build is None: @@ -541,7 +548,7 @@ def _is_anonymous_authenticate(token: ntlm.NTLMAuthChallengeResponse) -> bool: try: # Structural anonymous: all response fields empty or Z(1) flags: int = token["flags"] - user_name: bytes = token["user_name"] or b"" + user_name: bytes = (token["user_name"] or b"").rstrip(b"\x00") nt_response: bytes = token["ntlm"] or b"" lm_response: bytes = token["lanman"] or b"" @@ -555,10 +562,16 @@ def _is_anonymous_authenticate(token: ntlm.NTLMAuthChallengeResponse) -> bool: dm_logger.debug("Structurally anonymous AUTHENTICATE_MESSAGE detected") return True + if len(user_name.strip()) == 0: # user name containts only spaces + dm_logger.debug( + "Anonymous AUTHENTICATE_MESSAGE detected (username with spaces only)" + ) + return True + # [MS-NLMP] §2.2.2.5 flag J: supplementary anonymous flag check return bool(flags & ntlm.NTLMSSP_NEGOTIATE_ANONYMOUS) - except Exception: + except (KeyError, TypeError): dm_logger.debug( "Failed to check anonymous status in AUTHENTICATE_MESSAGE; " + "treating as non-anonymous to avoid dropping captures", @@ -570,7 +583,7 @@ def _is_anonymous_authenticate(token: ntlm.NTLMAuthChallengeResponse) -> bool: # --- NTLMSSP Transaction ---------------------------------------------------- -def NTLM_handle_negotiate_message( +def ntlm_handle_negotiate_message( negotiate: ntlm.NTLMAuthNegotiate, logger: ProtocolLogger, ) -> dict[str, str]: @@ -590,21 +603,21 @@ def NTLM_handle_negotiate_message( os_str = _decode_ntlmssp_os_version(negotiate) domain_str = "" workstation_str = "" + flags: int = 0 try: - flags: int = negotiate["flags"] + flags = negotiate["flags"] # [MS-NLMP] §2.2.1.1: NEGOTIATE domain/workstation are OEM-encoded domain_str = ( - NTLM_decode_string(negotiate["domain_name"], flags, is_negotiate_oem=True) + ntlm_decode_string(negotiate["domain_name"], flags, is_negotiate_oem=True) or "" ) workstation_str = ( - NTLM_decode_string(negotiate["host_name"], flags, is_negotiate_oem=True) or "" + ntlm_decode_string(negotiate["host_name"], flags, is_negotiate_oem=True) or "" ) - except Exception: + except (KeyError, UnicodeDecodeError, TypeError): dm_logger.debug("Failed to parse hostname/domain from NEGOTIATE_MESSAGE") try: - flags = negotiate["flags"] parts = [f"flags=0x{flags:08x}"] parts.append(f"os={os_str!r}" if os_str else "os=(empty)") parts.append(f"domain={domain_str!r}" if domain_str else "domain=(empty)") @@ -614,7 +627,7 @@ def NTLM_handle_negotiate_message( else "workstation=(empty)" ) logger.debug("NTLMSSP NEGOTIATE: %s", " ".join(parts), is_client=True) - except Exception: + except (KeyError, TypeError): logger.debug("NTLMSSP NEGOTIATE: (failed to parse fields)", is_client=True) # Build return dict with only non-empty values @@ -640,7 +653,7 @@ def NTLM_handle_negotiate_message( # rainbow tables with a fixed ServerChallenge) -def NTLM_build_challenge_message( +def ntlm_build_challenge_message( token: ntlm.NTLMAuthNegotiate | dict[str, Any], *, challenge: bytes, @@ -761,15 +774,15 @@ def NTLM_build_challenge_message( # -- Assemble the CHALLENGE_MESSAGE ------------------------------------ # TargetName (§2.2.1.2): the server's authentication realm. - # [MS-NLMP] §3.2.5.1.1: TARGET_TYPE_SERVER → TargetName = NetBIOSComputer; - # TARGET_TYPE_DOMAIN → TargetName = NetBIOSDomain. + # [MS-NLMP] §3.2.5.1.1: TARGET_TYPE_SERVER -> TargetName = NetBIOSComputer; + # TARGET_TYPE_DOMAIN -> TargetName = NetBIOSDomain. if target_type == "domain": target_name_str = nb_domain.upper() else: target_name_str = nb_computer.upper() - target_name_bytes: bytes = NTLM_encode_string(target_name_str, response_flags) + target_name_bytes: bytes = ntlm_encode_string(target_name_str, response_flags) - # VERSION structure — [MS-NLMP] §2.2.2.10 + # VERSION structure - [MS-NLMP] §2.2.2.10 version_bytes = version if version is not None else NTLM_VERSION_PLACEHOLDER challenge_message = ntlm.NTLMAuthChallenge() @@ -796,7 +809,7 @@ def NTLM_build_challenge_message( challenge_message["TargetInfoFields_offset"] = target_info_offset else: # TargetInfo is a sequence of AV_PAIR structures (§2.2.2.1). - # Full AvId space — disposition for each entry: + # Full AvId space - disposition for each entry: # # AvId Constant Sent Notes # 0x0000 MsvAvEOL auto List terminator; ntlm.AV_PAIRS appends it. @@ -806,18 +819,18 @@ def NTLM_build_challenge_message( # 0x0004 MsvAvDnsDomainName YES DNS domain FQDN. # 0x0005 MsvAvDnsTreeName COND Forest FQDN; omitted when not domain-joined. # 0x0006 MsvAvFlags NO Constrained-auth flag (0x1); not applicable - # here — Dementor does not enforce constrained - # delegation. 0x2/0x4 bits are client→server. + # here - Dementor does not enforce constrained + # delegation. 0x2/0x4 bits are client->server. # 0x0007 MsvAvTimestamp NO Intentionally omitted; see note below. - # 0x0008 MsvAvSingleHost N/A Client→server only (AUTHENTICATE_MESSAGE). - # 0x0009 MsvAvTargetName N/A Client→server only (AUTHENTICATE_MESSAGE). - # 0x000A MsvAvChannelBindings N/A Client→server only (AUTHENTICATE_MESSAGE). + # 0x0008 MsvAvSingleHost N/A Client->server only (AUTHENTICATE_MESSAGE). + # 0x0009 MsvAvTargetName N/A Client->server only (AUTHENTICATE_MESSAGE). + # 0x000A MsvAvChannelBindings N/A Client->server only (AUTHENTICATE_MESSAGE). # # §2.2.2.1: 0x0001 and 0x0002 MUST be present. MsvAvEOL is # appended automatically by ntlm.AV_PAIRS. AV_PAIRs may appear in # any order per spec; ascending AvId matches real Windows behaviour. - # AV_PAIR values used directly from kwargs — no derivation chains. + # AV_PAIR values used directly from kwargs - no derivation chains. # 0x0001 and 0x0002 are required by spec (always sent). # 0x0003, 0x0004, 0x0005 are optional (omitted when empty). @@ -828,7 +841,7 @@ def NTLM_build_challenge_message( # 4. AV_PAIRS ------------------------------------------------------- # §2.2.2.1: 0x0001 and 0x0002 MUST be present. - # 0x0003-0x0005 are optional — omitted when empty/not configured. + # 0x0003-0x0005 are optional - omitted when empty/not configured. av_pairs = ntlm.AV_PAIRS() av_pairs[ntlm.NTLMSSP_AV_HOSTNAME] = nb_computer.encode( "utf-16le" @@ -871,6 +884,42 @@ def NTLM_build_challenge_message( return challenge_message +def ntlm_auth_create_challenge( + token: "ntlm.NTLMAuthNegotiate | dict[str, Any]", + nb_computer: str, + nb_domain: str, + *, + challenge: bytes, + disable_ess: bool = False, + disable_ntlmv2: bool = False, +) -> "ntlm.NTLMAuthChallenge": + """Build a CHALLENGE_MESSAGE with positional computer/domain arguments. + + Convenience wrapper around :func:`ntlm_build_challenge_message` for callers + (e.g. LDAP) that supply ``nb_computer`` and ``nb_domain`` as positional + arguments rather than keyword arguments. + + :param token: Parsed NEGOTIATE_MESSAGE + :param nb_computer: NetBIOS computer name + :param nb_domain: NetBIOS domain name + :param challenge: 8-byte ServerChallenge nonce + :param disable_ess: Strip ESS flag + :param disable_ntlmv2: Omit TargetInfoFields + :return: Serialisable CHALLENGE_MESSAGE + """ + return ntlm_build_challenge_message( + token, + challenge=challenge, + nb_computer=nb_computer, + nb_domain=nb_domain, + disable_ess=disable_ess, + disable_ntlmv2=disable_ntlmv2, + ) + + +# -- Hash capture helpers (within NTLMSSP Transaction) ----------------------- + + def _log_ntlmv2_blob( auth_token: ntlm.NTLMAuthChallengeResponse, log: ProtocolLogger, @@ -895,7 +944,7 @@ def _log_ntlmv2_blob( try: nt_response: bytes = auth_token["ntlm"] or b"" if len(nt_response) <= NTLMV1_RESPONSE_LEN: - return None # NTLMv1 — no blob + return None # NTLMv1 - no blob # NTLMv2 blob starts after NTProofStr (16 bytes) blob = nt_response[NTLM_NTPROOFSTR_LEN:] @@ -907,7 +956,7 @@ def _log_ntlmv2_blob( # + ChallengeFromClient(8) + Reserved3(4) = 28 bytes # AV_PAIRs start at offset 28 in the blob. - # ClientChallenge — 8-byte client nonce at blob[16:24] + # ClientChallenge - 8-byte client nonce at blob[16:24] client_challenge = blob[16:24] av_data = blob[28:] @@ -951,20 +1000,117 @@ def _log_ntlmv2_blob( else "ChannelBindings=(empty)" ) - # MsvAvSingleHost (0x0008) — machine identity claim + # MsvAvSingleHost (0x0008) - machine identity claim if ntlm.NTLMSSP_AV_RESTRICTIONS in av_pairs.fields: _, sh_raw = av_pairs[ntlm.NTLMSSP_AV_RESTRICTIONS] blob_parts.append(f"SingleHost={sh_raw.hex()}") log.debug("NTLMv2 blob: %s", " ".join(blob_parts), is_client=True) - except Exception: + except (KeyError, AttributeError, TypeError, struct.error): log.debug("Failed to parse NTLMv2 blob AV_PAIRs", exc_info=True) return target_name -def NTLM_handle_authenticate_message( +def _build_ntlm_display_line( + negotiate_fields: dict[str, str] | None, + os_str: str, + user_name: str, + domain_name: str, + host_name_str: str, + spn: str, +) -> str | None: + """Build a deduped NTLM display line from Type 1 + Type 3 fields. + + Returns a formatted ``'key:value | ...'`` string, or ``None`` when all + fields are empty. + """ + ntlm_fields: dict[str, set[str]] = { + "os": set(), + "user": set(), + "domain": set(), + "name": set(), + "SPN": set(), + } + if negotiate_fields: + for k, v in negotiate_fields.items(): + if v and k in ntlm_fields: + ntlm_fields[k].add(v) + for key, value in ( + ("os", os_str), + ("user", user_name), + ("domain", domain_name), + ("name", host_name_str), + ("SPN", spn), + ): + if value: + ntlm_fields[key].add(value) + + parts = [ + f"{k}:{','.join(sorted(ntlm_fields[k]))}" + for k in ("os", "user", "domain", "name", "SPN") + if ntlm_fields.get(k) + ] + return " | ".join(parts) if parts else None + + +def _build_ntlm_host_info( + os_str: str, host_name_str: str, domain_name: str +) -> str | None: + """Build HOST_INFO string from NTLMSSP AUTHENTICATE fields.""" + host_parts: list[str] = [] + if os_str: + host_parts.append(os_str) + if host_name_str: + host_parts.append(f"(name: {host_name_str})") + if domain_name: + host_parts.append(f"(domain: {domain_name})") + return " ".join(host_parts) if host_parts else None + + +def _store_ntlm_captures( + all_hashes: list[tuple[str, str]], + *, + user_name: str, + domain_name: str, + os_str: str, + host_name_str: str, + client: tuple[str, int], + session: "SessionConfig", + logger: "ProtocolLogger | None", + extras: dict[str, Any] | None, +) -> None: + """Write captured NTLM hashes to the session capture database. + + Separated from hash extraction so the auth-parsing and storage + concerns live in different functions. + + :param all_hashes: Non-empty list of ``(version_label, hashcat_line)`` pairs + :param user_name: Decoded username from the AUTHENTICATE_MESSAGE + :param domain_name: Decoded domain name from the AUTHENTICATE_MESSAGE + :param os_str: OS version string for HOST_INFO annotation + :param host_name_str: Hostname string for HOST_INFO annotation + :param client: Client (host, port) tuple + :param session: Active session with capture database + :param logger: Protocol logger passed through to ``add_auth`` + :param extras: Extra metadata dict; mutated to add HOST_INFO + """ + extras = extras or {} + extras[HOST_INFO] = _build_ntlm_host_info(os_str, host_name_str, domain_name) + for version_label, hashcat_line in all_hashes: + session.db.add_auth( + client=client, + credtype=version_label, + username=user_name, + domain=domain_name, + password=hashcat_line, + logger=logger, + extras=extras, + ) + + +def ntlm_handle_authenticate_message( auth_token: ntlm.NTLMAuthChallengeResponse, *, challenge: bytes, @@ -1003,9 +1149,9 @@ def NTLM_handle_authenticate_message( :param transport: NTLM transport identifier (NTLM_TRANSPORT_*); used for logging only :type transport: str :param negotiate_fields: Fields extracted from the NEGOTIATE_MESSAGE by - :func:`NTLM_handle_negotiate_message`. Merged (Type 3 wins) into the display line + :func:`ntlm_handle_negotiate_message`. Merged (Type 3 wins) into the display line so the deduped output reflects both messages. This is ntlm.py's own - output passed back in — no protocol-layer state. + output passed back in - no protocol-layer state. :type negotiate_fields: dict[str, str] | None """ # Use the protocol logger for session-linked messages; fall back to the @@ -1013,7 +1159,7 @@ def NTLM_handle_authenticate_message( log = logger or dm_logger if _is_anonymous_authenticate(auth_token): - log.debug("Anonymous NTLM login attempt; skipping hash extraction") + log.display("Anonymous NTLM login attempt; skipping hash extraction") return False # -- AUTHENTICATE_MESSAGE parsed fields (single debug line) ------------ @@ -1023,11 +1169,11 @@ def NTLM_handle_authenticate_message( negotiate_flags: int = auth_token["flags"] try: host_name_str = ( - NTLM_decode_string(auth_token["host_name"], negotiate_flags) or "" + ntlm_decode_string(auth_token["host_name"], negotiate_flags) or "" ) - except Exception: + except (KeyError, UnicodeDecodeError): dm_logger.debug("Failed to parse host_name from AUTHENTICATE_MESSAGE") - mic_str: str = "(absent)" # no VERSION flag → MIC field doesn't exist + mic_str: str = "(absent)" # no VERSION flag -> MIC field doesn't exist try: if negotiate_flags & ntlm.NTLMSSP_NEGOTIATE_VERSION: mic_val: bytes = auth_token["MIC"] @@ -1036,12 +1182,13 @@ def NTLM_handle_authenticate_message( if mic_val and len(mic_val) == 16 and mic_val != b"\x00" * 16 else "(empty)" ) - except Exception: # noqa: S110 - pass + except (KeyError, TypeError): + log.debug("Failed to parse MIC field", exc_info=True) + mic_str = "(parse error)" auth_parts = [f"flags=0x{negotiate_flags:08x}"] auth_parts.append(f"os={os_str!r}" if os_str else "os=(empty)") - user_name: str = NTLM_decode_string(auth_token["user_name"], negotiate_flags) - domain_name: str = NTLM_decode_string(auth_token["domain_name"], negotiate_flags) + user_name: str = ntlm_decode_string(auth_token["user_name"], negotiate_flags) + domain_name: str = ntlm_decode_string(auth_token["domain_name"], negotiate_flags) auth_parts.append(f"user={user_name!r}" if user_name else "user=(empty)") auth_parts.append(f"domain={domain_name!r}" if domain_name else "domain=(empty)") auth_parts.append(f"name={host_name_str!r}" if host_name_str else "name=(empty)") @@ -1051,18 +1198,18 @@ def NTLM_handle_authenticate_message( auth_parts.append(f"LM_len={lm_len}") auth_parts.append(f"MIC={mic_str}") log.debug("NTLMSSP AUTHENTICATE: %s", " ".join(auth_parts), is_client=True) - except Exception: + except (KeyError, UnicodeDecodeError, TypeError): log.debug("Failed to parse AUTHENTICATE_MESSAGE fields", exc_info=True) try: negotiate_flags = auth_token["flags"] - except Exception: + except KeyError: negotiate_flags = 0 user_name = "" domain_name = "" # -- Hash extraction --------------------------------------------------- try: - all_hashes = NTLM_to_hashcat( + all_hashes = ntlm_to_hashcat( server_challenge=challenge, user_name=auth_token["user_name"], domain_name=auth_token["domain_name"], @@ -1084,47 +1231,11 @@ def NTLM_handle_authenticate_message( spn = _log_ntlmv2_blob(auth_token, log) # -- Consolidated display line (Type 1 + Type 3 deduped) ----------- - # Collect all identity fields into sets so values from both - # NEGOTIATE (Type 1) and AUTHENTICATE (Type 3) are shown, - # with duplicates removed. Empty strings are filtered. - ntlm_fields: dict[str, set[str]] = { - "os": set(), - "user": set(), - "domain": set(), - "name": set(), - "SPN": set(), - } - # Add Type 1 (NEGOTIATE) fields - if negotiate_fields: - for k, v in negotiate_fields.items(): - if v and k in ntlm_fields: - ntlm_fields[k].add(v) - # Add Type 3 (AUTHENTICATE) fields - if os_str: - ntlm_fields["os"].add(os_str) - if user_name: - ntlm_fields["user"].add(user_name) - if domain_name: - ntlm_fields["domain"].add(domain_name) - if host_name_str: - ntlm_fields["name"].add(host_name_str) - if spn: - ntlm_fields["SPN"].add(spn) - - display_keys = [ - ("os", "os"), - ("user", "user"), - ("domain", "domain"), - ("name", "name"), - ("SPN", "SPN"), - ] - parts = [ - f"{label}:{','.join(sorted(ntlm_fields[k]))}" - for k, label in display_keys - if ntlm_fields.get(k) - ] - if parts: - log.info("NTLM: %s", " | ".join(parts)) + display_line = _build_ntlm_display_line( + negotiate_fields, os_str, user_name, domain_name, host_name_str, spn + ) + if display_line: + log.debug("NTLM: %s", display_line) log.debug( "Writing %d hash(es) to capture database for user=%r domain=%r", @@ -1132,28 +1243,17 @@ def NTLM_handle_authenticate_message( user_name, domain_name, ) - # Build host_info for model.py from extracted fields. - host_parts: list[str] = [] - if os_str: - host_parts.append(os_str) - if host_name_str: - host_parts.append(f"(name: {host_name_str})") - if domain_name: - host_parts.append(f"(domain: {domain_name})") - host_info = " ".join(host_parts) if host_parts else None - extras = extras or {} - extras[_HOST_INFO] = host_info - - for version_label, hashcat_line in all_hashes: - session.db.add_auth( - client=client, - credtype=version_label, - username=user_name, - domain=domain_name, - password=hashcat_line, - logger=logger, - extras=extras, - ) + _store_ntlm_captures( + all_hashes, + user_name=user_name, + domain_name=domain_name, + os_str=os_str, + host_name_str=host_name_str, + client=client, + session=session, + logger=logger, + extras=extras, + ) return bool(all_hashes) @@ -1162,12 +1262,61 @@ def NTLM_handle_authenticate_message( "Invalid data in AUTHENTICATE_MESSAGE (bad challenge length or " "malformed response fields); skipping capture" ) - except Exception: + except (AttributeError, KeyError, IndexError, TypeError, OverflowError): log.exception("Failed to extract NTLM hashes from AUTHENTICATE_MESSAGE") return False +# -- Reporting helpers (within NTLMSSP Transaction) -------------------------- + + +def ntlm_report_auth( + auth_token: "ntlm.NTLMAuthChallengeResponse", + *, + challenge: bytes, + client: tuple[str, int], + logger: "ProtocolLogger | None" = None, + session: "SessionConfig", + negotiate_fields: "dict[str, str] | None" = None, +) -> bool: + """Report captured NTLM credentials. + + Convenience wrapper around :func:`ntlm_handle_authenticate_message` for + callers (e.g. LDAP) that pass ``auth_token`` positionally and want + keyword-only arguments for the remaining parameters. + + :param auth_token: Parsed AUTHENTICATE_MESSAGE + :param challenge: 8-byte ServerChallenge from the CHALLENGE_MESSAGE + :param client: Client (host, port) tuple + :param logger: Protocol logger + :param session: Active session config with capture database + :param negotiate_fields: Fields from the NEGOTIATE_MESSAGE to merge + :return: ``True`` if credentials were captured + """ + return ntlm_handle_authenticate_message( + auth_token, + challenge=challenge, + client=client, + session=session, + logger=logger, + negotiate_fields=negotiate_fields, + ) + + +def ntlm_split_fqdn(fqdn: str) -> tuple[str, str]: + """Split an FQDN into ``(hostname, domain)`` components. + + :param fqdn: Fully-qualified domain name, e.g. ``"dc01.corp.example.com"`` + :return: ``(hostname, domain)`` - ``domain`` is empty if no dot is present + :rtype: tuple[str, str] + """ + if "." in fqdn: + hostname, domain = fqdn.split(".", 1) + return hostname, domain + return fqdn, "WORKGROUP" + + # --- Hash Formatting --------------------------------------------------------- @@ -1286,7 +1435,25 @@ def _compute_dummy_lm_responses(server_challenge: bytes) -> set[bytes]: # User/Domain MUST be decoded plain-text strings, NOT raw hex bytes. -def NTLM_to_hashcat( +def _decode_ntlm_identity_string( + value: bytes | str, negotiate_flags: int, field: str +) -> str: + """Decode a UserName or DomainName field from an AUTHENTICATE_MESSAGE. + + Handles both pre-decoded strings and raw wire bytes. Returns an empty + string and logs a debug message on any decoding failure. + """ + try: + if isinstance(value, (bytes, bytearray, memoryview)): + return ntlm_decode_string(bytes(value), negotiate_flags) + except (UnicodeDecodeError, TypeError): + dm_logger.debug("Failed to decode %s; using empty string", field, exc_info=True) + return "" + else: + return value or "" + + +def ntlm_to_hashcat( server_challenge: bytes, user_name: bytes | str, domain_name: bytes | str, @@ -1345,29 +1512,12 @@ def NTLM_to_hashcat( # -- Decode identity strings --------------------------------------------- # Both hashcat modes require decoded plain-text strings, not raw wire # bytes. Hashcat does its own toupper + UTF-16LE expansion internally. - try: - user: str = ( - NTLM_decode_string(bytes(user_name), negotiate_flags) - if isinstance(user_name, (bytes, bytearray, memoryview)) - else (user_name or "") - ) - except Exception: - dm_logger.debug("Failed to decode UserName; using empty string", exc_info=True) - user = "" - - try: - domain: str = ( - NTLM_decode_string(bytes(domain_name), negotiate_flags) - if isinstance(domain_name, (bytes, bytearray, memoryview)) - else (domain_name or "") - ) - except Exception: - dm_logger.debug("Failed to decode DomainName; using empty string", exc_info=True) - domain = "" + user: str = _decode_ntlm_identity_string(user_name, negotiate_flags, "UserName") + domain: str = _decode_ntlm_identity_string(domain_name, negotiate_flags, "DomainName") try: hash_type: str = _classify_hash_type(nt_response, lm_response, negotiate_flags) - except Exception: + except (TypeError, ValueError, AttributeError): dm_logger.debug( "_classify_hash_type raised unexpectedly; defaulting to %s", NTLM_V1, @@ -1402,7 +1552,7 @@ def NTLM_to_hashcat( ) ) dm_logger.debug("Appended %s hash (nt_len=%d)", NTLM_V2, len(nt_response)) - except Exception: + except (TypeError, AttributeError): dm_logger.debug("Failed to format %s hash; skipping", NTLM_V2, exc_info=True) return captures @@ -1439,7 +1589,7 @@ def NTLM_to_hashcat( len(lm_response), NTLM_V2_LM, ) - except Exception: + except (TypeError, AttributeError): dm_logger.debug( "Failed to format %s hash; skipping", NTLM_V2_LM, exc_info=True ) @@ -1465,7 +1615,7 @@ def NTLM_to_hashcat( ) ) dm_logger.debug("Appended %s hash", NTLM_V1_ESS) - except Exception: + except (TypeError, AttributeError): dm_logger.debug( "Failed to format %s hash; skipping", NTLM_V1_ESS, exc_info=True ) @@ -1507,7 +1657,7 @@ def NTLM_to_hashcat( ) ) dm_logger.debug("Appended %s hash (lm_slot_empty=%s)", NTLM_V1, lm_slot_hex == "") - except Exception: + except (TypeError, AttributeError): dm_logger.debug("Failed to format %s hash; skipping", NTLM_V1, exc_info=True) return captures @@ -1516,7 +1666,21 @@ def NTLM_to_hashcat( # --- Legacy SMB1 Basic Auth (non-NTLMSSP) ----------------------------------- -def NTLM_handle_legacy_raw_auth( +def _build_smb1_host_info(extras: dict[str, Any], domain: str, fallback: str) -> str: + """Build a HOST_INFO string from SMB1 basic-security auth fields. + + Consumes the ``os`` key from *extras* (if present) and appends the domain. + Returns *fallback* when no parts are found. + """ + host_parts: list[str] = [] + if extras.get("os"): + host_parts.append(extras.pop("os")) + if domain: + host_parts.append(f"(domain: {domain})") + return " ".join(host_parts) if host_parts else fallback + + +def ntlm_handle_legacy_raw_auth( *, user_name: bytes | str, domain_name: bytes | str, @@ -1539,7 +1703,7 @@ def NTLM_handle_legacy_raw_auth( For NTLM_TRANSPORT_RAW: classifies LM/NT response bytes and formats hashcat lines using the existing pipeline. No NTLMSSP wrapper exists - on this path — do NOT create a fake NTLMAuthChallengeResponse. + on this path - do NOT create a fake NTLMAuthChallengeResponse. For NTLM_TRANSPORT_CLEARTEXT: stores the raw password directly. @@ -1547,9 +1711,9 @@ def NTLM_handle_legacy_raw_auth( :type user_name: bytes | str :param domain_name: PrimaryDomain from SESSION_SETUP_ANDX :type domain_name: bytes | str - :param lm_response: OEMPassword (LM response) — None for cleartext + :param lm_response: OEMPassword (LM response) - None for cleartext :type lm_response: bytes | None - :param nt_response: UnicodePassword (NT response) — None for cleartext + :param nt_response: UnicodePassword (NT response) - None for cleartext :type nt_response: bytes | None :param challenge: 8-byte server challenge from negotiate :type challenge: bytes @@ -1575,7 +1739,7 @@ def NTLM_handle_legacy_raw_auth( else (user_name or "") ) # Protocol handlers should decode strings before calling this function. - # The bytes fallback assumes UTF-16LE for safety — only reachable if + # The bytes fallback assumes UTF-16LE for safety - only reachable if # a caller passes raw bytes directly. domain: str = ( domain_name.decode("utf-16-le", errors="replace") @@ -1592,15 +1756,10 @@ def NTLM_handle_legacy_raw_auth( f"Cleartext password captured: {user}\\{domain}", ) extras = extras or {} - host_parts: list[str] = [] - if extras.get("os"): - host_parts.append(extras.pop("os")) - if domain: - host_parts.append(f"(domain: {domain})") - extras[_HOST_INFO] = " ".join(host_parts) if host_parts else "SMB1 cleartext" + extras[HOST_INFO] = _build_smb1_host_info(extras, domain, "SMB1 cleartext") session.db.add_auth( client=client, - credtype="Cleartext", + credtype=CLEARTEXT, username=user, domain=domain, password=cleartext_password, @@ -1609,11 +1768,11 @@ def NTLM_handle_legacy_raw_auth( ) return - # RAW transport — classify and format hashes + # RAW transport - classify and format hashes lm_response = lm_response or b"" nt_response = nt_response or b"" - # Anonymous check — empty user + empty NT + empty/null LM + # Anonymous check - empty user + empty NT + empty/null LM if not user and not nt_response and (not lm_response or lm_response == b"\x00"): log.debug("Anonymous SMB1 basic-security login; skipping hash extraction") return @@ -1624,7 +1783,7 @@ def NTLM_handle_legacy_raw_auth( try: # negotiate_flags=0: no NTLMSSP flags exist on this path - all_hashes = NTLM_to_hashcat( + all_hashes = ntlm_to_hashcat( server_challenge=challenge, user_name=user, domain_name=domain, @@ -1649,12 +1808,7 @@ def NTLM_handle_legacy_raw_auth( # Build host_info from available SMB1 basic-security fields. # SMB1 basic-security has NativeOS and PrimaryDomain but no # workstation name (unlike NTLMSSP AUTHENTICATE). - host_parts: list[str] = [] - if extras.get("os"): - host_parts.append(extras.pop("os")) - if domain: - host_parts.append(f"(domain: {domain})") - extras[_HOST_INFO] = " ".join(host_parts) if host_parts else "SMB1 raw" + extras[HOST_INFO] = _build_smb1_host_info(extras, domain, "SMB1 raw") for version_label, hashcat_line in all_hashes: session.db.add_auth( client=client, @@ -1668,20 +1822,20 @@ def NTLM_handle_legacy_raw_auth( except ValueError: log.exception("Invalid data in SMB1 basic-security auth; skipping capture") - except Exception: + except (AttributeError, KeyError, IndexError, TypeError, OverflowError): log.exception("Failed to extract hashes from SMB1 basic-security auth") # --- Utilities --------------------------------------------------------------- -def NTLM_timestamp() -> int: +def ntlm_timestamp() -> int: """Return the current UTC time as a Windows FILETIME (100ns ticks since 1601-01-01). :return: Current UTC time in 100-nanosecond intervals since Windows epoch (1601-01-01) :rtype: int """ - # calendar.timegm() → UTC seconds since 1970; scaled to 100ns ticks since 1601. + # calendar.timegm() -> UTC seconds since 1970; scaled to 100ns ticks since 1601. return ( NTLM_FILETIME_EPOCH_OFFSET + calendar.timegm(time.gmtime()) * NTLM_FILETIME_TICKS_PER_SECOND diff --git a/dementor/protocols/pop3.py b/dementor/protocols/pop3.py index 7eb4940..4df3dc5 100644 --- a/dementor/protocols/pop3.py +++ b/dementor/protocols/pop3.py @@ -25,9 +25,11 @@ # - https://www.rfc-editor.org/rfc/rfc1734 # - https://datatracker.ietf.org/doc/html/rfc4616 # - https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-pop3/ +from dementor.config.util import HostValue, HostFallbackValue import base64 import binascii import typing +from typing import Literal, overload from typing_extensions import override from impacket import ntlm @@ -35,9 +37,9 @@ from dementor.loader import BaseProtocolModule, DEFAULT_ATTR from dementor.config.session import SessionConfig from dementor.protocols.ntlm import ( - NTLM_build_challenge_message, - NTLM_handle_authenticate_message, - NTLM_handle_negotiate_message, + ntlm_build_challenge_message, + ntlm_handle_authenticate_message, + ntlm_handle_negotiate_message, ) from dementor.servers import ( ServerThread, @@ -47,14 +49,13 @@ BaseServerThread, ) from dementor.log.logger import ProtocolLogger -from dementor.db import _CLEARTEXT +from dementor.db import CLEARTEXT from dementor.config.toml import ( TomlConfig, Attribute as A, ) from dementor.config.attr import ATTR_TLS, ATTR_CERT, ATTR_KEY - __proto__ = ["POP3"] POP3_AUTH_MECHANISMS = [ @@ -70,7 +71,13 @@ class POP3ServerConfig(TomlConfig): _section_ = "POP3" _fields_ = [ A("pop3_port", "Port"), - A("pop3_fqdn", "FQDN", "Dementor", section_local=False), + A( + "pop3_fqdn", + "Host", + None, + section_local=False, + factory=HostFallbackValue(HostValue.HOST, "DEMENTOR"), + ), A("pop3_downgrade", "Downgrade", True), A("pop3_banner", "Banner", "POP3 Server ready"), A("pop3_auth_mechs", "AuthMechanisms", POP3_AUTH_MECHANISMS), @@ -100,7 +107,7 @@ class POP3(BaseProtocolModule[POP3ServerConfig]): @override def create_server_thread( self, session: SessionConfig, server_config: POP3ServerConfig - ) -> BaseServerThread: + ) -> BaseServerThread[POP3ServerConfig]: return ServerThread( session, server_config, @@ -115,8 +122,10 @@ class CloseConnection(Exception): class POP3Handler(BaseProtoHandler): - def __init__(self, config, server_config, request, client_address, server) -> None: - self.server_config = server_config + def __init__( + self, config, server_config: POP3ServerConfig, request, client_address, server + ) -> None: + self.server_config: POP3ServerConfig = server_config super().__init__(config, request, client_address, server) def proto_logger(self) -> ProtocolLogger: @@ -142,6 +151,22 @@ def line(self, msg: str, prefix: str | None = None) -> None: self.logger.debug(f"S: {line!r}") self.send(f"{line}\r\n".encode("utf-8", "strict")) + @overload + def challenge_auth( + self, + token: bytes | None = ..., + decode: Literal[False] = ..., + prefix: str | None = ..., + ) -> bytes: ... + + @overload + def challenge_auth( + self, + token: bytes | None = ..., + decode: Literal[True] = ..., + prefix: str | None = ..., + ) -> str: ... + def challenge_auth( self, token: bytes | None = None, @@ -169,7 +194,7 @@ def challenge_auth( response = response.decode("utf-8", errors="replace") return response - def handle_data(self, data, transport): + def handle_data(self, data: bytes, transport) -> None: # ty:ignore[invalid-method-override] self.request.settimeout(2) self.rfile = transport.makefile("rb") @@ -200,13 +225,13 @@ def handle_data(self, data, transport): # Implementation # [rfc1939] 4. The AUTHORIZATION State # QUIT - def do_QUIT(self, args): + def do_QUIT(self, args: list[str]) -> None: self.ok("Goodbye") raise CloseConnection # [rfc1939] 7. Optional POP3 Commands # USER - def do_USER(self, args): + def do_USER(self, args: list[str]) -> None: if len(args) != 1: self.err("Invalid number of arguments") return @@ -216,10 +241,9 @@ def do_USER(self, args): # [rfc1939] 7. Optional POP3 Commands # PASS - def do_PASS(self, args): + def do_PASS(self, args: list[str]) -> None: if len(args) < 1: return self.err("Invalid number of arguments") - return None if not hasattr(self, "username"): return self.err("Username not set") @@ -234,7 +258,7 @@ def do_PASS(self, args): username=self.username, password=self.password, logger=self.logger, - credtype=_CLEARTEXT, + credtype=CLEARTEXT, ) del self.username del self.password @@ -242,7 +266,7 @@ def do_PASS(self, args): # [rfc2449] 5. The CAPA Command # CAPA - def do_CAPA(self, args): + def do_CAPA(self, args: list[str]) -> None: self.ok("Capability list follows") # The USER capability indicates that the USER and PASS commands # are supported, although they may not be available to all users @@ -257,7 +281,7 @@ def do_CAPA(self, args): # [rfc1734] 2. The AUTH command # AUTH - def do_AUTH(self, args): + def do_AUTH(self, args: list[str]) -> None: if len(args) != 1: self.err("Invalid number of arguments") return None @@ -282,7 +306,7 @@ def do_AUTH(self, args): # [rfc4616] 2. PLAIN SASL Mechanism # PLAIN - def auth_PLAIN(self, initial_response=None): + def auth_PLAIN(self, initial_response: str | None = None) -> None: if not initial_response: initial_response = self.challenge_auth(decode=True) @@ -302,13 +326,13 @@ def auth_PLAIN(self, initial_response=None): username=login, password=password, logger=self.logger, - credtype=_CLEARTEXT, + credtype=CLEARTEXT, ) self.err("Invalid username or password") # https://datatracker.ietf.org/doc/html/draft-murchison-sasl-login-00 # LOGIN - def auth_LOGIN(self, username: bytes | None = None): + def auth_LOGIN(self, username: str | None = None) -> None: if not username: # The server issues the string "User Name" in challenge, and receives a # client response. This response is recorded as the authorization @@ -332,7 +356,7 @@ def auth_LOGIN(self, username: bytes | None = None): username=username, password=password, logger=self.logger, - credtype=_CLEARTEXT, + credtype=CLEARTEXT, ) self.err("Invalid username or password") @@ -353,14 +377,21 @@ def auth_NTLM(self, initial_response=None) -> None: # 3. The server sends a POP3_AUTH_NTLM_Blob_Response message containing # a base64-encoded NTLM CHALLENGE_MESSAGE. - negotiate_fields = NTLM_handle_negotiate_message(negotiate, self.logger) - challenge = NTLM_build_challenge_message( + negotiate_fields = ntlm_handle_negotiate_message(negotiate, self.logger) + host = HostValue(self.server_config.pop3_fqdn) + challenge = ntlm_build_challenge_message( negotiate, challenge=self.config.ntlm_challenge, - nb_computer=self.config.ntlm_nb_computer, - nb_domain=self.config.ntlm_nb_domain, + nb_computer=host.get_value(HostValue.NETBIOS_COMPUTER), + nb_domain=host.get_value(HostValue.NETBIOS_DOMAIN), disable_ess=self.config.ntlm_disable_ess, disable_ntlmv2=self.config.ntlm_disable_ntlmv2, + target_type=self.config.ntlm_target_type, + version=self.config.ntlm_version, + dns_computer=host.get_value(HostValue.DNS_COMPUTER), + dns_domain=host.get_value(HostValue.DNS_DOMAIN), + # REVISIT: capture DNSTree too + # dns_tree=self.config.ntlm_dns_tree, log=self.logger, ) token = self.challenge_auth(challenge.getData()) @@ -370,7 +401,7 @@ def auth_NTLM(self, initial_response=None) -> None: auth_message = ntlm.NTLMAuthChallengeResponse() auth_message.fromString(token) - NTLM_handle_authenticate_message( + ntlm_handle_authenticate_message( auth_message, challenge=self.config.ntlm_challenge, client=self.client_address, diff --git a/dementor/protocols/quic.py b/dementor/protocols/quic.py index d763a1a..6ad14c8 100644 --- a/dementor/protocols/quic.py +++ b/dementor/protocols/quic.py @@ -37,7 +37,7 @@ from dementor.servers import AsyncServerThread, BaseServerThread from dementor.config.toml import TomlConfig, Attribute as A from dementor.config.session import SessionConfig -from dementor.config.util import generate_self_signed_cert +from dementor.config.tls import generate_self_signed_cert from dementor.log.logger import ProtocolLogger, dm_logger if typing.TYPE_CHECKING: @@ -109,12 +109,12 @@ def __init__( self.config: SessionConfig = config # stream_id -> (w, r) self.conn_data: dict[int, tuple[asyncio.StreamWriter, asyncio.StreamReader]] = {} - self.logger: ProtocolLogger = QuicHandler.proto_logger( + self.logger: ProtocolLogger = QuicHandler._make_proto_logger( self.config.quic_config.quic_port ) @staticmethod - def proto_logger(port: int) -> ProtocolLogger: + def _make_proto_logger(port: int) -> ProtocolLogger: return ProtocolLogger( extra={ "protocol": "QUIC", @@ -123,6 +123,11 @@ def proto_logger(port: int) -> ProtocolLogger: } ) + @property + def proto_logger(self) -> ProtocolLogger: + """Return the protocol logger for this QUIC connection.""" + return QuicHandler._make_proto_logger(self.config.quic_config.quic_port) + @property def target_smb_host(self): return self.config.quic_config.quic_smb_host or self.host @@ -145,7 +150,9 @@ def quic_event_received(self, event: events.QuicEvent) -> None: case _: pass # ignore other events for now - async def handle_data(self, stream_id: int, data: bytes): + # NOTE: This is intentionally out-of-band from BaseServerThread.handle_data(data, transport). + # QUIC multiplexes streams, so this method takes a stream_id instead of a transport socket. + async def handle_data(self, stream_id: int, data: bytes) -> None: if stream_id not in self.conn_data: # create new connection network_path = self._quic._network_paths[0] @@ -218,7 +225,7 @@ def is_running(self) -> bool: def generate_self_signed_cert(self) -> None: """Generate a self-signed certificate and private key for QUIC server.""" - logger = QuicHandler.proto_logger(self.server_config.quic_port) + logger = QuicHandler._make_proto_logger(self.server_config.quic_port) logger.display("Generating self-signed certificate for QUIC server") cert_path, key_path, temp_dir = generate_self_signed_cert( @@ -236,7 +243,8 @@ def generate_self_signed_cert(self) -> None: self._temp_dir = temp_dir self._generated_temp_cert = True - def get_service_name(self) -> str: + @property + def service_name(self) -> str: return "QUIC" def create_handler(self, *args: typing.Any, **kwargs: typing.Any): diff --git a/dementor/protocols/smb.py b/dementor/protocols/smb.py index 3d543e9..cf0f3b9 100644 --- a/dementor/protocols/smb.py +++ b/dementor/protocols/smb.py @@ -18,6 +18,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # pyright: basic +import struct as _stdlib_struct import uuid import secrets import typing @@ -47,17 +48,17 @@ from dementor.config.toml import TomlConfig, Attribute as A from dementor.config.session import SessionConfig -from dementor.config.util import is_true +from dementor.config.util import is_true, HostFallbackValue from dementor.loader import BaseProtocolModule, DEFAULT_ATTR from dementor.log.logger import ProtocolLogger, dm_logger from dementor.protocols.ntlm import ( - NTLM_build_challenge_message, + ntlm_build_challenge_message, NTLM_TRANSPORT_CLEARTEXT, NTLM_TRANSPORT_RAW, - NTLM_handle_negotiate_message, - NTLM_timestamp, - NTLM_handle_authenticate_message, - NTLM_handle_legacy_raw_auth, + ntlm_handle_negotiate_message, + ntlm_timestamp, + ntlm_handle_authenticate_message, + ntlm_handle_legacy_raw_auth, ) from dementor.protocols.spnego import ( NEG_STATE_ACCEPT_COMPLETED, @@ -101,10 +102,10 @@ def _split_smb_strings(data: bytes, is_unicode: bool) -> list[str]: return [] if not is_unicode: - # [MS-CIFS] §2.2.1.1: OEM_STRING — single \x00 terminator + # [MS-CIFS] §2.2.1.1: OEM_STRING - single \x00 terminator return [s.decode("ascii", errors="replace") for s in data.split(b"\x00") if s] - # [MS-CIFS] §2.2.1.1: UNICODE_STRING — \x00\x00 at 2-byte aligned offsets + # [MS-CIFS] §2.2.1.1: UNICODE_STRING - \x00\x00 at 2-byte aligned offsets segments: list[str] = [] start = 0 i = 0 @@ -152,13 +153,13 @@ def _split_smb_strings(data: bytes, is_unicode: bool) -> list[str]: # Realistic SMB2 server values per [MS-SMB2] §2.2.4 (Windows Server defaults) # Per-dialect max sizes matching real Windows pcap behaviour: -# 2.0.2: 65536 (64K) — matches Vista/Srv2008 -# 2.1: 1048576 (1M) or 8388608 (8M) — varies; use 8M for Server 2012+ -# 3.0+: 8388608 (8M) — matches Windows Server 2012+ +# 2.0.2: 65536 (64K) - matches Vista/Srv2008 +# 2.1: 1048576 (1M) or 8388608 (8M) - varies; use 8M for Server 2012+ +# 3.0+: 8388608 (8M) - matches Windows Server 2012+ SMB2_MAX_SIZE_SMALL: int = 65_536 # SMB 2.0.2 SMB2_MAX_SIZE_LARGE: int = 8_388_608 # SMB 2.1+ -# Realistic SMB2 capabilities — [MS-SMB2] §2.2.4 +# Realistic SMB2 capabilities - [MS-SMB2] §2.2.4 # DFS(0x01) | Leasing(0x02) | LargeMTU(0x04) | MultiChannel(0x08) # | DirectoryLeasing(0x20) = 0x2f # We do NOT set Encryption(0x40) since we don't implement it. @@ -175,12 +176,15 @@ def _split_smb_strings(data: bytes, is_unicode: bool) -> list[str]: SMB1_MAX_MPX_COUNT: int = 50 SMB1_MAX_BUFFER_SIZE: int = 16644 -# STATUS_ACCOUNT_DISABLED — used for multi-credential SSPI retry +# STATUS_ACCOUNT_DISABLED - used for multi-credential SSPI retry STATUS_ACCOUNT_DISABLED: int = 0xC0000072 # [MS-SMB2] §2.2.3.1.7: SMB2_SIGNING_CAPABILITIES negotiate context type SMB2_SIGNING_CAPABILITIES_ID: int = 0x0008 +# Reusable exception tuple for parsing struct/packet fields. +_PARSE_ERRORS = (_stdlib_struct.error, KeyError, IndexError, TypeError, ValueError) + # (missing in impacket struct definitions) # [MS-SMB2] §2.2.3.1.7 SMB2_SIGNING_CAPABILITIES @@ -228,13 +232,25 @@ class SMBServerConfig(TomlConfig): A("smb2_min_dialect", "SMB2MinDialect", "2.002", factory=parse_dialect), A("smb2_max_dialect", "SMB2MaxDialect", "3.1.1", factory=parse_dialect), # --- SMB Identity --- - A("smb_nb_computer", "NetBIOSComputer", "DEMENTOR"), - A("smb_nb_domain", "NetBIOSDomain", "WORKGROUP"), + A( + "smb_nb_computer", + "NTLM.NetBIOSComputer", + None, + section_local=False, + factory=HostFallbackValue("NetBIOSComputer", "DEMENTOR"), + ), + A( + "smb_nb_domain", + "NTLM.NetBIOSDomain", + None, + section_local=False, + factory=HostFallbackValue("NetBIOSDomain", "WORKGROUP"), + ), A("smb_server_os", "ServerOS", "Windows"), A("smb_native_lanman", "NativeLanMan", "Windows"), # --- Post-Auth --- A("smb_captures_per_connection", "CapturesPerConnection", 0, factory=int), - A("smb_error_code", "ErrorCode", nt_errors.STATUS_SMB_BAD_UID), + A("smb_error_code", "ErrorCode", default_val=None), ] if typing.TYPE_CHECKING: @@ -249,9 +265,9 @@ class SMBServerConfig(TomlConfig): smb_server_os: str smb_native_lanman: str smb_captures_per_connection: int - smb_error_code: int + smb_error_code: int | None - def set_smb_error_code(self, value: str | int) -> None: + def set_smb_error_code(self, value: str | int | None) -> None: """Set the SMB error code from an integer or nt_errors attribute name. Falls back to STATUS_SMB_BAD_UID if the string does not match any @@ -261,16 +277,19 @@ def set_smb_error_code(self, value: str | int) -> None: (e.g. "STATUS_ACCESS_DENIED") :type value: str | int """ - if isinstance(value, int): - self.smb_error_code = value - else: - try: - self.smb_error_code = getattr(nt_errors, str(value)) - except AttributeError: - dm_logger.error( - f"Invalid SMB error code: {value} - using default: STATUS_SMB_BAD_UID" - ) - self.smb_error_code = nt_errors.STATUS_SMB_BAD_UID + match value: + case int(): + self.smb_error_code = value + case str(): + try: + self.smb_error_code = getattr(nt_errors, str(value)) + except AttributeError: + dm_logger.error( + f"Invalid SMB error code: {value} - using default: STATUS_SMB_BAD_UID" + ) + self.smb_error_code = nt_errors.STATUS_SMB_BAD_UID + case _: + self.smb_error_code = None class SMB(BaseProtocolModule[SMBServerConfig]): @@ -312,7 +331,7 @@ def get_server_time() -> int: :return: Current UTC time encoded as a 64-bit Windows FILETIME value :rtype: int """ - return NTLM_timestamp() + return ntlm_timestamp() def get_command_name(command: int, smb_version: int) -> str: @@ -339,8 +358,6 @@ def get_command_name(command: int, smb_version: int) -> str: for key, value in vars(smb2).items(): if key.startswith("SMB2_") and value == command: return key - case _: - pass return "Unknown" @@ -427,10 +444,10 @@ def __init__( self.client_info: dict[str, str] = {} # Filenames from CREATE/NT_CREATE_ANDX, deduped across the connection. self.client_files: set[str] = set() - # NTLM NEGOTIATE fields returned by NTLM_handle_negotiate_message(). - # Passed back to NTLM_handle_authenticate_message() so the display line is the + # NTLM NEGOTIATE fields returned by ntlm_handle_negotiate_message(). + # Passed back to ntlm_handle_authenticate_message() so the display line is the # deduped union of Type 1 + Type 3. This is ntlm.py's own output - # passed through — smb.py never reads or modifies it. + # passed through - smb.py never reads or modifies it. self.ntlm_negotiate_fields: dict[str, str] = {} # Sequential file ID counters for fake file handles. # SMB1 FIDs are 16-bit; SMB2 FileIDs are 64-bit volatile IDs. @@ -566,7 +583,7 @@ def send_smb1_command( # [MS-CIFS] §2.2.3.1: SMB_FLAGS_REPLY (0x80) on server responses resp["Flags1"] = smb.SMB.FLAGS1_REPLY - # Flags2 depends on security mode — [MS-SMB] §2.2.3.1 + # Flags2 depends on security mode - [MS-SMB] §2.2.3.1 flags2 = ( smb.SMB.FLAGS2_NT_STATUS | smb.SMB.FLAGS2_LONG_NAMES @@ -579,7 +596,7 @@ def send_smb1_command( resp["Pid"] = packet["Pid"] resp["Tid"] = packet["Tid"] resp["Mid"] = packet["Mid"] - # Server-assigned session UID — [MS-SMB] §3.3.5.3 + # Server-assigned session UID - [MS-SMB] §3.3.5.3 if self.smb1_uid: resp["Uid"] = self.smb1_uid if error_code: @@ -600,6 +617,7 @@ def send_smb2_command( packet: typing.Any | None = None, command: int | None = None, status: int | None = None, + tree_id: int | None = None, ) -> None: """Build and send an SMB2 response wrapped in a NetBIOS session packet. @@ -621,6 +639,11 @@ def send_smb2_command( :param status: NTSTATUS code for the response, defaults to None (STATUS_SUCCESS) :type status: int | None, optional + :param tree_id: Override TreeID in the response header; when None the + TreeID is echoed from *packet*. Use for TREE_CONNECT responses + where the server assigns a new TreeID rather than echoing the + client's value. [MS-SMB2] §3.3.5.7 + :type tree_id: int | None, optional """ resp = smb2.SMB2Packet() # [MS-SMB2] §2.2.1: SMB2_FLAGS_SERVER_TO_REDIR (0x01) on responses @@ -640,10 +663,10 @@ def send_smb2_command( resp["Command"] = packet["Command"] resp["CreditCharge"] = packet["CreditCharge"] resp["Reserved"] = packet["Reserved"] - # Server-assigned SessionID — [MS-SMB2] §3.3.5.5.1 + # Server-assigned SessionID - [MS-SMB2] §3.3.5.5.1 resp["SessionID"] = self.smb2_session_id resp["MessageID"] = packet["MessageID"] - resp["TreeID"] = packet["TreeID"] + resp["TreeID"] = tree_id if tree_id is not None else packet["TreeID"] # Real Windows grants 32-256 credits; 1 causes smbclient to exhaust # credits during compound requests (get, dir listing). resp["CreditRequestResponse"] = 32 @@ -714,7 +737,7 @@ def handle_data(self, data: bytes | None, transport: typing.Any) -> None: if cdn: self.client_info["smb_called_name"] = cdn except ValueError: - pass + self.logger.debug("Failed to parse NetBIOS session request names") self.send_data(b"\x00", nmb.NETBIOS_SESSION_POSITIVE_RESPONSE) continue @@ -773,7 +796,7 @@ def handle_smb_packet(self, packet: typing.Any, smbv1: bool = False) -> None: except Exception: self.logger.exception(f"Error in {title}") elif not smbv1: - # Unhandled SMB2 command — respond with STATUS_NOT_SUPPORTED + # Unhandled SMB2 command - respond with STATUS_NOT_SUPPORTED # instead of dropping the connection. This keeps the session # alive so the client can proceed to TREE_CONNECT with the # real share path after IPC$ queries (CREATE, IOCTL, CLOSE). @@ -788,7 +811,7 @@ def handle_smb_packet(self, packet: typing.Any, smbv1: bool = False) -> None: status=nt_errors.STATUS_NOT_SUPPORTED, ) else: - # Unhandled SMB1 command — respond with STATUS_NOT_IMPLEMENTED + # Unhandled SMB1 command - respond with STATUS_NOT_IMPLEMENTED # instead of dropping the connection. Keeps the session alive # so the client can proceed with file operations. # [MS-CIFS] §3.3.5: error response for unsupported commands @@ -887,7 +910,7 @@ def _smb3_get_target_capabilities( offset += context["DataLength"] + 8 offset += (8 - (offset % 8)) % 8 - except Exception as e: + except _PARSE_ERRORS as e: self.logger.debug(f"Warning: invalid negotiate context list: {e}") return target_cipher, target_sign @@ -916,12 +939,12 @@ def _build_smb2_negotiate_response( command["SecurityMode"] = 0x01 # [MS-SMB2] §3.3.5.4: set to the common dialect command["DialectRevision"] = target_revision - # Stable ServerGuid per server instance — [MS-SMB2] §2.2.4 + # Stable ServerGuid per server instance - [MS-SMB2] §2.2.4 command["ServerGuid"] = self.server.server_guid # type: ignore[union-attr] - # Realistic capabilities — [MS-SMB2] §2.2.4 + # Realistic capabilities - [MS-SMB2] §2.2.4 command["Capabilities"] = SMB2_SERVER_CAPABILITIES # Per-dialect max sizes matching real Windows pcap behaviour: - # 2.0.2 → 64K, 2.1+ → 8M (direct TCP, port 445) + # 2.0.2 -> 64K, 2.1+ -> 8M (direct TCP, port 445) max_size = ( SMB2_MAX_SIZE_SMALL if target_revision == smb2.SMB2_DIALECT_002 @@ -1014,14 +1037,14 @@ def handle_smb2_negotiate(self, packet: smb2.SMB2Packet) -> None: req_dialects: list[int] = req_raw_dialects[:dialect_count] if len(req_dialects) == 0: - # [MS-SMB2] §3.3.5.4: DialectCount == 0 → STATUS_INVALID_PARAMETER + # [MS-SMB2] §3.3.5.4: DialectCount == 0 -> STATUS_INVALID_PARAMETER self.logger.debug("SMB2_NEGOTIATE: no dialects offered", is_client=True) self.logger.fail("SMB Negotiation: Client failed to provide any dialects.") raise BaseProtoHandler.TerminateConnection str_req_dialects = ", ".join([SMB2_DIALECTS.get(d, hex(d)) for d in req_dialects]) - # Build ONE consolidated debug line — [MS-SMB2] §2.2.3 + # Build ONE consolidated debug line - [MS-SMB2] §2.2.3 try: guid = uuid.UUID(bytes_le=req["ClientGuid"]) sec_mode: int = req["SecurityMode"] @@ -1033,7 +1056,7 @@ def handle_smb2_negotiate(self, packet: smb2.SMB2Packet) -> None: f"Capabilities=0x{client_caps:08x}" ) - # Add NegotiateContexts only for 3.1.1 — [MS-SMB2] §2.2.3.1 + # Add NegotiateContexts only for 3.1.1 - [MS-SMB2] §2.2.3.1 ctx_data: bytes = req["NegotiateContextList"] or b"" if ctx_data: ctx_types = { @@ -1061,7 +1084,7 @@ def handle_smb2_negotiate(self, packet: smb2.SMB2Packet) -> None: ) # Select the highest common dialect within the configured range. - # No adaptive downgrade — negotiate at the client's native dialect. + # No adaptive downgrade - negotiate at the client's native dialect. # # At 3.1.1, hash capture works but the client disconnects after # SESSION_SETUP without sending TREE_CONNECT because the spec @@ -1096,7 +1119,7 @@ def handle_smb2_negotiate(self, packet: smb2.SMB2Packet) -> None: # [MS-SMB2] §2.2.3: SecurityMode bit 0x0002 = SIGNING_REQUIRED try: self.smb2_client_signing_required = bool(req["SecurityMode"] & 0x0002) - except Exception: + except _PARSE_ERRORS: self.smb2_client_signing_required = False # Client's highest offered dialect (uncapped by our MaxDialect) client_negotiable = [d for d in req_dialects if d in SMB2_NEGOTIABLE_DIALECTS] @@ -1113,7 +1136,7 @@ def handle_smb2_negotiate(self, packet: smb2.SMB2Packet) -> None: # from the user's password hash, which a capture server does # not have. Hash capture still works (the AUTHENTICATE_MESSAGE # arrives before the signed response is validated), but the - # client will disconnect after auth — no TREE_CONNECT, CREATE, + # client will disconnect after auth - no TREE_CONNECT, CREATE, # or READ follows, so share path and filename capture is not # possible. self.logger.debug( @@ -1126,6 +1149,24 @@ def handle_smb2_negotiate(self, packet: smb2.SMB2Packet) -> None: # -- SMB1 Negotiate -- + def _fill_smb1_negotiate_params( + self, params: typing.Any, nt_lm_index: int, server_time: int + ) -> None: + """Set negotiate parameter fields shared by both extended and non-extended paths.""" + params["DialectIndex"] = nt_lm_index + params["SecurityMode"] = ( + smb.SMB.SECURITY_AUTH_ENCRYPTED | smb.SMB.SECURITY_SHARE_USER + ) + params["MaxMpxCount"] = SMB1_MAX_MPX_COUNT + params["MaxNumberVcs"] = 1 + params["MaxBufferSize"] = SMB1_MAX_BUFFER_SIZE + params["MaxRawSize"] = 65536 + params["SessionKey"] = 0 + # [MS-CIFS] §2.2.4.52.2: SystemTime as FILETIME split into 32-bit words + params["LowDateTime"] = server_time & 0xFFFFFFFF + params["HighDateTime"] = (server_time >> 32) & 0xFFFFFFFF + params["ServerTimeZone"] = 0 + def handle_smb1_negotiate(self, packet: smb.NewSMBPacket) -> None: """Handle SMB1 NEGOTIATE -- [MS-SMB] §3.3.5.2. @@ -1168,7 +1209,7 @@ def handle_smb1_negotiate(self, packet: smb.NewSMBPacket) -> None: cfg = self.smb_config # Check for SMB2 dialect strings for protocol transition - # [MS-SMB2] §3.3.5.3.1 — only when AllowSMB1Upgrade and EnableSMB2 + # [MS-SMB2] §3.3.5.3.1 - only when AllowSMB1Upgrade and EnableSMB2 smb2_upgrade_target: str | None = None if cfg.smb_allow_smb1_upgrade and cfg.smb_enable_smb2: smb2_entries: dict[str, int] = { @@ -1199,7 +1240,7 @@ def handle_smb1_negotiate(self, packet: smb.NewSMBPacket) -> None: self.send_smb2_command(command.getData(), command=smb2.SMB2_NEGOTIATE) return - # Find NT LM 0.12 dialect — [MS-SMB] extensions only apply to it + # Find NT LM 0.12 dialect - [MS-SMB] extensions only apply to it nt_lm_index: int | None = None for i, d in enumerate(dialects): if d == "NT LM 0.12": @@ -1212,7 +1253,7 @@ def handle_smb1_negotiate(self, packet: smb.NewSMBPacket) -> None: ) raise BaseProtoHandler.TerminateConnection - # Shared negotiate parameters — [MS-CIFS] §2.2.4.52.2 + # Shared negotiate parameters - [MS-CIFS] §2.2.4.52.2 server_time = get_server_time() # Respond based on the client's capabilities: if the client sets @@ -1237,7 +1278,7 @@ def handle_smb1_negotiate(self, packet: smb.NewSMBPacket) -> None: ) _dialects_data = smb.SMBExtended_Security_Data() - # Stable ServerGuid per server instance — [MS-SMB2] §2.2.4 + # Stable ServerGuid per server instance - [MS-SMB2] §2.2.4 _dialects_data["ServerGUID"] = self.server.server_guid # type: ignore[union-attr] blob = build_neg_token_init([SPNEGO_NTLMSSP_MECH]) _dialects_data["SecurityBlob"] = blob.getData() @@ -1248,6 +1289,7 @@ def handle_smb1_negotiate(self, packet: smb.NewSMBPacket) -> None: smb.SMB.CAP_EXTENDED_SECURITY | SMB1_CAPABILITIES_BASE ) _dialects_parameters["ChallengeLength"] = 0 + dialect_label = "NT LM 0.12" else: # --- Non-extended security path (raw challenge/response) --- # [MS-SMB] §2.2.4.5.2.2 @@ -1265,70 +1307,30 @@ def handle_smb1_negotiate(self, packet: smb.NewSMBPacket) -> None: _dialects_parameters = smb.SMBNTLMDialect_Parameters() _dialects_data = smb.SMBNTLMDialect_Data() - # SecurityMode — [MS-CIFS] §2.2.4.52.2 - _dialects_parameters["SecurityMode"] = ( - smb.SMB.SECURITY_AUTH_ENCRYPTED | smb.SMB.SECURITY_SHARE_USER - ) _dialects_parameters["ChallengeLength"] = 8 _dialects_data["Challenge"] = self.config.ntlm_challenge - # Realistic capabilities matching Windows pcap — NO CAP_EXTENDED_SECURITY + # Realistic capabilities matching Windows pcap - NO CAP_EXTENDED_SECURITY _dialects_parameters["Capabilities"] = SMB1_CAPABILITIES_BASE - # DomainName and ServerName — [MS-CIFS] §2.2.4.52.2 + # DomainName and ServerName - [MS-CIFS] §2.2.4.52.2 # Payload is the raw concatenation of DomainName + ServerName; # the virtual DomainName/ServerName fields are parse-time only. _dialects_data["Payload"] = smbserver.encodeSMBString( resp["Flags2"], cfg.smb_nb_domain ) + smbserver.encodeSMBString(resp["Flags2"], cfg.smb_nb_computer) + dialect_label = "NT LM 0.12 (non-extended)" - _dialects_parameters["DialectIndex"] = nt_lm_index - _dialects_parameters["MaxMpxCount"] = SMB1_MAX_MPX_COUNT - _dialects_parameters["MaxNumberVcs"] = 1 - _dialects_parameters["MaxBufferSize"] = SMB1_MAX_BUFFER_SIZE - _dialects_parameters["MaxRawSize"] = 65536 - _dialects_parameters["SessionKey"] = 0 - # [MS-CIFS] §2.2.4.52.2: SystemTime as FILETIME split into 32-bit words - _dialects_parameters["LowDateTime"] = server_time & 0xFFFFFFFF - _dialects_parameters["HighDateTime"] = (server_time >> 32) & 0xFFFFFFFF - _dialects_parameters["ServerTimeZone"] = 0 - - command = smb.SMBCommand(smb.SMB.SMB_COM_NEGOTIATE) - command["Data"] = _dialects_data - command["Parameters"] = _dialects_parameters - - self.logger.debug( - "SMB_COM_NEGOTIATE: selected dialect NT LM 0.12 (non-extended)", - is_server=True, - ) - self.client_info["smb_dialect"] = "NT LM 0.12 (non-extended)" - resp.addCommand(command) - self.send_data(resp.getData()) - return - - # Extended security common path - _dialects_parameters["DialectIndex"] = nt_lm_index - _dialects_parameters["SecurityMode"] = ( - smb.SMB.SECURITY_AUTH_ENCRYPTED | smb.SMB.SECURITY_SHARE_USER - ) - _dialects_parameters["MaxMpxCount"] = SMB1_MAX_MPX_COUNT - _dialects_parameters["MaxNumberVcs"] = 1 - _dialects_parameters["MaxBufferSize"] = SMB1_MAX_BUFFER_SIZE - _dialects_parameters["MaxRawSize"] = 65536 - _dialects_parameters["SessionKey"] = 0 - # SystemTime must be current FILETIME — [MS-CIFS] §2.2.4.52.2 - _dialects_parameters["LowDateTime"] = server_time & 0xFFFFFFFF - _dialects_parameters["HighDateTime"] = (server_time >> 32) & 0xFFFFFFFF - _dialects_parameters["ServerTimeZone"] = 0 + self._fill_smb1_negotiate_params(_dialects_parameters, nt_lm_index, server_time) command = smb.SMBCommand(smb.SMB.SMB_COM_NEGOTIATE) command["Data"] = _dialects_data command["Parameters"] = _dialects_parameters self.logger.debug( - "SMB_COM_NEGOTIATE: selected dialect NT LM 0.12", is_server=True + "SMB_COM_NEGOTIATE: selected dialect %s", dialect_label, is_server=True ) - self.client_info["smb_dialect"] = "NT LM 0.12" + self.client_info["smb_dialect"] = dialect_label resp.addCommand(command) self.send_data(resp.getData()) @@ -1385,7 +1387,7 @@ def handle_ntlmssp( self.logger.debug(f"<{command_name}> GSSAPI negTokenInit", is_client=True) try: neg_token = spnego.SPNEGO_NegTokenInit(data=token) - except Exception as e: + except _PARSE_ERRORS as e: self.logger.debug(f"Invalid GSSAPI token: {e}") raise BaseProtoHandler.TerminateConnection from None @@ -1410,7 +1412,7 @@ def handle_ntlmssp( self.logger.debug(f"<{command_name}> GSSAPI negTokenArg", is_client=True) try: neg_token = spnego.SPNEGO_NegTokenResp(data=token) - except Exception as e: + except _PARSE_ERRORS as e: self.logger.debug(f"Invalid GSSAPI token: {e}") raise BaseProtoHandler.TerminateConnection from None token = neg_token["ResponseToken"] @@ -1419,9 +1421,6 @@ def handle_ntlmssp( self.logger.fail(f"<{command_name}> Invalid NTLM token length: {len(token)}") raise BaseProtoHandler.TerminateConnection - cfg = self.smb_config - error_code = cfg.smb_error_code - match token[8]: case 0x01: # [MS-NLMP] §2.2.1.1: NEGOTIATE_MESSAGE negotiate = ntlm.NTLMAuthNegotiate() @@ -1432,14 +1431,14 @@ def handle_ntlmssp( ) # NTLM-layer NEGOTIATE parsing and logging stays in ntlm.py. - # Store the returned dict to pass through to NTLM_handle_authenticate_message + # Store the returned dict to pass through to ntlm_handle_authenticate_message # for the deduped display line. Do NOT merge into client_info - # — the SMB display line uses only SMB-layer fields. - self.ntlm_negotiate_fields = NTLM_handle_negotiate_message( + # - the SMB display line uses only SMB-layer fields. + self.ntlm_negotiate_fields = ntlm_handle_negotiate_message( negotiate, self.logger ) - challenge = NTLM_build_challenge_message( + challenge = ntlm_build_challenge_message( negotiate, challenge=self.config.ntlm_challenge, nb_computer=self.config.ntlm_nb_computer, @@ -1468,7 +1467,7 @@ def handle_ntlmssp( # [MS-SMB2] §3.3.5.5.3: auth still in progress error_code = nt_errors.STATUS_MORE_PROCESSING_REQUIRED - case 0x02: # [MS-NLMP] §2.2.1.2: CHALLENGE_MESSAGE — unexpected + case 0x02: # [MS-NLMP] §2.2.1.2: CHALLENGE_MESSAGE - unexpected if not is_gssapi: self.logger.debug( f"<{command_name}> NTLMSSP_CHALLENGE_MESSAGE", is_client=True @@ -1487,7 +1486,7 @@ def handle_ntlmssp( # NTLM-layer AUTHENTICATE parsing and logging in ntlm.py. # Returns True if real credentials were captured, False # for anonymous or parse failures. - captured = NTLM_handle_authenticate_message( + captured = ntlm_handle_authenticate_message( authenticate, challenge=self.config.ntlm_challenge, client=self.client_address, @@ -1497,13 +1496,13 @@ def handle_ntlmssp( ) if not captured: - # Anonymous probe or parse failure — reject so the + # Anonymous probe or parse failure - reject so the # client retries with real credentials (XP sends # anonymous first, then the real auth). error_code = nt_errors.STATUS_ACCESS_DENIED resp = build_neg_token_resp(NEG_STATE_REJECT) else: - # Real credentials captured — resolve error code. + # Real credentials captured - resolve error code. # Returns STATUS_ACCOUNT_DISABLED for multi-cred # intermediate attempts, STATUS_SUCCESS for final # (to let client proceed to TREE_CONNECT for path). @@ -1535,6 +1534,10 @@ def _resolve_auth_error_code(self) -> int: STATUS_ACCOUNT_DISABLED to trigger retries, and the Nth capture returns STATUS_SUCCESS for the tree connect path capture. + When ``ErrorCode`` has been confiured to a valid value, it will + be returned regardless of whether ``CapturesPerConnection`` is + set. + :return: NTSTATUS code -- STATUS_ACCOUNT_DISABLED for intermediate attempts, or STATUS_SUCCESS for the final attempt (to allow tree connect path capture) @@ -1542,6 +1545,7 @@ def _resolve_auth_error_code(self) -> int: """ self.auth_attempt_count += 1 max_captures = self.smb_config.smb_captures_per_connection + error_code: int = nt_errors.STATUS_SUCCESS if max_captures > 0 and self.auth_attempt_count < max_captures: self.logger.debug( @@ -1551,17 +1555,22 @@ def _resolve_auth_error_code(self) -> int: max_captures, is_server=True, ) - return STATUS_ACCOUNT_DISABLED + error_code = STATUS_ACCOUNT_DISABLED - # Return SUCCESS to let the client proceed to TREE_CONNECT, - # where we capture the share path before returning the real - # error code. See handle_smb2_tree_connect / handle_smb1_tree_connect. - self.logger.debug( - "ErrorCode=0x%08x (STATUS_SUCCESS, awaiting tree connect)", - 0, - is_server=True, - ) - return nt_errors.STATUS_SUCCESS + if self.smb_config.smb_error_code is not None: + # Allow custom override via config + error_code = self.smb_config.smb_error_code + + if error_code == nt_errors.STATUS_SUCCESS: + # Return SUCCESS to let the client proceed to TREE_CONNECT, + # where we capture the share path before returning the real + # error code. See handle_smb2_tree_connect / handle_smb1_tree_connect. + self.logger.debug( + "ErrorCode=0x%08x (STATUS_SUCCESS, awaiting tree connect)", + 0, + is_server=True, + ) + return error_code # -- SMB2 Session -- @@ -1581,7 +1590,7 @@ def handle_smb2_session_setup(self, packet: smb2.SMB2Packet) -> None: """ req = smb2.SMB2SessionSetup(data=packet["Data"]) - # Log PreviousSessionId — [MS-SMB2] §2.2.5 + # Log PreviousSessionId - [MS-SMB2] §2.2.5 try: prev_session: int = req["PreviousSessionId"] prev_str = f"0x{prev_session:016x}" if prev_session else "(empty)" @@ -1589,7 +1598,7 @@ def handle_smb2_session_setup(self, packet: smb2.SMB2Packet) -> None: f"SMB2_SESSION_SETUP: PreviousSessionId={prev_str}", is_client=True, ) - except Exception: + except _PARSE_ERRORS: self.logger.debug("Failed to extract PreviousSessionId", exc_info=True) command = smb2.SMB2SessionSetup_Response() @@ -1608,17 +1617,17 @@ def handle_smb2_session_setup(self, packet: smb2.SMB2Packet) -> None: # # Three-tier decision using SIGNING_REQUIRED + client max dialect: # - # 1. SIGNING_REQUIRED set → never IS_GUEST + # 1. SIGNING_REQUIRED set -> never IS_GUEST # §3.2.5.3.1: IS_GUEST + SigningRequired = client MUST fail. # Future-proofing for Win11 24H2+ / Server 2025. # - # 2. Client max dialect ≤ 3.0.2 → IS_GUEST + # 2. Client max dialect ≤ 3.0.2 -> IS_GUEST # These clients (Win8.1, Srv2012R2, Srv2016) have - # AllowInsecureGuestAccess=TRUE → IS_GUEST accepted → ✓ + # AllowInsecureGuestAccess=TRUE -> IS_GUEST accepted -> ✓ # - # 3. Client max dialect ≥ 3.1.1 → no IS_GUEST + # 3. Client max dialect ≥ 3.1.1 -> no IS_GUEST # These clients (Win10, Win11, Srv2019, Srv2022) have - # AllowInsecureGuestAccess=FALSE → IS_GUEST rejected → H. + # AllowInsecureGuestAccess=FALSE -> IS_GUEST rejected -> H. # Without IS_GUEST at 2.x they get P (path from # TREE_CONNECT before VALIDATE_NEGOTIATE RST). if error_code == nt_errors.STATUS_SUCCESS: @@ -1705,7 +1714,7 @@ def handle_smb1_session_setup(self, packet: smb.NewSMBPacket) -> None: # [MS-CIFS] §2.2.4.53.1: Unicode strings are 2-byte aligned # from the start of the SMB header. Fixed overhead for # WordCount=12: 32(hdr)+1(WC)+24(params)+2(BC) = 59 (odd). - # Padding needed when (59 + blob_len) is odd → blob_len even. + # Padding needed when (59 + blob_len) is odd -> blob_len even. # Cannot check byte value: NT 4.0 uses non-zero pad bytes. needs_pad = is_unicode and blob_len % 2 == 0 if needs_pad and len(raw_after_blob) > 0: @@ -1752,7 +1761,7 @@ def handle_smb1_session_setup(self, packet: smb.NewSMBPacket) -> None: error_code=error_code, ) elif command["WordCount"] == 13: - # Non-extended security — [MS-CIFS] §2.2.4.53.1 + # Non-extended security - [MS-CIFS] §2.2.4.53.1 self.handle_smb1_session_setup_basic(packet, command) else: self.logger.warning( @@ -1797,7 +1806,7 @@ def handle_smb1_session_setup_basic( oem_len: int = setup_params["AnsiPwdLength"] uni_len: int = setup_params["UnicodePwdLength"] is_unicode = bool(packet["Flags2"] & smb.SMB.FLAGS2_UNICODE) - # [MS-CIFS] §2.2.4.53.1 — manually parse the data section. + # [MS-CIFS] §2.2.4.53.1 - manually parse the data section. # impacket's AsciiStructure truncates at \x00 (wrong for Unicode) # and UnicodeStructure decodes as UTF-16BE (impacket bug). raw_data: bytes = command["Data"] @@ -1805,9 +1814,9 @@ def handle_smb1_session_setup_basic( oem_pwd: bytes = raw_data[:oem_len] if oem_len else b"" uni_pwd: bytes = raw_data[oem_len : oem_len + uni_len] if uni_len else b"" - # Determine transport type FIRST — needed for string parsing. + # Determine transport type FIRST - needed for string parsing. if oem_len == 0 and uni_len == 0: - # Anonymous — no credentials at all + # Anonymous - no credentials at all transport: str | None = None elif uni_len == 0 and oem_len <= 1 and oem_pwd in (b"", b"\x00"): # NT 4.0 null session: OemPwdLen=1 with value \x00. @@ -1819,7 +1828,7 @@ def handle_smb1_session_setup_basic( # Only Unicode populated with non-standard length or (oem_len == 0 and uni_len not in (0, 24) and uni_len <= 512) ): - # Unexpected plaintext despite challenge — [MS-CIFS] §3.2.4.2.4 + # Unexpected plaintext despite challenge - [MS-CIFS] §3.2.4.2.4 self.logger.debug( "SMB_COM_SESSION_SETUP_ANDX: plaintext password detected " "despite challenge (unusual client behavior)", @@ -1903,7 +1912,7 @@ def handle_smb1_session_setup_basic( ct_extras["os"] = client_os if client_lanman: ct_extras["lanman"] = client_lanman - NTLM_handle_legacy_raw_auth( + ntlm_handle_legacy_raw_auth( user_name=account, domain_name=domain, lm_response=None, @@ -1922,7 +1931,7 @@ def handle_smb1_session_setup_basic( extras["os"] = client_os if client_lanman: extras["lanman"] = client_lanman - NTLM_handle_legacy_raw_auth( + ntlm_handle_legacy_raw_auth( user_name=account, domain_name=domain, lm_response=oem_pwd, @@ -1935,7 +1944,7 @@ def handle_smb1_session_setup_basic( extras=extras or None, ) else: - # Anonymous — reject to force the client to retry with real + # Anonymous - reject to force the client to retry with real # credentials. Without this, XP's redirector uses the anonymous # session for the share and never sends real hashes. self.logger.debug( @@ -1962,11 +1971,11 @@ def handle_smb1_session_setup_basic( ) return - # Allocate Uid for this session — [MS-SMB] §3.3.5.3 + # Allocate Uid for this session - [MS-SMB] §3.3.5.3 if self.smb1_uid == 0: self.smb1_uid = secrets.randbelow(0xFFFE) + 1 - # Build response — [MS-CIFS] §2.2.4.53.2 (WordCount=3) + # Build response - [MS-CIFS] §2.2.4.53.2 (WordCount=3) resp_params = smb.SMBSessionSetupAndXResponse_Parameters() resp_data = smb.SMBSessionSetupAndXResponse_Data(flags=packet["Flags2"]) resp_params["Action"] = 0 @@ -1980,7 +1989,7 @@ def handle_smb1_session_setup_basic( packet["Flags2"], cfg.smb_nb_domain ) - # Determine error code — multi-cred or final + # Determine error code - multi-cred or final error_code = self._resolve_auth_error_code() self.send_smb1_command( @@ -2040,35 +2049,6 @@ def _extract_smb2_tree_path(self, packet: smb2.SMB2Packet) -> str: ) return "" - def _send_smb2_tree_connect_response( - self, packet: smb2.SMB2Packet, resp: typing.Any, tree_id: int - ) -> None: - """Send an SMB2 TREE_CONNECT response with a server-assigned TreeID. - - Uses manual SMB2Packet construction rather than :meth:`send_smb2_command` - because the TreeID in the response must be the server-assigned value, - not the echoed value from the request. - - :param packet: The original TREE_CONNECT request - :type packet: smb2.SMB2Packet - :param resp: The populated SMB2TreeConnect_Response structure - :type resp: typing.Any - :param tree_id: Server-assigned TreeID for this tree connect - :type tree_id: int - """ - smb2_resp = smb2.SMB2Packet() - smb2_resp["Flags"] = smb2.SMB2_FLAGS_SERVER_TO_REDIR - smb2_resp["Status"] = nt_errors.STATUS_SUCCESS - smb2_resp["Command"] = packet["Command"] - smb2_resp["CreditCharge"] = packet["CreditCharge"] - smb2_resp["Reserved"] = packet["Reserved"] - smb2_resp["SessionID"] = self.smb2_session_id - smb2_resp["MessageID"] = packet["MessageID"] - smb2_resp["TreeID"] = tree_id - smb2_resp["CreditRequestResponse"] = 32 - smb2_resp["Data"] = resp.getData() - self.send_data(smb2_resp.getData()) - def handle_smb2_tree_connect(self, packet: smb2.SMB2Packet) -> None: r"""SMB2 TREE_CONNECT handler -- [MS-SMB2] §3.3.5.7. @@ -2099,7 +2079,7 @@ def handle_smb2_tree_connect(self, packet: smb2.SMB2Packet) -> None: exc_info=True, ) - # Extract share name from UNC path (\\server\share → share) + # Extract share name from UNC path (\\server\share -> share) share_name = path.rsplit("\\", 1)[-1].upper() if path else "" self.smb2_tree_id_counter += 1 @@ -2116,7 +2096,7 @@ def handle_smb2_tree_connect(self, packet: smb2.SMB2Packet) -> None: is_server=True, ) else: - # Non-IPC$ disk share — capture the path for intelligence + # Non-IPC$ disk share - capture the path for intelligence if path: self.client_info["smb_path"] = path resp["ShareType"] = 0x01 # SMB2_SHARE_TYPE_DISK @@ -2129,7 +2109,7 @@ def handle_smb2_tree_connect(self, packet: smb2.SMB2Packet) -> None: is_server=True, ) - self._send_smb2_tree_connect_response(packet, resp, self.smb2_tree_id_counter) + self.send_smb2_command(resp.getData(), packet, tree_id=self.smb2_tree_id_counter) def handle_smb2_tree_disconnect(self, packet: smb2.SMB2Packet) -> None: """SMB2 TREE_DISCONNECT handler -- [MS-SMB2] §3.3.5.8. @@ -2172,7 +2152,7 @@ def handle_smb2_ioctl(self, packet: smb2.SMB2Packet) -> None: req = smb2.SMB2Ioctl(packet["Data"]) ctl_code = req["CtlCode"] self.logger.debug("SMB2_IOCTL CtlCode=0x%08x", ctl_code, is_client=True) - except Exception: + except _PARSE_ERRORS: self.logger.debug("SMB2_IOCTL (malformed)", is_client=True) self._smb2_error_response(packet, nt_errors.STATUS_INVALID_PARAMETER) return @@ -2181,7 +2161,7 @@ def handle_smb2_ioctl(self, packet: smb2.SMB2Packet) -> None: # [MS-SMB2] §3.3.5.15.12: echo back server negotiate params self._handle_validate_negotiate(packet, req) else: - # [MS-SMB2] §3.3.5.15.2: non-DFS → STATUS_FS_DRIVER_REQUIRED + # [MS-SMB2] §3.3.5.15.2: non-DFS -> STATUS_FS_DRIVER_REQUIRED self._smb2_error_response(packet, nt_errors.STATUS_FS_DRIVER_REQUIRED) def _handle_validate_negotiate( @@ -2234,7 +2214,7 @@ def _handle_validate_negotiate( ) self.send_smb2_command(resp.getData(), packet) - except Exception: + except _PARSE_ERRORS: self.logger.debug("FSCTL_VALIDATE_NEGOTIATE_INFO failed", exc_info=True) self._smb2_error_response(packet, nt_errors.STATUS_ACCESS_DENIED) @@ -2253,12 +2233,13 @@ def handle_smb1_tree_connect(self, packet: smb.NewSMBPacket) -> None: :param packet: Parsed SMB1 packet from the client :type packet: smb.NewSMBPacket """ + path = "" try: # [MS-CIFS] §2.2.4.55.1: SMB_COM_TREE_CONNECT_ANDX Request # Use impacket for Parameters parsing (PasswordLength), but # extract the Path manually because impacket's # SMBTreeConnectAndX_Data has no alignment-pad field between - # Password and Path — when PasswordLength causes an odd SMB + # Password and Path - when PasswordLength causes an odd SMB # offset, the client inserts a pad byte that impacket's 'u' # format parser includes in the Path, producing garbled # UTF-16LE. This only happens with even PasswordLength @@ -2320,40 +2301,30 @@ def handle_smb1_tree_connect(self, packet: smb.NewSMBPacket) -> None: resp_data = smb.SMBTreeConnectAndXResponse_Data(flags=packet["Flags2"]) if share_name == "IPC$": - # Accept IPC$ so the client can proceed to the real share self.logger.debug("SMB1 TREE_CONNECT IPC$ accepted", is_server=True) resp_data["Service"] = b"IPC\x00" - resp_data["NativeFileSystem"] = smbserver.encodeSMBString( - packet["Flags2"], "" - ) - self.send_smb1_command( - smb.SMB.SMB_COM_TREE_CONNECT_ANDX, - resp_data, - resp_params, - packet, - ) else: - # Non-IPC$ disk share — accept so client proceeds to + # Non-IPC$ disk share - accept so client proceeds to # NT_CREATE / READ, allowing filename capture. + if path: + self.client_info["smb_path"] = path + resp_data["Service"] = b"A:\x00" self.logger.debug( "SMB1 TREE_CONNECT share accepted (path=%s)", path, is_server=True ) - resp_data["Service"] = b"A:\x00" - resp_data["NativeFileSystem"] = smbserver.encodeSMBString( - packet["Flags2"], "" - ) - self.send_smb1_command( - smb.SMB.SMB_COM_TREE_CONNECT_ANDX, - resp_data, - resp_params, - packet, - ) + resp_data["NativeFileSystem"] = smbserver.encodeSMBString(packet["Flags2"], "") + self.send_smb1_command( + smb.SMB.SMB_COM_TREE_CONNECT_ANDX, + resp_data, + resp_params, + packet, + ) def handle_smb1_tree_disconnect(self, packet: smb.NewSMBPacket) -> None: """SMB1 TREE_DISCONNECT handler -- [MS-CIFS] §3.3.5.29. Acknowledges tree disconnect requests. ``SMB_COM_TREE_DISCONNECT`` - is NOT an AndX command — the response has zero parameter words + is NOT an AndX command - the response has zero parameter words and zero data bytes. :param packet: Parsed SMB1 packet from the client @@ -2397,7 +2368,7 @@ def handle_smb2_create(self, packet: smb2.SMB2Packet) -> None: self.logger.debug("SMB2_CREATE Name=%s", name or "(empty)", is_client=True) if name: self.client_files.add(name) - except Exception: + except _PARSE_ERRORS: self.logger.debug("SMB2_CREATE (malformed)", is_client=True) # Allocate a sequential volatile FileId @@ -2458,7 +2429,7 @@ def handle_smb2_query_directory(self, packet: smb2.SMB2Packet) -> None: end = min(name_offset + name_length, len(raw)) pattern = raw[name_offset:end].decode("utf-16-le", errors="replace") self.logger.debug("SMB2_QUERY_DIRECTORY Pattern=%s", pattern, is_client=True) - except Exception: + except _PARSE_ERRORS: self.logger.debug("SMB2_QUERY_DIRECTORY (malformed)", is_client=True) self._smb2_error_response(packet, nt_errors.STATUS_NO_MORE_FILES) @@ -2491,7 +2462,7 @@ def handle_smb2_query_info(self, packet: smb2.SMB2Packet) -> None: file_info_class, is_client=True, ) - except Exception: + except _PARSE_ERRORS: self.logger.debug("SMB2_QUERY_INFO (malformed)", is_client=True) self._smb2_error_response(packet, nt_errors.STATUS_INVALID_PARAMETER) return @@ -2644,7 +2615,7 @@ def handle_smb2_read(self, packet: smb2.SMB2Packet) -> None: req["Length"], is_client=True, ) - except Exception: + except _PARSE_ERRORS: self.logger.debug("SMB2_READ (malformed)", is_client=True) self._smb2_error_response(packet, nt_errors.STATUS_END_OF_FILE) @@ -2658,14 +2629,14 @@ def handle_smb2_close(self, packet: smb2.SMB2Packet) -> None: :type packet: smb2.SMB2Packet """ self.logger.debug("SMB2_CLOSE", is_client=True) - # SMB2Close_Response has all zeros for timestamps/sizes — spec-compliant + # SMB2Close_Response has all zeros for timestamps/sizes - spec-compliant resp = smb2.SMB2Close_Response() self.send_smb2_command(resp.getData(), packet) def handle_smb2_write(self, packet: smb2.SMB2Packet) -> None: """SMB2 WRITE handler -- [MS-SMB2] §3.3.5.13. - Acknowledges write requests. No data is actually written — the + Acknowledges write requests. No data is actually written - the fake files are read-only scaffolding. Returns the requested byte count as ``Count`` so the client believes the write succeeded. @@ -2679,7 +2650,7 @@ def handle_smb2_write(self, packet: smb2.SMB2Packet) -> None: self.logger.debug( "SMB2_WRITE Length=%d Offset=%d", count, req["Offset"], is_client=True ) - except Exception: + except _PARSE_ERRORS: self.logger.debug("SMB2_WRITE (malformed)", is_client=True) resp = smb2.SMB2Write_Response() resp["Count"] = count @@ -2688,7 +2659,7 @@ def handle_smb2_write(self, packet: smb2.SMB2Packet) -> None: def handle_smb2_flush(self, packet: smb2.SMB2Packet) -> None: """SMB2 FLUSH handler -- [MS-SMB2] §3.3.5.11. - Acknowledges flush requests. No data is actually flushed — the + Acknowledges flush requests. No data is actually flushed - the fake files have no backing store. Observed from Win8.1 and Srv2012R2 (SMB 3.0.2 IS_GUEST clients) after WRITE operations. @@ -2716,7 +2687,7 @@ def handle_smb2_set_info(self, packet: smb2.SMB2Packet) -> None: """SMB2 SET_INFO handler -- [MS-SMB2] §3.3.5.21. Acknowledges set-info requests. No attributes are actually - changed — the fake files are immutable scaffolding. Response + changed - the fake files are immutable scaffolding. Response is 2 bytes (StructureSize only). :param packet: Parsed SMB2 packet from the client @@ -2730,7 +2701,7 @@ def handle_smb2_set_info(self, packet: smb2.SMB2Packet) -> None: req["FileInfoClass"], is_client=True, ) - except Exception: + except _PARSE_ERRORS: self.logger.debug("SMB2_SET_INFO (malformed)", is_client=True) resp = smb2.SMB2SetInfo_Response() self.send_smb2_command(resp.getData(), packet) @@ -2775,7 +2746,7 @@ def handle_smb1_nt_create(self, packet: smb.NewSMBPacket) -> None: ) if name: self.client_files.add(name) - except Exception: + except _PARSE_ERRORS: self.logger.debug("SMB_COM_NT_CREATE_ANDX (malformed)", is_client=True) # Allocate a sequential FID @@ -2833,7 +2804,7 @@ def _send_smb1_trans2_response( """ # Absolute offsets from SMB header start # 32(hdr) + 1(WC) + 20(Words) + 2(BC) = 55 - pad1 = b"\x00" # align to even offset (55 → 56) + pad1 = b"\x00" # align to even offset (55 -> 56) param_offset = 56 if trans_parameters else 0 param_len = len(trans_parameters) @@ -2892,37 +2863,37 @@ def _build_trans2_file_info(self, info_level: int) -> bytes | None: # noqa: PLR # Normalise pass-through levels to native NT info class. # Raw FileInformationClass values (< 0x03E8) pass through unchanged, - # which is correct — XP SP3 sends them without the 0x03E8 base. + # which is correct - XP SP3 sends them without the 0x03E8 base. native = info_level if info_level >= PASS_THROUGH_BASE: native = info_level - PASS_THROUGH_BASE # ── CIFS-native levels ([MS-CIFS] §2.2.8.3) ────────────────────── - # SMB_INFO_STANDARD (0x0001/0x0100) — NT 4.0 + # SMB_INFO_STANDARD (0x0001/0x0100) - NT 4.0 # [MS-CIFS] §2.2.8.3.1: 3×(Date+Time) + DataSize + AllocationSize + Attributes if info_level in {0x0001, 0x0100}: return b"\x00" * 22 - # SMB_INFO_QUERY_EA_SIZE (0x0002/0x0200) — NT 4.0 EA query + # SMB_INFO_QUERY_EA_SIZE (0x0002/0x0200) - NT 4.0 EA query if info_level in {0x0002, 0x0200}: return b"\x00" * 26 # 22 (standard) + 4 (EaSize) - # SMB_INFO_QUERY_EAS_FROM_LIST (0x0003) — Srv2003 + # SMB_INFO_QUERY_EAS_FROM_LIST (0x0003) - Srv2003 # [MS-CIFS] §2.2.8.3.3: return empty EA list (4-byte size = 0) if info_level == 0x0003: return b"\x00" * 4 # 0x0006: dual meaning depending on TRANS2 subcommand: - # QUERY_PATH_INFORMATION: SMB_INFO_IS_NAME_VALID — empty SUCCESS - # QUERY_FILE_INFORMATION: FileInternalInformation (class 6) — 8-byte file ID + # QUERY_PATH_INFORMATION: SMB_INFO_IS_NAME_VALID - empty SUCCESS + # QUERY_FILE_INFORMATION: FileInternalInformation (class 6) - 8-byte file ID # Since both return SUCCESS and the 8-byte response is a superset of # the empty response, always return 8 bytes. [MS-FSCC] §2.4.20. if info_level == 0x0006: return b"\x00" * 8 # SMB_QUERY_FILE_EA_INFO (0x0103) / FileEaInformation (class 7) - # [MS-FSCC] §2.4.12: EaSize(4) — no EAs on fake files + # [MS-FSCC] §2.4.12: EaSize(4) - no EAs on fake files if native == 7 or info_level == 0x0103: return b"\x00" * 4 @@ -2975,41 +2946,9 @@ def _build_trans2_file_info(self, info_level: int) -> bytes | None: # noqa: PLR file_info["Directory"] = 0 return file_info.getData() - # FileInternalInformation (class 6) — XP SP3/SP0/Srv2003 - # [MS-FSCC] §2.4.20: IndexNumber(8) — unique file ID - if native == 6: - return b"\x00" * 8 - - # FilePositionInformation (class 11/0x0b) — XP SP3 - # [MS-FSCC] §2.4.32: CurrentByteOffset(8) - if native == 11: - return b"\x00" * 8 - - # FileNamesInformation (class 12/0x0c) — XP SP3/Srv2003 - # [MS-FSCC] §2.4.28: NextEntryOffset(4) + FileIndex(4) + - # FileNameLength(4) + FileName(variable) — return empty entry - if native == 12: - return b"\x00" * 12 - - # FileModeInformation (class 13/0x0d) — XP SP3/SP0 - # [MS-FSCC] §2.4.26: Mode(4) - if native == 13: - return b"\x00" * 4 - - # FileAlignmentInformation (class 14/0x0e) — XP SP3/SP0 - # [MS-FSCC] §2.4.3: AlignmentRequirement(4) — 0 = byte-aligned - if native == 14: - return b"\x00" * 4 - - # FileAllocationInformation (class 16/0x10) — XP SP3/SP0 - # This is a SET class per spec, but XP queries it. - # [MS-FSCC] §2.4.4: AllocationSize(8) - if native == 16: - return b"\x00" * 8 - # FileNetworkOpenInformation (class 34/0x22 or raw 0x0026=38) # [MS-FSCC] §2.4.29: 4×FILETIME + sizes + attributes (56 bytes) - # Note: 0x0026 = 38 decimal — observed from XP SP3 as raw class. + # Note: 0x0026 = 38 decimal - observed from XP SP3 as raw class. if native in {34, 38}: info = smb.SMBFileNetworkOpenInfo() info["CreationTime"] = now @@ -3021,30 +2960,34 @@ def _build_trans2_file_info(self, info_level: int) -> bytes | None: # noqa: PLR info["FileAttributes"] = smb2.FILE_ATTRIBUTE_ARCHIVE return info.getData() - # FilePipeInformation (class 23/0x17) — XP SP3 on IPC$ - # [MS-FSCC] §2.4.31: ReadMode(4) + CompletionMode(4) - if native == 23: + # Zero-filled responses grouped by size: + # + # 8 bytes: FileInternalInformation (6) [MS-FSCC] §2.4.20: IndexNumber + # FilePositionInformation (11) [MS-FSCC] §2.4.32: CurrentByteOffset + # FileAllocationInformation (16) [MS-FSCC] §2.4.4: AllocationSize + # FilePipeInformation (23) [MS-FSCC] §2.4.31: ReadMode+CompletionMode + if native in {6, 11, 16, 23}: return b"\x00" * 8 - # FilePipeLocalInformation (class 24/0x18) — XP SP3 on IPC$ - # [MS-FSCC] §2.4.30: 9 × ULONG (36 bytes) - if native == 24: - return b"\x00" * 36 - - # FilePipeRemoteInformation (class 25/0x19) — XP SP3 on IPC$ - # [MS-FSCC] §2.4.31: CollectDataTime(8) + MaximumCollectionCount(4) - if native == 25: + # 12 bytes: FileNamesInformation (12) [MS-FSCC] §2.4.28: NextOffset+Index+NameLen + # FilePipeRemoteInformation (25) [MS-FSCC] §2.4.31: CollectTime+MaxCount + if native in {12, 25}: return b"\x00" * 12 - # FileMailslotQueryInformation (class 26/0x1a) — XP SP3 - # Not defined in [MS-FSCC]; return minimal 4-byte response - if native == 26: + # 4 bytes: FileModeInformation (13) [MS-FSCC] §2.4.26: Mode + # FileAlignmentInformation (14) [MS-FSCC] §2.4.3: AlignmentRequirement + # FileMailslotQueryInformation (26) - minimal response (XP SP3) + if native in {13, 14, 26}: return b"\x00" * 4 + # 36 bytes: FilePipeLocalInformation (24) [MS-FSCC] §2.4.30: 9 × ULONG + if native == 24: + return b"\x00" * 36 + # ── Samba Unix extensions ───────────────────────────────────────── # SMB_QUERY_FILE_UNIX_BASIC (0x0120) - # Samba extension — smbclient sends this before READ on NT1. + # Samba extension - smbclient sends this before READ on NT1. if info_level == 0x0120: return b"\x00" * 100 @@ -3078,7 +3021,7 @@ def handle_smb1_trans2(self, packet: smb.NewSMBPacket) -> None: self.logger.debug( "SMB_COM_TRANSACTION2 Subcommand=0x%04x", subcommand, is_client=True ) - except Exception: + except _PARSE_ERRORS: self.logger.debug("SMB_COM_TRANSACTION2 (malformed)", is_client=True) if subcommand == smb.SMB.TRANS2_QUERY_PATH_INFORMATION: @@ -3148,7 +3091,7 @@ def handle_smb1_trans2(self, packet: smb.NewSMBPacket) -> None: ) elif subcommand == smb.SMB.TRANS2_QUERY_FS_INFORMATION: # [MS-CIFS] §2.2.6.4: NT 4.0 queries filesystem info after - # tree connect. Return empty success — the info level doesn't + # tree connect. Return empty success - the info level doesn't # matter for a capture server; the client proceeds regardless. self.logger.debug("TRANS2_QUERY_FS_INFORMATION", is_client=True) self._send_smb1_trans2_response( @@ -3188,7 +3131,7 @@ def handle_smb1_read(self, packet: smb.NewSMBPacket) -> None: params["Offset"], is_client=True, ) - except Exception: + except _PARSE_ERRORS: self.logger.debug("SMB_COM_READ_ANDX (malformed)", is_client=True) self.send_smb1_command( smb.SMB.SMB_COM_READ_ANDX, @@ -3202,7 +3145,7 @@ def handle_smb1_close(self, packet: smb.NewSMBPacket) -> None: """SMB1 CLOSE handler -- [MS-CIFS] §3.3.5.27. Acknowledges close requests. ``SMB_COM_CLOSE`` is NOT an AndX - command — the response has zero parameter words and zero data + command - the response has zero parameter words and zero data bytes. :param packet: Parsed SMB1 packet from the client @@ -3212,7 +3155,7 @@ def handle_smb1_close(self, packet: smb.NewSMBPacket) -> None: cmd = smb.SMBCommand(packet["Data"][0]) params = smb.SMBClose_Parameters(cmd["Parameters"]) self.logger.debug("SMB_COM_CLOSE Fid=%d", params["FID"], is_client=True) - except Exception: + except _PARSE_ERRORS: self.logger.debug("SMB_COM_CLOSE (malformed)", is_client=True) self.send_smb1_command( smb.SMB.SMB_COM_CLOSE, @@ -3258,7 +3201,7 @@ def __init__( :type RequestHandlerClass: type | None, optional """ self.server_config = server_config - # Stable ServerGuid per server instance — [MS-SMB2] §2.2.4 + # Stable ServerGuid per server instance - [MS-SMB2] §2.2.4 self.server_guid: bytes = secrets.token_bytes(16) super().__init__(config, server_address, RequestHandlerClass) diff --git a/dementor/protocols/smtp.py b/dementor/protocols/smtp.py index faea35b..ae2fe34 100644 --- a/dementor/protocols/smtp.py +++ b/dementor/protocols/smtp.py @@ -48,14 +48,15 @@ from dementor.config.toml import TomlConfig, Attribute as A from dementor.config.session import SessionConfig +from dementor.config.util import HostValue, HostFallbackValue from dementor.log.logger import ProtocolLogger, dm_logger from dementor.protocols.ntlm import ( - NTLM_build_challenge_message, - NTLM_handle_authenticate_message, - NTLM_handle_negotiate_message, + ntlm_build_challenge_message, + ntlm_handle_authenticate_message, + ntlm_handle_negotiate_message, ) -from dementor.db import _CLEARTEXT -from dementor.servers import AsyncServerThread +from dementor.db import CLEARTEXT +from dementor.servers import AsyncServerThread, BaseServerThread from dementor.loader import BaseProtocolModule, DEFAULT_ATTR __proto__ = ["SMTP"] @@ -82,7 +83,13 @@ class SMTPServerConfig(TomlConfig): _fields_ = [ A("smtp_port", "Port"), A("smtp_tls", "TLS", False), - A("smtp_fqdn", "FQDN", "DEMENTOR", section_local=False), + A( + "smtp_fqdn", + "Host", + None, + section_local=False, + factory=HostFallbackValue(HostValue.HOST, "DEMENTOR"), + ), A("smtp_ident", "Ident", "Dementor 1.0dev0"), A("smtp_downgrade", "Downgrade", False), A("smtp_auth_mechanisms", "AuthMechanisms", list), @@ -108,12 +115,14 @@ class SMTPServerConfig(TomlConfig): class SMTP(BaseProtocolModule[SMTPServerConfig]): name = "SMTP" config_ty = SMTPServerConfig - config_attr = "smtp_servers" + config_attr = DEFAULT_ATTR config_enabled_attr = DEFAULT_ATTR config_list = True @override - def create_server_thread(self, session, server_config): + def create_server_thread( + self, session: SessionConfig, server_config: SMTPServerConfig + ) -> BaseServerThread[SMTPServerConfig]: return SMTPServerThread(session, server_config) @@ -144,15 +153,8 @@ def __call__( ) -> AuthResult: match auth_data: case NTLMAuth(): - # successful NTLM authentication - # self.config.db.add_auth( - # client=session.peer, - # credtype=auth_data.hash_version, - # password=auth_data.hash_string, - # logger=self.logger, - # username=auth_data.user_name, - # domain=auth_data.domain_name, - # ) + # Credentials are captured upstream in chapture_ntlm_auth via + # ntlm_handle_authenticate_message; nothing to do here. pass case LoginPassword(): @@ -161,7 +163,7 @@ def __call__( password = auth_data.password.decode(errors="replace") self.config.db.add_auth( client=session.peer, - credtype=_CLEARTEXT, + credtype=CLEARTEXT, password=password, logger=self.logger, username=username, @@ -190,11 +192,6 @@ async def auth_plain( ) -> SMTP_AUTH_Result: return await server.auth_PLAIN(server, args) - async def auth_ntlm( - self, server: SMTPServerBase, args: list[str] - ) -> SMTP_AUTH_Result: - return await self.auth_NTLM(server, args) - async def auth_NTLM( self, server: SMTPServerBase, args: list[bytes] ) -> SMTP_AUTH_Result: @@ -239,16 +236,23 @@ async def chapture_ntlm_auth(self, server: SMTPServerBase, blob=None) -> Any: negotiate_message = NTLMAuthNegotiate() negotiate_message.fromString(blob) - negotiate_fields = NTLM_handle_negotiate_message(negotiate_message, self.logger) + negotiate_fields = ntlm_handle_negotiate_message(negotiate_message, self.logger) # now we can build the challenge using the answer flags - ntlm_challenge = NTLM_build_challenge_message( + host = HostValue(self.server_config.smtp_fqdn) + ntlm_challenge = ntlm_build_challenge_message( negotiate_message, challenge=self.config.ntlm_challenge, - nb_computer=self.config.ntlm_nb_computer, - nb_domain=self.config.ntlm_nb_domain, + nb_computer=host.get_value(HostValue.NETBIOS_COMPUTER), + nb_domain=host.get_value(HostValue.NETBIOS_DOMAIN), disable_ess=self.config.ntlm_disable_ess, disable_ntlmv2=self.config.ntlm_disable_ntlmv2, + target_type=self.config.ntlm_target_type, + version=self.config.ntlm_version, + dns_computer=host.get_value(HostValue.DNS_COMPUTER), + dns_domain=host.get_value(HostValue.DNS_DOMAIN), + # REVISIT: capture DNSTree too + # dns_tree=self.config.ntlm_dns_tree, log=self.logger, ) @@ -260,7 +264,7 @@ async def chapture_ntlm_auth(self, server: SMTPServerBase, blob=None) -> Any: # NTLM AUTHENTICATE_MESSAGE. auth_message = NTLMAuthChallengeResponse() auth_message.fromString(blob) - NTLM_handle_authenticate_message( + ntlm_handle_authenticate_message( auth_message, challenge=self.config.ntlm_challenge, client=server.session.peer, @@ -284,19 +288,21 @@ async def chapture_ntlm_auth(self, server: SMTPServerBase, blob=None) -> Any: class SMTPServerThread(AsyncServerThread[SMTPServerConfig]): - def __init__(self, config: SessionConfig, server_config: SMTPServerConfig): + def __init__(self, config: SessionConfig, server_config: SMTPServerConfig) -> None: super().__init__(config, server_config) self.controller: Controller | None = None self._running = False @override - def is_running(self): + def is_running(self) -> bool: return self._running - def get_service_name(self) -> str: + @property + def service_name(self) -> str: return "SMTPS" if self.server_config.smtp_tls else "SMTP" - def get_port(self): + @property + def get_port(self) -> int: return self.server_config.smtp_port def create_logger(self) -> ProtocolLogger: @@ -307,9 +313,9 @@ def create_logger(self) -> ProtocolLogger: } ) - async def start_server( + def start_server( self, controller: Controller, config: SessionConfig, smtp_config - ): + ) -> None: controller.port = smtp_config.smtp_port # NOTE: hostname on the controller points to the local address that will be @@ -358,7 +364,7 @@ async def arun(self) -> None: tls_context=tls_context, require_starttls=server.smtp_require_starttls, ) - await self.start_server( + self.start_server( self.controller, self.config, server, diff --git a/dementor/protocols/spnego.py b/dementor/protocols/spnego.py index 7484b67..c42519b 100644 --- a/dementor/protocols/spnego.py +++ b/dementor/protocols/spnego.py @@ -147,7 +147,7 @@ def _handle_neg_token_init(self, token: SPNEGO_NegTokenInit) -> tuple[bytes, boo ) # Build response - response = negTokenInit_step( + response = build_neg_token_resp( 0x00 if complete else 0x01, response_token, chosen_mech ) return response.getData(), complete @@ -182,9 +182,10 @@ def _handle_neg_token_resp(self, token: SPNEGO_NegTokenResp) -> tuple[bytes, boo ) # Build final response - response = negTokenInit_step(0x00 if complete else 0x01, final_token) + response = build_neg_token_resp(0x00 if complete else 0x01, final_token) return response.getData(), complete + # [RFC4178] §4.2.2 / [MS-SPNG]: negState enumeration values for NegTokenResp. # These indicate the outcome of each round of the SPNEGO exchange. NEG_STATE_ACCEPT_COMPLETED: int = 0 # Authentication succeeded, context established @@ -208,7 +209,7 @@ def build_neg_token_resp( Spec: [RFC4178] §4.2.2, [MS-SPNG] §3.2.5.2 - :param neg_state: Negotiation state — one of ``NEG_STATE_ACCEPT_COMPLETED``, + :param neg_state: Negotiation state - one of ``NEG_STATE_ACCEPT_COMPLETED``, ``NEG_STATE_ACCEPT_INCOMPLETE``, or ``NEG_STATE_REJECT`` :type neg_state: int :param resp_token: The mechanism-specific response token (e.g., serialized @@ -255,13 +256,6 @@ def build_neg_token_init(mech_types: list[str]) -> SPNEGO_NegTokenInit: :return: Populated NegTokenInit ready for serialization via ``.getData()`` :rtype: SPNEGO_NegTokenInit """ -def negTokenInit(mech_types: list[str]) -> SPNEGO_NegTokenInit: - """Create a NegTokenInit with specified mechanism types. - - :param list mech_types: List of mechanism names - :return: NegTokenInit structure - :rtype: SPNEGO_NegTokenInit - """ token_init = SPNEGO_NegTokenInit() token_init["MechTypes"] = [TypesMech[x] for x in mech_types] return token_init diff --git a/dementor/protocols/ssdp.py b/dementor/protocols/ssdp.py index b6bba68..f869bbc 100644 --- a/dementor/protocols/ssdp.py +++ b/dementor/protocols/ssdp.py @@ -87,7 +87,7 @@ class SSDP(BaseProtocolModule[SSDPConfig]): @override def create_server_thread( self, session: SessionConfig, server_config: SSDPConfig - ) -> BaseServerThread: + ) -> BaseServerThread[SSDPConfig]: return ServerThread( session, server_config, @@ -165,7 +165,7 @@ def __init__(self, config, request, client_address, server) -> None: super().__init__(config, request, client_address, server) @property - def ssdp_config(self): + def ssdp_config(self) -> "SSDPConfig": return self.config.ssdp_config @property @@ -220,7 +220,7 @@ def handle_data(self, data, transport) -> None: return self.handle_search(transport) # [1.2 Advertisement] - def handle_advertisement(self): + def handle_advertisement(self) -> None: # All NOTIFY messages MUST store the NTS (notification subtype) match self.message["NTS"]: # [1.2.2 Device available - NOTIFY with ssdp:alive] @@ -241,7 +241,7 @@ def handle_advertisement(self): # REVISIT: shoul we report ssdp:update? # [1.3.2 Search request with M-SEARCH] - def handle_search(self, transport): + def handle_search(self, transport) -> None: target = self.message["ST"] or "uuid:invalid" target_text = "" search_tyoe = "" @@ -347,7 +347,7 @@ def handle_search(self, transport): f"Sent poisoned response to {self.client_host} for {target_text or search_tyoe}" ) - def describe_server(self, server: str) -> tuple: + def describe_server(self, server: str) -> tuple[str, str, str]: os_name = product_name = "" if "UPnP/" in server: os_name, _, product_name = server.partition("UPnP/") diff --git a/dementor/protocols/upnp.py b/dementor/protocols/upnp.py index beb0d2c..adb6b81 100644 --- a/dementor/protocols/upnp.py +++ b/dementor/protocols/upnp.py @@ -74,7 +74,7 @@ class UPNPConfig(TomlConfig): upnp_scpd_path: str upnp_present_path: str - def set_upnp_templates_path(self, path_list): + def set_upnp_templates_path(self, path_list) -> None: dirs = set() for templates_dir in path_list: path = pathlib.Path(templates_dir) @@ -89,7 +89,7 @@ def set_upnp_templates_path(self, path_list): dirs.add(HTTP_TEMPLATES_PATH) self.upnp_templates_path = list(dirs) - def set_upnp_template(self, template): + def set_upnp_template(self, template) -> None: upnp_template = None for templates_dir in self.upnp_templates_path: path = pathlib.Path(templates_dir) / template @@ -112,7 +112,7 @@ class UPnP(BaseProtocolModule[UPNPConfig]): @override def create_server_thread( self, session: SessionConfig, server_config: UPNPConfig - ) -> BaseServerThread: + ) -> BaseServerThread[UPNPConfig]: return ServerThread( session, server_config, @@ -150,7 +150,7 @@ def version_string(self) -> str: def log_message(self, format: str, *args) -> None: pass - def send_page(self, template, content_type): + def send_page(self, template, content_type) -> None: self.logger.debug(escape(f"{self.command} {self.path} 200"), is_server=True) path = pathlib.Path(template) script = self.server.render(path.name, uuid=self.target_uuid) @@ -162,7 +162,7 @@ def send_page(self, template, content_type): self.end_headers() self.wfile.write(body) - def do_GET(self): + def do_GET(self) -> None: self.logger.debug(f"Request for {self.path}", is_client=True) user_agent = escape(self.headers.get("User-Agent", "")) if len(user_agent) > 50: @@ -192,7 +192,7 @@ def do_GET(self): path = posixpath.normpath(path) try: - mime_type, _ = mimetypes.guess_file_type(path) + mime_type, _ = mimetypes.guess_file_type(path) # ty:ignore[unresolved-attribute] except AttributeError: mime_type, _ = mimetypes.guess_type(path) @@ -235,7 +235,7 @@ def finish_request(self, request, client_address) -> None: with contextlib.suppress(ConnectionError): self.RequestHandlerClass(self.config, request, client_address, self) - def render(self, template, **kwargs): + def render(self, template, **kwargs) -> str: return self.env.get_template(template).render( **kwargs, session=self.config, diff --git a/dementor/protocols/x11.py b/dementor/protocols/x11.py index 122df49..ae71750 100644 --- a/dementor/protocols/x11.py +++ b/dementor/protocols/x11.py @@ -36,7 +36,7 @@ ServerThread, BaseServerThread, ) -from dementor.db import _NO_USER +from dementor.db import NO_USER __proto__ = ["X11"] @@ -52,7 +52,7 @@ class X11Config(TomlConfig): x11_ports: range x11_error_reason: str - def set_x11_ports(self, port_range: str | range | dict[str, int]): + def set_x11_ports(self, port_range: str | range | dict[str, int]) -> None: x11_ports = None match port_range: case range(): @@ -129,7 +129,7 @@ def _wrap(context): return _wrap -def xConnClient_set_length(context): +def x_conn_client_set_length(context): self = context._obj self.nbytesAuthProto = len(self.authProto) self.nbytesAuthString = len(self.authString) @@ -143,7 +143,7 @@ class xConnClientPrefixLE: majorVersion: CARD16 minorVersion: CARD16 - _aset_length: py.Action(pack=xConnClient_set_length) + _aset_length: py.Action(pack=x_conn_client_set_length) nbytesAuthProto: CARD16 nbytesAuthString: CARD16 _pad2: py.padding[2] @@ -164,7 +164,7 @@ class xConnClientPrefixBE: _pad: py.padding[1] majorVersion: CARD16 minorVersion: CARD16 - _aset_length: py.Action(pack=xConnClient_set_length) + _aset_length: py.Action(pack=x_conn_client_set_length) nbytesAuthProto: CARD16 nbytesAuthString: CARD16 _pad2: py.padding[2] @@ -188,7 +188,7 @@ class xConnClientPrefixBE: # 2 (n+p)/4 length in 4-byte units of "additional data" # n STRING8 reason # p unused, p=pad(n) -def xConnSetup_set_length(context): +def x_conn_setup_set_length(context): self = context._obj self.lengthReason = len(self.reason) pad = (4 - (self.lengthReason % 4)) % 4 @@ -197,7 +197,7 @@ def xConnSetup_set_length(context): @py.struct(order=py.LittleEndian) class xConnSetupPrefixLE: - _aset_length: py.Action(pack=xConnSetup_set_length) + _aset_length: py.Action(pack=x_conn_setup_set_length) # Taken from Xproto.h#L286: # The protocol also defines a case of success == Authenticate, but @@ -213,7 +213,7 @@ class xConnSetupPrefixLE: @py.struct(order=py.BigEndian) class xConnSetupPrefixBE: - _aset_length: py.Action(pack=xConnSetup_set_length) + _aset_length: py.Action(pack=x_conn_setup_set_length) success: CARD8 = X_CONN_SUCCESS lengthReason: BYTE = 0 majorVersion: CARD16 @@ -265,7 +265,7 @@ def handle_data(self, data, transport) -> None: self.config.db.add_auth( client=self.client_address, credtype=request.authProto.decode(errors="replace").strip(), - username=_NO_USER, + username=NO_USER, password=request.authString.hex(), logger=self.logger, custom=True, diff --git a/dementor/servers.py b/dementor/servers.py index 69fa9a6..2f6f0d8 100644 --- a/dementor/servers.py +++ b/dementor/servers.py @@ -20,8 +20,6 @@ # pyright: reportAny=false, reportExplicitAny=false import contextlib import asyncio -from dementor.config.toml import TomlConfig -from asyncio import Task import traceback import pathlib import socket @@ -33,11 +31,13 @@ import sys from io import StringIO +from asyncio import Task from typing import Any, ClassVar, Generic from socketserver import BaseRequestHandler from typing_extensions import override, TypeVar from dementor import db +from dementor.config.toml import TomlConfig from dementor.log import hexdump from dementor.log.logger import ProtocolLogger, dm_logger from dementor.log.stream import log_host @@ -56,24 +56,15 @@ def __init__(self, config: SessionConfig, server_config: _ConfigTy) -> None: self.address: str | None = None super().__init__(daemon=False) - def get_service_name(self) -> str: - """Get the service name for logging purposes. - - This method should be overridden by subclasses to provide a specific service name. - :return: Service name string - :rtype: str - """ - raise NotImplementedError("get_service_name must be implemented by subclasses") - @property def service_name(self) -> str: - """Get the service name from server class or use class name as fallback. + """Return the display name for this service (used in logging). - :return: Service name. - :rtype: str + Subclasses must implement this property. """ - return self.get_service_name() + raise NotImplementedError("service_name must be implemented by subclasses") + @property def get_port(self) -> int: """Return the listening port of the server. @@ -88,6 +79,7 @@ def get_port(self) -> int: raise ValueError("Port not set - the server may not have been started yet.") return self.port + @property def get_address(self) -> str: """Return the bound address of the server. @@ -137,11 +129,8 @@ def task(self) -> Task[None]: return self._task async def arun(self) -> None: - """Asynchronous run method to start the server. - - This method should be overridden to implement the actual async server logic. - """ - # To be implemented with async server logic in the future + """Subclasses must override this to implement async server logic.""" + raise NotImplementedError("arun must be implemented by subclasses") def run(self) -> None: """Start the asynchronous server.""" @@ -462,19 +451,15 @@ def __init__( super().__init__(config, request, client_address, server) -class ThreadingUDPServer(socketserver.ThreadingMixIn, socketserver.UDPServer): - """Threaded UDP server with IPv6 support and cross-platform binding. - - :var default_port: Default port to listen on - :var default_handler_class: Handler class for processing requests - :var ipv4_only: Whether to only use IPv4 (skip IPv6) - """ +class _ThreadingServerInitMixin: + """Shared ``__init__`` logic for :class:`ThreadingUDPServer` and :class:`ThreadingTCPServer`.""" default_port: ClassVar[int] default_handler_class: ClassVar[type] + config: SessionConfig ipv4_only: bool - - allow_reuse_address: bool = True + stop_flag: threading.Event + address_family: int def __init__( self, @@ -482,26 +467,34 @@ def __init__( server_address: tuple[str, int] | None = None, RequestHandlerClass: type | None = None, ) -> None: - """Initialize the UDP server. - - :param config: Session configuration - :type config: SessionConfig - :param server_address: (host, port) tuple or None to use defaults - :type server_address: tuple[str, int] | None - :param RequestHandlerClass: Handler class or None to use default - :type RequestHandlerClass: type | None - """ - self.config: SessionConfig = config + """Initialize the server with session config, optional address and handler overrides.""" + self.config = config self.ipv4_only = getattr(config, "ipv4_only", False) self.stop_flag = threading.Event() if config.ipv6 and not self.ipv4_only: self.address_family = socket.AF_INET6 - - super().__init__( + super().__init__( # type: ignore[call-arg] server_address or (self.config.bind_address, self.default_port), RequestHandlerClass or self.default_handler_class, ) + +class ThreadingUDPServer( + _ThreadingServerInitMixin, socketserver.ThreadingMixIn, socketserver.UDPServer +): + """Threaded UDP server with IPv6 support and cross-platform binding. + + :var default_port: Default port to listen on + :var default_handler_class: Handler class for processing requests + :var ipv4_only: Whether to only use IPv4 (skip IPv6) + """ + + default_port: ClassVar[int] + default_handler_class: ClassVar[type] + ipv4_only: bool + + allow_reuse_address: bool = True + @override def server_bind(self) -> None: """Bind the server socket with interface and IPv6 settings.""" @@ -513,7 +506,7 @@ def finish_request( # pyright: ignore[reportIncompatibleMethodOverride] self, request: bytes, client_address: tuple[str, int], - ) -> None: + ) -> None: # ty:ignore[invalid-method-override] """Finish a single request by instantiating the handler. :param request: The request data @@ -555,7 +548,9 @@ def bind_server( dm_logger.warning(f"Failed to set IPV6_V6ONLY: {e}") -class ThreadingTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer): +class ThreadingTCPServer( + _ThreadingServerInitMixin, socketserver.ThreadingMixIn, socketserver.TCPServer +): """Threaded TCP server with IPv6 support and cross-platform binding. :var default_port: Default port to listen on @@ -568,31 +563,6 @@ class ThreadingTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer): ipv4_only: bool allow_reuse_address: bool = True - def __init__( - self, - config: SessionConfig, - server_address: tuple[str, int] | None = None, - RequestHandlerClass: type | None = None, - ) -> None: - """Initialize the TCP server. - - :param config: Session configuration - :type config: SessionConfig - :param server_address: (host, port) tuple or None to use defaults - :type server_address: tuple[str, int] | None - :param RequestHandlerClass: Handler class or None to use default - :type RequestHandlerClass: type | None - """ - self.config: SessionConfig = config - self.ipv4_only = getattr(config, "ipv4_only", False) - self.stop_flag = threading.Event() - if config.ipv6 and not self.ipv4_only: - self.address_family = socket.AF_INET6 - super().__init__( - server_address or (self.config.bind_address, self.default_port), - RequestHandlerClass or self.default_handler_class, - ) - @override def server_bind(self) -> None: """Bind the server socket with interface and IPv6 settings.""" @@ -604,7 +574,7 @@ def finish_request( # pyright: ignore[reportIncompatibleMethodOverride] self, request: socket.socket, client_address: tuple[str, int], - ) -> None: + ) -> None: # ty:ignore[invalid-method-override] """Finish a single request by instantiating the handler. :param request: Connected socket diff --git a/dementor/standalone.py b/dementor/standalone.py index 7c9a10b..fd36870 100644 --- a/dementor/standalone.py +++ b/dementor/standalone.py @@ -60,7 +60,7 @@ def serve( analyze_only: bool = False, config_path: str | None = None, session: SessionConfig | None = None, - supress_output: bool = False, + suppress_output: bool = False, loop: asyncio.AbstractEventLoop | None = None, run_forever: bool = True, run_repl: bool = False, @@ -122,10 +122,6 @@ def serve( # Load protocols loader = ProtocolLoader() session.manager = ProtocolManager(session, loader) - # REVISIT: ? - if not supress_output: - pass - if not getattr(session, "loop", None): session.loop = loop or asyncio.new_event_loop() @@ -176,7 +172,7 @@ def stop_session(session: SessionConfig) -> None: _SkippedOption = typer.Option(parser=lambda _: _, hidden=True, expose_value=False) -def parse_options(options: list[str]) -> dict: +def parse_options(options: list[str]) -> dict[str, Any]: result = {} for option in options: key, raw_value = option.split("=", 1) @@ -259,8 +255,8 @@ def main_format_config(name: str, value: str) -> str: return f"{line}[/white] {value}" -# TODO: refactor this -def main_print_options(session: SessionConfig, interface: str, config_path: str): +# TODO: extract main_print_options into a dedicated display module (too many responsibilities here) +def main_print_options(session: SessionConfig, interface: str, config_path: str) -> None: console = Console() console.rule(style="white", title="Dementor Configuration") analyze_only = r"[bold grey]\[Analyze Only][/bold grey]" @@ -268,7 +264,7 @@ def main_print_options(session: SessionConfig, interface: str, config_path: str) off = r"[bold red]\[OFF][/bold red]" poisoners_lines = ["", "[bold]Poisoners:[/bold]"] - # REVISIT: creation of poisoners list + # REVISIT: poisoners list is hardcoded here; extract to a registry in loader.py to avoid duplication poisoners = ("LLMNR", "MDNS", "NBTNS", "SSRP", "SSDP") for name in poisoners: attr_name = f"{name.lower()}_enabled" @@ -370,6 +366,17 @@ def main( help="Add an extra option to the global configuration file.", ), ] = None, + host: Annotated[ + str | None, + typer.Option( + "--host", + "-H", + metavar="HOST", + show_default=False, + help="Host FQDN for all protocol servers (e.g. DC01.contoso.lab). " + "Shortcut for -O Globals.Host=FQDN.", + ), + ] = None, ignore_prompt: Annotated[ bool, typer.Option( @@ -476,6 +483,10 @@ def main( for key, value in section_opts.items(): config.dm_config[section][key] = value + if host: + cfg = config.get_global_config() + cfg.setdefault("Globals", {})["Host"] = host + if ignored: ignore_targets = config.dm_config["Globals"].setdefault("Ignore", []) for target_format in ignored: diff --git a/dementor/tui/commands/database.py b/dementor/tui/commands/database.py index b978294..a4630c5 100644 --- a/dementor/tui/commands/database.py +++ b/dementor/tui/commands/database.py @@ -31,7 +31,7 @@ from prompt_toolkit.document import Document from typing_extensions import override -from dementor.db import _CLEARTEXT +from dementor.db import CLEARTEXT from dementor.tui.action import command, ReplAction from dementor.db.model import Credential, HostInfo, HostExtra @@ -274,9 +274,7 @@ def export(self, argv: argparse.Namespace) -> None: password = cred.password credtype = (cred.credtype or "").lower() - line = ( - f"{username}:{password}" if credtype == _CLEARTEXT.lower() else password - ) + line = f"{username}:{password}" if credtype == CLEARTEXT.lower() else password lines.append(line) if not lines: diff --git a/dementor/tui/commands/proto.py b/dementor/tui/commands/proto.py index 3ec8341..b2b0d4c 100644 --- a/dementor/tui/commands/proto.py +++ b/dementor/tui/commands/proto.py @@ -232,7 +232,7 @@ def service_status( if not active: label = " [white]".ljust(50, ".") else: - addr, port = thread.get_address(), thread.get_port() + addr, port = thread.get_address, thread.get_port label = f"{addr}:{port} [white]".ljust(50, ".") _ = tree.add(f"{label} {ON if active else OFF}") diff --git a/dementor/tui/completer.py b/dementor/tui/completer.py index cc6a232..c2b805c 100644 --- a/dementor/tui/completer.py +++ b/dementor/tui/completer.py @@ -75,7 +75,9 @@ def _get_parser_for_command(self, command_name: str): # Completer interface ------------------------------------------------------ @override - def get_completions(self, document: Document, complete_event: CompleteEvent): + def get_completions( + self, document: Document, complete_event: CompleteEvent + ) -> Iterable[Completion]: """Yield :class:`prompt_toolkit.completion.Completion` objects. The logic mirrors the description in the class docstring. It works on diff --git a/dementor/tui/repl.py b/dementor/tui/repl.py index 4268fd8..03de1a8 100644 --- a/dementor/tui/repl.py +++ b/dementor/tui/repl.py @@ -76,7 +76,7 @@ def __init__( ) self.console: Console = Console() - def get_prompt(self): + def get_prompt(self) -> list[tuple[str, str]]: """Build the prompt parts for the REPL. :return: A list of style/segment tuples understood by ``prompt_toolkit``. @@ -127,8 +127,8 @@ async def arun(self) -> None: while True: try: line = await self.prompt_session.prompt_async( - self.get_prompt(), - placeholder=self.get_placeholder(), + self.get_prompt(), # ty:ignore[invalid-argument-type] + placeholder=self.get_placeholder(), # ty:ignore[invalid-argument-type] ) line = line.strip() if not line: diff --git a/docs/Makefile b/docs/Makefile index d0c3cbf..faf5089 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -4,7 +4,7 @@ # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build +SPHINXBUILD ?= uv run sphinx-build SOURCEDIR = source BUILDDIR = build diff --git a/docs/make.bat b/docs/make.bat index dc1312a..1bf7193 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -5,7 +5,7 @@ pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build + set SPHINXBUILD="uv run sphinx-build" ) set SOURCEDIR=source set BUILDDIR=build diff --git a/docs/source/cli.rst b/docs/source/cli.rst index e7a0745..f95b22a 100644 --- a/docs/source/cli.rst +++ b/docs/source/cli.rst @@ -46,18 +46,18 @@ Command-Line Options - ``VALUE``: Will be automatically parsed into the appropriate type based on the following rules: - - **Boolean**: ``on``, ``yes``, ``true`` → `True`; ``off``, ``no``, ``false`` → `False` + - **Boolean**: ``on``, ``yes``, ``true`` -> `True`; ``off``, ``no``, ``false`` -> `False` - **Lists**: Values enclosed in brackets (e.g., ``["foo", "bar"]``) will be parsed as JSON strings. - **Numbers**: Numeric values will be automatically converted to integers or floats where applicable. - **Strings**: All other values will be treated as plain strings. Examples: - - ``-OLLMNR.AnswerName=pki-srv`` → Maps to :attr:`LLMNR.AnswerName`, value parsed as string. - - ``--option mDNS.TTL=340`` → Maps to :attr:`mDNS.TTL`, value parsed as integer. - - ``--option SMB.SMB2Support=off`` → Maps to :attr:`SMB.Server.SMB2Support`, value parsed as boolean. - - ``--option Log.DebugLoggers='["asyncio", "quic"]'`` → Maps to :attr:`Log.DebugLoggers`, value parsed as list. - - ``-O Globals.Ignore+="foobar"`` → Appends the parsed string value to :attr:`Globals.Ignore` + - ``-OLLMNR.AnswerName=pki-srv`` -> Maps to :attr:`LLMNR.AnswerName`, value parsed as string. + - ``--option mDNS.TTL=340`` -> Maps to :attr:`mDNS.TTL`, value parsed as integer. + - ``--option SMB.SMB2Support=off`` -> Maps to :attr:`SMB.Server.SMB2Support`, value parsed as boolean. + - ``--option Log.DebugLoggers='["asyncio", "quic"]'`` -> Maps to :attr:`Log.DebugLoggers`, value parsed as list. + - ``-O Globals.Ignore+="foobar"`` -> Appends the parsed string value to :attr:`Globals.Ignore` .. note:: Overrides made via the ``--option`` flag will **always take precedence** over the values @@ -67,6 +67,13 @@ Command-Line Options Options now support an "append" action using the ``+=`` operator for settings storing multiple values. +-H, --host HOST + + .. versionadded:: 1.0.0.dev22 + + Specify the host FQDN (fully qualified domain name) for all protocol servers + (e.g., ``DC01.contoso.lab``). This is a shortcut for ``-O Globals.Host=HOST``. + --verbose Enables verbose output for protocol-specific loggers, including debug-level messages. diff --git a/docs/source/config/dcerpc.rst b/docs/source/config/dcerpc.rst index cd9a2fe..1231801 100644 --- a/docs/source/config/dcerpc.rst +++ b/docs/source/config/dcerpc.rst @@ -92,14 +92,18 @@ Section ``[RPC]`` are configured globally in the :ref:`config_ntlm` section and apply to all protocols including DCE/RPC. -.. py:attribute:: Server.FQDN +.. py:attribute:: Server.Host :type: str :value: "DEMENTOR" *Maps to* :attr:`rpc.RPCConfig.rpc_fqdn`. *Can also be set in* ``[Globals]`` - Specifies the Fully Qualified Domain Name (FQDN) used by the server. The hostname part is - included in NTLM responses. The domain part is optional. + Specifies the host identity for this server. Accepts a full FQDN (e.g. ``DC01.contoso.lab``) or a bare + hostname. All protocol-level identity values (FQDN, NetBIOS names, DNS names) are derived from this + entry. Inherits from ``Globals.Host`` when not set here. + + .. versionchanged:: 1.0.0.dev22 + Renamed from ``FQDN`` to ``Host`` Section ``[EPM]`` ----------------- diff --git a/docs/source/config/globals.rst b/docs/source/config/globals.rst index d9046e8..c49ec91 100644 --- a/docs/source/config/globals.rst +++ b/docs/source/config/globals.rst @@ -134,4 +134,37 @@ TLS Options .. py:attribute:: Key :type: str - Specifies the private key file corresponding to the certificate used for TLS. \ No newline at end of file + Specifies the private key file corresponding to the certificate used for TLS. + + +.. py:attribute:: Host + :type: str + :value: "DEMENTOR" + + Specifies the host identity for all protocol servers. Accepts a full FQDN (e.g. ``DC01.contoso.lab``) or a bare + hostname. All protocol-level identity values (FQDN, NetBIOS names, DNS names) are derived from this entry when + not configured explicitly. + + When a Host value is set, the following identities are automatically derived: + + - **NetBIOS Domain Name** (``NetBIOSDomainName``): Domain portion of the Host (after first dot), uppercase + - **DNS Host Name** (``DNSHostName``): Hostname portion (before first dot) + - **NetBIOS Name** (``NetBIOSName``): First 15 characters of hostname, uppercase + - **DNS Domain Name** (``DNSDomainName``): Domain portion of the Host, lowercase + + Example: + + .. code-block:: toml + + [Globals] + Host = "DC01.contoso.lab" + + Results in: + + - ``NetBIOS Computer Name``: ``DC01`` + - ``NetBIOS Domain Name``: ``CONTOSO`` + - ``DNS Host Name``: ``dc01.contoso.lab`` + - ``DNS Domain Name``: ``contoso.lab`` + + .. versionadded:: 1.0.0.dev22 + Global Host configuration provides centralized identity management across all protocols. \ No newline at end of file diff --git a/docs/source/config/imap.rst b/docs/source/config/imap.rst index b6c9cd2..04fe78f 100644 --- a/docs/source/config/imap.rst +++ b/docs/source/config/imap.rst @@ -40,14 +40,18 @@ Section ``[IMAP]`` Defines the server capabilities to advertise to the client. According to the IMAP specification, the revision (such as `IMAP4rev1`) **must** be returned. - .. py:attribute:: Server.FQDN + .. py:attribute:: Server.Host :type: str - :value: "Dementor" + :value: "DEMENTOR" *Linked to* :attr:`imap.IMAPServerConfig.imap_fqdn`. *Can also be set in* ``[IMAP]`` *or* ``[Globals]``. - Specifies the Fully Qualified Domain Name (FQDN) hostname used by the IMAP server. - The hostname portion appears in server responses; the domain part is optional. + Specifies the host identity for this server. Accepts a full FQDN (e.g. ``DC01.contoso.lab``) or a bare + hostname. All protocol-level identity values (FQDN, NetBIOS names, DNS names) are derived from this + entry. Inherits from ``Globals.Host`` when not set here. + + .. versionchanged:: 1.0.0.dev22 + Renamed from ``FQDN`` to ``Host`` .. py:attribute:: Server.Banner :type: str diff --git a/docs/source/config/ldap.rst b/docs/source/config/ldap.rst index bcc6793..866d89e 100644 --- a/docs/source/config/ldap.rst +++ b/docs/source/config/ldap.rst @@ -77,14 +77,17 @@ Section ``[LDAP]`` for operations. - .. py:attribute:: Server.FQDN + .. py:attribute:: Server.Host :type: str *Maps to* :attr:`ldap.LDAPServerConfig.ldap_fqdn`. *Can also be set in* ``[LDAP]`` - Specifies the server's hostname or fully qualified domain name (FQDN). The domain portion is optional. - Example: ``"HOSTNAME.domain.local"``. + Specifies the host identity for this server. Accepts a full FQDN (e.g. ``DC01.contoso.lab``) or a bare + hostname. All protocol-level identity values (FQDN, NetBIOS names, DNS names) are derived from this + entry. Inherits from ``Globals.Host`` when not set here. + .. versionchanged:: 1.0.0.dev22 + Renamed from ``FQDN`` to ``Host`` .. py:attribute:: Server.ErrorCode :type: str | int diff --git a/docs/source/config/mssql.rst b/docs/source/config/mssql.rst index 56c807b..6439db3 100644 --- a/docs/source/config/mssql.rst +++ b/docs/source/config/mssql.rst @@ -57,14 +57,19 @@ Section ``[MSSQL]`` are configured globally in the :ref:`config_ntlm` section and apply to all protocols including MSSQL. -.. py:attribute:: FQDN +.. py:attribute:: Host :type: str :value: "DEMENTOR" *Maps to* :attr:`mssql.MSSQLServerConfig.mssql_fqdn`. *May also be set in* ``[Globals]`` - Sets the Fully Qualified Domain Name (FQDN) returned by the server. The hostname portion - is used in NTLM responses; the domain portion is optional. + Specifies the host identity for this server. Accepts a full FQDN (e.g. ``DC01.contoso.lab``) or a bare + hostname. All protocol-level identity values (FQDN, NetBIOS names, DNS names) are derived from this + entry. Inherits from ``Globals.Host`` when not set here. + + .. versionchanged:: 1.0.0.dev22 + Renamed from ``FQDN`` to ``Host`` + Error Configuration ^^^^^^^^^^^^^^^^^^^ @@ -123,16 +128,19 @@ Section ``[SSRP]`` would be valid. + .. versionchanged:: 1.0.0.dev22 + Renamed from ``FQDN`` to ``Host`` + Inherited from ``[MSSQL]`` ^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. py:attribute:: FQDN +.. py:attribute:: Host :type: str - :value: MSSQL.FQDN + :value: MSSQL.Host *Maps to* :attr:`mssql.SSRPConfig.ssrp_server_name`. *May also be set in* ``[Globals]`` - Defines the server name as described in :attr:`MSSQL.FQDN`. + Defines the server name as described in :attr:`MSSQL.Host`. .. py:attribute:: Version :type: str diff --git a/docs/source/config/ntlm.rst b/docs/source/config/ntlm.rst index 176b746..c38f269 100644 --- a/docs/source/config/ntlm.rst +++ b/docs/source/config/ntlm.rst @@ -10,7 +10,7 @@ Section ``[NTLM]`` Dementor's NTLM module (``ntlm.py``) implements the server side of the three-message NTLM handshake per `[MS-NLMP] `__. -It is a **capture module** — it builds a valid ``CHALLENGE_MESSAGE`` to keep +It is a **capture module** - it builds a valid ``CHALLENGE_MESSAGE`` to keep the handshake alive, extracts crackable hashes from the ``AUTHENTICATE_MESSAGE``, formats them for hashcat, and writes them to the database. It does not verify responses, compute session keys, or participate @@ -19,7 +19,7 @@ in post-authentication signing, sealing, or encryption. The ``[NTLM]`` config section provides **global settings** shared by every protocol that uses NTLM (SMB, HTTP, LDAP, MSSQL, etc.). All NTLM settings are configured exclusively in the ``[NTLM]`` section and apply identically to -every protocol — there are no per-protocol overrides. +every protocol - there are no per-protocol overrides. .. |rarr| unicode:: U+2192 @@ -54,7 +54,7 @@ Capture Behaviour `crack.sh `__) can crack NTLMv1 hashes offline without GPU resources. - **For NTLMv2 cracking:** the challenge value does not matter — + **For NTLMv2 cracking:** the challenge value does not matter - NTLMv2 incorporates the challenge into an HMAC-MD5 construction that is not amenable to rainbow tables. Use hashcat ``-m 5600`` with a wordlist or rules. @@ -104,19 +104,19 @@ Capture Behaviour **Effect on captured hashes:** - - ``false`` (default) — the server echoes ESS back to clients that + - ``false`` (default) - the server echoes ESS back to clients that request it. NTLMv1 clients (LmCompatibilityLevel 0-2) produce **NetNTLMv1-ESS** hashes (hashcat ``-m 5500``). The effective challenge becomes ``MD5(ServerChallenge || ClientChallenge)[0:8]``; hashcat derives this internally from the emitted ``ClientChallenge`` field. - - ``true`` — the server strips ESS from the response regardless of + - ``true`` - the server strips ESS from the response regardless of what the client requested. NTLMv1 clients produce plain **NetNTLMv1** hashes instead. A fixed :attr:`Challenge` combined with rainbow tables can crack these without GPU resources. - NTLMv2 clients (level 3+, all modern Windows) are **unaffected** — + NTLMv2 clients (level 3+, all modern Windows) are **unaffected** - they always produce NetNTLMv2 regardless of ESS. .. note:: @@ -145,11 +145,11 @@ Capture Behaviour **Effect on captured hashes:** - - ``false`` (default) — ``TargetInfoFields`` is populated. Clients can + - ``false`` (default) - ``TargetInfoFields`` is populated. Clients can construct an NTLMv2 response and produce **NetNTLMv2** (and sometimes **NetLMv2**) hashes (hashcat ``-m 5600``). - - ``true`` — ``TargetInfoFields`` is empty. Without it, clients cannot + - ``true`` - ``TargetInfoFields`` is empty. Without it, clients cannot build the NTLMv2 ``NTLMv2_CLIENT_CHALLENGE`` blob per [MS-NLMP] §3.3.2. LmCompatibilityLevel 0-2 clients fall back to NTLMv1. **Level 3+ clients** (all modern Windows defaults) **fail @@ -172,7 +172,7 @@ Server Identity These options control the identity values embedded in the NTLM ``CHALLENGE_MESSAGE``. They determine what appears on the wire, in captured hash lines, and in NTLMv2 ``AV_PAIR`` structures. **No client -changes authentication behavior** based on any of these values — they are +changes authentication behavior** based on any of these values - they are cosmetic from the client's perspective but operationally important for blending in and for hash formatting. @@ -188,9 +188,9 @@ protocols. Sets the ``NTLMSSP_TARGET_TYPE`` flag in the ``CHALLENGE_MESSAGE`` and determines the ``TargetName`` field value: - - ``"server"`` — sets ``NTLMSSP_TARGET_TYPE_SERVER`` (bit 17); + - ``"server"`` - sets ``NTLMSSP_TARGET_TYPE_SERVER`` (bit 17); ``TargetName`` is the NetBIOS computer name. - - ``"domain"`` — sets ``NTLMSSP_TARGET_TYPE_DOMAIN`` (bit 16); + - ``"domain"`` - sets ``NTLMSSP_TARGET_TYPE_DOMAIN`` (bit 16); ``TargetName`` is the NetBIOS domain name. .. py:attribute:: Version @@ -293,7 +293,7 @@ CHALLENGE_MESSAGE Construction ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The ``CHALLENGE_MESSAGE`` is built by ``NTLM_build_challenge_message()`` -per [MS-NLMP] §3.2.5.1.1. It is the **only message Dementor authors** — +per [MS-NLMP] §3.2.5.1.1. It is the **only message Dementor authors** - the other two are client-originated. **Flag mirroring:** @@ -342,15 +342,15 @@ The following client-requested flags are echoed back when present: When the client requests both ``NEGOTIATE_EXTENDED_SESSIONSECURITY`` (P) and ``NEGOTIATE_LM_KEY`` (G), only ESS is returned. Per [MS-NLMP] §2.2.2.5, -these flags are mutually exclusive — ESS takes priority. +these flags are mutually exclusive - ESS takes priority. **Server-set flags:** -- ``NTLMSSP_NEGOTIATE_NTLM`` — always set (NTLM authentication) -- ``NTLMSSP_REQUEST_TARGET`` — always set (TargetName present) -- ``NTLMSSP_TARGET_TYPE_SERVER`` or ``_DOMAIN`` — per :attr:`TargetType` -- ``NTLMSSP_NEGOTIATE_TARGET_INFO`` — set unless :attr:`DisableNTLMv2` -- ``NTLMSSP_NEGOTIATE_VERSION`` — echoed when the client requests it +- ``NTLMSSP_NEGOTIATE_NTLM`` - always set (NTLM authentication) +- ``NTLMSSP_REQUEST_TARGET`` - always set (TargetName present) +- ``NTLMSSP_TARGET_TYPE_SERVER`` or ``_DOMAIN`` - per :attr:`TargetType` +- ``NTLMSSP_NEGOTIATE_TARGET_INFO`` - set unless :attr:`DisableNTLMv2` +- ``NTLMSSP_NEGOTIATE_VERSION`` - echoed when the client requests it AV_PAIRs (``TargetInfoFields``) @@ -375,19 +375,19 @@ independent default configured in the ``[NTLM]`` section: - :attr:`NTLM.NetBIOSDomain` (default ``"WORKGROUP"``) * - ``0x0003`` - ``MsvAvDnsComputerName`` - - :attr:`NTLM.DnsComputer` (default ``""`` — omitted from AV_PAIRs + - :attr:`NTLM.DnsComputer` (default ``""`` - omitted from AV_PAIRs when empty) * - ``0x0004`` - ``MsvAvDnsDomainName`` - - :attr:`NTLM.DnsDomain` (default ``""`` — omitted from AV_PAIRs + - :attr:`NTLM.DnsDomain` (default ``""`` - omitted from AV_PAIRs when empty) * - ``0x0005`` - ``MsvAvDnsTreeName`` - - :attr:`NTLM.DnsTree` (default ``""`` — omitted from AV_PAIRs + - :attr:`NTLM.DnsTree` (default ``""`` - omitted from AV_PAIRs when empty) * - ``0x0007`` - ``MsvAvTimestamp`` - - **Intentionally omitted** — see below + - **Intentionally omitted** - see below * - ``0x0000`` - ``MsvAvEOL`` - Always appended (list terminator) @@ -472,15 +472,15 @@ LM Response Filtering For **NetNTLMv1** captures, the LM slot in the hashcat line is omitted when any of the following conditions hold: -- **Identical response** — ``LmChallengeResponse == NtChallengeResponse``. +- **Identical response** - ``LmChallengeResponse == NtChallengeResponse``. This occurs at LmCompatibilityLevel 2, where the client copies the NT response into both slots. Using the LM copy with the NT one-way function during cracking would yield incorrect results. -- **Long-password placeholder** — ``LmChallengeResponse == DESL(Z(16))``. +- **Long-password placeholder** - ``LmChallengeResponse == DESL(Z(16))``. Clients send this deterministic value when the password exceeds 14 characters or the ``NoLMHash`` registry policy is enforced. It carries no crackable material. -- **Empty-password placeholder** — ``LmChallengeResponse == DESL(LMOWFv1(""))``. +- **Empty-password placeholder** - ``LmChallengeResponse == DESL(LMOWFv1(""))``. The LM derivative of an empty password; equally uncrackable. @@ -499,12 +499,12 @@ the primary NetNTLMv2 response when all of the following hold: Clients set ``LmChallengeResponse`` to ``Z(24)`` when: -- **MsvAvTimestamp present in CHALLENGE_MESSAGE** — Per [MS-NLMP] §3.3.2 +- **MsvAvTimestamp present in CHALLENGE_MESSAGE** - Per [MS-NLMP] §3.3.2 rule 7, when the server includes ``MsvAvTimestamp`` (``0x0007``) in the AV_PAIR list, the client MUST suppress ``LmChallengeResponse``. Dementor intentionally omits ``MsvAvTimestamp`` to avoid this. -- **Win 7+ / Server 2008 R2+ defaults** — These versions suppress LMv2 +- **Win 7+ / Server 2008 R2+ defaults** - These versions suppress LMv2 regardless of LmCompatibilityLevel. Only Vista and Server 2008 send real LMv2 responses. @@ -541,7 +541,7 @@ avoid writing a malformed capture. Anonymous tokens are silently discarded. XP SP3 and XP SP0 send an anonymous ``AUTHENTICATE_MESSAGE`` probe before the real credential auth on each connection. This is normal - SSPI behavior — the anonymous probe is discarded and the real auth + SSPI behavior - the anonymous probe is discarded and the real auth that follows is captured. @@ -573,7 +573,7 @@ The three messages provide increasingly detailed information: - Yes * - Username - No - - — + - - - Yes * - NegotiateFlags - Yes @@ -581,27 +581,27 @@ The three messages provide increasingly detailed information: - Yes * - NTLMv2 blob AV_PAIRs - No - - — + - - - Yes (NTLMv2 only) * - SPN (``MsvAvTargetName``, 0x0009) - No - - — + - - - Yes (NTLMv2 only) [3]_ * - Client timestamp (``MsvAvTimestamp``, 0x0007) - No - - — + - - - Yes (NTLMv2 only) * - MIC (Message Integrity Code) - No - - — + - - - Yes (if VERSION flag set) * - Channel Bindings (``MsvAvChannelBindings``, 0x000A) - No - - — + - - - Yes (NTLMv2 only) [4]_ * - ``MsvAvFlags`` (0x0006) - No - - — + - - - Yes (NTLMv2 only) .. [1] Only when ``NTLMSSP_NEGOTIATE_VERSION`` is set. XP SP0 does not @@ -663,7 +663,7 @@ LmCompatibilityLevel Reference The Windows ``LmCompatibilityLevel`` registry value (``HKLM\SYSTEM\CurrentControlSet\Control\Lsa``) controls which NTLM response types a client sends. This is the **single most important client -setting** for hash capture — it determines what Dementor can extract. +setting** for hash capture - it determines what Dementor can extract. .. list-table:: :header-rows: 1 @@ -768,23 +768,23 @@ Default Configuration # This section applies to all NTLM-enabled protocols # (SMB, HTTP, SMTP, IMAP, POP3, LDAP, MSSQL, RPC). # 8-byte ServerChallenge nonce. Accepted formats: - # "hex:1122334455667788" — explicit hex (recommended) - # "ascii:1337LEET" — explicit ASCII (recommended) - # "1122334455667788" — 16 hex chars, auto-detected - # "1337LEET" — 8 ASCII chars, auto-detected + # "hex:1122334455667788" - explicit hex (recommended) + # "ascii:1337LEET" - explicit ASCII (recommended) + # "1122334455667788" - 16 hex chars, auto-detected + # "1337LEET" - 8 ASCII chars, auto-detected # Omit entirely for a cryptographically random value per run. # Challenge = "1337LEET" # Strip NTLMSSP_NEGOTIATE_EXTENDED_SESSIONSECURITY. - # false (default): ESS negotiated → NetNTLMv1-ESS (hashcat -m 5500). - # true: ESS suppressed → plain NetNTLMv1; crackable with + # false (default): ESS negotiated -> NetNTLMv1-ESS (hashcat -m 5500). + # true: ESS suppressed -> plain NetNTLMv1; crackable with # rainbow tables when combined with a fixed Challenge. DisableExtendedSessionSecurity = false # Omit TargetInfoFields (AV_PAIRS) from CHALLENGE_MESSAGE. # false (default): NetNTLMv2 + NetLMv2 captured from modern clients. # true: Level 0-2 clients fall back to NTLMv1; level 3+ - # (all modern Windows) will FAIL — NO captures. + # (all modern Windows) will FAIL - NO captures. DisableNTLMv2 = false # Server identity in the CHALLENGE_MESSAGE: diff --git a/docs/source/config/pop3.rst b/docs/source/config/pop3.rst index f38ba1e..37252b3 100644 --- a/docs/source/config/pop3.rst +++ b/docs/source/config/pop3.rst @@ -35,10 +35,8 @@ Section ``[POP3]`` :type: str :value: "Dementor" - *Linked to* :attr:`pop3.POP3ServerConfig.pop3_fqdn`. *Can also be set in* ``[POP3]`` *or* ``[Globals]`` - - Specifies the Fully Qualified Domain Name (FQDN) hostname used by the POP3 server. - The hostname portion of the FQDN will be included in server responses. The domain part is optional. + .. versionremoved:: 1.0.0.dev22 + Unused attribute was removed .. py:attribute:: Server.Banner :type: str diff --git a/docs/source/config/smb.rst b/docs/source/config/smb.rst index a4db8d1..6be6b82 100644 --- a/docs/source/config/smb.rst +++ b/docs/source/config/smb.rst @@ -297,8 +297,8 @@ Post-Auth Behaviour effect over CPC = 0. .. py:attribute:: ErrorCode - :type: str | int - :value: "STATUS_SMB_BAD_UID" + :type: str | int | None + :value: None *Maps to* :attr:`smb.SMBServerConfig.smb_error_code` @@ -311,13 +311,13 @@ Post-Auth Behaviour * - Value - Effect - * - ``"STATUS_SMB_BAD_UID"`` (default) + * - ``"STATUS_SMB_BAD_UID"`` - Client disconnects cleanly. * - ``"STATUS_ACCESS_DENIED"`` - Client may retry, then disconnects. * - ``"STATUS_LOGON_FAILURE"`` - Client disconnects cleanly. - * - ``"STATUS_SUCCESS"`` + * - ``"STATUS_SUCCESS"`` (default behavior if no value is set) - Client proceeds to tree connect. Useful for extending the session to capture tree-connect paths. @@ -358,6 +358,19 @@ Server Instances - **139** -- NetBIOS session service (used by XP/Server 2003 in addition to port 445; leaks NetBIOS CallingName) + .. py:attribute:: Server.Host + :type: str + :value: "DEMENTOR" + + *Maps to* :attr:`smb.SMBServerConfig.smb_fqdn`. *Can also be set in* ``[SMB]`` *or* ``[Globals]`` + + Specifies the host identity for this server. Accepts a full FQDN (e.g. ``DC01.contoso.lab``) or a bare + hostname. All protocol-level identity values (FQDN, NetBIOS names, DNS names) are derived from this + entry. Inherits from ``Globals.Host`` when not set here. + + .. versionchanged:: 1.0.0.dev22 + Renamed from ``FQDN`` to ``Host`` + SMB 3.1.1 Negotiate Contexts ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -436,7 +449,7 @@ Default Configuration # ServerOS = "Windows" # NativeLanMan = "Windows" CapturesPerConnection = 0 - ErrorCode = "STATUS_SMB_BAD_UID" + # ErrorCode = "STATUS_SMB_BAD_UID" # NTLM settings are in the [NTLM] section. diff --git a/docs/source/config/smtp.rst b/docs/source/config/smtp.rst index f7f906b..2df9acc 100644 --- a/docs/source/config/smtp.rst +++ b/docs/source/config/smtp.rst @@ -31,14 +31,18 @@ Section: ``[SMTP]`` This value must be specified within a ``[[SMTP.Server]]`` section. - .. py:attribute:: Server.FQDN + .. py:attribute:: Server.Host :type: str :value: "DEMENTOR" *Linked to* :attr:`smtp.SMTPServerConfig.smtp_fqdn`. *Can also be set in* ``[SMTP]`` - Specifies the Fully Qualified Domain Name (FQDN) hostname used by the SMTP server. - The hostname portion of the FQDN will be included in server responses. The domain part is optional. + Specifies the host identity for this server. Accepts a full FQDN (e.g. ``DC01.contoso.lab``) or a bare + hostname. All protocol-level identity values (FQDN, NetBIOS names, DNS names) are derived from this + entry. Inherits from ``Globals.Host`` when not set here. + + .. versionchanged:: 1.0.0.dev22 + Renamed from ``FQDN`` to ``Host`` .. py:attribute:: Server.Ident :type: str diff --git a/tests/config/test_host_config.py b/tests/config/test_host_config.py new file mode 100644 index 0000000..a0fdc32 --- /dev/null +++ b/tests/config/test_host_config.py @@ -0,0 +1,544 @@ +# Copyright (c) 2025-Present MatrixEditor +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Unit tests for the Host identity configuration feature. + +Covers: +- ``dementor.config.util.HostValue``: FQDN parsing and value derivation +- ``dementor.config.util.HostDerivedValue``: per-attribute factory (pure, no global fetches) +- ``dementor.config.attr.ATTR_GLOBALS_HOST``: single Attribute for Host +- NTLM session identity (ntlm_nb_computer, ntlm_nb_domain, …) via apply_config +- Protocol FQDN fallback (SMTP, LDAP, POP3, IMAP, MSSQL, RPC, HTTP, SMB) +- CLI -H / --host option (parse_options integration) +""" + +import pytest + +from unittest.mock import MagicMock + +from dementor.config import _set_global_config, get_global_config +from dementor.config.attr import ATTR_GLOBALS_HOST +from dementor.config.toml import Attribute +from dementor.config.util import HostValue, HostFallbackValue + +from dementor.protocols import ntlm +from dementor.protocols.smtp import SMTPServerConfig +from dementor.protocols.ldap import LDAPServerConfig +from dementor.protocols.imap import IMAPServerConfig +from dementor.protocols.mssql import MSSQLConfig +from dementor.protocols.msrpc.rpc import RPCConfig +from dementor.protocols.mssql import SSRPConfig +from dementor.protocols.smb import SMBServerConfig + +from dementor.standalone import parse_options + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _with_globals(**kw): + """Return a minimal global config dict that has a Globals section.""" + return {"Globals": kw} + + +# --------------------------------------------------------------------------- +# HostValue - construction +# --------------------------------------------------------------------------- + + +class TestHostValueConstruction: + def test_full_fqdn_splits_correctly(self): + hv = HostValue("DC01.contoso.lab") + assert hv.hostname == "DC01" + assert hv.domain == "contoso.lab" + + def test_no_dot_hostname_only(self): + hv = HostValue("DEMENTOR") + assert hv.hostname == "DEMENTOR" + assert hv.domain == "" + + def test_multi_label_domain(self): + hv = HostValue("srv.sub.corp.example.com") + assert hv.hostname == "srv" + assert hv.domain == "sub.corp.example.com" + + def test_str_returns_raw(self): + hv = HostValue("DC01.contoso.lab") + assert str(hv) == "DC01.contoso.lab" + + def test_str_hostname_only(self): + hv = HostValue("MYHOST") + assert str(hv) == "MYHOST" + + def test_call_returns_new_instance(self): + factory = HostValue("DEMENTOR") + result = factory("DC01.corp.local") + assert isinstance(result, HostValue) + assert result.hostname == "DC01" + assert result.domain == "corp.local" + + def test_whitespace_stripped(self): + hv = HostValue(" DC01.corp.com ") + assert str(hv) == "DC01.corp.com" + + +# --------------------------------------------------------------------------- +# HostValue - get_value field derivation +# --------------------------------------------------------------------------- + + +class TestHostValueGetValue: + @pytest.fixture + def hv(self): + return HostValue("DC01.contoso.lab") + + @pytest.fixture + def hv_no_domain(self): + return HostValue("DEMENTOR") + + # --- with domain --- + + def test_host_field(self, hv): + assert hv.get_value("Host") == "DC01.contoso.lab" + + def test_fqdn_field(self, hv): + assert hv.get_value("FQDN") == "DC01.contoso.lab" + + def test_dns_computer(self, hv): + assert hv.get_value("DnsComputer") == "DC01.contoso.lab" + + def test_dns_hostname(self, hv): + assert hv.get_value("DNSHostName") == "DC01" + + def test_netbios_computer(self, hv): + assert hv.get_value("NetBIOSComputer") == "DC01" + + def test_netbios_name(self, hv): + assert hv.get_value("NetBIOSName") == "DC01" + + def test_netbios_domain(self, hv): + assert hv.get_value("NetBIOSDomain") == "CONTOSO.LAB" + + def test_netbios_domain_name(self, hv): + assert hv.get_value("NetBIOSDomainName") == "CONTOSO.LAB" + + def test_dns_domain(self, hv): + assert hv.get_value("DnsDomain") == "contoso.lab" + + def test_dns_domain_name(self, hv): + assert hv.get_value("DNSDomainName") == "contoso.lab" + + def test_dns_tree(self, hv): + assert hv.get_value("DnsTree") == "contoso.lab" + + def test_unknown_field_returns_hostname(self, hv): + assert hv.get_value("UnknownField") == "DC01" + + # --- without domain --- + + def test_no_domain_fqdn_is_hostname(self, hv_no_domain): + assert hv_no_domain.get_value("FQDN") == "DEMENTOR" + + def test_no_domain_dns_computer_is_empty(self, hv_no_domain): + # DnsComputer must be empty when there is no domain (omit AV_PAIR) + assert hv_no_domain.get_value("DnsComputer") == "" + + def test_no_domain_netbios_domain_fallback(self, hv_no_domain): + assert hv_no_domain.get_value("NetBIOSDomain") == "WORKGROUP" + + def test_no_domain_dns_domain_empty(self, hv_no_domain): + assert hv_no_domain.get_value("DnsDomain") == "WORKGROUP" + + def test_no_domain_dns_tree_empty(self, hv_no_domain): + assert hv_no_domain.get_value("DnsTree") == "WORKGROUP" + + # --- NetBIOS 15-char truncation --- + + def test_netbios_computer_truncated_to_15(self): + hv = HostValue("AVERYLONGHOSTNAME123.corp.local") + nb = hv.get_value("NetBIOSComputer") + assert len(nb) <= 15 + assert nb == "AVERYLONGHOSTNA" + + def test_netbios_computer_uppercased(self): + hv = HostValue("dc01.contoso.lab") + assert hv.get_value("NetBIOSComputer") == "DC01" + + def test_netbios_domain_uppercased(self): + hv = HostValue("dc01.contoso.lab") + assert hv.get_value("NetBIOSDomain") == "CONTOSO.LAB" + + def test_dns_domain_lowercased(self): + hv = HostValue("DC01.CONTOSO.LAB") + assert hv.get_value("DnsDomain") == "contoso.lab" + + +# --------------------------------------------------------------------------- +# ATTR_GLOBALS_HOST - Attribute metadata +# --------------------------------------------------------------------------- + + +class TestAttrGlobalsHost: + def test_is_attribute_instance(self): + assert isinstance(ATTR_GLOBALS_HOST, Attribute) + + def test_attr_name(self): + assert ATTR_GLOBALS_HOST.attr_name == "host" + + def test_qname(self): + assert ATTR_GLOBALS_HOST.qname == "Host" + + def test_not_section_local(self): + assert ATTR_GLOBALS_HOST.section_local is False + + def test_factory_is_host_value(self): + assert ATTR_GLOBALS_HOST.factory is HostValue + + def test_factory_produces_host_value_instance(self): + result = ATTR_GLOBALS_HOST.factory("DC01.corp.local") + assert isinstance(result, HostValue) + assert result.hostname == "DC01" + + +# --------------------------------------------------------------------------- +# HostFallbackValue - explicit-first, Host-derived fallback factory +# --------------------------------------------------------------------------- + + +class TestHostFallbackValue: + """Verify that HostFallbackValue uses explicit values directly and + reads Globals.Host only as a last resort. + """ # noqa: D205 + + def test_explicit_value_returned_as_is(self): + factory = HostFallbackValue("FQDN", "DEMENTOR") + assert factory("explicit.smtp.com") == "explicit.smtp.com" + + def test_explicit_netbios_returned_as_is(self): + factory = HostFallbackValue("NetBIOSComputer", "DEMENTOR") + assert factory("MYSERVER") == "MYSERVER" + + def test_fallback_when_none_and_no_host(self): + original = get_global_config() + try: + _set_global_config({}) + factory = HostFallbackValue("FQDN", "DEMENTOR") + assert factory(None) == "DEMENTOR" + finally: + _set_global_config(original) + + def test_derives_fqdn_from_globals_host_when_none(self): + original = get_global_config() + try: + _set_global_config({"Globals": {"Host": "DC01.contoso.lab"}}) + factory = HostFallbackValue("FQDN", "DEMENTOR") + assert factory(None) == "DC01.contoso.lab" + finally: + _set_global_config(original) + + def test_derives_netbios_computer_from_globals_host(self): + original = get_global_config() + try: + _set_global_config({"Globals": {"Host": "DC01.contoso.lab"}}) + factory = HostFallbackValue("NetBIOSComputer", "DEMENTOR") + assert factory(None) == "DC01" + finally: + _set_global_config(original) + + def test_derives_netbios_domain_from_globals_host(self): + original = get_global_config() + try: + _set_global_config({"Globals": {"Host": "DC01.contoso.lab"}}) + factory = HostFallbackValue("NetBIOSDomain", "WORKGROUP") + assert factory(None) == "CONTOSO.LAB" + finally: + _set_global_config(original) + + def test_explicit_beats_globals_host(self): + """An explicit value must win over [Globals].Host derivation.""" + original = get_global_config() + try: + _set_global_config({"Globals": {"Host": "DC01.contoso.lab"}}) + factory = HostFallbackValue("FQDN", "DEMENTOR") + assert factory("override.corp.com") == "override.corp.com" + finally: + _set_global_config(original) + + def test_post_factory_applied_to_explicit(self): + factory = HostFallbackValue("FQDN", "DEMENTOR", post_factory=str.upper) + assert factory("smtp.corp.com") == "SMTP.CORP.COM" + + def test_post_factory_applied_to_derived(self): + original = get_global_config() + try: + _set_global_config({"Globals": {"Host": "dc01.corp.com"}}) + factory = HostFallbackValue("FQDN", "DEMENTOR", post_factory=str.upper) + assert factory(None) == "DC01.CORP.COM" + finally: + _set_global_config(original) + + def test_hostname_only_netbios_domain_workgroup_fallback(self): + original = get_global_config() + try: + _set_global_config({"Globals": {"Host": "DEMENTOR"}}) + factory = HostFallbackValue("NetBIOSDomain", "WORKGROUP") + assert factory(None) == "WORKGROUP" + finally: + _set_global_config(original) + + +# --------------------------------------------------------------------------- +# NTLM session identity - apply_config picks up Globals.Host +# --------------------------------------------------------------------------- + + +class TestNTLMApplyConfigWithHost: + """Verify that ntlm.apply_config() derives identity from Globals.Host.""" + + def _apply(self, globals_dict: dict, extra_sections: dict | None = None): + original = get_global_config() + try: + cfg = {"Globals": globals_dict} + if extra_sections: + cfg.update(extra_sections) + _set_global_config(cfg) + session = MagicMock() + ntlm.apply_config(session) + return session + finally: + _set_global_config(original) + + def test_default_identity_when_no_host(self): + session = self._apply({}) + assert session.ntlm_nb_computer == "DEMENTOR" + assert session.ntlm_nb_domain == "WORKGROUP" + assert session.ntlm_dns_computer == "" + assert session.ntlm_dns_domain == "" + + def test_identity_derived_from_host(self): + session = self._apply({"Host": "DC01.contoso.lab"}) + assert session.ntlm_nb_computer == "DC01" + assert session.ntlm_nb_domain == "CONTOSO.LAB" + assert session.ntlm_dns_computer == "DC01.contoso.lab" + assert session.ntlm_dns_domain == "contoso.lab" + + def test_ntlm_section_overrides_host(self): + """Explicit [NTLM].NetBIOSComputer beats [Globals].Host derivation.""" + session = self._apply( + {"Host": "DC01.contoso.lab"}, + {"NTLM": {"NetBIOSComputer": "OVERRIDE"}}, + ) + assert session.ntlm_nb_computer == "OVERRIDE" + # Domain still from Host since no [NTLM].NetBIOSDomain + assert session.ntlm_nb_domain == "CONTOSO.LAB" + + def test_globals_explicit_field_beats_host(self): + """An explicit [Globals].NetBIOSComputer beats [Globals].Host derivation.""" + session = self._apply({"Host": "DC01.contoso.lab", "NetBIOSComputer": "EXPLICIT"}) + assert session.ntlm_nb_computer == "EXPLICIT" + # Other fields still from Host + assert session.ntlm_nb_domain == "CONTOSO.LAB" + + +# --------------------------------------------------------------------------- +# Protocol FQDN fallback - SMTP, LDAP, POP3, IMAP, MSSQL, HTTP, RPC +# --------------------------------------------------------------------------- + + +class TestProtocolFQDNFallback: + """Verify that protocol FQDN attrs derive from Globals.Host.""" + + def _global_cfg(self, host: str, **extra_globals): + return {"Globals": {"Host": host, **extra_globals}} + + def test_smtp_fqdn_from_host(self): + original = get_global_config() + try: + _set_global_config(self._global_cfg("MAIL01.corp.com")) + + cfg = SMTPServerConfig({"Port": 25}) + assert cfg.smtp_fqdn == "MAIL01.corp.com" + finally: + _set_global_config(original) + + def test_smtp_fqdn_explicit_in_protocol_overrides_host(self): + original = get_global_config() + try: + _set_global_config( + { + "Globals": {"Host": "MAIL01.corp.com"}, + "SMTP": { + "Host": "explicit.smtp.com" + }, # explicit Host in [SMTP] beats [Globals].Host + } + ) + + cfg = SMTPServerConfig({"Port": 25}) + assert cfg.smtp_fqdn == "explicit.smtp.com" + finally: + _set_global_config(original) + + def test_ldap_fqdn_from_host(self): + original = get_global_config() + try: + _set_global_config(self._global_cfg("DC01.corp.local")) + + cfg = LDAPServerConfig({"Port": 389, "Connectionless": False}) + assert cfg.ldap_fqdn == "DC01.corp.local" + finally: + _set_global_config(original) + + def test_imap_fqdn_from_host(self): + original = get_global_config() + try: + _set_global_config(self._global_cfg("MAIL01.corp.com")) + + cfg = IMAPServerConfig({"Port": 143}) + assert cfg.imap_fqdn == "MAIL01.corp.com" + finally: + _set_global_config(original) + + def test_mssql_fqdn_from_host(self): + original = get_global_config() + try: + _set_global_config(self._global_cfg("SQL01.corp.com")) + + cfg = MSSQLConfig({"Port": 1433}) + assert cfg.mssql_fqdn == "SQL01.corp.com" + finally: + _set_global_config(original) + + def test_rpc_fqdn_from_host(self): + original = get_global_config() + try: + _set_global_config(self._global_cfg("DC01.corp.local")) + + cfg = RPCConfig({}) + assert cfg.rpc_fqdn == "DC01.corp.local" + finally: + _set_global_config(original) + + def test_ssrp_server_name_from_host(self): + """SSRP derives its server name from Globals.Host (via MSSQL.FQDN chain).""" + original = get_global_config() + try: + _set_global_config(self._global_cfg("SQL01.corp.com")) + + cfg = SSRPConfig({}) + assert cfg.ssrp_server_name == "SQL01.corp.com" + finally: + _set_global_config(original) + + def test_smb_nb_computer_from_host(self): + """SMB identity derives NetBIOSComputer from Globals.Host.""" + original = get_global_config() + try: + _set_global_config(self._global_cfg("DC01.corp.local")) + + cfg = SMBServerConfig({"Port": 445}) + assert cfg.smb_nb_computer == "DC01" + assert cfg.smb_nb_domain == "CORP.LOCAL" + finally: + _set_global_config(original) + + def test_smb_nb_computer_explicit_overrides_host(self): + """Explicit [SMB].NetBIOSComputer beats Globals.Host derivation.""" + original = get_global_config() + try: + _set_global_config( + { + "Globals": {"Host": "DC01.corp.local"}, + "SMB": {"NetBIOSComputer": "EXPLICIT"}, + } + ) + + cfg = SMBServerConfig({"Port": 445}) + assert cfg.smb_nb_computer == "EXPLICIT" + assert cfg.smb_nb_domain == "CORP.LOCAL" # still derived from Host + finally: + _set_global_config(original) + + def test_smb_nb_computer_default_when_no_host(self): + """Without Host, SMB identity falls back to hardcoded 'DEMENTOR'.""" + original = get_global_config() + try: + _set_global_config({}) + + cfg = SMBServerConfig({"Port": 445}) + finally: + _set_global_config(original) + assert cfg.smb_nb_computer == "DEMENTOR" + assert cfg.smb_nb_domain == "WORKGROUP" + + def test_per_server_fqdn_overrides_globals(self): + """Per-server FQDN takes priority over Globals.Host.""" + original = get_global_config() + try: + _set_global_config(self._global_cfg("MAIL01.corp.com")) + + # Per-server config dict with explicit "Host" key + cfg = SMTPServerConfig({"Port": 25, "Host": "perserver.example.com"}) + assert cfg.smtp_fqdn == "perserver.example.com" + finally: + _set_global_config(original) + + def test_no_host_smtp_falls_back_to_hardcoded_default(self): + original = get_global_config() + try: + _set_global_config({}) + + cfg = SMTPServerConfig({"Port": 25}) + finally: + _set_global_config(original) + assert cfg.smtp_fqdn == "DEMENTOR" + + +# --------------------------------------------------------------------------- +# CLI -H option integration +# --------------------------------------------------------------------------- + + +class TestCLIHostOption: + def test_parse_options_globals_host(self): + + result = parse_options(["Globals.Host=DC01.contoso.lab"]) + assert result == {"Globals": {"Host": "DC01.contoso.lab"}} + + def test_host_flag_sets_globals_host(self): + """Simulate -H DC01.contoso.lab being applied to the config.""" + original = get_global_config() + try: + import dementor.config as cfg_mod # noqa: PLC0415 + + cfg_mod.dm_config.setdefault("Globals", {})["Host"] = "DC01.contoso.lab" + result = get_global_config()["Globals"] + finally: + _set_global_config(original) + + assert result["Host"] == "DC01.contoso.lab" + # Factories derive on attribute access - verify the factory works correctly + factory = HostFallbackValue("NetBIOSComputer", "DEMENTOR") + assert factory("DC01") == "DC01" + + def test_option_flag_equivalent_to_host_flag(self): + """Globals.Host via -O must produce same result as -H.""" + via_O = parse_options(["Globals.Host=DC01.contoso.lab"]) + # -H DC01.contoso.lab is equivalent to setting Globals.Host directly + assert via_O["Globals"]["Host"] == "DC01.contoso.lab" diff --git a/tests/config/test_toml_config.py b/tests/config/test_toml_config.py new file mode 100644 index 0000000..ba6bc63 --- /dev/null +++ b/tests/config/test_toml_config.py @@ -0,0 +1,362 @@ +# Copyright (c) 2025-Present MatrixEditor +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Unit tests for the dementor configuration backbone. + +Covers: +- ``dementor.config.toml``: ``TomlConfig``, ``Attribute``, ``_set_field``, ``build_config`` +- ``dementor.config.util``: ``get_value``, ``is_true`` +- ``dementor.config``: ``get_global_config``, ``_set_global_config`` +""" + +import pytest + +from unittest.mock import patch + +from dementor.config.toml import TomlConfig, Attribute, _LOCAL +from dementor.config.util import get_value, is_true +from dementor.config import get_global_config, _set_global_config + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class SimpleConfig(TomlConfig): + _section_ = "Simple" + _fields_ = [ + Attribute("host", "host", default_val="localhost"), + Attribute("port", "port", default_val=8080, factory=int), + ] + + +class RequiredFieldConfig(TomlConfig): + _section_ = "Req" + _fields_ = [ + Attribute("required_val", "value"), # no default_val -> _LOCAL sentinel + ] + + +class FactoryConfig(TomlConfig): + _section_ = "Factory" + _fields_ = [ + Attribute("count", "count", default_val="5", factory=int), + Attribute( + "flag", + "flag", + default_val="yes", + factory=lambda v: str(v).lower() in ("true", "1", "yes"), + ), + ] + + +class GlobalFallbackConfig(TomlConfig): + _section_ = "Proto" + _fields_ = [ + # section_local=False means also look in Globals + Attribute("timeout", "timeout", default_val=30, section_local=False), + ] + + +# --------------------------------------------------------------------------- +# TomlConfig - basic instantiation +# --------------------------------------------------------------------------- + + +class TestTomlConfigBasic: + def test_default_values_used_when_config_empty(self): + cfg = SimpleConfig({}) + assert cfg.host == "localhost" + assert cfg.port == 8080 + + def test_none_config_treated_as_empty_dict(self): + cfg = SimpleConfig(None) + assert cfg.host == "localhost" + assert cfg.port == 8080 + + def test_config_values_override_defaults(self): + cfg = SimpleConfig({"host": "example.com", "port": "9090"}) + assert cfg.host == "example.com" + assert cfg.port == 9090 # factory=int applied + + def test_partial_override(self): + cfg = SimpleConfig({"host": "myhost"}) + assert cfg.host == "myhost" + assert cfg.port == 8080 # default preserved + + def test_factory_applied_to_value(self): + cfg = SimpleConfig({"port": "1234"}) + assert cfg.port == 1234 + assert isinstance(cfg.port, int) + + def test_factory_applied_to_default(self): + cfg = FactoryConfig({}) + assert cfg.count == 5 + assert isinstance(cfg.count, int) + assert cfg.flag is True + + def test_factory_processes_supplied_value(self): + cfg = FactoryConfig({"count": "99", "flag": "false"}) + assert cfg.count == 99 + assert cfg.flag is False + + +# --------------------------------------------------------------------------- +# Attribute - sentinel behaviour +# --------------------------------------------------------------------------- + + +class TestAttributeSentinel: + def test_required_field_raises_when_missing(self): + with pytest.raises(ValueError, match="value"): + RequiredFieldConfig({}) + + def test_required_field_succeeds_when_supplied(self): + cfg = RequiredFieldConfig({"value": "hello"}) + assert cfg.required_val == "hello" + + def test_local_sentinel_is_not_none(self): + # _LOCAL must be a distinct object from None so that None can be a valid default + assert _LOCAL is not None + + def test_none_default_distinct_from_local(self): + class NullableConfig(TomlConfig): + _section_ = "N" + _fields_ = [Attribute("val", "val", default_val=None)] + + cfg = NullableConfig({}) + assert cfg.val is None + + +# --------------------------------------------------------------------------- +# _set_field - type coercion and setter dispatch +# --------------------------------------------------------------------------- + + +class TestSetField: + def test_custom_setter_called_when_present(self): + class SetterConfig(TomlConfig): + _section_ = "S" + _fields_ = [Attribute("value", "value", default_val=0, factory=int)] + + def set_value(self, val: int) -> None: + self.value = val * 2 # double on set + + cfg = SetterConfig({"value": "7"}) + assert cfg.value == 14 + + def test_setattr_used_when_no_setter(self): + cfg = SimpleConfig({"host": "direct"}) + assert cfg.host == "direct" + + def test_dotted_qname_reads_nested_config(self): + class DottedConfig(TomlConfig): + _section_ = "Outer" + _fields_ = [ + Attribute("inner_val", "Inner.key", default_val="default"), + ] + + cfg = DottedConfig({"Inner": {"key": "nested_value"}}) + assert cfg.inner_val == "nested_value" + + +# --------------------------------------------------------------------------- +# TomlConfig.__getitem__ +# --------------------------------------------------------------------------- + + +class TestGetItem: + def test_getitem_by_attr_name(self): + cfg = SimpleConfig({"host": "h1"}) + assert cfg["host"] == "h1" + + def test_getitem_by_qname(self): + cfg = SimpleConfig({"port": "7777"}) + assert cfg["port"] == 7777 + + def test_getitem_missing_raises_key_error(self): + cfg = SimpleConfig({}) + with pytest.raises(KeyError): + _ = cfg["nonexistent"] + + +# --------------------------------------------------------------------------- +# TomlConfig.build_config - reads from global config +# --------------------------------------------------------------------------- + + +class TestBuildConfig: + def test_build_config_uses_global_config(self): + with patch( + "dementor.config.util.get_global_config", + return_value={"Simple": {"host": "fromglobal"}}, + ): + cfg = TomlConfig.build_config(SimpleConfig) + assert cfg.host == "fromglobal" + + def test_build_config_empty_section_uses_defaults(self): + with patch("dementor.config.util.get_global_config", return_value={}): + cfg = TomlConfig.build_config(SimpleConfig) + assert cfg.host == "localhost" + + def test_build_config_section_override(self): + with patch( + "dementor.config.util.get_global_config", + return_value={"Alt": {"host": "althost"}}, + ): + cfg = TomlConfig.build_config(SimpleConfig, section="Alt") + assert cfg.host == "althost" + + def test_build_config_raises_when_section_none(self): + class NoSection(TomlConfig): + _section_ = "" + _fields_ = [] + + with pytest.raises(ValueError, match="section cannot be None"): + TomlConfig.build_config(NoSection) + + +# --------------------------------------------------------------------------- +# TomlConfig.as_dict / __repr__ +# --------------------------------------------------------------------------- + + +class TestAsDict: + def test_as_dict_returns_all_fields(self): + cfg = SimpleConfig({"host": "myhost", "port": "9000"}) + d = cfg.as_dict() + assert d == {"host": "myhost", "port": 9000} + + def test_repr_contains_field_values(self): + cfg = SimpleConfig({"host": "repr_test"}) + assert "repr_test" in repr(cfg) + + +# --------------------------------------------------------------------------- +# get_value - utility function +# --------------------------------------------------------------------------- + + +class TestGetValue: + def test_returns_default_when_key_missing(self): + with patch("dementor.config.util.get_global_config", return_value={}): + result = get_value("NoSuchSection", "key", default="fallback") + assert result == "fallback" + + def test_returns_value_from_section(self): + with patch( + "dementor.config.util.get_global_config", + return_value={"NTLM": {"Challenge": "aabbccdd"}}, + ): + result = get_value("NTLM", "Challenge", default=None) + assert result == "aabbccdd" + + def test_returns_none_default_when_not_specified(self): + with patch("dementor.config.util.get_global_config", return_value={}): + result = get_value("Missing", "key") + assert result is None + + def test_returns_whole_section_when_key_none(self): + section_data = {"Port": 443, "SSL": True} + with patch( + "dementor.config.util.get_global_config", return_value={"LDAP": section_data} + ): + result = get_value("LDAP", key=None, default={}) + assert result == section_data + + def test_dotted_section_path(self): + config = {"HTTP": {"server": {"Port": 80}}} + with patch("dementor.config.util.get_global_config", return_value=config): + result = get_value("HTTP.server", "Port", default=0) + assert result == 80 + + def test_missing_nested_path_returns_default(self): + with patch("dementor.config.util.get_global_config", return_value={}): + result = get_value("A.B.C", "key", default=42) + assert result == 42 + + +# --------------------------------------------------------------------------- +# is_true - bool coercion +# --------------------------------------------------------------------------- + + +class TestIsTrue: + @pytest.mark.parametrize( + "truthy", ["true", "True", "TRUE", "1", "on", "ON", "yes", "YES"] + ) + def test_truthy_values(self, truthy): + assert is_true(truthy) is True + + @pytest.mark.parametrize( + "falsy", ["false", "False", "0", "off", "no", "", "random", "2"] + ) + def test_falsy_values(self, falsy): + assert is_true(falsy) is False + + def test_non_string_input(self): + # is_true coerces via str() so non-string inputs should work + assert is_true(1) is True + assert is_true(0) is False + + +# --------------------------------------------------------------------------- +# get_global_config / _set_global_config +# --------------------------------------------------------------------------- + + +class TestGlobalConfig: + def test_set_and_get_global_config(self): + original = get_global_config() + try: + _set_global_config({"test_key": "test_val"}) + assert get_global_config()["test_key"] == "test_val" + finally: + _set_global_config(original) + + def test_global_config_returns_dict(self): + result = get_global_config() + assert isinstance(result, dict) + + def test_global_config_isolation(self): + original = get_global_config() + _set_global_config({"isolated": True}) + assert get_global_config().get("isolated") is True + _set_global_config(original) + assert "isolated" not in get_global_config() + + +# --------------------------------------------------------------------------- +# GlobalFallbackConfig - section_local=False reads Globals +# --------------------------------------------------------------------------- + + +class TestGlobalFallback: + def test_global_fallback_used_when_section_missing(self): + config = {"Globals": {"timeout": 60}} + with patch("dementor.config.util.get_global_config", return_value=config): + cfg = TomlConfig.build_config(GlobalFallbackConfig) + assert cfg.timeout == 60 + + def test_section_value_overrides_global(self): + config = {"Proto": {"timeout": 10}, "Globals": {"timeout": 60}} + with patch("dementor.config.util.get_global_config", return_value=config): + cfg = TomlConfig.build_config(GlobalFallbackConfig) + assert cfg.timeout == 10 diff --git a/tests/test_db.py b/tests/test_db.py index b9b8dcf..c76842b 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -4,8 +4,6 @@ All tests use SQLite :memory: with StaticPool -- no external DB required. """ -from __future__ import annotations - import json import os import tempfile diff --git a/tests/test_filters.py b/tests/test_filters.py new file mode 100644 index 0000000..a984c5f --- /dev/null +++ b/tests/test_filters.py @@ -0,0 +1,351 @@ +# Copyright (c) 2025-Present MatrixEditor +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Unit tests for dementor.filters. + +Covers: +- ``FilterObj``: literal, regex, and glob pattern matching +- ``Filters``: construction from strings, dicts (Target/File), and membership test +- ``in_scope``: whitelist-only, blacklist-only, combined, and no-filter paths +""" + +import sys +import tempfile +import textwrap +from pathlib import Path + +import pytest + +from dementor.filters import FilterObj, Filters, in_scope + + +# --------------------------------------------------------------------------- +# FilterObj - literal matching +# --------------------------------------------------------------------------- + + +class TestFilterObjLiteral: + def test_exact_match(self): + f = FilterObj("host1") + assert f.matches("host1") is True + + def test_no_match(self): + f = FilterObj("host1") + assert f.matches("host2") is False + + def test_empty_target_matches_empty_string(self): + f = FilterObj("") + assert f.matches("") is True + + def test_case_sensitive(self): + f = FilterObj("HOST1") + assert f.matches("host1") is False + + def test_ip_address_literal(self): + f = FilterObj("192.168.1.1") + assert f.matches("192.168.1.1") is True + assert f.matches("192.168.1.2") is False + + def test_from_string_factory(self): + f = FilterObj("host99") + assert f.matches("host99") is True + assert f.matches("host1") is False + + +# --------------------------------------------------------------------------- +# FilterObj - regex matching +# --------------------------------------------------------------------------- + + +class TestFilterObjRegex: + def test_regex_prefix(self): + f = FilterObj(r"re:.*\.example\.com") + assert f.matches("api.example.com") is True + assert f.matches("www.example.com") is True + + def test_regex_no_match(self): + f = FilterObj(r"re:.*\.example\.com") + assert f.matches("attacker.evil.com") is False + + def test_regex_ip_range(self): + f = FilterObj(r"re:192\.168\.1\.[0-9]+") + assert f.matches("192.168.1.100") is True + assert f.matches("10.0.0.1") is False + + def test_regex_target_stripped_of_prefix(self): + f = FilterObj(r"re:^admin$") + assert f.target == "^admin$" + + def test_regex_pattern_not_none(self): + f = FilterObj(r"re:foo") + assert f.pattern is not None + + def test_literal_pattern_is_none(self): + f = FilterObj("foo") + assert f.pattern is None + + +# --------------------------------------------------------------------------- +# FilterObj - glob matching (Python 3.13+) +# --------------------------------------------------------------------------- + + +class TestFilterObjGlob: + @pytest.mark.skipif( + (sys.version_info.major, sys.version_info.minor) < (3, 13), + reason="glob.translate requires Python 3.13+", + ) + def test_glob_wildcard(self): + f = FilterObj("g:*.example.com") + assert f.matches("api.example.com") is True + assert f.matches("www.example.com") is True + assert f.matches("evil.net") is False + + @pytest.mark.skipif( + (sys.version_info.major, sys.version_info.minor) >= (3, 13), + reason="This test covers the <3.13 fallback path only", + ) + def test_glob_pre_313_falls_back_to_literal(self): + with pytest.warns(UserWarning, match="glob.translate"): + f = FilterObj("g:*.example.com") + # In fallback mode, pattern is None and matches uses exact string compare + assert f.pattern is None + # Target has the prefix stripped + assert f.target == "*.example.com" + + def test_glob_target_stripped_of_prefix(self): + if (sys.version_info.major, sys.version_info.minor) < (3, 13): + with pytest.warns(UserWarning): # noqa: PT030 + f = FilterObj("g:*.local") + else: + f = FilterObj("g:*.local") + assert f.target == "*.local" + + +# --------------------------------------------------------------------------- +# FilterObj.from_file +# --------------------------------------------------------------------------- + + +class TestFilterObjFromFile: + def test_from_file_loads_patterns(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as fh: + fh.write( + textwrap.dedent("""\ + host1 + host2 + re:.*\\.admin\\. + """) + ) + tmppath = fh.name + try: + filters = FilterObj.from_file(tmppath, extra=None) + assert len(filters) == 3 + assert any(f.matches("host1") for f in filters) + assert any(f.matches("host2") for f in filters) + finally: + Path(tmppath).unlink(missing_ok=True) + + def test_from_file_nonexistent_returns_empty(self): + result = FilterObj.from_file("/nonexistent/path/targets.txt", extra=None) + assert result == [] + + def test_from_file_attaches_extra_to_each_filter(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as fh: + fh.write("myhost\n") + tmppath = fh.name + try: + extra_meta = {"source": "test"} + filters = FilterObj.from_file(tmppath, extra=extra_meta) + assert filters[0].extra == extra_meta + finally: + Path(tmppath).unlink(missing_ok=True) + + +# --------------------------------------------------------------------------- +# Filters - construction +# --------------------------------------------------------------------------- + + +class TestFiltersConstruction: + def test_from_string_list(self): + f = Filters(["host1", "host2"]) + assert len(f.filters) == 2 + + def test_from_string_list_skips_empty_strings(self): + f = Filters(["host1", "", "host2"]) + assert len(f.filters) == 2 + + def test_from_dict_with_target_key(self): + f = Filters([{"Target": "host3", "reason": "admin"}]) + assert len(f.filters) == 1 + assert f.filters[0].matches("host3") + assert f.filters[0].extra == {"Target": "host3", "reason": "admin"} + + def test_from_dict_with_file_key(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as fh: + fh.write("filehost1\nfilehost2\n") + tmppath = fh.name + try: + f = Filters([{"File": tmppath}]) + assert len(f.filters) == 2 + finally: + Path(tmppath).unlink(missing_ok=True) + + def test_dict_missing_target_and_file_skipped(self): + f = Filters([{"other_key": "value"}]) + assert len(f.filters) == 0 + + def test_mixed_string_and_dict(self): + f = Filters(["host1", {"Target": "host2"}]) + assert len(f.filters) == 2 + + def test_empty_config(self): + f = Filters([]) + assert len(f.filters) == 0 + + +# --------------------------------------------------------------------------- +# Filters - matching +# --------------------------------------------------------------------------- + + +class TestFiltersMatching: + def test_contains_true_for_match(self): + f = Filters(["host1", "host2"]) + assert "host1" in f + assert "host2" in f + + def test_contains_false_for_non_match(self): + f = Filters(["host1"]) + assert "host99" not in f + + def test_get_matched_returns_all_matching_filters(self): + f = Filters([r"re:host[0-9]", "host5"]) + matches = f.get_matched("host5") + # both the regex and the literal match "host5" + assert len(matches) == 2 + + def test_get_first_match_returns_first(self): + f = Filters(["host1", r"re:host.*"]) + first = f.get_first_match("host1") + assert first is not None + assert first.matches("host1") + + def test_get_first_match_returns_none_when_no_match(self): + f = Filters(["host1"]) + assert f.get_first_match("host99") is None + + def test_has_match_true(self): + f = Filters(["host1"]) + assert f.has_match("host1") is True + + def test_has_match_false(self): + f = Filters(["host1"]) + assert f.has_match("host99") is False + + def test_regex_filter_in_contains(self): + f = Filters([r"re:192\.168\..*"]) + assert "192.168.1.100" in f + assert "10.0.0.1" not in f + + +# --------------------------------------------------------------------------- +# in_scope - whitelist / blacklist logic +# --------------------------------------------------------------------------- + + +class _ScopeConfig: + """Minimal config stub for in_scope tests.""" + + def __init__(self, targets=None, ignored=None): + if targets is not None: + self.targets = targets + if ignored is not None: + self.ignored = ignored + + +class TestInScope: + def test_no_filters_always_in_scope(self): + cfg = _ScopeConfig() + assert in_scope("anything", cfg) is True + + def test_whitelist_only_passes_match(self): + cfg = _ScopeConfig(targets=Filters(["host1", "host2"])) + assert in_scope("host1", cfg) is True + assert in_scope("host2", cfg) is True + + def test_whitelist_only_blocks_non_match(self): + cfg = _ScopeConfig(targets=Filters(["host1"])) + assert in_scope("host99", cfg) is False + + def test_blacklist_only_blocks_match(self): + cfg = _ScopeConfig(ignored=Filters(["host1"])) + assert in_scope("host1", cfg) is False + + def test_blacklist_only_passes_non_match(self): + cfg = _ScopeConfig(ignored=Filters(["host1"])) + assert in_scope("host99", cfg) is True + + def test_whitelist_and_blacklist_combined(self): + cfg = _ScopeConfig( + targets=Filters(["host1", "host2"]), + ignored=Filters(["host1"]), + ) + # host1 is whitelisted but also blacklisted -> out of scope + assert in_scope("host1", cfg) is False + # host2 is whitelisted and not blacklisted -> in scope + assert in_scope("host2", cfg) is True + # host3 is not whitelisted -> out of scope + assert in_scope("host3", cfg) is False + + def test_none_targets_treated_as_no_whitelist(self): + cfg = _ScopeConfig(targets=None, ignored=Filters(["bad"])) + assert in_scope("good", cfg) is True + assert in_scope("bad", cfg) is False + + def test_none_ignored_treated_as_no_blacklist(self): + cfg = _ScopeConfig(targets=Filters(["good"]), ignored=None) + assert in_scope("good", cfg) is True + assert in_scope("bad", cfg) is False + + def test_config_without_targets_attribute(self): + class NoTargets: + ignored = Filters(["blocked"]) + + assert in_scope("allowed", NoTargets()) is True + assert in_scope("blocked", NoTargets()) is False + + def test_config_without_ignored_attribute(self): + class NoIgnored: + targets = Filters(["allowed"]) + + assert in_scope("allowed", NoIgnored()) is True + assert in_scope("other", NoIgnored()) is False + + def test_whitelist_with_regex(self): + cfg = _ScopeConfig(targets=Filters([r"re:10\.0\.0\.[0-9]+"])) + assert in_scope("10.0.0.1", cfg) is True + assert in_scope("10.0.0.255", cfg) is True + assert in_scope("192.168.1.1", cfg) is False + + def test_blacklist_with_regex(self): + cfg = _ScopeConfig(ignored=Filters([r"re:.*\.internal\."])) + assert in_scope("host.internal.corp", cfg) is False + assert in_scope("host.external.corp", cfg) is True diff --git a/tests/test_logger.py b/tests/test_logger.py new file mode 100644 index 0000000..ca1f26c --- /dev/null +++ b/tests/test_logger.py @@ -0,0 +1,229 @@ +# Copyright (c) 2025-Present MatrixEditor +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Tests for dementor/log/logger.py - ProtocolLogger and LoggingConfig.""" + +import pytest + +from unittest.mock import patch + +from dementor.log.logger import ProtocolLogger, LoggingConfig + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def logger(): + """Return a fresh ProtocolLogger with no extra context.""" + return ProtocolLogger() + + +@pytest.fixture +def logger_with_context(): + """Return a ProtocolLogger pre-configured with protocol context.""" + return ProtocolLogger( + extra={ + "protocol": "SMB", + "protocol_color": "cyan", + "host": "192.168.1.1", + "port": "445", + } + ) + + +# --------------------------------------------------------------------------- +# ProtocolLogger init tests +# --------------------------------------------------------------------------- + + +class TestProtocolLoggerInit: + def test_init_no_extra(self): + lg = ProtocolLogger() + assert lg.extra == {} + + def test_init_with_extra(self): + extra = {"protocol": "HTTP", "host": "10.0.0.1"} + lg = ProtocolLogger(extra=extra) + assert lg.extra["protocol"] == "HTTP" + assert lg.extra["host"] == "10.0.0.1" + + def test_init_none_extra_uses_empty_dict(self): + lg = ProtocolLogger(extra=None) + assert lg.extra == {} + + def test_logger_name(self, logger): + assert logger.logger.name == "dementor" + + +# --------------------------------------------------------------------------- +# Accessor tests +# --------------------------------------------------------------------------- + + +class TestProtocolLoggerAccessors: + def test_get_protocol_name_default_empty(self, logger): + assert logger.get_protocol_name() == "" + + def test_get_protocol_name_from_context(self, logger_with_context): + assert logger_with_context.get_protocol_name() == "SMB" + + def test_get_protocol_color_default_white(self, logger): + assert logger.get_protocol_color() == "white" + + def test_get_protocol_color_from_context(self, logger_with_context): + assert logger_with_context.get_protocol_color() == "cyan" + + def test_get_host_default_empty(self, logger): + assert logger.get_host() == "" + + def test_get_host_from_context(self, logger_with_context): + assert logger_with_context.get_host() == "192.168.1.1" + + def test_get_port_default_empty(self, logger): + assert logger.get_port() == "" + + def test_get_port_from_context(self, logger_with_context): + assert logger_with_context.get_port() == "445" + + def test_get_extra_from_per_call_extra(self, logger): + per_call = {"protocol": "SMTP"} + result = logger._get_extra("protocol", per_call, "default") + assert result == "SMTP" + # per_call should have had the key popped + assert "protocol" not in per_call + + def test_get_extra_falls_back_to_default(self, logger): + result = logger._get_extra("nonexistent_key", None, "fallback") + assert result == "fallback" + + def test_get_extra_per_call_overrides_context(self, logger_with_context): + per_call = {"protocol": "FTP"} + result = logger_with_context._get_extra("protocol", per_call) + assert result == "FTP" + + +# --------------------------------------------------------------------------- +# Format tests +# --------------------------------------------------------------------------- + + +class TestProtocolLoggerFormat: + def test_format_returns_tuple(self, logger): + msg, kwargs = logger.format("test message") + assert isinstance(msg, str) + assert isinstance(kwargs, dict) + + def test_format_with_context(self, logger_with_context): + msg, _ = logger_with_context.format("hello") + assert isinstance(msg, str) + assert "hello" in msg + + def test_format_inline_returns_tuple(self, logger): + result = logger.format_inline("test", {}) + assert isinstance(result, tuple) + msg, _ = result + assert isinstance(msg, str) + + def test_format_inline_with_context(self, logger_with_context): + result = logger_with_context.format_inline("inline message", {}) + msg, _ = result + assert isinstance(msg, str) + assert "inline message" in msg + + +# --------------------------------------------------------------------------- +# Log method tests +# --------------------------------------------------------------------------- + + +class TestProtocolLoggerLogMethods: + def test_log_method_exists(self, logger): + assert callable(logger.log) + + def test_debug_method_exists(self, logger): + assert callable(logger.debug) + + def test_info_method_exists(self, logger): + assert callable(logger.info) + + def test_warning_method_exists(self, logger): + assert callable(logger.warning) + + def test_success_method_exists(self, logger): + assert callable(logger.success) + + def test_display_method_exists(self, logger): + assert callable(logger.display) + + def test_highlight_method_exists(self, logger): + assert callable(logger.highlight) + + def test_fail_method_exists(self, logger): + assert callable(logger.fail) + + def test_log_does_not_raise_with_basic_message(self, logger): + with patch.object(logger.logger, "log"): + logger.debug("test message") + + def test_log_accepts_is_client_kwarg(self, logger): + with patch.object(logger.logger, "log"): + logger.debug("test", is_client=True) + + def test_log_accepts_is_server_kwarg(self, logger): + with patch.object(logger.logger, "log"): + logger.info("test", is_server=True) + + +# --------------------------------------------------------------------------- +# log_config lazy load test +# --------------------------------------------------------------------------- + + +class TestProtocolLoggerLogConfig: + def test_log_config_returns_logging_config(self, logger): + cfg = logger.log_config + assert isinstance(cfg, LoggingConfig) + + def test_log_config_is_cached(self, logger): + cfg1 = logger.log_config + cfg2 = logger.log_config + assert cfg1 is cfg2 + + +# --------------------------------------------------------------------------- +# add_logfile test +# --------------------------------------------------------------------------- + + +class TestProtocolLoggerAddLogfile: + def test_add_logfile_method_exists(self, logger): + assert callable(logger.add_logfile) + + def test_add_logfile_adds_rotating_handler(self, logger, tmp_path): + log_path = str(tmp_path / "test.log") + logger.add_logfile(log_path) + handler_types = [type(h).__name__ for h in logger.logger.handlers] + assert "RotatingFileHandler" in handler_types + # Cleanup + for h in logger.logger.handlers[:]: + if hasattr(h, "baseFilename"): + h.close() + logger.logger.removeHandler(h) diff --git a/tests/test_ntlm.py b/tests/test_ntlm.py index 07784eb..809782e 100644 --- a/tests/test_ntlm.py +++ b/tests/test_ntlm.py @@ -1,12 +1,10 @@ -"""Unit tests for dementor.protocols.ntlm — NTLM authentication helpers. +"""Unit tests for dementor.protocols.ntlm - NTLM authentication helpers. Tests cover every public and private function in ntlm.py, organized by tier: Tier 1 (pure functions): no mocking needed Tier 2 (mock-dependent): require impacket objects or MagicMock """ -from __future__ import annotations - import struct from unittest.mock import MagicMock @@ -24,14 +22,14 @@ NTLM_V2, NTLM_V2_LM, NTLM_VERSION_PLACEHOLDER, - NTLM_build_challenge_message, - NTLM_decode_string, - NTLM_encode_string, - NTLM_handle_authenticate_message, - NTLM_handle_legacy_raw_auth, - NTLM_handle_negotiate_message, - NTLM_timestamp, - NTLM_to_hashcat, + ntlm_build_challenge_message, + ntlm_decode_string, + ntlm_encode_string, + ntlm_handle_authenticate_message, + ntlm_handle_legacy_raw_auth, + ntlm_handle_negotiate_message, + ntlm_timestamp, + ntlm_to_hashcat, _classify_hash_type, _compute_dummy_lm_responses, _config_version_to_bytes, @@ -144,17 +142,17 @@ def _build_ntlm_authenticate( class TestNTLMTimestamp: - """NTLM_timestamp() at line 1607.""" + """ntlm_timestamp() at line 1607.""" def test_returns_positive_int(self): - assert NTLM_timestamp() > 0 + assert ntlm_timestamp() > 0 def test_after_epoch_offset(self): - assert NTLM_timestamp() > NTLM_FILETIME_EPOCH_OFFSET + assert ntlm_timestamp() > NTLM_FILETIME_EPOCH_OFFSET def test_monotonic(self): - t1 = NTLM_timestamp() - t2 = NTLM_timestamp() + t1 = ntlm_timestamp() + t2 = ntlm_timestamp() assert t2 >= t1 @@ -229,7 +227,7 @@ def test_version(self, value, expected): class TestNTLMDecodeString: - """NTLM_decode_string(data, negotiate_flags, is_negotiate_oem) at line 357.""" + """ntlm_decode_string(data, negotiate_flags, is_negotiate_oem) at line 357.""" UNICODE = ntlm.NTLMSSP_NEGOTIATE_UNICODE @@ -264,12 +262,12 @@ class TestNTLMDecodeString: ], ) def test_decode(self, data, flags, is_oem, expected): - result = NTLM_decode_string(data, flags, is_oem) + result = ntlm_decode_string(data, flags, is_oem) assert result == expected class TestNTLMEncodeString: - """NTLM_encode_string(string, negotiate_flags) at line 397.""" + """ntlm_encode_string(string, negotiate_flags) at line 397.""" @pytest.mark.parametrize( ("string", "flags", "expected"), @@ -291,13 +289,13 @@ class TestNTLMEncodeString: ], ) def test_encode(self, string, flags, expected): - assert NTLM_encode_string(string, flags) == expected + assert ntlm_encode_string(string, flags) == expected def test_roundtrip_unicode(self): flags = ntlm.NTLMSSP_NEGOTIATE_UNICODE original = "Test123" - encoded = NTLM_encode_string(original, flags) - decoded = NTLM_decode_string(encoded, flags) + encoded = ntlm_encode_string(original, flags) + decoded = ntlm_decode_string(encoded, flags) assert decoded == original @@ -388,7 +386,7 @@ def test_contains_desl_of_null_hash(self): class TestNTLMToHashcat: - """NTLM_to_hashcat(...) at line 1231 — THE MOST CRITICAL function.""" + """ntlm_to_hashcat(...) at line 1231 - THE MOST CRITICAL function.""" # -- NetNTLMv2 (hashcat -m 5600) ----------------------------------------- @@ -396,7 +394,7 @@ def test_v2_primary_hash_format(self): nt_proof = b"\xaa" * 16 blob = b"\xbb" * 32 nt_response = nt_proof + blob - result = NTLM_to_hashcat( + result = ntlm_to_hashcat( CHALLENGE, "user", "domain", b"\x00" * 24, nt_response, 0 ) assert len(result) == 1 # Z(24) LM suppressed @@ -416,7 +414,7 @@ def test_v2_with_lmv2_companion(self): lm_proof = b"\xcc" * 16 lm_cchal = b"\xdd" * 8 lm_response = lm_proof + lm_cchal - result = NTLM_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) + result = ntlm_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) assert len(result) == 2 assert result[0][0] == NTLM_V2 assert result[1][0] == NTLM_V2_LM @@ -427,19 +425,19 @@ def test_v2_with_lmv2_companion(self): def test_v2_lm_suppressed_when_null(self): nt_response = b"\xaa" * 48 lm_response = b"\x00" * 24 - result = NTLM_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) + result = ntlm_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) assert len(result) == 1 assert result[0][0] == NTLM_V2 def test_v2_lm_wrong_length_skipped(self): nt_response = b"\xaa" * 48 lm_response = b"\xcc" * 16 # wrong length - result = NTLM_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) + result = ntlm_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) assert len(result) == 1 def test_v2_server_challenge_hex_16_chars(self): nt_response = b"\xaa" * 48 - result = NTLM_to_hashcat( + result = ntlm_to_hashcat( CHALLENGE, "user", "domain", b"\x00" * 24, nt_response, 0 ) parts = result[0][1].split(":") @@ -447,7 +445,7 @@ def test_v2_server_challenge_hex_16_chars(self): def test_v2_ntproofstr_hex_32_chars(self): nt_response = b"\xaa" * 48 - result = NTLM_to_hashcat( + result = ntlm_to_hashcat( CHALLENGE, "user", "domain", b"\x00" * 24, nt_response, 0 ) parts = result[0][1].split(":") @@ -455,7 +453,7 @@ def test_v2_ntproofstr_hex_32_chars(self): def test_v2_user_domain_are_strings(self): nt_response = b"\xaa" * 48 - result = NTLM_to_hashcat(CHALLENGE, "Admin", "CORP", b"\x00" * 24, nt_response, 0) + result = ntlm_to_hashcat(CHALLENGE, "Admin", "CORP", b"\x00" * 24, nt_response, 0) parts = result[0][1].split(":") assert parts[0] == "Admin" assert parts[2] == "CORP" @@ -463,7 +461,7 @@ def test_v2_user_domain_are_strings(self): def test_v2_user_as_bytes_decoded(self): nt_response = b"\xaa" * 48 user_bytes = "Admin".encode("utf-16-le") - result = NTLM_to_hashcat( + result = ntlm_to_hashcat( CHALLENGE, user_bytes, "CORP", @@ -480,7 +478,7 @@ def test_v1ess_hash_format(self): client_challenge = b"\xdd" * 8 lm_response = client_challenge + b"\x00" * 16 nt_response = b"\xee" * 24 - result = NTLM_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) + result = ntlm_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) assert len(result) == 1 label, line = result[0] assert label == NTLM_V1_ESS @@ -490,7 +488,7 @@ def test_v1ess_hash_format(self): def test_v1ess_lm_field_48_hex(self): lm_response = b"\xdd" * 8 + b"\x00" * 16 nt_response = b"\xee" * 24 - result = NTLM_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) + result = ntlm_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) parts = result[0][1].split(":") # LM field = CChal(8) + Z(16) = 24 bytes = 48 hex chars assert len(parts[3]) == 48 @@ -498,14 +496,14 @@ def test_v1ess_lm_field_48_hex(self): def test_v1ess_nt_field_48_hex(self): lm_response = b"\xdd" * 8 + b"\x00" * 16 nt_response = b"\xee" * 24 - result = NTLM_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) + result = ntlm_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) parts = result[0][1].split(":") assert len(parts[4]) == 48 # 24 bytes = 48 hex chars def test_v1ess_server_challenge_raw(self): lm_response = b"\xdd" * 8 + b"\x00" * 16 nt_response = b"\xee" * 24 - result = NTLM_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) + result = ntlm_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) parts = result[0][1].split(":") # Must be raw ServerChallenge, NOT pre-computed FinalChallenge assert parts[5] == CHALLENGE.hex() @@ -515,7 +513,7 @@ def test_v1ess_server_challenge_raw(self): def test_v1_with_real_lm(self): nt_response = b"\xaa" * 24 lm_response = b"\xbb" * 24 - result = NTLM_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) + result = ntlm_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) assert len(result) == 1 label, line = result[0] assert label == NTLM_V1 @@ -524,21 +522,21 @@ def test_v1_with_real_lm(self): def test_v1_level2_duplication_lm_empty(self): shared = b"\xaa" * 24 - result = NTLM_to_hashcat(CHALLENGE, "user", "domain", shared, shared, 0) + result = ntlm_to_hashcat(CHALLENGE, "user", "domain", shared, shared, 0) parts = result[0][1].split(":") assert parts[3] == "" # LM slot empty def test_v1_dummy_lm_null_hash(self): nt_response = b"\xaa" * 24 dummy_null = ntlm.ntlmssp_DES_encrypt(NTLM_ESS_ZERO_PAD, CHALLENGE) - result = NTLM_to_hashcat(CHALLENGE, "user", "domain", dummy_null, nt_response, 0) + result = ntlm_to_hashcat(CHALLENGE, "user", "domain", dummy_null, nt_response, 0) parts = result[0][1].split(":") assert parts[3] == "" def test_v1_dummy_lm_default_hash(self): nt_response = b"\xaa" * 24 dummy_default = ntlm.ntlmssp_DES_encrypt(ntlm.DEFAULT_LM_HASH, CHALLENGE) - result = NTLM_to_hashcat( + result = ntlm_to_hashcat( CHALLENGE, "user", "domain", dummy_default, nt_response, 0 ) parts = result[0][1].split(":") @@ -547,31 +545,31 @@ def test_v1_dummy_lm_default_hash(self): def test_v1_hashcat_format_six_tokens(self): nt_response = b"\xaa" * 24 lm_response = b"\xbb" * 24 - result = NTLM_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) + result = ntlm_to_hashcat(CHALLENGE, "user", "domain", lm_response, nt_response, 0) parts = result[0][1].split(":") assert len(parts) == 6 # -- Edge cases ---------------------------------------------------------- def test_empty_nt_response_returns_empty(self): - result = NTLM_to_hashcat(CHALLENGE, "user", "domain", b"", b"", 0) + result = ntlm_to_hashcat(CHALLENGE, "user", "domain", b"", b"", 0) assert result == [] def test_none_nt_response_returns_empty(self): - result = NTLM_to_hashcat(CHALLENGE, "user", "domain", None, None, 0) + result = ntlm_to_hashcat(CHALLENGE, "user", "domain", None, None, 0) assert result == [] def test_bad_challenge_7_raises(self): with pytest.raises(ValueError, match="8 bytes"): - NTLM_to_hashcat(b"\x00" * 7, "u", "d", b"", b"\xaa" * 24, 0) + ntlm_to_hashcat(b"\x00" * 7, "u", "d", b"", b"\xaa" * 24, 0) def test_bad_challenge_9_raises(self): with pytest.raises(ValueError, match="8 bytes"): - NTLM_to_hashcat(b"\x00" * 9, "u", "d", b"", b"\xaa" * 24, 0) + ntlm_to_hashcat(b"\x00" * 9, "u", "d", b"", b"\xaa" * 24, 0) def test_user_as_string(self): nt_response = b"\xaa" * 48 - result = NTLM_to_hashcat( + result = ntlm_to_hashcat( CHALLENGE, "TestUser", "TestDomain", b"\x00" * 24, nt_response, 0 ) parts = result[0][1].split(":") @@ -614,17 +612,17 @@ def test_non_anonymous_with_user(self): ) assert _is_anonymous_authenticate(token) is False - def test_non_anonymous_with_nt(self): + def test_anonymous_with_nt(self): token = _build_ntlm_authenticate( user_name=b"", nt_response=b"\xaa" * 24, lm_response=b"" ) - assert _is_anonymous_authenticate(token) is False + assert _is_anonymous_authenticate(token) is True - def test_non_anonymous_with_lm(self): + def test_anonymous_with_lm(self): token = _build_ntlm_authenticate( user_name=b"", nt_response=b"", lm_response=b"\xbb" * 24 ) - assert _is_anonymous_authenticate(token) is False + assert _is_anonymous_authenticate(token) is True def test_exception_returns_false(self): """Fail-open: parse error returns False so we don't drop captures.""" @@ -698,7 +696,7 @@ def test_malformed_version(self): class TestNTLMBuildChallengeMessage: - """NTLM_build_challenge_message(token, *, ...) at line 587.""" + """ntlm_build_challenge_message(token, *, ...) at line 587.""" def _build(self, client_flags: int, **kwargs): token = _build_ntlm_negotiate(client_flags) @@ -708,7 +706,7 @@ def _build(self, client_flags: int, **kwargs): "nb_domain": "WORKGROUP", } defaults.update(kwargs) - return NTLM_build_challenge_message(token, **defaults) + return ntlm_build_challenge_message(token, **defaults) def test_challenge_in_response(self): msg = self._build(ntlm.NTLMSSP_NEGOTIATE_UNICODE) @@ -717,7 +715,7 @@ def test_challenge_in_response(self): def test_bad_challenge_length_raises(self): token = _build_ntlm_negotiate(ntlm.NTLMSSP_NEGOTIATE_UNICODE) with pytest.raises(ValueError, match="8 bytes"): - NTLM_build_challenge_message(token, challenge=b"\x00" * 7) + ntlm_build_challenge_message(token, challenge=b"\x00" * 7) def test_unicode_flag_echoed(self): msg = self._build(ntlm.NTLMSSP_NEGOTIATE_UNICODE) @@ -831,21 +829,21 @@ def test_version_echoed(self): data = neg.getData() token = ntlm.NTLMAuthNegotiate() token.fromString(data) - msg = NTLM_build_challenge_message(token, challenge=CHALLENGE) + msg = ntlm_build_challenge_message(token, challenge=CHALLENGE) assert msg["flags"] & ntlm.NTLMSSP_NEGOTIATE_VERSION class TestNTLMHandleNegotiateMessage: - """NTLM_handle_negotiate_message(negotiate, logger) at line 517.""" + """ntlm_handle_negotiate_message(negotiate, logger) at line 517.""" def test_returns_dict(self, mock_logger): neg = _build_ntlm_negotiate(ntlm.NTLMSSP_NEGOTIATE_UNICODE) - result = NTLM_handle_negotiate_message(neg, mock_logger) + result = ntlm_handle_negotiate_message(neg, mock_logger) assert isinstance(result, dict) def test_empty_fields_omitted(self, mock_logger): neg = _build_ntlm_negotiate(ntlm.NTLMSSP_NEGOTIATE_UNICODE) - result = NTLM_handle_negotiate_message(neg, mock_logger) + result = ntlm_handle_negotiate_message(neg, mock_logger) # Minimal negotiate has no workstation/domain for k in ("name", "domain"): if k in result: @@ -853,13 +851,13 @@ def test_empty_fields_omitted(self, mock_logger): def test_logger_debug_called(self, mock_logger): neg = _build_ntlm_negotiate(ntlm.NTLMSSP_NEGOTIATE_UNICODE) - NTLM_handle_negotiate_message(neg, mock_logger) + ntlm_handle_negotiate_message(neg, mock_logger) assert mock_logger.debug.called def test_no_version_no_os_key(self, mock_logger): # Without VERSION flag, os field should be empty/absent neg = _build_ntlm_negotiate(ntlm.NTLMSSP_NEGOTIATE_UNICODE) - result = NTLM_handle_negotiate_message(neg, mock_logger) + result = ntlm_handle_negotiate_message(neg, mock_logger) if "os" in result: assert result["os"] == "" @@ -869,16 +867,16 @@ def test_malformed_no_crash(self, mock_logger): token.__getitem__ = MagicMock(side_effect=KeyError("bad")) token.fields = {} # Should not raise - result = NTLM_handle_negotiate_message(token, mock_logger) + result = ntlm_handle_negotiate_message(token, mock_logger) assert isinstance(result, dict) class TestNTLMHandleAuthenticateMessage: - """NTLM_handle_authenticate_message(auth_token, *, ...) at line 909.""" + """ntlm_handle_authenticate_message(auth_token, *, ...) at line 909.""" def test_anonymous_returns_false(self, mock_logger, mock_session): token = _build_ntlm_authenticate(user_name=b"", nt_response=b"", lm_response=b"") - result = NTLM_handle_authenticate_message( + result = ntlm_handle_authenticate_message( token, challenge=CHALLENGE, client=("10.0.0.1", 12345), @@ -898,7 +896,7 @@ def test_valid_v2_returns_true(self, mock_logger, mock_session): nt_response=nt_response, lm_response=lm_response, ) - result = NTLM_handle_authenticate_message( + result = ntlm_handle_authenticate_message( token, challenge=CHALLENGE, client=("10.0.0.1", 12345), @@ -915,7 +913,7 @@ def test_valid_v1_returns_true(self, mock_logger, mock_session): nt_response=b"\xaa" * 24, lm_response=b"\xbb" * 24, ) - result = NTLM_handle_authenticate_message( + result = ntlm_handle_authenticate_message( token, challenge=CHALLENGE, client=("10.0.0.1", 12345), @@ -931,7 +929,7 @@ def test_empty_nt_response_returns_false(self, mock_logger, mock_session): nt_response=b"", lm_response=b"", ) - result = NTLM_handle_authenticate_message( + result = ntlm_handle_authenticate_message( token, challenge=CHALLENGE, client=("10.0.0.1", 12345), @@ -951,7 +949,7 @@ def test_v2_with_lmv2_companion_calls_db_twice(self, mock_logger, mock_session): nt_response=nt_response, lm_response=lm_response, ) - NTLM_handle_authenticate_message( + ntlm_handle_authenticate_message( token, challenge=CHALLENGE, client=("10.0.0.1", 12345), @@ -966,7 +964,7 @@ def test_bad_challenge_returns_false(self, mock_logger, mock_session): user_name="admin".encode("utf-16-le"), nt_response=b"\xaa" * 24, ) - result = NTLM_handle_authenticate_message( + result = ntlm_handle_authenticate_message( token, challenge=b"\x00" * 7, # bad client=("10.0.0.1", 12345), @@ -983,7 +981,7 @@ def test_extras_passed_through(self, mock_logger, mock_session): lm_response=b"\xbb" * 24, ) extras = {"custom_key": "custom_value"} - NTLM_handle_authenticate_message( + ntlm_handle_authenticate_message( token, challenge=CHALLENGE, client=("10.0.0.1", 12345), @@ -1003,7 +1001,7 @@ def test_negotiate_fields_merged(self, mock_logger, mock_session): lm_response=b"\xbb" * 24, ) neg_fields = {"os": "Windows 10 Build 19041"} - result = NTLM_handle_authenticate_message( + result = ntlm_handle_authenticate_message( token, challenge=CHALLENGE, client=("10.0.0.1", 12345), @@ -1021,7 +1019,7 @@ def test_none_extras_handled(self, mock_logger, mock_session): nt_response=b"\xaa" * 24, ) # extras=None should not crash - result = NTLM_handle_authenticate_message( + result = ntlm_handle_authenticate_message( token, challenge=CHALLENGE, client=("10.0.0.1", 12345), @@ -1033,10 +1031,10 @@ def test_none_extras_handled(self, mock_logger, mock_session): class TestNTLMHandleLegacyRawAuth: - """NTLM_handle_legacy_raw_auth(*, ...) at line 1461.""" + """ntlm_handle_legacy_raw_auth(*, ...) at line 1461.""" def test_cleartext_captured(self, mock_logger, mock_session): - NTLM_handle_legacy_raw_auth( + ntlm_handle_legacy_raw_auth( user_name="admin", domain_name="CORP", lm_response=b"", @@ -1055,7 +1053,7 @@ def test_cleartext_captured(self, mock_logger, mock_session): ) def test_cleartext_empty_skips(self, mock_logger, mock_session): - NTLM_handle_legacy_raw_auth( + ntlm_handle_legacy_raw_auth( user_name="admin", domain_name="CORP", lm_response=b"", @@ -1070,7 +1068,7 @@ def test_cleartext_empty_skips(self, mock_logger, mock_session): mock_session.db.add_auth.assert_not_called() def test_raw_v1_captured(self, mock_logger, mock_session): - NTLM_handle_legacy_raw_auth( + ntlm_handle_legacy_raw_auth( user_name="admin", domain_name="CORP", lm_response=b"\xbb" * 24, @@ -1084,7 +1082,7 @@ def test_raw_v1_captured(self, mock_logger, mock_session): assert mock_session.db.add_auth.called def test_raw_anonymous_skips(self, mock_logger, mock_session): - NTLM_handle_legacy_raw_auth( + ntlm_handle_legacy_raw_auth( user_name="", domain_name="", lm_response=b"", @@ -1098,7 +1096,7 @@ def test_raw_anonymous_skips(self, mock_logger, mock_session): mock_session.db.add_auth.assert_not_called() def test_raw_anonymous_z1_skips(self, mock_logger, mock_session): - NTLM_handle_legacy_raw_auth( + ntlm_handle_legacy_raw_auth( user_name="", domain_name="", lm_response=b"\x00", @@ -1112,7 +1110,7 @@ def test_raw_anonymous_z1_skips(self, mock_logger, mock_session): mock_session.db.add_auth.assert_not_called() def test_raw_both_empty_skips(self, mock_logger, mock_session): - NTLM_handle_legacy_raw_auth( + ntlm_handle_legacy_raw_auth( user_name="admin", domain_name="CORP", lm_response=b"", @@ -1127,7 +1125,7 @@ def test_raw_both_empty_skips(self, mock_logger, mock_session): def test_bad_challenge_no_crash(self, mock_logger, mock_session): # 7-byte challenge should log error, not crash - NTLM_handle_legacy_raw_auth( + ntlm_handle_legacy_raw_auth( user_name="admin", domain_name="CORP", lm_response=b"\xbb" * 24, @@ -1138,10 +1136,10 @@ def test_bad_challenge_no_crash(self, mock_logger, mock_session): logger=mock_logger, transport=NTLM_TRANSPORT_RAW, ) - # Should not crash — ValueError caught internally + # Should not crash - ValueError caught internally def test_user_bytes_decoded(self, mock_logger, mock_session): - NTLM_handle_legacy_raw_auth( + ntlm_handle_legacy_raw_auth( user_name=b"Admin", domain_name=b"CORP", lm_response=b"\xbb" * 24, @@ -1222,7 +1220,7 @@ def test_logger_called_with_blob_info(self, mock_logger): # expected_hash_type, has_lmv2_companion) # Extracted from real Windows-to-Windows SMB authentication exchanges. PCAP_VECTORS = [ - # XP SP3 -> XP SP0: NetNTLMv1-ESS (v5.1.2600) — TCP-flow-matched challenge + # XP SP3 -> XP SP0: NetNTLMv1-ESS (v5.1.2600) - TCP-flow-matched challenge ( "XPSP3", bytes.fromhex("a2bb534e5d77cde7"), @@ -1234,7 +1232,7 @@ def test_logger_called_with_blob_info(self, mock_logger): "NetNTLMv1-ESS", False, ), - # XP SP0 -> XP SP3: NetNTLMv1-ESS (no VERSION) — TCP-flow-matched challenge + # XP SP0 -> XP SP3: NetNTLMv1-ESS (no VERSION) - TCP-flow-matched challenge ( "XPSP0", bytes.fromhex("61c6ccdc55be7307"), @@ -1316,7 +1314,7 @@ def test_logger_called_with_blob_info(self, mock_logger): "NetNTLMv2", False, ), - # Srv03 -> XP SP3: NetNTLMv1-ESS (v5.2.3790) — TCP-flow-matched challenge + # Srv03 -> XP SP3: NetNTLMv1-ESS (v5.2.3790) - TCP-flow-matched challenge ( "Srv03", bytes.fromhex("38ec222f9dedff96"), @@ -1370,7 +1368,7 @@ def test_logger_called_with_blob_info(self, mock_logger): "NetNTLMv2", False, ), - # Srv16 -> XP SP3: NetNTLMv2, LM=Z(24) (v10.0.14393) — TCP-flow-matched + # Srv16 -> XP SP3: NetNTLMv2, LM=Z(24) (v10.0.14393) - TCP-flow-matched ( "Srv16", bytes.fromhex("77936e2ec48d1eb5"), @@ -1384,7 +1382,7 @@ def test_logger_called_with_blob_info(self, mock_logger): "NetNTLMv2", False, ), - # Srv19 -> Vista: NetNTLMv2, LM=Z(24) (v10.0.17763) — TCP-flow-matched + # Srv19 -> Vista: NetNTLMv2, LM=Z(24) (v10.0.17763) - TCP-flow-matched ( "Srv19", bytes.fromhex("0e3f0e0f5c3add3d"), @@ -1398,7 +1396,7 @@ def test_logger_called_with_blob_info(self, mock_logger): "NetNTLMv2", False, ), - # Srv22 -> Vista: NetNTLMv2, LM=Z(24) (v10.0.20348) — TCP-flow-matched + # Srv22 -> Vista: NetNTLMv2, LM=Z(24) (v10.0.20348) - TCP-flow-matched ( "Srv22", bytes.fromhex("975db6c485693f24"), @@ -1414,7 +1412,7 @@ def test_logger_called_with_blob_info(self, mock_logger): ), ] -# Anonymous probes from pcap — XP SP3, XP SP0, Srv03, Win7 send these before real auth +# Anonymous probes from pcap - XP SP3, XP SP0, Srv03, Win7 send these before real auth # Tuple: (id, flags, lm_response) PCAP_ANONYMOUS_PROBES = [ ("XPSP3_anon", 0xA2888A05, b"\x00"), @@ -1461,7 +1459,7 @@ def test_classify(self, vec): class TestPcapHashcatFormat: - """Verify NTLM_to_hashcat produces valid hashcat lines from real pcap data.""" + """Verify ntlm_to_hashcat produces valid hashcat lines from real pcap data.""" @pytest.mark.parametrize( "vec", @@ -1470,7 +1468,7 @@ class TestPcapHashcatFormat: ) def test_hashcat_output(self, vec): _id, challenge, user, domain, nt_resp, lm_resp, flags, expected_type, _lmv2 = vec - result = NTLM_to_hashcat(challenge, user, domain, lm_resp, nt_resp, flags) + result = ntlm_to_hashcat(challenge, user, domain, lm_resp, nt_resp, flags) assert len(result) >= 1, f"{_id}: expected at least 1 hash, got 0" label, line = result[0] @@ -1509,7 +1507,7 @@ def test_hashcat_output(self, vec): def test_lmv2_companion(self, vec): """Vista and Srv08 produce LMv2 companion hashes (no MsvAvTimestamp).""" _id, challenge, user, domain, nt_resp, lm_resp, flags, _, _ = vec - result = NTLM_to_hashcat(challenge, user, domain, lm_resp, nt_resp, flags) + result = ntlm_to_hashcat(challenge, user, domain, lm_resp, nt_resp, flags) assert len(result) == 2, ( f"{_id}: expected 2 hashes (primary + LMv2), got {len(result)}" ) @@ -1527,10 +1525,10 @@ def test_lmv2_companion(self, vec): ids=[v[0] for v in PCAP_VECTORS if v[7] == "NetNTLMv2" and not v[8]], ) def test_lmv2_suppressed_when_null(self, vec): - """Win7+ sends LM=Z(24) due to MsvAvTimestamp — LMv2 must be suppressed.""" + """Win7+ sends LM=Z(24) due to MsvAvTimestamp - LMv2 must be suppressed.""" _id, challenge, user, domain, nt_resp, lm_resp, flags, _, _ = vec assert lm_resp == b"\x00" * 24, f"{_id}: expected Z(24) LM response" - result = NTLM_to_hashcat(challenge, user, domain, lm_resp, nt_resp, flags) + result = ntlm_to_hashcat(challenge, user, domain, lm_resp, nt_resp, flags) assert len(result) == 1, ( f"{_id}: expected 1 hash (LMv2 suppressed), got {len(result)}" ) @@ -1598,7 +1596,7 @@ def test_log_blob_no_crash(self, vec, mock_logger): class TestPcapFullAuthPipeline: - """End-to-end: run real pcap vectors through NTLM_handle_authenticate_message.""" + """End-to-end: run real pcap vectors through ntlm_handle_authenticate_message.""" @pytest.mark.parametrize( "vec", @@ -1614,7 +1612,7 @@ def test_authenticate_captures(self, vec, mock_logger, mock_session): nt_response=nt_resp, lm_response=lm_resp, ) - result = NTLM_handle_authenticate_message( + result = ntlm_handle_authenticate_message( token, challenge=challenge, client=("10.0.0.99", 12345), @@ -1681,7 +1679,7 @@ def test_is_anonymous_detects_probe(self, probe): def test_hashcat_returns_empty(self, probe): """Anonymous probes produce no hashcat output.""" _id, flags, lm = probe - result = NTLM_to_hashcat( + result = ntlm_to_hashcat( b"\x00" * 8, # challenge doesn't matter "", "", @@ -1693,7 +1691,7 @@ def test_hashcat_returns_empty(self, probe): class TestPcapNegotiateFlags: - """Verify NTLM_build_challenge_message echoes real Windows negotiate flags correctly. + """Verify ntlm_build_challenge_message echoes real Windows negotiate flags correctly. Tests every unique negotiate flag combination from the pcap (14 Windows versions). Key behaviors validated: @@ -1712,7 +1710,7 @@ class TestPcapNegotiateFlags: def test_challenge_mandatory_flags(self, client_id, neg_flags): """Server response always has NTLM + ALWAYS_SIGN + REQUEST_TARGET.""" token = _build_ntlm_negotiate(neg_flags) - msg = NTLM_build_challenge_message(token, challenge=CHALLENGE) + msg = ntlm_build_challenge_message(token, challenge=CHALLENGE) resp_flags = msg["flags"] assert resp_flags & ntlm.NTLMSSP_NEGOTIATE_NTLM assert resp_flags & ntlm.NTLMSSP_NEGOTIATE_ALWAYS_SIGN @@ -1726,7 +1724,7 @@ def test_challenge_mandatory_flags(self, client_id, neg_flags): def test_challenge_echoes_unicode(self, client_id, neg_flags): """Server echoes UNICODE flag from client.""" token = _build_ntlm_negotiate(neg_flags) - msg = NTLM_build_challenge_message(token, challenge=CHALLENGE) + msg = ntlm_build_challenge_message(token, challenge=CHALLENGE) client_unicode = bool(neg_flags & ntlm.NTLMSSP_NEGOTIATE_UNICODE) server_unicode = bool(msg["flags"] & ntlm.NTLMSSP_NEGOTIATE_UNICODE) assert client_unicode == server_unicode, f"{client_id}: UNICODE echo mismatch" @@ -1739,7 +1737,7 @@ def test_challenge_echoes_unicode(self, client_id, neg_flags): def test_challenge_ess_lm_key_exclusivity(self, client_id, neg_flags): """When client sends both ESS and LM_KEY, server keeps only ESS.""" token = _build_ntlm_negotiate(neg_flags) - msg = NTLM_build_challenge_message(token, challenge=CHALLENGE) + msg = ntlm_build_challenge_message(token, challenge=CHALLENGE) resp_flags = msg["flags"] has_ess = bool(resp_flags & ntlm.NTLMSSP_NEGOTIATE_EXTENDED_SESSIONSECURITY) has_lm_key = bool(resp_flags & ntlm.NTLMSSP_NEGOTIATE_LM_KEY) @@ -1754,6 +1752,6 @@ def test_challenge_ess_lm_key_exclusivity(self, client_id, neg_flags): def test_challenge_has_target_info(self, client_id, neg_flags): """Server always includes TargetInfo (NTLMv2 AV_PAIRs) for pcap clients.""" token = _build_ntlm_negotiate(neg_flags) - msg = NTLM_build_challenge_message(token, challenge=CHALLENGE) + msg = ntlm_build_challenge_message(token, challenge=CHALLENGE) assert msg["flags"] & ntlm.NTLMSSP_NEGOTIATE_TARGET_INFO assert msg["TargetInfoFields_len"] > 0 diff --git a/tests/test_smb.py b/tests/test_smb.py index 3fabc68..25cf181 100644 --- a/tests/test_smb.py +++ b/tests/test_smb.py @@ -1,4 +1,4 @@ -"""Unit tests for dementor.protocols.smb — SMB protocol handler. +"""Unit tests for dementor.protocols.smb - SMB protocol handler. Tests cover module-level functions, SMBServerConfig methods, and SMBHandler methods (via a mock handler that bypasses the real socket). @@ -47,7 +47,7 @@ def mock_smb_config(): cfg.smb_server_os = "Windows" cfg.smb_native_lanman = "Windows" cfg.smb_captures_per_connection = 0 - cfg.smb_error_code = nt_errors.STATUS_SMB_BAD_UID + cfg.smb_error_code = None cfg.ntlm_challenge = b"\x01\x02\x03\x04\x05\x06\x07\x08" cfg.ntlm_disable_ess = False cfg.ntlm_disable_ntlmv2 = False @@ -290,7 +290,7 @@ def test_monotonic(self): class TestSetSmbErrorCode: - """SMBServerConfig.set_smb_error_code(value) at line 288.""" + """SMBServerConfig.set_smb_error_code(value).""" def _make_config(self): cfg = object.__new__(SMBServerConfig) @@ -322,9 +322,14 @@ def test_int_zero(self): cfg.set_smb_error_code(0) assert cfg.smb_error_code == 0 + def test_no_code(self): + cfg = self._make_config() + cfg.set_smb_error_code(None) + assert cfg.smb_error_code is None + class TestSmb3NegContextPad: - """_smb3_neg_context_pad(data_len) — instance method at line 840.""" + """_smb3_neg_context_pad(data_len) - instance method at line 840.""" @pytest.mark.parametrize( ("data_len", "expected_pad_len"), @@ -371,8 +376,8 @@ class TestBuildTrans2FileInfo: (26, 4), # FileMailslotQueryInformation (30, 16), # FileCompressionInformation # Pass-through levels - (0x03EC, None), # FileBasicInfo (class 4) — size varies - (0x03ED, None), # FileStandardInfo (class 5) — size varies + (0x03EC, None), # FileBasicInfo (class 4) - size varies + (0x03ED, None), # FileStandardInfo (class 5) - size varies ], ids=[ "standard_0001", @@ -419,7 +424,7 @@ def test_file_all_info_0107(self, mock_handler): assert len(result) > 0 def test_file_all_info_raw_15(self, mock_handler): - """FileAllInformation (class 15) — XP SP3 sends as raw class.""" + """FileAllInformation (class 15) - XP SP3 sends as raw class.""" result = mock_handler._build_trans2_file_info(15) assert result is not None assert len(result) > 0 @@ -430,7 +435,7 @@ def test_compression_info_010b(self, mock_handler): assert len(result) == 16 def test_name_valid_0006(self, mock_handler): - """0x0006: FileInternalInformation / SMB_INFO_IS_NAME_VALID — 8 bytes.""" + """0x0006: FileInternalInformation / SMB_INFO_IS_NAME_VALID - 8 bytes.""" result = mock_handler._build_trans2_file_info(0x0006) assert result is not None assert len(result) == 8 @@ -930,7 +935,7 @@ def test_unknown_smb1_sends_not_implemented(self, mock_handler): class TestSmb2SessionSetup: """handle_smb2_session_setup IS_GUEST logic at line 1597. - These are partial integration tests — we mock handle_ntlmssp to return + These are partial integration tests - we mock handle_ntlmssp to return a fixed response and only verify the IS_GUEST SessionFlags logic. """