Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion curl_cffi/requests/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(
]
elif isinstance(headers, list):
# list of "Name: Value" pairs
if isinstance(headers[0], (str, bytes)):
if isinstance(headers[0], str | bytes):
sep = ":" if isinstance(headers[0], str) else b":"
h = []
for line in headers:
Expand Down
70 changes: 41 additions & 29 deletions curl_cffi/requests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
import threading
import warnings
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, Future
from contextlib import asynccontextmanager, contextmanager, suppress
from collections.abc import Callable
from io import BytesIO
Expand All @@ -21,6 +21,7 @@
Union,
cast,
)
from collections.abc import AsyncGenerator, Generator
from urllib.parse import urlparse
from datetime import timedelta

Expand Down Expand Up @@ -245,7 +246,12 @@ def __init__(
)

def _parse_response(
self, curl, buffer, header_buffer, default_encoding, discard_cookies
self,
curl: Curl,
buffer: BytesIO,
header_buffer: BytesIO,
default_encoding: Union[str, Callable[[bytes], str]],
discard_cookies: bool,
) -> R:
c = curl
rsp = cast(R, self.response_class(c))
Expand Down Expand Up @@ -349,7 +355,7 @@ def __init__(
thread: Optional[ThreadType] = None,
use_thread_local_curl: bool = True,
**kwargs: Unpack[BaseSessionParams[R]],
):
) -> None:
"""
Parameters set in the ``__init__`` method will be overriden by the same
parameter in request method.
Expand Down Expand Up @@ -439,7 +445,7 @@ def executor(self):
def __enter__(self):
return self

def __exit__(self, *args):
def __exit__(self, *args) -> None:
self.close()

def close(self) -> None:
Expand All @@ -453,7 +459,7 @@ def stream(
method: HttpMethod,
url: str,
**kwargs: Unpack[StreamRequestParams],
):
) -> Generator[R, None, None]:
"""Equivalent to ``with request(..., stream=True) as r:``"""
rsp = self.request(method=method, url=url, **kwargs, stream=True)
try:
Expand All @@ -462,7 +468,13 @@ def stream(
rsp.close()

def ws_connect(
self, url, on_message=None, on_error=None, on_open=None, on_close=None, **kwargs
self,
url: str,
on_message=None,
on_error=None,
on_open=None,
on_close=None,
**kwargs,
) -> WebSocket:
"""Connects to a websocket url.

Expand Down Expand Up @@ -534,7 +546,7 @@ def request(
max_recv_speed: int = 0,
multipart: Optional[CurlMime] = None,
discard_cookies: bool = False,
):
) -> R:
"""Send the request, see ``requests.request`` for details on parameters."""

self._check_session_closed()
Expand Down Expand Up @@ -608,7 +620,7 @@ def perform():
cast(threading.Event, header_recved).set()
q.put(STREAM_END) # type: ignore

def cleanup(fut):
def cleanup(fut: Future[None]):
header_parsed.wait()
c.reset()

Expand Down Expand Up @@ -668,31 +680,31 @@ def cleanup(fut):
finally:
c.reset()

def head(self, url: str, **kwargs: Unpack[RequestParams]):
def head(self, url: str, **kwargs: Unpack[RequestParams]) -> R:
return self.request(method="HEAD", url=url, **kwargs)

def get(self, url: str, **kwargs: Unpack[RequestParams]):
def get(self, url: str, **kwargs: Unpack[RequestParams]) -> R:
return self.request(method="GET", url=url, **kwargs)

def post(self, url: str, **kwargs: Unpack[RequestParams]):
def post(self, url: str, **kwargs: Unpack[RequestParams]) -> R:
return self.request(method="POST", url=url, **kwargs)

def put(self, url: str, **kwargs: Unpack[RequestParams]):
def put(self, url: str, **kwargs: Unpack[RequestParams]) -> R:
return self.request(method="PUT", url=url, **kwargs)

def patch(self, url: str, **kwargs: Unpack[RequestParams]):
def patch(self, url: str, **kwargs: Unpack[RequestParams]) -> R:
return self.request(method="PATCH", url=url, **kwargs)

def delete(self, url: str, **kwargs: Unpack[RequestParams]):
def delete(self, url: str, **kwargs: Unpack[RequestParams]) -> R:
return self.request(method="DELETE", url=url, **kwargs)

def options(self, url: str, **kwargs: Unpack[RequestParams]):
def options(self, url: str, **kwargs: Unpack[RequestParams]) -> R:
return self.request(method="OPTIONS", url=url, **kwargs)

def trace(self, url: str, **kwargs: Unpack[RequestParams]):
def trace(self, url: str, **kwargs: Unpack[RequestParams]) -> R:
return self.request(method="TRACE", url=url, **kwargs)

def query(self, url: str, **kwargs: Unpack[RequestParams]):
def query(self, url: str, **kwargs: Unpack[RequestParams]) -> R:
return self.request(method="QUERY", url=url, **kwargs)


Expand All @@ -706,7 +718,7 @@ def __init__(
async_curl: Optional[AsyncCurl] = None,
max_clients: int = 10,
**kwargs: Unpack[BaseSessionParams[R]],
):
) -> None:
"""
Parameters set in the ``__init__`` method will be override by the same parameter
in request method.
Expand Down Expand Up @@ -766,39 +778,39 @@ def __init__(
self.init_pool()

@property
def loop(self):
def loop(self) -> asyncio.AbstractEventLoop:
if self._loop is None:
self._loop = asyncio.get_running_loop()
return self._loop

@property
def acurl(self):
def acurl(self) -> AsyncCurl:
if self._acurl is None:
self._acurl = AsyncCurl(loop=self.loop)
return self._acurl

def init_pool(self):
self.pool = asyncio.LifoQueue(self.max_clients)
def init_pool(self) -> None:
self.pool: asyncio.LifoQueue[Curl | None] = asyncio.LifoQueue(self.max_clients)
while True:
try:
self.pool.put_nowait(None)
except asyncio.QueueFull:
break

async def pop_curl(self):
async def pop_curl(self) -> Curl:
curl = await self.pool.get()
if curl is None:
curl = Curl(debug=self.debug)
return curl

def push_curl(self, curl):
def push_curl(self, curl: Curl | None) -> None:
with suppress(asyncio.QueueFull):
self.pool.put_nowait(curl)

async def __aenter__(self):
async def __aenter__(self): # TODO: -> Self
return self

async def __aexit__(self, *args):
async def __aexit__(self, *args) -> None:
await self.close()
return None

Expand All @@ -814,7 +826,7 @@ async def close(self) -> None:
except asyncio.QueueEmpty:
break

def release_curl(self, curl):
def release_curl(self, curl: Curl) -> None:
curl.clean_handles_and_buffers()
if not self._closed:
self.acurl.remove_handle(curl)
Expand All @@ -829,7 +841,7 @@ async def stream(
method: HttpMethod,
url: str,
**kwargs: Unpack[StreamRequestParams],
):
) -> AsyncGenerator[R, None, None]:
"""Equivalent to ``async with request(..., stream=True) as r:``"""
rsp = await self.request(method=method, url=url, **kwargs, stream=True)
try:
Expand Down Expand Up @@ -1095,7 +1107,7 @@ async def request(
if stream:
task = self.acurl.add_handle(curl)

async def perform():
async def perform() -> None:
try:
await task
except CurlError as e:
Expand Down
8 changes: 4 additions & 4 deletions curl_cffi/requests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def update_url_params(url: str, params: Union[dict, list, tuple]) -> str:
new_args_counter = Counter(x[0] for x in params)
for key, value in params:
# Bool and Dict values should be converted to json-friendly values
if isinstance(value, (bool, dict)):
if isinstance(value, bool | dict):
value = dumps(value)
# 1 to 1 mapping, we have to search and update it.
if old_args_counter.get(key) == 1 and new_args_counter.get(key) == 1:
Expand Down Expand Up @@ -395,7 +395,7 @@ def set_curl_options(
c.setopt(CurlOpt.URL, url.encode())

# data/body/json
if isinstance(data, (dict, list, tuple)):
if isinstance(data, dict | list | tuple):
body = urlencode(data).encode()
elif isinstance(data, str):
body = data.encode()
Expand Down Expand Up @@ -455,7 +455,7 @@ def set_curl_options(
update_header_line(
header_lines, "Content-Type", "application/x-www-form-urlencoded"
)
if isinstance(data, (str, bytes)) and data:
if isinstance(data, str | bytes) and data:
update_header_line(header_lines, "Content-Type", "application/octet-stream")

# Never send `Expect` header.
Expand Down Expand Up @@ -514,7 +514,7 @@ def set_curl_options(
c.setopt(CurlOpt.LOW_SPEED_LIMIT, 1)
c.setopt(CurlOpt.LOW_SPEED_TIME, math.ceil(all_timeout))

elif isinstance(timeout, (int, float)):
elif isinstance(timeout, int | float):
if not stream:
c.setopt(CurlOpt.TIMEOUT_MS, int(timeout * 1000))
else:
Expand Down
2 changes: 1 addition & 1 deletion curl_cffi/requests/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ async def send(
# cURL expects bytes
if isinstance(payload, str):
payload = payload.encode("utf-8")
elif isinstance(payload, (bytearray, memoryview)):
elif isinstance(payload, bytearray | memoryview):
payload = bytes(payload)

try:
Expand Down
Loading