import functools
import hashlib
from datetime import datetime, timedelta

from django.core.cache import cache
from django.http import HttpResponse


class ratelimit(object):
    """Decorator that rate-limits a view using per-minute counters in the cache.

    Instances of this class can be used as decorators. Each decorated call
    increments a counter stored under a key built from ``prefix``, a client
    identifier (see ``key_extra``) and the current minute. The request is
    refused when the counters for the current minute and the preceding
    ``minutes`` minutes already add up to ``requests`` or more.

    Any class attribute can be overridden per instance through keyword
    arguments, e.g. ``@ratelimit(minutes=5, requests=20)``.
    """

    # This class is designed to be sub-classed
    minutes = 1  # The time period
    requests = 4  # Number of allowed requests in that time period
    use_path = False  # Whether to include the request path in the key
    prefix = "rl-"  # Prefix for memcache key

    # Longest client identifier embedded verbatim in a cache key; longer ones
    # are hashed. Memcached limits keys to 250 characters, and the prefix,
    # the timestamp and whatever the cache backend prepends need room too.
    # NOTE(review): 200 is a conservative guess; tune it if the configured
    # cache KEY_PREFIX is unusually long.
    _MAX_EXTRA_LENGTH = 200

    def __init__(self, **options):
        """Override class-level settings with the given keyword options."""
        for key, value in options.items():
            setattr(self, key, value)

    def __call__(self, fn):
        """Wrap the view ``fn`` so every call goes through ``view_wrapper``."""

        @functools.wraps(fn)
        def wrapper(request, *args, **kwargs):
            return self.view_wrapper(request, fn, *args, **kwargs)

        return wrapper

    def view_wrapper(self, request, fn, *args, **kwargs):
        """Run ``fn`` unless the client has used up its request allowance.

        Returns the result of ``fn(request, *args, **kwargs)``, or the result
        of ``disallowed(request)`` when the limit has been reached.
        """
        if not self.should_ratelimit(request):
            return fn(request, *args, **kwargs)

        # Read the counters before incrementing, so the comparison below only
        # looks at requests made before this one.
        counts = list(self.get_counters(request).values())

        # Increment rate limiting counter (refused requests are counted too)
        self.cache_incr(self.current_key(request))

        # Have they failed?
        if sum(counts) >= self.requests:
            return self.disallowed(request)

        return fn(request, *args, **kwargs)

    def cache_get_many(self, keys):
        """Return a dict of the cached values found for ``keys``."""
        return cache.get_many(keys)

    def cache_incr(self, key):
        """Add one to the counter stored under ``key``, creating it if needed."""
        # memcache is only backend that can increment atomically
        try:
            # add first, to ensure the key exists
            cache.add(key, 0, self.expire_after())
            cache.incr(key)
        except (AttributeError, ValueError):
            # Fallback for when incr is unavailable or rejects the key:
            # a plain, non-atomic read-modify-write.
            cache.set(key, cache.get(key, 0) + 1, self.expire_after())

    def should_ratelimit(self, request):
        """Return True if ``request`` is subject to rate limiting."""
        return True

    def get_counters(self, request):
        """Return the cached counters for every minute in the current window."""
        return self.cache_get_many(self.keys_to_check(request))

    def keys_to_check(self, request):
        """Return the cache keys for the current and previous ``minutes`` minutes."""
        extra = self._safe_key_extra(request)
        now = datetime.now()
        return [
            self._format_key(extra, now - timedelta(minutes=minute))
            for minute in range(self.minutes + 1)
        ]

    def current_key(self, request):
        """Return the cache key of the counter for the current minute."""
        return self._format_key(self._safe_key_extra(request), datetime.now())

    def _format_key(self, extra, moment):
        """Build the cache key for client identifier ``extra`` at ``moment``."""
        stamp = moment.strftime("%Y%m%d%H%M")
        return f"{self.prefix}{extra}-{stamp}"

    def _safe_key_extra(self, request):
        """Return ``key_extra(request)`` in a form that is safe inside a cache key.

        Identifiers made only of printable, non-space ASCII characters and no
        longer than ``_MAX_EXTRA_LENGTH`` are returned unchanged, so keys for
        session ids and IP addresses look exactly as before. Anything else
        (for example a User-Agent header containing spaces, or a very long
        request path) is replaced by its SHA-1 hex digest, because memcached
        rejects keys that contain whitespace or control characters or that
        are too long.
        """
        extra = str(self.key_extra(request))
        is_short = len(extra) <= self._MAX_EXTRA_LENGTH
        if is_short and all("!" <= char <= "~" for char in extra):
            return extra
        # backslashreplace keeps encoding from failing on unencodable characters.
        return hashlib.sha1(extra.encode("utf-8", "backslashreplace")).hexdigest()

    def key_extra(self, request):
        """Return the string identifying the client that made ``request``.

        The first non-empty value among these is used: the session key, the
        first entry of the X-Forwarded-For header, the ``newsblur_sessionid``
        cookie, and the User-Agent header. When ``use_path`` is set the
        request path is appended, so each path is limited separately.
        """
        # The session attribute may be absent, so guard the lookup instead of
        # letting it raise AttributeError.
        session = getattr(request, "session", None)
        key = getattr(session, "session_key", "")
        if not key:
            forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR", "")
            key = forwarded_for.split(",")[0].strip()
        if not key:
            key = request.COOKIES.get("newsblur_sessionid", "")
        if not key:
            key = request.META.get("HTTP_USER_AGENT", "")

        # Add request path to the key if use_path is enabled
        if self.use_path:
            path = request.path
            key = f"{key}-{path}"

        return key

    def disallowed(self, request):
        """Return the response sent to a client that exceeded the limit."""
        return HttpResponse("Rate limit exceeded", status=429)

    def expire_after(self):
        """Return the cache expiry, in seconds, used for the counters."""
        return (self.minutes + 1) * 60


class ratelimit_post(ratelimit):
    """Rate limit POSTs - can be used to protect a login form.

    Only POST requests are counted. When ``key_field`` is set, the SHA-1
    digest of that POST variable becomes part of the client identifier, so
    each submitted value is limited separately.
    """

    key_field = None  # If provided, this POST var will affect the rate limit

    def should_ratelimit(self, request):
        """Return True only for POST requests."""
        return request.method == "POST"

    def key_extra(self, request):
        """Return the base client identifier plus the hashed ``key_field`` value."""
        # IP address and key_field (if it is set)
        extra = super().key_extra(request)
        if self.key_field:
            submitted = request.POST.get(self.key_field, "")
            value = hashlib.sha1(submitted.encode("utf-8")).hexdigest()
            extra += "-" + value
        return extra