diff --git a/RateLimiter.Tests/RateLimiter.Tests.csproj b/RateLimiter.Tests/RateLimiter.Tests.csproj index 5cbfc4e8..d2dca726 100644 --- a/RateLimiter.Tests/RateLimiter.Tests.csproj +++ b/RateLimiter.Tests/RateLimiter.Tests.csproj @@ -9,7 +9,11 @@ - - + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + \ No newline at end of file diff --git a/RateLimiter.Tests/RateLimiterTest.cs b/RateLimiter.Tests/RateLimiterTest.cs index 172d44a7..1561e4c7 100644 --- a/RateLimiter.Tests/RateLimiterTest.cs +++ b/RateLimiter.Tests/RateLimiterTest.cs @@ -1,13 +1,206 @@ -using NUnit.Framework; +using Microsoft.AspNetCore.Http; +using Moq; +using RateLimiter.Rules; +using System.Text; +using System.Threading.Tasks; +using Xunit; +using Microsoft.Extensions.Caching.Distributed; +using System.Net; +using RateLimiter.Configuration; +using System.Collections.Generic; +using RateLimiter.Policy; +using System.Linq; -namespace RateLimiter.Tests; - -[TestFixture] -public class RateLimiterTest +namespace RateLimiter.Tests { - [Test] - public void Example() - { - Assert.That(true, Is.True); - } -} \ No newline at end of file + public class RateLimiterTest + { + private readonly Mock _cacheMock; + + public RateLimiterTest() + { + _cacheMock = new Mock(); + } + + #region Combined + + [Fact] + public async Task EvaluateAll_ShouldAllow_WhenAllUnderLimit() + { + // Arrange + List configIpBlocked = ["192.168.1.1", "192.168.1.2"]; + FixedWindowConfig configFixed = new(5, 10); + string current = "2"; + _cacheMock.Setup(c => c.GetAsync(It.IsAny(), default)) + .ReturnsAsync(Encoding.UTF8.GetBytes(current)); + + var ruleFixed = new FixedWindowRule(_cacheMock.Object, configFixed.Limit, configFixed.Seconds); + var ruleIp = new IpBlacklistRule(configIpBlocked); + var policy = new RateLimitPolicy(); + policy.AddRule(ruleFixed); + policy.AddRule(ruleIp); + + var httpContext = GetHttpContext(); + + // Act + bool result = await policy.EvaluateAllAsync(httpContext); + + // Assert + Assert.True(result); + } + + [Fact] + public async Task EvaluateAll_ShouldNotAllow_BlockedIp() + { + // Arrange + List configIpBlocked = ["192.168.1.101"]; + FixedWindowConfig configFixed = new(5, 10); + string current = "2"; + _cacheMock.Setup(c => c.GetAsync(It.IsAny(), default)) + .ReturnsAsync(Encoding.UTF8.GetBytes(current)); + + var ruleFixed = new FixedWindowRule(_cacheMock.Object, configFixed.Limit, configFixed.Seconds); + var ruleIp = new IpBlacklistRule(configIpBlocked); + var policy = new RateLimitPolicy(); + policy.AddRule(ruleFixed); + policy.AddRule(ruleIp); + + var httpContext = GetHttpContext(configIpBlocked.First()); + + // Act + bool result = await policy.EvaluateAllAsync(httpContext); + + // Assert + Assert.False(result); + } + + #endregion + + #region Fixed + + [Fact] + public async Task EvaluateFixed_ShouldAllow_WhenUnderLimit() + { + // Arrange + FixedWindowConfig config = new(5, 10); + string current = "2"; + _cacheMock.Setup(c => c.GetAsync(It.IsAny(), default)) + .ReturnsAsync(Encoding.UTF8.GetBytes(current)); + + var rule = new FixedWindowRule(_cacheMock.Object, config.Limit, config.Seconds); + var httpContext = GetHttpContext(); + + // Act + bool result = await rule.EvaluateAsync(httpContext); + + // Assert + Assert.True(result); + } + + [Fact] + public async Task EvaluateFixed_ShouldNotAllow_WhenOverLimit() + { + // Arrange + FixedWindowConfig config = new(2, 10); + string current = "2"; + _cacheMock.Setup(c => c.GetAsync(It.IsAny(), default)) + .ReturnsAsync(Encoding.UTF8.GetBytes(current)); + + var rule = new FixedWindowRule(_cacheMock.Object, config.Limit, config.Seconds); + var httpContext = GetHttpContext(); + + // Act + bool result = await rule.EvaluateAsync(httpContext); + + // Assert + Assert.False(result); + } + + [Fact] + public async Task EvaluateFixed_ShouldAllow_WhenCacheEmpty() + { + // Arrange + FixedWindowConfig config = new(5, 10); + _cacheMock.Setup(c => c.GetAsync(It.IsAny(), default)) + .ReturnsAsync(null as byte[]); + + var rule = new FixedWindowRule(_cacheMock.Object, config.Limit, config.Seconds); + var httpContext = GetHttpContext(); + + // Act + bool result = await rule.EvaluateAsync(httpContext); + + // Assert + Assert.True(result); // First request should pass + } + + [Fact] + public async Task EvaluateFixed_ShouldUpdateCache_WhenRequestMade() + { + // Arrange + FixedWindowConfig config = new(5, 10); + _cacheMock.Setup(c => c.GetAsync(It.IsAny(), default)) + .ReturnsAsync(Encoding.UTF8.GetBytes("2")); // Simulate 2 requests so far + + var rule = new FixedWindowRule(_cacheMock.Object, config.Limit, config.Seconds); + var httpContext = GetHttpContext(); + + // Act + bool result = await rule.EvaluateAsync(httpContext); + + // Assert + Assert.True(result); // Request should pass + _cacheMock.Verify(c => c.SetAsync(It.IsAny(), It.IsAny(), It.IsAny(), default), Times.Once); + } + + #endregion + + #region Geo + + [Fact] + public async Task EvaluateGeo_ShouldAllow_WhenUnderLimit() + { + // Arrange + GeoBasedConfig config = new(Country.EU, 10); + _cacheMock.Setup(c => c.GetAsync(It.IsAny(), default)) + .ReturnsAsync(null as byte[]); + + var rule = new GeoBasedRule(_cacheMock.Object, [config]); + var httpContext = GetHttpContext(); + + // Act + bool result = await rule.EvaluateAsync(httpContext); + + // Assert + Assert.True(result); + } + + [Fact] + public async Task EvaluateGeo_ShouldNotAllow_WhenOverLimit() + { + // Arrange + GeoBasedConfig config = new(Country.EU, 10); + string current = "0"; + _cacheMock.Setup(c => c.GetAsync(It.IsAny(), default)) + .ReturnsAsync(Encoding.UTF8.GetBytes(current)); + + var rule = new GeoBasedRule(_cacheMock.Object, [config]); + var httpContext = GetHttpContext(); + + // Act + bool result = await rule.EvaluateAsync(httpContext); + + // Assert + Assert.False(result); + } + + #endregion + + private static DefaultHttpContext GetHttpContext(string ip = "192.168.1.100") + { + var httpContext = new DefaultHttpContext(); + httpContext.Connection.RemoteIpAddress = IPAddress.Parse(ip); + return httpContext; + } + } +} diff --git a/RateLimiter/Configuration/CooldownConfig.cs b/RateLimiter/Configuration/CooldownConfig.cs new file mode 100644 index 00000000..eccc2118 --- /dev/null +++ b/RateLimiter/Configuration/CooldownConfig.cs @@ -0,0 +1,17 @@ +namespace RateLimiter.Configuration +{ + public class CooldownConfig + { + public int Seconds { get; set; } + + public CooldownConfig() + { + + } + + public CooldownConfig(int seconds) + { + Seconds = seconds; + } + } +} diff --git a/RateLimiter/Configuration/Country.cs b/RateLimiter/Configuration/Country.cs new file mode 100644 index 00000000..92ae568a --- /dev/null +++ b/RateLimiter/Configuration/Country.cs @@ -0,0 +1,10 @@ +namespace RateLimiter.Configuration +{ + // Enum not used to skip calling ".ToString()" every time + public static class Country + { + public const string Default = nameof(Default); + public const string EU = nameof(EU); + public const string US = nameof(US); + } +} diff --git a/RateLimiter/Configuration/FixedWindowConfig.cs b/RateLimiter/Configuration/FixedWindowConfig.cs new file mode 100644 index 00000000..4b53444f --- /dev/null +++ b/RateLimiter/Configuration/FixedWindowConfig.cs @@ -0,0 +1,19 @@ +namespace RateLimiter.Configuration +{ + public class FixedWindowConfig + { + public int Limit { get; set; } + public int Seconds { get; set; } + + public FixedWindowConfig() + { + + } + + public FixedWindowConfig(int limit, int seconds) + { + Limit = limit; + Seconds = seconds; + } + } +} diff --git a/RateLimiter/Configuration/GeoBasedConfig.cs b/RateLimiter/Configuration/GeoBasedConfig.cs new file mode 100644 index 00000000..975fd4a2 --- /dev/null +++ b/RateLimiter/Configuration/GeoBasedConfig.cs @@ -0,0 +1,19 @@ +namespace RateLimiter.Configuration +{ + public class GeoBasedConfig + { + public string Country { get; set; } + public int Seconds { get; set; } + + public GeoBasedConfig() + { + Country = Configuration.Country.Default; + } + + public GeoBasedConfig(string country, int seconds) + { + Country = country; + Seconds = seconds; + } + } +} diff --git a/RateLimiter/Configuration/IpWhitelistConfig.cs b/RateLimiter/Configuration/IpWhitelistConfig.cs new file mode 100644 index 00000000..31caa475 --- /dev/null +++ b/RateLimiter/Configuration/IpWhitelistConfig.cs @@ -0,0 +1,19 @@ +using System.Collections.Generic; + +namespace RateLimiter.Configuration +{ + public class IpWhitelistConfig + { + public IEnumerable Blocked { get; set; } + + public IpWhitelistConfig() + { + Blocked = []; + } + + public IpWhitelistConfig(IEnumerable blocked) + { + Blocked = blocked; + } + } +} diff --git a/RateLimiter/Configuration/RateLimiterConfig.cs b/RateLimiter/Configuration/RateLimiterConfig.cs new file mode 100644 index 00000000..a7d0f40a --- /dev/null +++ b/RateLimiter/Configuration/RateLimiterConfig.cs @@ -0,0 +1,12 @@ +using System.Collections.Generic; + +namespace RateLimiter.Configuration +{ + public class RateLimiterConfig + { + public FixedWindowConfig FixedWindowConfig { get; set; } = new(3, 5); + public CooldownConfig CooldownConfig { get; set; } = new(1); + public IEnumerable GeoBasedConfig { get; set; } = []; + public IEnumerable IpBlacklistConfig { get; set; } = []; + } +} diff --git a/RateLimiter/Policy/RateLimitPolicy.cs b/RateLimiter/Policy/RateLimitPolicy.cs new file mode 100644 index 00000000..7d35fdc0 --- /dev/null +++ b/RateLimiter/Policy/RateLimitPolicy.cs @@ -0,0 +1,32 @@ +using Microsoft.AspNetCore.Http; +using RateLimiter.Rules; +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace RateLimiter.Policy +{ + public class RateLimitPolicy() + { + private readonly List _rateLimiters = []; + + public void AddRule(IRateLimiterRule rule) + { + _rateLimiters.Add(rule); + } + + public async Task EvaluateAllAsync(HttpContext httpContext) + { + // Use Task.WhenAny() if evaluation is long-running + + foreach (var rule in _rateLimiters) + { + if (!await rule.EvaluateAsync(httpContext)) + { + return false; + } + } + + return true; + } + } +} diff --git a/RateLimiter/Policy/RateLimitPolicyRegistry.cs b/RateLimiter/Policy/RateLimitPolicyRegistry.cs new file mode 100644 index 00000000..0b61b22a --- /dev/null +++ b/RateLimiter/Policy/RateLimitPolicyRegistry.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; + +namespace RateLimiter.Policy +{ + public class RateLimitPolicyRegistry() + { + private readonly Dictionary _policies = []; + + public void AddPolicy(string name, Action configure) + { + var policy = new RateLimitPolicy(); + configure(policy); + // overrides if already exists + _policies[name] = policy; + } + + public RateLimitPolicy? GetPolicy(string name) + { + return _policies.TryGetValue(name, out var policy) ? policy : null; + } + } +} diff --git a/RateLimiter/RateLimiter.csproj b/RateLimiter/RateLimiter.csproj index 19962f52..b7d3f96e 100644 --- a/RateLimiter/RateLimiter.csproj +++ b/RateLimiter/RateLimiter.csproj @@ -4,4 +4,8 @@ latest enable + + + + \ No newline at end of file diff --git a/RateLimiter/Rules/CooldownRule.cs b/RateLimiter/Rules/CooldownRule.cs new file mode 100644 index 00000000..d3705769 --- /dev/null +++ b/RateLimiter/Rules/CooldownRule.cs @@ -0,0 +1,37 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Caching.Distributed; +using System; +using System.Text.Json; +using System.Threading.Tasks; + +namespace RateLimiter.Rules +{ + public class CooldownRule(IDistributedCache cache, int seconds) : IRateLimiterRule + { + private readonly IDistributedCache _cache = cache; + private readonly TimeSpan _timeSpan = TimeSpan.FromSeconds(seconds); + + public async Task EvaluateAsync(HttpContext httpContext) + { + var request = httpContext.Request; + var serviceProvider = request.HttpContext.RequestServices; + + string? key = ((IRateLimiterRule)this).GetKey(httpContext); + if (string.IsNullOrEmpty(key)) return false; + + var cacheKey = $"RateLimit_{nameof(CooldownRule)}_{request.Method}_{request.Path}_{key}"; + var cacheValue = await _cache.GetStringAsync(cacheKey); + if (cacheValue != null) + { + return false; + } + + var options = new DistributedCacheEntryOptions + { + AbsoluteExpirationRelativeToNow = _timeSpan + }; + await _cache.SetStringAsync(cacheKey, JsonSerializer.Serialize(0), options); + return true; + } + } +} diff --git a/RateLimiter/Rules/FixedWindowRule.cs b/RateLimiter/Rules/FixedWindowRule.cs new file mode 100644 index 00000000..c171b436 --- /dev/null +++ b/RateLimiter/Rules/FixedWindowRule.cs @@ -0,0 +1,40 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Caching.Distributed; +using System; +using System.Text.Json; +using System.Threading.Tasks; + +namespace RateLimiter.Rules +{ + public class FixedWindowRule(IDistributedCache cache, int limit, int seconds) : IRateLimiterRule + { + private readonly IDistributedCache _cache = cache; + private readonly int _limit = limit; + private readonly TimeSpan _timeSpan = TimeSpan.FromSeconds(seconds); + + public async Task EvaluateAsync(HttpContext httpContext) + { + var request = httpContext.Request; + var serviceProvider = request.HttpContext.RequestServices; + + string? key = ((IRateLimiterRule)this).GetKey(httpContext); + if (string.IsNullOrEmpty(key)) return false; + + var cacheKey = $"RateLimit_{nameof(FixedWindowRule)}_{request.Method}_{request.Path}_{key}"; + var cacheValue = await _cache.GetStringAsync(cacheKey); + int requestCount = cacheValue != null ? JsonSerializer.Deserialize(cacheValue) : 0; + if (requestCount >= _limit) + { + return false; + } + + requestCount++; + var options = new DistributedCacheEntryOptions + { + AbsoluteExpirationRelativeToNow = _timeSpan + }; + await _cache.SetStringAsync(cacheKey, JsonSerializer.Serialize(requestCount), options); + return true; + } + } +} diff --git a/RateLimiter/Rules/GeoBasedRule.cs b/RateLimiter/Rules/GeoBasedRule.cs new file mode 100644 index 00000000..fd563cbb --- /dev/null +++ b/RateLimiter/Rules/GeoBasedRule.cs @@ -0,0 +1,55 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Caching.Distributed; +using RateLimiter.Configuration; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Threading.Tasks; + +namespace RateLimiter.Rules +{ + public class GeoBasedRule(IDistributedCache cache, + IEnumerable configs) + : IRateLimiterRule + { + private readonly IDistributedCache _cache = cache; + private readonly Dictionary _configs = configs.ToDictionary(x => x.Country); + + public async Task EvaluateAsync(HttpContext httpContext) + { + string country = ResolveCountry(); + if (!_configs.TryGetValue(country, out var config)) + { + // no limit for current country, so skip it + return true; + } + + var request = httpContext.Request; + var serviceProvider = request.HttpContext.RequestServices; + + string? key = ((IRateLimiterRule)this).GetKey(httpContext); + if (string.IsNullOrEmpty(key)) return false; + + var cacheKey = $"RateLimit_{nameof(GeoBasedRule)}_{request.Method}_{request.Path}_{key}"; + var cacheValue = await _cache.GetStringAsync(cacheKey); + if (cacheValue != null) + { + return false; + } + + var options = new DistributedCacheEntryOptions + { + AbsoluteExpirationRelativeToNow = TimeSpan.FromSeconds(config.Seconds) + }; + await _cache.SetStringAsync(cacheKey, JsonSerializer.Serialize(0), options); + return true; + } + + private string ResolveCountry() + { + // get from token or resolve by ip + return Country.EU; + } + } +} diff --git a/RateLimiter/Rules/IRateLimiterRule.cs b/RateLimiter/Rules/IRateLimiterRule.cs new file mode 100644 index 00000000..512079a7 --- /dev/null +++ b/RateLimiter/Rules/IRateLimiterRule.cs @@ -0,0 +1,23 @@ +using Microsoft.AspNetCore.Http; +using System.Threading.Tasks; + +namespace RateLimiter.Rules +{ + public interface IRateLimiterRule + { + Task EvaluateAsync(HttpContext httpContext); + + string? GetKey(HttpContext context) + { + var user = context.User; + var userId = user?.Identity?.IsAuthenticated == true + ? user.FindFirst("sub")?.Value + : null; + + var ipAddress = context.Connection.RemoteIpAddress?.ToString(); + var rateLimitKey = !string.IsNullOrEmpty(userId) ? userId : ipAddress; + + return rateLimitKey; + } + } +} \ No newline at end of file diff --git a/RateLimiter/Rules/IpBlacklistRule.cs b/RateLimiter/Rules/IpBlacklistRule.cs new file mode 100644 index 00000000..bd516579 --- /dev/null +++ b/RateLimiter/Rules/IpBlacklistRule.cs @@ -0,0 +1,19 @@ +using Microsoft.AspNetCore.Http; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace RateLimiter.Rules +{ + public class IpBlacklistRule(IEnumerable blocked) : IRateLimiterRule + { + private readonly HashSet _blockedIps = blocked.ToHashSet(); + + public Task EvaluateAsync(HttpContext httpContext) + { + string ip = httpContext.Connection.RemoteIpAddress.ToString(); + + return Task.FromResult(!_blockedIps.Contains(ip)); + } + } +} diff --git a/RateLimiter/Usage/RateLimitAttribute.cs b/RateLimiter/Usage/RateLimitAttribute.cs new file mode 100644 index 00000000..62d8c9a9 --- /dev/null +++ b/RateLimiter/Usage/RateLimitAttribute.cs @@ -0,0 +1,38 @@ +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.Extensions.DependencyInjection; +using RateLimiter.Policy; +using System; +using System.Threading.Tasks; + +namespace RateLimiter.Usage +{ + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class)] + public class RateLimitAttribute(string policyName) : ActionFilterAttribute + { + public override async Task OnActionExecutionAsync(ActionExecutingContext context, ActionExecutionDelegate next) + { + var serviceProvider = context.HttpContext.RequestServices; + + var policyRegistry = serviceProvider.GetRequiredService(); + var policy = policyRegistry.GetPolicy(policyName); + if (policy == null) + { + context.Result = new StatusCodeResult(500); + return; + } + + if (!await policy.EvaluateAllAsync(context.HttpContext)) + { + context.Result = new ContentResult + { + StatusCode = 429, + Content = "Rate limit exceeded. Try again later." + }; + return; + } + + await next(); + } + } +} diff --git a/RateLimiter/Usage/appsettings.json b/RateLimiter/Usage/appsettings.json new file mode 100644 index 00000000..d9adf2d6 --- /dev/null +++ b/RateLimiter/Usage/appsettings.json @@ -0,0 +1,29 @@ +{ + "RateLimiterConfig": { + "FixedWindowConfig": { + "Limit": 2, + "Seconds": 10 + }, + "CooldownConfig": { + "Seconds": 5 + }, + "GeoBasedConfig": [ + { + "Country": "Default", + "Seconds": 5 + }, + { + "Country": "EU", + "Seconds": 5 + }, + { + "Country": "US", + "Seconds": 10 + } + ], + "IpBlacklistConfig": [ + "192.168.1.100", + "192.168.1.200" + ] + } +}