Skip to content

Add WAF rate limiting custom resource. #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions examples/waf_rate_limit.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# CFNDSL

Resource('RateLimitRule') {
Type 'Custom::WAFRateLimit'
Property('ServiceToken', FnGetAtt('WAFRateLimitFunction', 'Arn'))
Property('EnvironmentName', Ref('EnvironmentName'))
Property('Region', Ref("AWS::Region"))
Property('Rate', 5000)
Property('Negated', true)
Property('Action', 'BLOCK')
Property('IPSet', waf_ip_set(ip_blocks, ['rate_limited']))
Property('WebACLId', Ref('WebACL'))
Property('Priority', 2)
}

Resource('WAFRateLimitFunction') {
Type 'AWS::Lambda::Function'
Property('Code', './waf_rate_limit/')
Property('Handler', 'handler.lambda_handler')
Property('Runtime', 'python3.6')
Property('Timeout', 60)
Property('Role', FnGetAtt('WAFRole', 'Arn'))
}

Resource("WAFRole") {
Type 'AWS::IAM::Role'
Property('AssumeRolePolicyDocument', {
Statement: [
Effect: 'Allow',
Principal: { Service: [ 'lambda.amazonaws.com' ] },
Action: [ 'sts:AssumeRole' ]
]
})
Property('Path','/')
Property('Policies', Policies.new.get_policies('waf'))
}
4 changes: 3 additions & 1 deletion ssm-secure-parameter/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def lambda_handler(event, context):
lambda_response.respond_error(f"{key} property missing")
return

replace = cr_params.get('Update', True)

try:
parameter = logic.SSMSecureParameterLogic(cr_params['Path'])
length = 16 or cr_params['Length']
Expand All @@ -40,7 +42,7 @@ def lambda_handler(event, context):
elif event['RequestType'] == 'Update':
password, version = parameter.create(
length=length,
update=True
update=replace
)

event['PhysicalResourceId'] = cr_params['Path']
Expand Down
Empty file added waf_rate_limit/__init__.py
Empty file.
58 changes: 58 additions & 0 deletions waf_rate_limit/cr_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import logging
from urllib.request import urlopen, Request, HTTPError, URLError
import json

logger = logging.getLogger()
logger.setLevel(logging.INFO)

class CustomResourceResponse:
def __init__(self, request_payload):
self.payload = request_payload
self.response = {
"StackId": request_payload["StackId"],
"RequestId": request_payload["RequestId"],
"LogicalResourceId": request_payload["LogicalResourceId"],
"Status": 'SUCCESS',
}


def respond_error(self, message):
self.response['Status'] = 'FAILED'
self.response['Reason'] = message
self.respond({})

def respond(self, data):
event = self.payload
response = self.response
####
#### copied from https://github.com/ryansb/cfn-wrapper-python/blob/master/cfn_resource.py
####

if event.get("PhysicalResourceId", False):
response["PhysicalResourceId"] = event["PhysicalResourceId"]

logger.debug("Received %s request with event: %s" %
(event['RequestType'], json.dumps(event)))

response["Data"] = data

serialized = json.dumps(response)

logger.info(f"Responding to {event['RequestType']} request with: {serialized}")
req_data = serialized.encode('utf-8')

req = Request(
event['ResponseURL'],
data=req_data,
headers={'Content-Length': len(req_data), 'Content-Type': ''}
)
req.get_method = lambda: 'PUT'

try:
urlopen(req)
logger.debug("Request to CFN API succeeded, nothing to do here")
except HTTPError as e:
logger.error("Callback to CFN API failed with status %d" % e.code)
logger.error("Response: %s" % e.reason)
except URLError as e:
logger.error("Failed to reach the server - %s" % e.reason)
44 changes: 44 additions & 0 deletions waf_rate_limit/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import sys
import os

sys.path.append(f"{os.environ['LAMBDA_TASK_ROOT']}/lib")
sys.path.append(os.path.dirname(os.path.realpath(__file__)))

import cr_response
from logic import WafRateLimit
import json

def lambda_handler(event, context):

print(f"Received event:{json.dumps(event)}")

lambda_response = cr_response.CustomResourceResponse(event)
cr_params = event['ResourceProperties']
waf_logic = WafRateLimit(cr_params)
try:
# if create request, generate physical id, both for create/update copy files
if event['RequestType'] == 'Create':
event['PhysicalResourceId'] = waf_logic._create_rate_based_rule()
data = {
"RuleID" : event['PhysicalResourceId']
}
lambda_response.respond(data)

elif event['RequestType'] == 'Update':
waf_logic._update_rate_based_rule(event['PhysicalResourceId'])
data = {
"RuleID" : event['PhysicalResourceId']
}
lambda_response.respond(data)

elif event['RequestType'] == 'Delete':
print(event['PhysicalResourceId'])
waf_logic._delete_rate_based_rule(event['PhysicalResourceId'])
data = { }
lambda_response.respond(data)

except Exception as e:
message = str(e)
lambda_response.respond_error(message)

return 'OK'
228 changes: 228 additions & 0 deletions waf_rate_limit/logic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
import boto3
import os
import glob
import logging

logger = logging.getLogger()
logger.setLevel(logging.INFO)


class WafRateLimit:

def __init__(self, resource_properties):
self.rate = resource_properties['Rate']
self.action = resource_properties['Action']
self.region = resource_properties['Region']
self.env = resource_properties['EnvironmentName']
self.ip_set = resource_properties['IPSet']
self.negated = resource_properties['Negated']
self.region = resource_properties['Region']
self.regional = resource_properties.get('Regional', 'false')
self.web_acl_id = resource_properties['WebACLId']
self.priority = int(resource_properties['Priority'])

if 'EnvironmentName' in resource_properties:
self.rule_name = f"{resource_properties['EnvironmentName']}-rate-limit"
self.ip_set_name = f"{resource_properties['EnvironmentName']}-rate-limit-ip-set"
else:
self.rule_name = resource_properties['RuleName']
self.ip_set_name = resource_properties['IpSetName']

self.metric_name = self.rule_name.replace('-', '')

if to_bool(self.regional):
self.client = boto3.client('waf-regional', region_name=self.region)
else:
self.client = boto3.client('waf', region_name=self.region)

def retry(func):
# Reattempt to execute a given function with optional arguments.
# This is to avoid the insane error about a token already being expired.
def wrapper(self, *args, **kwargs):
attempts = 5
remaining = attempts

while remaining:
try:
result = func(self, *args, **kwargs)
return result
except self.client.exceptions.WAFStaleDataException as e:
logger.info(str(e))
logger.info("(%d/%d) Retrying request with a new change token..." % (remaining + 1, attempts))
remaining -= 1

logger.info("ERROR - failed to execute request.")
exit(1)

return wrapper

def _create_rate_based_rule(self):
rule_id = self.create_rate_based_rule()

if len(self.ip_set):
ip_set_id = self.create_ip_set()
self.update_ip_set('INSERT', ip_set_id, self.ip_set)
self.update_rate_based_rule('INSERT', ip_set_id, rule_id)

self._add_to_web_acl(rule_id)

return rule_id

@retry
def create_rate_based_rule(self):
change_token = self._get_change_token()
logger.info("Creating WAF rule '%s' ..." % self.rule_name)

rule_id = self.client.create_rate_based_rule(
Name=self.rule_name,
MetricName=self.metric_name,
RateLimit=int(self.rate),
RateKey='IP',
ChangeToken=change_token
)['Rule']['RuleId']

return rule_id

@retry
def create_ip_set(self):
change_token = self._get_change_token()
logger.info("Creating IP set '%s' ..." % self.ip_set_name)

ip_set_id = self.client.create_ip_set(
Name=self.ip_set_name,
ChangeToken=change_token
)['IPSet']['IPSetId']

return ip_set_id

@retry
def update_ip_set(self, action, ip_set_id, ip_set):
change_token = self._get_change_token()
logger.info("Updating IP set '%s' (%s) with %d IPs as %s ..." % (self.ip_set_name, ip_set_id, len(self.ip_set), action))

self.client.update_ip_set(
IPSetId=ip_set_id,
ChangeToken=change_token,
Updates=generate_waf_ip_set(action, ip_set)
)

def _update_rate_based_rule(self, rule_id):
self._delete_rate_based_rule(rule_id)
return self._create_rate_based_rule()

@retry
def update_rate_based_rule(self, action, ip_set_id, rule_id):
change_token = self._get_change_token()
logger.info("Updating rule '%s' (%s) with IP set '%s' (%s) as %s ..." % (self.rule_name, rule_id, self.ip_set_name, ip_set_id, action))

self.client.update_rate_based_rule(
RuleId=rule_id,
ChangeToken=change_token,
Updates=[{
'Action': action,
'Predicate': {
'Negated': to_bool(self.negated),
'Type': 'IPMatch',
'DataId': ip_set_id
}
}],
RateLimit=int(self.rate)
)

def _delete_rate_based_rule(self, rule_id):
logger.info("Getting IP set for rule '%s' (%s) ..." % (self.rule_name, rule_id))

try:
predicates = self.client.get_rate_based_rule(
RuleId=rule_id
)['Rule']['MatchPredicates']
except self.client.exceptions.WAFNonexistentItemException as e:
logger.info("%s: rule ID '%s' does not exist. Returning success" % (str(e), rule_id))
return

if len(predicates):
ip_set_id = predicates[0]['DataId']

logger.info("Getting IPs for IP set '%s' ..." % (ip_set_id))

current_ip_set = self.client.get_ip_set(
IPSetId=ip_set_id
)['IPSet']['IPSetDescriptors']

if len(current_ip_set):
self.update_ip_set('DELETE', ip_set_id, current_ip_set)

self.update_rate_based_rule('DELETE', ip_set_id, rule_id)
self.delete_ip_set(ip_set_id)

self._delete_from_web_acl(rule_id)
self.delete_rate_based_rule(rule_id)

@retry
def delete_ip_set(self, ip_set_id):
change_token = self._get_change_token()
logger.info("Deleting IP set '%s' ..." % (ip_set_id))

self.client.delete_ip_set(
IPSetId=ip_set_id,
ChangeToken=change_token
)

@retry
def delete_rate_based_rule(self, rule_id):
change_token = self._get_change_token()
logger.info("Deleting rule '%s' (%s) ..." % (self.rule_name, rule_id))

self.client.delete_rate_based_rule(
RuleId=rule_id,
ChangeToken=change_token
)

def _get_change_token(self):
token = self.client.get_change_token()['ChangeToken']
logger.info("Got change token: %s" % token)
return token

def _add_to_web_acl(self, rule_id):
self._update_web_acl('INSERT', self.action, self.priority, rule_id)

def _delete_from_web_acl(self, rule_id):
# Get the current rule priority, as it is needed in the update request
web_acl_rules = self.client.get_web_acl(
WebACLId=self.web_acl_id
)['WebACL']['Rules']

current_rule = list(filter(lambda rule: rule['RuleId'] == rule_id, web_acl_rules))[0]
current_action = current_rule['Action']['Type']
current_priority = int(current_rule['Priority'])

self._update_web_acl('DELETE', current_action, current_priority, rule_id)

@retry
def _update_web_acl(self, new_action, current_action, priority, rule_id):
"""Add a rule ID with a web ACL.
"""
change_token = self._get_change_token()
logger.info("%sing rule '%s' (%s) in web ACL ID '%s'" % (new_action, self.rule_name, rule_id, self.web_acl_id))

self.client.update_web_acl(
WebACLId=self.web_acl_id,
Updates=[{
"Action": new_action,
"ActivatedRule": {
"Action": {
"Type": current_action
},
"Priority": priority,
"RuleId": rule_id,
"Type": "RATE_BASED"
}
}],
ChangeToken=change_token
)

def generate_waf_ip_set(action, ips):
return [{'Action': action, 'IPSetDescriptor': ip } for ip in ips]

def to_bool(value):
return value.lower() == 'true'
Loading