mirror of
				https://github.com/samuelclay/NewsBlur.git
				synced 2025-11-01 09:09:51 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			103 lines
		
	
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			103 lines
		
	
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import functools
 | 
						|
import hashlib
 | 
						|
from datetime import datetime, timedelta
 | 
						|
 | 
						|
from django.core.cache import cache
 | 
						|
from django.http import HttpResponse
 | 
						|
 | 
						|
 | 
						|
class ratelimit(object):
    """Rate-limit a Django view; instances of this class can be used as decorators.

    Usage::

        @ratelimit(minutes=2, requests=10)
        def my_view(request): ...

    Each limited request increments a per-client counter stored in the cache
    under a key bucketed by minute (``<prefix><client>-<YYYYmmddHHMM>``).

    Before incrementing, the counters for the current minute and the previous
    ``minutes`` minutes are summed. If that total is already at or above
    ``requests``, the request is answered by :meth:`disallowed` instead of the
    view.

    This class is designed to be sub-classed: override :meth:`should_ratelimit`,
    :meth:`key_extra` or :meth:`disallowed` to choose which requests are
    limited, how clients are identified, and what a rejected client receives.
    Any class attribute can also be overridden per instance through keyword
    arguments to the constructor.
    """

    minutes = 1  # Previous whole minutes (plus the current one) counted toward the limit
    requests = 4  # Number of allowed requests in that time period
    prefix = "rl-"  # Prefix for cache keys
    # Memcached rejects keys longer than 250 characters or containing whitespace
    # or control characters, and Django prepends KEY_PREFIX and a version to
    # every key, so generated keys stay comfortably below the hard limit.
    max_key_length = 200

    def __init__(self, **options):
        """Override any class attribute (``minutes``, ``requests``, ...) per instance."""
        for key, value in options.items():
            setattr(self, key, value)

    def __call__(self, fn):
        """Wrap the view ``fn`` so every call goes through :meth:`view_wrapper`."""

        @functools.wraps(fn)
        def wrapper(request, *args, **kwargs):
            return self.view_wrapper(request, fn, *args, **kwargs)

        return wrapper

    def view_wrapper(self, request, fn, *args, **kwargs):
        """Count ``request`` and call ``fn``, or return :meth:`disallowed` when over the limit."""
        if not self.should_ratelimit(request):
            return fn(request, *args, **kwargs)

        # Read the counters *before* incrementing so this request is judged on
        # the traffic that preceded it.
        total = sum(self.get_counters(request).values())

        # Increment even when the request is about to be rejected, so clients
        # that keep retrying stay blocked.
        self.cache_incr(self.current_key(request))

        if total >= self.requests:
            return self.disallowed(request)

        return fn(request, *args, **kwargs)

    def cache_get_many(self, keys):
        """Fetch the counters stored under ``keys`` as a dict (Django ``cache.get_many``)."""
        return cache.get_many(keys)

    def cache_incr(self, key):
        """Increment the counter at ``key``, creating it with :meth:`expire_after` if needed."""
        # memcache is only backend that can increment atomically
        try:
            # add first, to ensure the key exists
            cache.add(key, 0, self.expire_after())
            cache.incr(key)
        except (AttributeError, ValueError):
            # Fallback for backends without a usable incr(). Presumably ValueError
            # also covers the key vanishing between add() and incr(). This
            # get-then-set is not atomic.
            cache.set(key, cache.get(key, 0) + 1, self.expire_after())

    def should_ratelimit(self, request):
        """Return True if ``request`` is subject to the limit (always, by default)."""
        return True

    def get_counters(self, request):
        """Return the cached counters for every minute bucket in the current window."""
        return self.cache_get_many(self.keys_to_check(request))

    def keys_to_check(self, request):
        """Return the cache keys for the current minute and the previous ``minutes`` minutes."""
        extra = self._safe_extra(self.key_extra(request))
        now = datetime.now()
        return [
            "%s%s-%s" % (self.prefix, extra, (now - timedelta(minutes=minute)).strftime("%Y%m%d%H%M"))
            for minute in range(self.minutes + 1)
        ]

    def current_key(self, request):
        """Return the cache key of the current minute's counter for this client."""
        extra = self._safe_extra(self.key_extra(request))
        return "%s%s-%s" % (self.prefix, extra, datetime.now().strftime("%Y%m%d%H%M"))

    def _safe_extra(self, extra):
        """Return ``extra`` if it is safe inside a memcached key, else its SHA-1 hex digest."""
        # 13 = "-" separator + 12-character "%Y%m%d%H%M" timestamp.
        too_long = len(self.prefix) + len(extra) + 13 > self.max_key_length
        has_unsafe_char = any(ord(char) < 33 or ord(char) == 127 for char in extra)
        if too_long or has_unsafe_char:
            # e.g. a User-Agent fallback contains spaces; hash it into a fixed,
            # key-safe 40-character token instead.
            return hashlib.sha1(extra.encode("utf-8")).hexdigest()
        return extra

    def key_extra(self, request):
        """Identify the client making ``request``.

        Tries, in order, and returns the first non-empty value:
        1. the session key;
        2. the first address in X-Forwarded-For;
        3. the ``newsblur_sessionid`` cookie;
        4. the User-Agent header;
        5. REMOTE_ADDR.

        Returns "" if none is available. The value may contain characters that
        are unsafe in cache keys; key construction sanitizes it.
        """
        # Requests without session middleware have no ``session`` attribute.
        session = getattr(request, "session", None)
        key = getattr(session, "session_key", "")
        if not key:
            # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
            # proxy overwrites it -- confirm the deployment does so.
            key = request.META.get("HTTP_X_FORWARDED_FOR", "").split(",")[0].strip()
        if not key:
            key = request.COOKIES.get("newsblur_sessionid", "")
        if not key:
            key = request.META.get("HTTP_USER_AGENT", "")
        if not key:
            # Last resort, so anonymous header-less clients don't all share one bucket.
            key = request.META.get("REMOTE_ADDR", "")
        return key

    def disallowed(self, request):
        """Return the response sent to a client that exceeded the limit (HTTP 429)."""
        return HttpResponse("Rate limit exceeded", status=429)

    def expire_after(self):
        "Used for setting the memcached cache expiry"
        # A minute bucket is read for ``minutes + 1`` minutes, so it must live
        # at least that many seconds.
        return (self.minutes + 1) * 60
 | 
						|
 | 
						|
 | 
						|
class ratelimit_post(ratelimit):
    """Rate limit POSTs - can be used to protect a login form.

    Only POST requests are counted and checked; other methods pass straight
    through to the view. When ``key_field`` names a POST variable, the SHA-1
    hex digest of its submitted value is appended to the client identifier.
    Each distinct value (for example, a submitted username) therefore gets its
    own counter.
    """

    key_field = None  # If provided, this POST var will affect the rate limit

    def should_ratelimit(self, request):
        """Limit POST requests only."""
        is_post = request.method == "POST"
        return is_post

    def key_extra(self, request):
        """Return the base client identifier, suffixed with a digest of ``key_field`` if set."""
        base = super().key_extra(request)
        if not self.key_field:
            return base
        # Hash the submitted value so arbitrary user input never lands in the
        # cache key verbatim.
        submitted = request.POST.get(self.key_field, "")
        digest = hashlib.sha1(submitted.encode("utf-8")).hexdigest()
        return base + "-" + digest
 |