diff --git a/.vscode/launch.json b/.vscode/launch.json index 4508b45..5442383 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -3,9 +3,12 @@ "configurations": [ { "name": "Attach to Python Functions", - "type": "python", + "type": "debugpy", "request": "attach", - "port": 9091, + "connect": { + "host": "localhost", + "port": 9091, + }, "preLaunchTask": "func: host start" } ] diff --git a/README.md b/README.md index 0d7d36c..1ff4057 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ To debug in VS Code create a local.settings.json file in the root of the project "AZURE_TENANT_ID": "", "AZURE_CLIENT_ID": "", "AZURE_CLIENT_SECRET": "", - "AZURE_AUTHORITY_HOST": "login.microsoftonline.com", + "AZURE_AUTHORITY_HOST": "https://login.microsoftonline.com", "ARM_ENDPOINT": "management.azure.com", "GRAPH_ENDPOINT": "graph.microsoft.com", "LOGANALYTICS_ENDPOINT": "api.loganalytics.io", diff --git a/classes/__init__.py b/classes/__init__.py index f1a8a7c..88d5c9b 100644 --- a/classes/__init__.py +++ b/classes/__init__.py @@ -46,6 +46,14 @@ def __init__(self, error:str, source_error:dict={}, status_code:int=400): self.source_error = source_error self.status_code = status_code +class STATServerError(STATError): + """STAT exception raised when an API call returns a 5xx series error. + + This exception is a specialized version of STATError for cases where + a server side error was encountered and a retry may be needed. + """ + pass + class STATNotFound(STATError): """STAT exception raised when an API call returns a 404 Not Found error. @@ -85,6 +93,7 @@ def __init__(self): self.AccountsCount = 0 self.AccountsOnPrem = [] self.Alerts = [] + self.CreatedTime = '' self.Domains = [] self.DomainsCount = 0 self.EntitiesCount = 0 @@ -99,6 +108,8 @@ def __init__(self): self.IncidentARMId = "" self.IncidentTriggered = False self.IncidentAvailable = False + self.MailMessages = [] + self.MailMessagesCount = 0 self.ModuleVersions = {} self.MultiTenantConfig = {} self.OtherEntities = [] @@ -117,6 +128,7 @@ def __init__(self): def load_incident_trigger(self, req_body): self.IncidentARMId = req_body['object']['id'] + self.CreatedTime = req_body['object']['properties']['createdTimeUtc'] self.IncidentTriggered = True self.IncidentAvailable = True self.SentinelRGARMId = "/subscriptions/" + req_body['workspaceInfo']['SubscriptionId'] + "/resourceGroups/" + req_body['workspaceInfo']['ResourceGroupName'] @@ -127,6 +139,7 @@ def load_incident_trigger(self, req_body): def load_alert_trigger(self, req_body): self.IncidentTriggered = False + self.CreatedTime = req_body['EndTimeUtc'] self.SentinelRGARMId = "/subscriptions/" + req_body['WorkspaceSubscriptionId'] + "/resourceGroups/" + req_body['WorkspaceResourceGroup'] self.WorkspaceId = req_body['WorkspaceId'] @@ -135,6 +148,7 @@ def load_from_input(self, basebody): self.AccountsCount = basebody['AccountsCount'] self.AccountsOnPrem = basebody.get('AccountsOnPrem', []) self.Alerts = basebody.get('Alerts', []) + self.CreatedTime = basebody.get('CreatedTime', '') self.Domains = basebody['Domains'] self.DomainsCount = basebody['DomainsCount'] self.EntitiesCount = basebody['EntitiesCount'] @@ -149,6 +163,8 @@ def load_from_input(self, basebody): self.IncidentTriggered = basebody['IncidentTriggered'] self.IncidentAvailable = basebody['IncidentAvailable'] self.IncidentARMId = basebody['IncidentARMId'] + self.MailMessages = basebody.get('MailMessages', []) + self.MailMessagesCount = basebody.get('MailMessagesCount', 0) self.ModuleVersions = basebody['ModuleVersions'] self.MultiTenantConfig = basebody.get('MultiTenantConfig', {}) self.OtherEntities = basebody['OtherEntities'] @@ -189,11 +205,16 @@ def add_account_entity(self, data): def add_onprem_account_entity(self, data): self.AccountsOnPrem.append(data) - def get_ip_list(self): + def get_ip_list(self, include_mail_ips:bool=True): ip_list = [] for ip in self.IPs: ip_list.append(ip['Address']) + if include_mail_ips: + for message in self.MailMessages: + if message.get('senderDetail', {}).get('ipv4'): + ip_list.append(message.get('senderDetail', {}).get('ipv4')) + return ip_list def get_domain_list(self): @@ -203,27 +224,43 @@ def get_domain_list(self): return domain_list - def get_url_list(self): + def get_url_list(self, include_mail_urls:bool=True): url_list = [] for url in self.URLs: url_list.append(url['Url']) + if include_mail_urls: + for message in self.MailMessages: + for url in message.get('urls', []): + url_list.append(url.get('url')) + return url_list - def get_filehash_list(self): + def get_filehash_list(self, include_mail_hashes:bool=True): hash_list = [] for hash in self.FileHashes: hash_list.append(hash['FileHash']) + + if include_mail_hashes: + for message in self.MailMessages: + for attachment in message.get('attachments', []): + if attachment.get('sha256'): + hash_list.append(attachment.get('sha256')) return hash_list - def get_ip_kql_table(self): + def get_ip_kql_table(self, include_mail_ips:bool=True): ip_data = [] for ip in self.IPs: ip_data.append({'Address': ip.get('Address'), 'Latitude': ip.get('GeoData').get('latitude'), 'Longitude': ip.get('GeoData').get('longitude'), \ 'Country': ip.get('GeoData').get('country'), 'State': ip.get('GeoData').get('state')}) + + if include_mail_ips: + for message in self.MailMessages: + if message.get('senderDetail', {}).get('ipv4'): + ip_data.append({'Address': message.get('senderDetail', {}).get('ipv4')}) encoded = urllib.parse.quote(json.dumps(ip_data)) @@ -268,12 +305,17 @@ def get_host_kql_table(self): ''' return kql - def get_url_kql_table(self): + def get_url_kql_table(self, include_mail_urls:bool=True): url_data = [] for url in self.URLs: url_data.append({'Url': url.get('Url')}) + if include_mail_urls: + for message in self.MailMessages: + for url in message.get('urls', []): + url_data.append({'Url': url.get('url')}) + encoded = urllib.parse.quote(json.dumps(url_data)) kql = f'''let urlEntities = print t = todynamic(url_decode('{encoded}')) @@ -282,12 +324,18 @@ def get_url_kql_table(self): ''' return kql - def get_filehash_kql_table(self): + def get_filehash_kql_table(self, include_mail_hashes:bool=True): hash_data = [] for hash in self.FileHashes: hash_data.append({'FileHash': hash.get('FileHash'), 'Algorithm': hash.get('Algorithm')}) + if include_mail_hashes: + for message in self.MailMessages: + for attachment in message.get('attachments', []): + if attachment.get('sha256'): + hash_data.append({'FileHash': attachment.get('sha256'), 'Algorithm': 'SHA256'}) + encoded = urllib.parse.quote(json.dumps(hash_data)) kql = f'''let hashEntities = print t = todynamic(url_decode('{encoded}')) @@ -308,6 +356,21 @@ def get_domain_kql_table(self): kql = f'''let domainEntities = print t = todynamic(url_decode('{encoded}')) | mv-expand t | project Domain=tostring(t.Domain); +''' + return kql + + def get_mail_kql_table(self): + + mail_data = [] + + for mail in self.MailMessages: + mail_data.append({'rec': mail.get('recipientEmailAddress'), 'nid': mail.get('networkMessageId'), 'send': mail.get('senderDetail', {}).get('fromAddress'), 'sendfrom': mail.get('senderDetail', {}).get('mailFromAddress')}) + + encoded = urllib.parse.quote(json.dumps(mail_data)) + + kql = f'''let mailEntities = print t = todynamic(url_decode('{encoded}')) +| mv-expand t +| project RecipientEmailAddress=tostring(t.rec), NetworkMessageId=tostring(t.nid), SenderMailFromAddress=tostring(t.send), SenderFromAddress=tostring(t.sendfrom); ''' return kql diff --git a/host.json b/host.json index fd4bee7..f2b7c0d 100644 --- a/host.json +++ b/host.json @@ -10,6 +10,6 @@ }, "extensionBundle": { "id": "Microsoft.Azure.Functions.ExtensionBundle", - "version": "[3.*, 4.0.0)" + "version": "[4.0.0, 5.0.0)" } } \ No newline at end of file diff --git a/modules/base.py b/modules/base.py index c196c32..5ab5eff 100644 --- a/modules/base.py +++ b/modules/base.py @@ -5,6 +5,7 @@ import logging import requests import ipaddress +import datetime as dt stat_version = None @@ -43,10 +44,11 @@ def execute_base_module (req_body): enrich_files(entities) enrich_filehashes(entities) enrich_urls(entities) + enrich_mail_message(entities) append_other_entities(entities) base_object.CurrentVersion = data.get_current_version() - base_object.EntitiesCount = base_object.AccountsCount + base_object.IPsCount + base_object.DomainsCount + base_object.FileHashesCount + base_object.FilesCount + base_object.HostsCount + base_object.OtherEntitiesCount + base_object.URLsCount + base_object.EntitiesCount = base_object.AccountsCount + base_object.IPsCount + base_object.DomainsCount + base_object.FileHashesCount + base_object.FilesCount + base_object.HostsCount + base_object.OtherEntitiesCount + base_object.URLsCount + base_object.MailMessagesCount org_info = json.loads(rest.rest_call_get(base_object, api='msgraph', path='/v1.0/organization').content) base_object.TenantDisplayName = org_info['value'][0]['displayName'] @@ -67,15 +69,26 @@ def execute_base_module (req_body): account_comment = '' ip_comment = '' + mail_comment = '' if req_body.get('AddAccountComments', True) and base_object.AccountsCount > 0: - account_comment = 'Account Info:
' + get_account_comment() + account_comment = '

Account Info:

' + get_account_comment() if req_body.get('AddIPComments', True) and base_object.check_global_and_local_ips(): - ip_comment = 'IP Info:
' + get_ip_comment() - - if (req_body.get('AddAccountComments', True) and base_object.AccountsCount > 0) or (req_body.get('AddIPComments', True) and base_object.check_global_and_local_ips()): - comment = account_comment + '

' + ip_comment + ip_comment = '

IP Info:

' + get_ip_comment() + + if req_body.get('AddMailComments', True) and base_object.MailMessages: + mail_comment = '

Mail Message Info:

' + get_mail_comment() + + if (req_body.get('AddAccountComments', True) and base_object.AccountsCount > 0) or (req_body.get('AddIPComments', True) and base_object.check_global_and_local_ips()) or (req_body.get('AddMailComments', True) and base_object.MailMessages): + comment = '' + if account_comment: + comment += account_comment + '

' + if ip_comment: + comment += ip_comment + '

' + if mail_comment: + comment += mail_comment + rest.add_incident_comment(base_object, comment) return Response(base_object) @@ -212,6 +225,56 @@ def enrich_domains(entities): raw_entity = data.coalesce(domain.get('properties'), domain) base_object.Domains.append({'Domain': domain_name, 'RawEntity': raw_entity}) +def enrich_mail_message(entities): + mail_entities = list(filter(lambda x: x['kind'].lower() == 'mailmessage', entities)) + base_object.MailMessagesCount = len(mail_entities) + message_role = rest.check_app_role(base_object, 'msgraph', ['SecurityAnalyzedMessage.Read.All','SecurityAnalyzedMessage.ReadWrite.All']) + + for mail in mail_entities: + recipient = data.coalesce(mail.get('properties',{}).get('recipient'), mail.get('Recipient')) + network_message_id = data.coalesce(mail.get('properties',{}).get('networkMessageId'), mail.get('NetworkMessageId')) + receive_date = data.coalesce(mail.get('properties',{}).get('receiveDate'), mail.get('ReceivedDate')) + + if receive_date: + start_time = (dt.datetime.fromisoformat(receive_date) + dt.timedelta(days=-14)).strftime("%Y-%m-%dT%H:%M:%SZ") + end_time = (dt.datetime.fromisoformat(receive_date) + dt.timedelta(days=14)).strftime("%Y-%m-%dT%H:%M:%SZ") + else: + start_time = (dt.datetime.fromisoformat(base_object.CreatedTime) + dt.timedelta(days=-14)).strftime("%Y-%m-%dT%H:%M:%SZ") + end_time = (dt.datetime.fromisoformat(base_object.CreatedTime) + dt.timedelta(days=14)).strftime("%Y-%m-%dT%H:%M:%SZ") + + raw_entity = data.coalesce(mail.get('properties'), mail) + + if not message_role: + base_object.MailMessages.append({'networkMessageId': network_message_id, 'recipientEmailAddress': recipient, 'EnrichmentMethod': 'MailMessage - No App Role', 'RawEntity': raw_entity}) + continue + + if recipient and network_message_id: + try: + get_message = json.loads(rest.rest_call_get(base_object, api='msgraph', path=f"/beta/security/collaboration/analyzedemails?startTime={start_time}&endTime={end_time}&filter=networkMessageId eq '{network_message_id}' and recipientEmailAddress eq '{recipient}'").content) + if get_message['value']: + message_details = json.loads(rest.rest_call_get(base_object, api='msgraph', path=f"/beta/security/collaboration/analyzedemails/{get_message['value'][0]['id']}").content) + message_details['RawEntity'] = raw_entity + else: + message_details = { + 'networkMessageId': network_message_id, + 'recipientEmailAddress': recipient, + 'EnrichmentMethod': 'MailMessage - analyzedMessage could not be found', + 'RawEntity': raw_entity + } + except: + message_details = { + 'networkMessageId': network_message_id, + 'recipientEmailAddress': recipient, + 'EnrichmentMethod': 'MailMessage - Failed to get analyzedMessage', + 'RawEntity': raw_entity + } + + else: + message_details = {'EnrichmentMethod': 'MailMessage - No Recipient or NetworkMessageId', 'RawEntity': raw_entity} + + base_object.MailMessages.append(message_details) + + def enrich_files(entities): file_entities = list(filter(lambda x: x['kind'].lower() == 'file', entities)) base_object.FilesCount = len(file_entities) @@ -240,7 +303,7 @@ def enrich_urls(entities): base_object.URLs.append({'Url': url_data, 'RawEntity': raw_entity}) def append_other_entities(entities): - other_entities = list(filter(lambda x: x['kind'].lower() not in ('ip','account','dnsresolution','dns','file','filehash','host','url'), entities)) + other_entities = list(filter(lambda x: x['kind'].lower() not in ('ip','account','dnsresolution','dns','file','filehash','host','url','mailmessage'), entities)) base_object.OtherEntitiesCount = len(other_entities) for entity in other_entities: @@ -444,17 +507,25 @@ def get_account_comment(): upn_data = f'{account_upn}
(Contact User)' else: upn_data = account_upn - - account_list.append({'UserPrincipalName': upn_data, 'City': account.get('city'), 'Country': account.get('country'), \ - 'Department': account.get('department'), 'JobTitle': account.get('jobTitle'), 'Office': account.get('officeLocation'), \ - 'AADRoles': account.get('AssignedRoles'), 'ManagerUPN': account.get('manager', {}).get('userPrincipalName'), \ - 'MfaRegistered': account.get('isMfaRegistered'), 'SSPREnabled': account.get('isSSPREnabled'), \ - 'SSPRRegistered': account.get('isSSPRRegistered')}) - + + if upn_data: + account_list.append({ + 'User': f"{upn_data}
JobTitle: {account.get('jobTitle')}", + 'Location': f"Department: {account.get('department')}
Office: {account.get('officeLocation')}
City: {account.get('city')}
Country: {account.get('country')}", + 'OtherDetails': f"AADRoles: {', '.join(account.get('AssignedRoles', []))}
Manager: {account.get('manager', {}).get('userPrincipalName')}
MFA Registered: {account.get('isMfaRegistered')}
SSPR Enabled: {account.get('isSSPREnabled')}
SSPR Registered: {account.get('isSSPRRegistered')}
OnPremSynced: {account.get('onPremisesSyncEnabled')}" + }) + else: + account_list.append({ + 'User': "Unknown User", + 'OtherDetails': f"Failed to lookup account details for 1 account entity
Enrichment Method: {account.get('EnrichmentMethod')}", + }) + for onprem_acct in base_object.AccountsOnPrem: - account_list.append( - {'UserPrincipalName': data.coalesce(onprem_acct.get('userPrincipalName'),onprem_acct.get('onPremisesSamAccountName')), 'Department': onprem_acct.get('department'), 'JobTitle': onprem_acct.get('jobTitle'), 'ManagerUPN': onprem_acct.get('manager'), 'Notes': 'On-Prem - No Entra Sync'} - ) + account_list.append({ + 'User': f"{data.coalesce(onprem_acct.get('userPrincipalName'),onprem_acct.get('onPremisesSamAccountName'))}
JobTitle: {onprem_acct.get('jobTitle')}", + 'Location': f"Department: {onprem_acct.get('department')}", + 'OtherDetails': f"Manager: {onprem_acct.get('manager')}
OnPremSynced: On-Prem Only" + }) return data.list_to_html_table(account_list, 20, 20, escape_html=False) @@ -465,12 +536,36 @@ def get_ip_comment(): if ip.get('IPType') != 3: #Excludes link local addresses from the IP comment geo = ip.get('GeoData') - ip_list.append({'IP': ip.get('Address'), 'City': geo.get('city'), 'State': geo.get('state'), 'Country': geo.get('country'), \ - 'Organization': geo.get('organization'), 'OrganizationType': geo.get('organizationType'), 'ASN': geo.get('asn'), 'IPType': ip.get('IPType')}) + + ip_list.append({ + 'IP': ip.get('Address'), + 'Location': f"City: {geo.get('city', 'Unknown')}
State: {geo.get('state', 'Unknown')}
Country: {geo.get('country', 'Unknown')}", + 'OtherDetails': f"Organization: {geo.get('organization', 'Unknown')}
OrganizationType: {geo.get('organizationType', 'Unknown')}
ASN: {geo.get('asn', 'Unknown')}", + 'IPType': ip.get('IPType') + }) ip_list = data.sort_list_by_key(ip_list, 'IPType', ascending=True, drop_columns=['IPType']) - return data.list_to_html_table(ip_list) + return data.list_to_html_table(ip_list, escape_html=False) + +def get_mail_comment(): + + mail_list = [] + for msg in base_object.MailMessages: + if msg.get('EnrichmentMethod'): + mail_list.append({ + 'MessageDetails': f"NetworkMessageId: {msg.get('networkMessageId')}
Recipient: {msg.get('recipientEmailAddress', 'Unknown')}", + 'EnrichmentMethod': f"Enrichment Method: {msg.get('EnrichmentMethod')}", + }) + else: + mail_list.append({ + 'MessageDetails': f"Recipient: {msg.get('recipientEmailAddress')}
Sender: {msg.get('senderDetail', {}).get('fromAddress')}
SenderFromAddress: {msg.get('senderDetail', {}).get('mailFromAddress')}
Subject: {msg.get('subject')}
AttachmentCount: {len(msg.get('attachments', []))}
URLCount: {len(msg.get('urls', []))}", + 'Delivery': f"Original Delivery: {msg.get('originalDelivery', {}).get('location')}
Latest Delivery: {msg.get('latestDelivery', {}).get('location')}", + 'Authentication': f"SPF: {msg.get('authenticationDetails', {}).get('senderPolicyFramework')}
DKIM: {msg.get('authenticationDetails', {}).get('dkim')}
DMARC: {msg.get('authenticationDetails', {}).get('dmarc')}", + 'ThreatInfo': f"ThreatTypes: {', '.join(msg.get('threatTypes', []))}
DetectionMethods: {', '.join(msg.get('detectionMethods', []))}" + }) + + return data.list_to_html_table(mail_list, escape_html=False) def get_stat_version(version_check_type): diff --git a/modules/kql.py b/modules/kql.py index 2a67995..5197e5d 100644 --- a/modules/kql.py +++ b/modules/kql.py @@ -15,8 +15,9 @@ def execute_kql_module (req_body): ip_entities = base_object.get_ip_kql_table() account_entities = base_object.get_account_kql_table(include_unsynced=True) host_entities = base_object.get_host_kql_table() + mail_entities = base_object.get_mail_kql_table() - query = arm_id + ip_entities + account_entities + host_entities + req_body['KQLQuery'] + query = arm_id + ip_entities + account_entities + host_entities + mail_entities + req_body['KQLQuery'] if req_body.get('RunQueryAgainst') == 'M365': results = rest.execute_m365d_query(base_object, query) diff --git a/modules/version.json b/modules/version.json index d20a644..aeb01b2 100644 --- a/modules/version.json +++ b/modules/version.json @@ -1,3 +1,3 @@ { - "FunctionVersion": "2.2.0" + "FunctionVersion": "2.2.3" } diff --git a/requirements.txt b/requirements.txt index 0c034fc..d0b7ec2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,3 @@ azure-identity azure-keyvault-secrets requests pandas -cryptography==43.0.3 - -#Limiting cryptograpy due to https://github.com/Azure/azure-functions-python-worker/issues/1651 \ No newline at end of file diff --git a/shared/rest.py b/shared/rest.py index 3539923..b5635b8 100644 --- a/shared/rest.py +++ b/shared/rest.py @@ -8,7 +8,7 @@ import uuid import time import base64 -from classes import STATError, STATNotFound, BaseModule, STATTooManyRequests +from classes import STATError, STATNotFound, BaseModule, STATTooManyRequests, STATServerError stat_token = {} graph_endpoint = os.getenv('GRAPH_ENDPOINT') @@ -201,9 +201,14 @@ def execute_rest_call(base_module:BaseModule, method:str, api:str, path:str, bod if wait_time > 60: raise STATTooManyRequests(error=e.error, source_error=e.source_error, status_code=e.status_code, retry_after=e.retry_after) time.sleep(retry_after) + except STATServerError as e: + wait_time += 15 + if wait_time > 60: + raise STATServerError(error=f'Server error returned by {url}', source_error=e.source_error, status_code=500) + time.sleep(15) except ConnectionError as e: wait_time += 20 - if wait_time >= 60: + if wait_time > 60: raise STATError(error=f'Failed to establish a new connection to {url}', source_error=e, status_code=500) time.sleep(20) else: @@ -217,6 +222,8 @@ def check_rest_response(response:Response, api, path): raise STATNotFound(f'The API call to {api} with path {path} failed with status {response.status_code}', source_error={'status_code': int(response.status_code), 'reason': str(response.reason)}) elif response.status_code == 429 or response.status_code == 408: raise STATTooManyRequests(f'The API call to {api} with path {path} failed with status {response.status_code}', source_error={'status_code': int(response.status_code), 'reason': str(response.reason)}, retry_after=response.headers.get('Retry-After', 10), status_code=int(response.status_code)) + elif response.status_code >= 500: + raise STATServerError(f'The API call to {api} with path {path} failed with status {response.status_code}', source_error={'status_code': int(response.status_code), 'reason': str(response.reason)}) elif response.status_code >= 300: raise STATError(f'The API call to {api} with path {path} failed with status {response.status_code}', source_error={'status_code': int(response.status_code), 'reason': str(response.reason)}) return