diff --git a/curl_cffi/requests/headers.py b/curl_cffi/requests/headers.py index 000c03da..7c30cd19 100644 --- a/curl_cffi/requests/headers.py +++ b/curl_cffi/requests/headers.py @@ -102,7 +102,7 @@ def __init__( ] elif isinstance(headers, list): # list of "Name: Value" pairs - if isinstance(headers[0], (str, bytes)): + if isinstance(headers[0], str | bytes): sep = ":" if isinstance(headers[0], str) else b":" h = [] for line in headers: diff --git a/curl_cffi/requests/session.py b/curl_cffi/requests/session.py index 00157278..5c253646 100644 --- a/curl_cffi/requests/session.py +++ b/curl_cffi/requests/session.py @@ -7,7 +7,7 @@ import sys import threading import warnings -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, Future from contextlib import asynccontextmanager, contextmanager, suppress from collections.abc import Callable from io import BytesIO @@ -21,6 +21,7 @@ Union, cast, ) +from collections.abc import AsyncGenerator, Generator from urllib.parse import urlparse from datetime import timedelta @@ -245,7 +246,12 @@ def __init__( ) def _parse_response( - self, curl, buffer, header_buffer, default_encoding, discard_cookies + self, + curl: Curl, + buffer: BytesIO, + header_buffer: BytesIO, + default_encoding: Union[str, Callable[[bytes], str]], + discard_cookies: bool, ) -> R: c = curl rsp = cast(R, self.response_class(c)) @@ -349,7 +355,7 @@ def __init__( thread: Optional[ThreadType] = None, use_thread_local_curl: bool = True, **kwargs: Unpack[BaseSessionParams[R]], - ): + ) -> None: """ Parameters set in the ``__init__`` method will be overriden by the same parameter in request method. @@ -439,7 +445,7 @@ def executor(self): def __enter__(self): return self - def __exit__(self, *args): + def __exit__(self, *args) -> None: self.close() def close(self) -> None: @@ -453,7 +459,7 @@ def stream( method: HttpMethod, url: str, **kwargs: Unpack[StreamRequestParams], - ): + ) -> Generator[R, None, None]: """Equivalent to ``with request(..., stream=True) as r:``""" rsp = self.request(method=method, url=url, **kwargs, stream=True) try: @@ -462,7 +468,13 @@ def stream( rsp.close() def ws_connect( - self, url, on_message=None, on_error=None, on_open=None, on_close=None, **kwargs + self, + url: str, + on_message=None, + on_error=None, + on_open=None, + on_close=None, + **kwargs, ) -> WebSocket: """Connects to a websocket url. @@ -534,7 +546,7 @@ def request( max_recv_speed: int = 0, multipart: Optional[CurlMime] = None, discard_cookies: bool = False, - ): + ) -> R: """Send the request, see ``requests.request`` for details on parameters.""" self._check_session_closed() @@ -608,7 +620,7 @@ def perform(): cast(threading.Event, header_recved).set() q.put(STREAM_END) # type: ignore - def cleanup(fut): + def cleanup(fut: Future[None]): header_parsed.wait() c.reset() @@ -668,31 +680,31 @@ def cleanup(fut): finally: c.reset() - def head(self, url: str, **kwargs: Unpack[RequestParams]): + def head(self, url: str, **kwargs: Unpack[RequestParams]) -> R: return self.request(method="HEAD", url=url, **kwargs) - def get(self, url: str, **kwargs: Unpack[RequestParams]): + def get(self, url: str, **kwargs: Unpack[RequestParams]) -> R: return self.request(method="GET", url=url, **kwargs) - def post(self, url: str, **kwargs: Unpack[RequestParams]): + def post(self, url: str, **kwargs: Unpack[RequestParams]) -> R: return self.request(method="POST", url=url, **kwargs) - def put(self, url: str, **kwargs: Unpack[RequestParams]): + def put(self, url: str, **kwargs: Unpack[RequestParams]) -> R: return self.request(method="PUT", url=url, **kwargs) - def patch(self, url: str, **kwargs: Unpack[RequestParams]): + def patch(self, url: str, **kwargs: Unpack[RequestParams]) -> R: return self.request(method="PATCH", url=url, **kwargs) - def delete(self, url: str, **kwargs: Unpack[RequestParams]): + def delete(self, url: str, **kwargs: Unpack[RequestParams]) -> R: return self.request(method="DELETE", url=url, **kwargs) - def options(self, url: str, **kwargs: Unpack[RequestParams]): + def options(self, url: str, **kwargs: Unpack[RequestParams]) -> R: return self.request(method="OPTIONS", url=url, **kwargs) - def trace(self, url: str, **kwargs: Unpack[RequestParams]): + def trace(self, url: str, **kwargs: Unpack[RequestParams]) -> R: return self.request(method="TRACE", url=url, **kwargs) - def query(self, url: str, **kwargs: Unpack[RequestParams]): + def query(self, url: str, **kwargs: Unpack[RequestParams]) -> R: return self.request(method="QUERY", url=url, **kwargs) @@ -706,7 +718,7 @@ def __init__( async_curl: Optional[AsyncCurl] = None, max_clients: int = 10, **kwargs: Unpack[BaseSessionParams[R]], - ): + ) -> None: """ Parameters set in the ``__init__`` method will be override by the same parameter in request method. @@ -766,39 +778,39 @@ def __init__( self.init_pool() @property - def loop(self): + def loop(self) -> asyncio.AbstractEventLoop: if self._loop is None: self._loop = asyncio.get_running_loop() return self._loop @property - def acurl(self): + def acurl(self) -> AsyncCurl: if self._acurl is None: self._acurl = AsyncCurl(loop=self.loop) return self._acurl - def init_pool(self): - self.pool = asyncio.LifoQueue(self.max_clients) + def init_pool(self) -> None: + self.pool: asyncio.LifoQueue[Curl | None] = asyncio.LifoQueue(self.max_clients) while True: try: self.pool.put_nowait(None) except asyncio.QueueFull: break - async def pop_curl(self): + async def pop_curl(self) -> Curl: curl = await self.pool.get() if curl is None: curl = Curl(debug=self.debug) return curl - def push_curl(self, curl): + def push_curl(self, curl: Curl | None) -> None: with suppress(asyncio.QueueFull): self.pool.put_nowait(curl) - async def __aenter__(self): + async def __aenter__(self): # TODO: -> Self return self - async def __aexit__(self, *args): + async def __aexit__(self, *args) -> None: await self.close() return None @@ -814,7 +826,7 @@ async def close(self) -> None: except asyncio.QueueEmpty: break - def release_curl(self, curl): + def release_curl(self, curl: Curl) -> None: curl.clean_handles_and_buffers() if not self._closed: self.acurl.remove_handle(curl) @@ -829,7 +841,7 @@ async def stream( method: HttpMethod, url: str, **kwargs: Unpack[StreamRequestParams], - ): + ) -> AsyncGenerator[R, None, None]: """Equivalent to ``async with request(..., stream=True) as r:``""" rsp = await self.request(method=method, url=url, **kwargs, stream=True) try: @@ -1095,7 +1107,7 @@ async def request( if stream: task = self.acurl.add_handle(curl) - async def perform(): + async def perform() -> None: try: await task except CurlError as e: diff --git a/curl_cffi/requests/utils.py b/curl_cffi/requests/utils.py index 5055bbdc..55e2a67e 100644 --- a/curl_cffi/requests/utils.py +++ b/curl_cffi/requests/utils.py @@ -118,7 +118,7 @@ def update_url_params(url: str, params: Union[dict, list, tuple]) -> str: new_args_counter = Counter(x[0] for x in params) for key, value in params: # Bool and Dict values should be converted to json-friendly values - if isinstance(value, (bool, dict)): + if isinstance(value, bool | dict): value = dumps(value) # 1 to 1 mapping, we have to search and update it. if old_args_counter.get(key) == 1 and new_args_counter.get(key) == 1: @@ -395,7 +395,7 @@ def set_curl_options( c.setopt(CurlOpt.URL, url.encode()) # data/body/json - if isinstance(data, (dict, list, tuple)): + if isinstance(data, dict | list | tuple): body = urlencode(data).encode() elif isinstance(data, str): body = data.encode() @@ -455,7 +455,7 @@ def set_curl_options( update_header_line( header_lines, "Content-Type", "application/x-www-form-urlencoded" ) - if isinstance(data, (str, bytes)) and data: + if isinstance(data, str | bytes) and data: update_header_line(header_lines, "Content-Type", "application/octet-stream") # Never send `Expect` header. @@ -514,7 +514,7 @@ def set_curl_options( c.setopt(CurlOpt.LOW_SPEED_LIMIT, 1) c.setopt(CurlOpt.LOW_SPEED_TIME, math.ceil(all_timeout)) - elif isinstance(timeout, (int, float)): + elif isinstance(timeout, int | float): if not stream: c.setopt(CurlOpt.TIMEOUT_MS, int(timeout * 1000)) else: diff --git a/curl_cffi/requests/websockets.py b/curl_cffi/requests/websockets.py index 9560a1bd..480b9a9b 100644 --- a/curl_cffi/requests/websockets.py +++ b/curl_cffi/requests/websockets.py @@ -906,7 +906,7 @@ async def send( # cURL expects bytes if isinstance(payload, str): payload = payload.encode("utf-8") - elif isinstance(payload, (bytearray, memoryview)): + elif isinstance(payload, bytearray | memoryview): payload = bytes(payload) try: