69 lines
1.9 KiB
Python
69 lines
1.9 KiB
Python
# SPDX-License-Identifier: AGPL-3.0-only
|
|
|
|
import anyio
|
|
import contextlib
|
|
from functools import wraps
|
|
from datetime import datetime, timezone
|
|
from dateutil.parser import parse as parsedate
|
|
|
|
def as_corofunc(f):
|
|
@wraps(f)
|
|
async def wrapped(*args, **kwargs):
|
|
# can't decide if i want an `anyio.sleep(0)` here.
|
|
return f(*args, **kwargs)
|
|
return wrapped
|
|
|
|
def as_async_cm(cls):
|
|
@wraps(cls, updated=()) # cls.__dict__ doesn't support .update()
|
|
class wrapped(cls, contextlib.AbstractAsyncContextManager):
|
|
__aenter__ = as_corofunc(cls.__enter__)
|
|
__aexit__ = as_corofunc(cls.__exit__)
|
|
return wrapped
|
|
|
|
suppress = as_async_cm(contextlib.suppress)
|
|
|
|
def shield(f):
|
|
@wraps(f)
|
|
async def shielded(*args, **kwargs):
|
|
with anyio.CancelScope(shield=True):
|
|
return await f(*args, **kwargs)
|
|
return shielded
|
|
|
|
def removeprefix(s, prefix):
|
|
try:
|
|
return s.removeprefix(prefix)
|
|
except AttributeError:
|
|
# compatibility for pre-3.9
|
|
return s[len(prefix):] if s.startswith(prefix) else s
|
|
|
|
async def sleep_until(dt):
|
|
await anyio.sleep((dt - datetime.now(timezone.utc)).total_seconds())
|
|
|
|
class HandleRateLimits:
|
|
def __init__(self, http):
|
|
self.http = http
|
|
|
|
def request(self, *args, **kwargs):
|
|
return _RateLimitContextManager(self.http, args, kwargs)
|
|
|
|
class _RateLimitContextManager(contextlib.AbstractAsyncContextManager):
|
|
def __init__(self, http, args, kwargs):
|
|
self.http = http
|
|
self.args = args
|
|
self.kwargs = kwargs
|
|
|
|
async def __aenter__(self):
|
|
self._request_cm = self.http.request(*self.args, **self.kwargs)
|
|
return await self._do_enter()
|
|
|
|
async def _do_enter(self):
|
|
resp = await self._request_cm.__aenter__()
|
|
if resp.headers.get('X-RateLimit-Remaining') not in {'0', '1'}:
|
|
return resp
|
|
|
|
await sleep_until(parsedate(resp.headers['X-RateLimit-Reset']))
|
|
await self._request_cm.__aexit__(*(None,)*3)
|
|
return await self.__aenter__()
|
|
|
|
async def __aexit__(self, *excinfo):
|
|
return await self._request_cm.__aexit__(*excinfo)
|