diff --git a/WebApiThrottle/ThrottlingCore.cs b/WebApiThrottle/ThrottlingCore.cs index a179792..50dd1b5 100644 --- a/WebApiThrottle/ThrottlingCore.cs +++ b/WebApiThrottle/ThrottlingCore.cs @@ -184,7 +184,7 @@ internal List> RatesWithDefaults(List= DateTime.UtcNow) { // increment request count - var totalRequests = entry.Value.TotalRequests + 1; + var totalRequests = entry.Value.TotalRequests + effectiveInffective; // deep copy throttleCounter = new ThrottleCounter diff --git a/WebApiThrottle/ThrottlingFilter.cs b/WebApiThrottle/ThrottlingFilter.cs index 97dfdf1..16f4414 100644 --- a/WebApiThrottle/ThrottlingFilter.cs +++ b/WebApiThrottle/ThrottlingFilter.cs @@ -182,7 +182,7 @@ public override void OnActionExecuting(HttpActionContext actionContext) { // increment counter var requestId = ComputeThrottleKey(identity, rateLimitPeriod); - var throttleCounter = core.ProcessRequest(timeSpan, requestId); + var throttleCounter = core.ProcessRequest(timeSpan, requestId, this.EffectiveIncrement); // check if key expired if (throttleCounter.Timestamp + timeSpan < DateTime.UtcNow) @@ -222,6 +222,11 @@ public override void OnActionExecuting(HttpActionContext actionContext) base.OnActionExecuting(actionContext); } + protected virtual int EffectiveIncrement + { + get { return 1; } + } + protected virtual RequestIdentity SetIdentity(HttpRequestMessage request) { var entry = new RequestIdentity(); diff --git a/WebApiThrottle/ThrottlingHandler.cs b/WebApiThrottle/ThrottlingHandler.cs index 93e5bb1..6ea4aa9 100644 --- a/WebApiThrottle/ThrottlingHandler.cs +++ b/WebApiThrottle/ThrottlingHandler.cs @@ -53,9 +53,9 @@ public ThrottlingHandler() /// /// The IpAddressParser /// - public ThrottlingHandler(ThrottlePolicy policy, - IPolicyRepository policyRepository, - IThrottleRepository repository, + public ThrottlingHandler(ThrottlePolicy policy, + IPolicyRepository policyRepository, + IThrottleRepository repository, IThrottleLogger logger, IIpAddressParser ipAddressParser = null) { @@ -178,7 +178,7 @@ protected override Task SendAsync(HttpRequestMessage reques { // increment counter var requestId = ComputeThrottleKey(identity, rateLimitPeriod); - var throttleCounter = core.ProcessRequest(timeSpan, requestId); + var throttleCounter = core.ProcessRequest(timeSpan, requestId, this.EffectiveIncrement); // check if key expired if (throttleCounter.Timestamp + timeSpan < DateTime.UtcNow) @@ -195,8 +195,8 @@ protected override Task SendAsync(HttpRequestMessage reques Logger.Log(core.ComputeLogEntry(requestId, identity, throttleCounter, rateLimitPeriod.ToString(), rateLimit, request)); } - var message = !string.IsNullOrEmpty(this.QuotaExceededMessage) - ? this.QuotaExceededMessage + var message = !string.IsNullOrEmpty(this.QuotaExceededMessage) + ? this.QuotaExceededMessage : "API calls quota exceeded! maximum admitted {0} per {1}."; var content = this.QuotaExceededContent != null @@ -217,6 +217,11 @@ protected override Task SendAsync(HttpRequestMessage reques return base.SendAsync(request, cancellationToken); } + protected virtual int EffectiveIncrement + { + get { return 1; } + } + protected IPAddress GetClientIp(HttpRequestMessage request) { return core.GetClientIp(request); @@ -227,8 +232,8 @@ protected virtual RequestIdentity SetIdentity(HttpRequestMessage request) var entry = new RequestIdentity(); entry.ClientIp = core.GetClientIp(request).ToString(); entry.Endpoint = request.RequestUri.AbsolutePath.ToLowerInvariant(); - entry.ClientKey = request.Headers.Contains("Authorization-Token") - ? request.Headers.GetValues("Authorization-Token").First() + entry.ClientKey = request.Headers.Contains("Authorization-Token") + ? request.Headers.GetValues("Authorization-Token").First() : "anon"; return entry; diff --git a/WebApiThrottle/ThrottlingMiddleware.cs b/WebApiThrottle/ThrottlingMiddleware.cs index 0178d16..157949d 100644 --- a/WebApiThrottle/ThrottlingMiddleware.cs +++ b/WebApiThrottle/ThrottlingMiddleware.cs @@ -49,10 +49,10 @@ public ThrottlingMiddleware(OwinMiddleware next) /// /// The IpAddressParser /// - public ThrottlingMiddleware(OwinMiddleware next, - ThrottlePolicy policy, - IPolicyRepository policyRepository, - IThrottleRepository repository, + public ThrottlingMiddleware(OwinMiddleware next, + ThrottlePolicy policy, + IPolicyRepository policyRepository, + IThrottleRepository repository, IThrottleLogger logger, IIpAddressParser ipAddressParser) : base(next) @@ -174,7 +174,7 @@ public override async Task Invoke(IOwinContext context) { // increment counter var requestId = ComputeThrottleKey(identity, rateLimitPeriod); - var throttleCounter = core.ProcessRequest(timeSpan, requestId); + var throttleCounter = core.ProcessRequest(timeSpan, requestId, this.EffectiveIncrement); // check if key expired if (throttleCounter.Timestamp + timeSpan < DateTime.UtcNow) @@ -213,6 +213,11 @@ public override async Task Invoke(IOwinContext context) await Next.Invoke(context); } + protected virtual int EffectiveIncrement + { + get { return 1; } + } + protected virtual RequestIdentity SetIdentity(IOwinRequest request) { var entry = new RequestIdentity();