diff --git a/README.md b/README.md index 92ca795..2203c1c 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,7 @@ async def websocket_endpoint(websocket: WebSocket): The lua script used. +### fixed window ```lua local key = KEYS[1] local limit = tonumber(ARGV[1]) @@ -162,6 +163,27 @@ else end ``` +### sliding window +```lua +local key = KEYS[1] +local limit = tonumber(ARGV[1]) +local expire_time = tonumber(ARGV[2]) +local current_time = redis.call('TIME')[1] +local start_time = current_time - expire_time / 1000 + +redis.call('ZREMRANGEBYSCORE', key, 0, start_time) + +local current = redis.call('ZCARD', key) + +if current >= limit then + return redis.call("PTTL",key) +else + redis.call("ZADD", key, current_time, current_time) + redis.call('PEXPIRE', key, expire_time) + return 0 +end +``` + ## License This project is licensed under the diff --git a/examples/main.py b/examples/main.py index 3e0a2b0..9eea407 100644 --- a/examples/main.py +++ b/examples/main.py @@ -6,7 +6,7 @@ from fastapi_limiter import FastAPILimiter from fastapi_limiter.depends import RateLimiter, WebSocketRateLimiter - +from fastapi_limiter.utils import RateLimitType @asynccontextmanager async def lifespan(_: FastAPI): @@ -52,6 +52,15 @@ async def websocket_endpoint(websocket: WebSocket): except HTTPException: await websocket.send_text("Hello again") +@app.get( + "/test_sliding_window", + dependencies=[ + Depends(RateLimiter(times=2, seconds=5, rate_limit_type=RateLimitType.SLIDING_WINDOW)) + ], +) +async def test_sliding_window(): + return {"msg": "Hello World"} + if __name__ == "__main__": uvicorn.run("main:app", debug=True, reload=True) diff --git a/fastapi_limiter/__init__.py b/fastapi_limiter/__init__.py index 5b48091..9505f11 100644 --- a/fastapi_limiter/__init__.py +++ b/fastapi_limiter/__init__.py @@ -6,6 +6,7 @@ from starlette.responses import Response from starlette.status import HTTP_429_TOO_MANY_REQUESTS from starlette.websockets import WebSocket +from fastapi_limiter.utils import LuaScript async def default_identifier(request: Union[Request, WebSocket]): @@ -51,22 +52,8 @@ class FastAPILimiter: identifier: Optional[Callable] = None http_callback: Optional[Callable] = None ws_callback: Optional[Callable] = None - lua_script = """local key = KEYS[1] -local limit = tonumber(ARGV[1]) -local expire_time = ARGV[2] - -local current = tonumber(redis.call('get', key) or "0") -if current > 0 then - if current + 1 > limit then - return redis.call("PTTL",key) - else - redis.call("INCR", key) - return 0 - end -else - redis.call("SET", key, 1,"px",expire_time) - return 0 -end""" + lua_sha_fix_window: Optional[str] = None + lua_sha_sliding_window: Optional[str] = None @classmethod async def init( @@ -82,7 +69,8 @@ async def init( cls.identifier = identifier cls.http_callback = http_callback cls.ws_callback = ws_callback - cls.lua_sha = await redis.script_load(cls.lua_script) + cls.lua_sha_fix_window = await redis.script_load(LuaScript.FIXED_WINDOW_LIMIT_SCRIPT.value) + cls.lua_sha_sliding_window = await redis.script_load(LuaScript.SLIDING_WINDOW_LIMIT_SCRIPT.value) @classmethod async def close(cls) -> None: diff --git a/fastapi_limiter/depends.py b/fastapi_limiter/depends.py index 295df94..5f39143 100644 --- a/fastapi_limiter/depends.py +++ b/fastapi_limiter/depends.py @@ -7,6 +7,7 @@ from starlette.websockets import WebSocket from fastapi_limiter import FastAPILimiter +from fastapi_limiter.utils import RateLimitType class RateLimiter: @@ -19,16 +20,30 @@ def __init__( hours: Annotated[int, Field(ge=-1)] = 0, identifier: Optional[Callable] = None, callback: Optional[Callable] = None, + rate_limit_type: RateLimitType = RateLimitType.FIXED_WINDOW ): self.times = times self.milliseconds = milliseconds + 1000 * seconds + 60000 * minutes + 3600000 * hours self.identifier = identifier self.callback = callback + self.rate_limit_type = rate_limit_type - async def _check(self, key): + def _get_lua_sha(self, specific_lua_sha=None): + if specific_lua_sha: + return specific_lua_sha + elif self.rate_limit_type is RateLimitType.SLIDING_WINDOW: + return FastAPILimiter.lua_sha_sliding_window + return FastAPILimiter.lua_sha_fix_window + + + async def _check(self, key, specific_lua_sha=None): redis = FastAPILimiter.redis pexpire = await redis.evalsha( - FastAPILimiter.lua_sha, 1, key, str(self.times), str(self.milliseconds) + self._get_lua_sha(specific_lua_sha), + 1, + key, + str(self.times), + str(self.milliseconds) ) return pexpire @@ -53,10 +68,7 @@ async def __call__(self, request: Request, response: Response): try: pexpire = await self._check(key) except pyredis.exceptions.NoScriptError: - FastAPILimiter.lua_sha = await FastAPILimiter.redis.script_load( - FastAPILimiter.lua_script - ) - pexpire = await self._check(key) + pexpire = await self._check(key, specific_lua_sha=FastAPILimiter.lua_sha_fix_window) if pexpire != 0: return await callback(request, response, pexpire) diff --git a/fastapi_limiter/utils.py b/fastapi_limiter/utils.py new file mode 100644 index 0000000..dcebc7a --- /dev/null +++ b/fastapi_limiter/utils.py @@ -0,0 +1,47 @@ +from enum import Enum + + +class RateLimitType(Enum): + FIXED_WINDOW = "fixed_window" + SLIDING_WINDOW = "sliding_window" + + +class LuaScript(Enum): + FIXED_WINDOW_LIMIT_SCRIPT = """ + local key = KEYS[1] + local limit = tonumber(ARGV[1]) + local expire_time = ARGV[2] + + local current = tonumber(redis.call('get', key) or "0") + + if current > 0 then + if current + 1 > limit then + return redis.call("PTTL",key) + else + redis.call("INCR", key) + return 0 + end + else + redis.call("SET", key, 1, "px", expire_time) + return 0 + end + """ + SLIDING_WINDOW_LIMIT_SCRIPT = """ + local key = KEYS[1] + local limit = tonumber(ARGV[1]) + local expire_time = tonumber(ARGV[2]) + local current_time = redis.call('TIME')[1] + local start_time = current_time - expire_time / 1000 + + redis.call('ZREMRANGEBYSCORE', key, 0, start_time) + + local current = redis.call('ZCARD', key) + + if current >= limit then + return redis.call("PTTL",key) + else + redis.call("ZADD", key, current_time, current_time) + redis.call('PEXPIRE', key, expire_time) + return 0 + end + """ diff --git a/pyproject.toml b/pyproject.toml index 574e7dc..609a6a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ packages = [ ] readme = "README.md" repository = "https://github.com/long2ice/fastapi-limiter.git" -version = "0.1.6" +version = "0.1.7" [tool.poetry.dependencies] redis = ">=4.2.0rc1" diff --git a/tests/test_depends.py b/tests/test_depends.py index 3979784..5c154f3 100644 --- a/tests/test_depends.py +++ b/tests/test_depends.py @@ -65,3 +65,20 @@ def test_limiter_websockets(): data = ws.receive_text() assert data == "Hello, world" ws.close() + + +def test_limiter_sliding_window(): + with TestClient(app) as client: + def req(sleep_times, assert_code): + nonlocal client + response = client.get("/test_sliding_window") + assert response.status_code == assert_code + sleep(sleep_times) + + req(4, 200) # 0s + req(1, 200) # 4s + req(1, 200) # 5s + req(1, 429) # 6s + req(1, 429) # 7s + req(1, 429) # 8s + req(1, 200) # 9s \ No newline at end of file