Skip to content

Commit

Permalink
Bump version to 0.1.7 and add sliding window rate limit
Browse files Browse the repository at this point in the history
  • Loading branch information
xyzizz committed Apr 11, 2024
1 parent 8d179c0 commit aabe3c3
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 25 deletions.
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ async def websocket_endpoint(websocket: WebSocket):

The lua script used.

### fixed window
```lua
local key = KEYS[1]
local limit = tonumber(ARGV[1])
Expand All @@ -162,6 +163,27 @@ else
end
```

### sliding window
```lua
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire_time = tonumber(ARGV[2])
local current_time = redis.call('TIME')[1]
local start_time = current_time - expire_time / 1000

redis.call('ZREMRANGEBYSCORE', key, 0, start_time)

local current = redis.call('ZCARD', key)

if current >= limit then
return redis.call("PTTL",key)
else
redis.call("ZADD", key, current_time, current_time)
redis.call('PEXPIRE', key, expire_time)
return 0
end
```

## License

This project is licensed under the
Expand Down
11 changes: 10 additions & 1 deletion examples/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter, WebSocketRateLimiter

from fastapi_limiter.utils import RateLimitType

@asynccontextmanager
async def lifespan(_: FastAPI):
Expand Down Expand Up @@ -52,6 +52,15 @@ async def websocket_endpoint(websocket: WebSocket):
except HTTPException:
await websocket.send_text("Hello again")

@app.get(
"/test_sliding_window",
dependencies=[
Depends(RateLimiter(times=2, seconds=5, rate_limit_type=RateLimitType.SLIDING_WINDOW))
],
)
async def test_sliding_window():
return {"msg": "Hello World"}


if __name__ == "__main__":
uvicorn.run("main:app", debug=True, reload=True)
22 changes: 5 additions & 17 deletions fastapi_limiter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from starlette.responses import Response
from starlette.status import HTTP_429_TOO_MANY_REQUESTS
from starlette.websockets import WebSocket
from fastapi_limiter.utils import LuaScript


async def default_identifier(request: Union[Request, WebSocket]):
Expand Down Expand Up @@ -51,22 +52,8 @@ class FastAPILimiter:
identifier: Optional[Callable] = None
http_callback: Optional[Callable] = None
ws_callback: Optional[Callable] = None
lua_script = """local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire_time = ARGV[2]
local current = tonumber(redis.call('get', key) or "0")
if current > 0 then
if current + 1 > limit then
return redis.call("PTTL",key)
else
redis.call("INCR", key)
return 0
end
else
redis.call("SET", key, 1,"px",expire_time)
return 0
end"""
lua_sha_fix_window: Optional[str] = None
lua_sha_sliding_window: Optional[str] = None

@classmethod
async def init(
Expand All @@ -82,7 +69,8 @@ async def init(
cls.identifier = identifier
cls.http_callback = http_callback
cls.ws_callback = ws_callback
cls.lua_sha = await redis.script_load(cls.lua_script)
cls.lua_sha_fix_window = await redis.script_load(LuaScript.FIXED_WINDOW_LIMIT_SCRIPT.value)
cls.lua_sha_sliding_window = await redis.script_load(LuaScript.SLIDING_WINDOW_LIMIT_SCRIPT.value)

@classmethod
async def close(cls) -> None:
Expand Down
24 changes: 18 additions & 6 deletions fastapi_limiter/depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from starlette.websockets import WebSocket

from fastapi_limiter import FastAPILimiter
from fastapi_limiter.utils import RateLimitType


class RateLimiter:
Expand All @@ -19,16 +20,30 @@ def __init__(
hours: Annotated[int, Field(ge=-1)] = 0,
identifier: Optional[Callable] = None,
callback: Optional[Callable] = None,
rate_limit_type: RateLimitType = RateLimitType.FIXED_WINDOW
):
self.times = times
self.milliseconds = milliseconds + 1000 * seconds + 60000 * minutes + 3600000 * hours
self.identifier = identifier
self.callback = callback
self.rate_limit_type = rate_limit_type

async def _check(self, key):
def _get_lua_sha(self, specific_lua_sha=None):
if specific_lua_sha:
return specific_lua_sha
elif self.rate_limit_type is RateLimitType.SLIDING_WINDOW:
return FastAPILimiter.lua_sha_sliding_window
return FastAPILimiter.lua_sha_fix_window


async def _check(self, key, specific_lua_sha=None):
redis = FastAPILimiter.redis
pexpire = await redis.evalsha(
FastAPILimiter.lua_sha, 1, key, str(self.times), str(self.milliseconds)
self._get_lua_sha(specific_lua_sha),
1,
key,
str(self.times),
str(self.milliseconds)
)
return pexpire

Expand All @@ -53,10 +68,7 @@ async def __call__(self, request: Request, response: Response):
try:
pexpire = await self._check(key)
except pyredis.exceptions.NoScriptError:
FastAPILimiter.lua_sha = await FastAPILimiter.redis.script_load(
FastAPILimiter.lua_script
)
pexpire = await self._check(key)
pexpire = await self._check(key, specific_lua_sha=FastAPILimiter.lua_sha_fix_window)
if pexpire != 0:
return await callback(request, response, pexpire)

Expand Down
47 changes: 47 additions & 0 deletions fastapi_limiter/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from enum import Enum


class RateLimitType(Enum):
FIXED_WINDOW = "fixed_window"
SLIDING_WINDOW = "sliding_window"


class LuaScript(Enum):
FIXED_WINDOW_LIMIT_SCRIPT = """
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire_time = ARGV[2]
local current = tonumber(redis.call('get', key) or "0")
if current > 0 then
if current + 1 > limit then
return redis.call("PTTL",key)
else
redis.call("INCR", key)
return 0
end
else
redis.call("SET", key, 1, "px", expire_time)
return 0
end
"""
SLIDING_WINDOW_LIMIT_SCRIPT = """
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local expire_time = tonumber(ARGV[2])
local current_time = redis.call('TIME')[1]
local start_time = current_time - expire_time / 1000
redis.call('ZREMRANGEBYSCORE', key, 0, start_time)
local current = redis.call('ZCARD', key)
if current >= limit then
return redis.call("PTTL",key)
else
redis.call("ZADD", key, current_time, current_time)
redis.call('PEXPIRE', key, expire_time)
return 0
end
"""
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ packages = [
]
readme = "README.md"
repository = "https://github.com/long2ice/fastapi-limiter.git"
version = "0.1.6"
version = "0.1.7"

[tool.poetry.dependencies]
redis = ">=4.2.0rc1"
Expand Down
17 changes: 17 additions & 0 deletions tests/test_depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,20 @@ def test_limiter_websockets():
data = ws.receive_text()
assert data == "Hello, world"
ws.close()


def test_limiter_sliding_window():
with TestClient(app) as client:
def req(sleep_times, assert_code):
nonlocal client
response = client.get("/test_sliding_window")
assert response.status_code == assert_code
sleep(sleep_times)

req(4, 200) # 0s
req(1, 200) # 4s
req(1, 200) # 5s
req(1, 429) # 6s
req(1, 429) # 7s
req(1, 429) # 8s
req(1, 200) # 9s

0 comments on commit aabe3c3

Please sign in to comment.