mirror of
https://github.com/samuelclay/NewsBlur.git
synced 2025-04-13 09:42:01 +00:00
110 lines
3.6 KiB
Python
110 lines
3.6 KiB
Python
import functools
|
|
import hashlib
|
|
from datetime import datetime, timedelta
|
|
|
|
from django.core.cache import cache
|
|
from django.http import HttpResponse
|
|
|
|
|
|
class ratelimit(object):
    """Instances of this class can be used as view decorators that rate-limit.

    Hits are counted in per-minute cache buckets whose keys look like
    ``<prefix><key_extra(request)>-<YYYYmmddHHMM>``.  A request is rejected
    (see ``disallowed``) once the buckets for the current minute and the
    previous ``minutes`` minutes already hold ``requests`` hits in total.

    This class is designed to be sub-classed; any class attribute can also be
    overridden per instance through keyword arguments to the constructor.
    """

    minutes = 1  # The time period: previous minutes counted in addition to the current one
    requests = 4  # Number of allowed requests in that time period
    use_path = False  # Whether to include the request path in the key
    prefix = "rl-"  # Prefix for memcache key

    def __init__(self, **options):
        # Any keyword (e.g. minutes=5, requests=10) overrides the class default.
        for key, value in options.items():
            setattr(self, key, value)

    def __call__(self, fn):
        """Wrap view ``fn`` so every call goes through ``view_wrapper``."""

        @functools.wraps(fn)
        def wrapper(request, *args, **kwargs):
            return self.view_wrapper(request, fn, *args, **kwargs)

        return wrapper

    def view_wrapper(self, request, fn, *args, **kwargs):
        """Call ``fn`` unless the client has exhausted its quota.

        Counts are read *before* this request is recorded, so at most
        ``requests`` calls pass per window.  Rejected requests are still
        counted, so a client that keeps retrying stays blocked.
        """
        if not self.should_ratelimit(request):
            return fn(request, *args, **kwargs)

        counts = list(self.get_counters(request).values())

        # Increment rate limiting counter (also for requests that get rejected).
        self.cache_incr(self.current_key(request))

        # Have they failed?
        if sum(counts) >= self.requests:
            return self.disallowed(request)

        return fn(request, *args, **kwargs)

    def cache_get_many(self, keys):
        """Return a dict of the given cache keys that currently exist."""
        return cache.get_many(keys)

    def cache_incr(self, key):
        """Add one to the counter stored under ``key``, creating it if needed."""
        # memcache is only backend that can increment atomically
        try:
            # add first, to ensure the key exists
            cache.add(key, 0, self.expire_after())
            cache.incr(key)
        except (AttributeError, ValueError):
            # Fallback when incr is unsupported or the key vanished; this
            # get/set pair is not atomic and can lose concurrent increments.
            cache.set(key, cache.get(key, 0) + 1, self.expire_after())

    def should_ratelimit(self, request):
        """Return True if this request is subject to limiting (always, here)."""
        return True

    def get_counters(self, request):
        """Return ``{cache_key: count}`` for every bucket in the current window."""
        return self.cache_get_many(self.keys_to_check(request))

    def _make_key(self, extra, when):
        """Build the bucket key for client ``extra`` in the minute containing ``when``."""
        return "%s%s-%s" % (self.prefix, extra, when.strftime("%Y%m%d%H%M"))

    def keys_to_check(self, request):
        """Return bucket keys for the current minute and the ``minutes`` before it."""
        extra = self.key_extra(request)
        now = datetime.now()
        return [
            self._make_key(extra, now - timedelta(minutes=minute))
            for minute in range(self.minutes + 1)
        ]

    def current_key(self, request):
        """Return the bucket key that this request's hit is recorded under."""
        return self._make_key(self.key_extra(request), datetime.now())

    def key_extra(self, request):
        """Return the per-client part of the cache key.

        Preference order: Django session key, left-most X-Forwarded-For entry,
        the ``newsblur_sessionid`` cookie, then the User-Agent.  When
        ``use_path`` is set the request path is appended, giving each URL its
        own budget.
        """
        # Look the session up defensively: a request that never went through
        # session middleware has no ``session`` attribute at all.
        session = getattr(request, "session", None)
        key = getattr(session, "session_key", None) or ""
        if not key:
            # Strip so a padded header does not produce a key with spaces.
            key = request.META.get("HTTP_X_FORWARDED_FOR", "").split(",")[0].strip()
        if not key:
            key = request.COOKIES.get("newsblur_sessionid", "")
        if not key:
            key = request.META.get("HTTP_USER_AGENT", "")
        # NOTE(review): a User-Agent or path may contain spaces or be long;
        # memcached-backed caches reject such keys — confirm the cache backend.

        # Add request path to the key if use_path is enabled
        if self.use_path:
            key = f"{key}-{request.path}"

        return key

    def disallowed(self, request):
        """Return the response sent to a client that exceeded its quota."""
        return HttpResponse("Rate limit exceeded", status=429)

    def expire_after(self):
        "Used for setting the memcached cache expiry"
        # Seconds: long enough for a bucket to survive the whole window
        # (current minute plus ``minutes`` previous ones).
        return (self.minutes + 1) * 60
|
|
|
|
|
|
class ratelimit_post(ratelimit):
    """Rate limit POSTs - can be used to protect a login form.

    Only POST requests are counted.  If ``key_field`` names a POST field, a
    SHA-1 digest of that field's submitted value is appended to the client
    key, so different values of that field are counted separately.
    """

    key_field = None  # If provided, this POST var will affect the rate limit

    def should_ratelimit(self, request):
        # Non-POST requests (e.g. rendering the form) are never limited.
        return request.method == "POST"

    def key_extra(self, request):
        # Client key from the base class (session key, X-Forwarded-For,
        # cookie or User-Agent - not necessarily an IP address), plus the
        # key_field value if it is set.
        extra = super().key_extra(request)
        if self.key_field:
            # Hash the raw user input into a fixed-length hex token so it is
            # safe to embed in a cache key.
            value = hashlib.sha1(request.POST.get(self.key_field, "").encode("utf-8")).hexdigest()
            extra += "-" + value
        return extra
|