From 1b994af2ee6c5d310e205a988c17d953a3c16535 Mon Sep 17 00:00:00 2001 From: LilSpazJoekp <15524072+LilSpazJoekp@users.noreply.github.com> Date: Mon, 26 Aug 2024 14:08:35 -0500 Subject: [PATCH] Support delayed session creation --- CHANGES.rst | 1 + asyncpraw/models/reddit/emoji.py | 8 +++--- asyncpraw/models/reddit/subreddit.py | 27 +++++++++++-------- asyncpraw/models/reddit/widgets.py | 8 +++--- pyproject.toml | 1 - tests/integration/__init__.py | 4 +-- .../models/reddit/test_subreddit.py | 15 ++++++----- tests/unit/models/reddit/test_subreddit.py | 8 +++--- tests/unit/test_reddit.py | 21 ++++++++++++--- 9 files changed, 58 insertions(+), 35 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 86ce2208..505ffd75 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -33,6 +33,7 @@ Unreleased - :func:`.stream_generator` now accepts the ``continue_after_id`` parameter, which starts the stream after a given item ID. - Support for new share URL format created from Reddit's mobile apps. +- Support delayed session creation in asyncprawcore 2.5.0+. **Changed** diff --git a/asyncpraw/models/reddit/emoji.py b/asyncpraw/models/reddit/emoji.py index 578f8545..98895d58 100644 --- a/asyncpraw/models/reddit/emoji.py +++ b/asyncpraw/models/reddit/emoji.py @@ -224,10 +224,10 @@ async def add( # TODO(@LilSpazJoekp): This is a blocking operation. It should be made async. with file.open("rb") as image: # noqa: ASYNC230 upload_data["file"] = image - response = await self._reddit._core._requestor._http.post( - upload_url, data=upload_data - ) - response.raise_for_status() + async with self._reddit._core._requestor.request( + "POST", upload_url, data=upload_data + ) as response: + response.raise_for_status() data = { "mod_flair_only": mod_flair_only, diff --git a/asyncpraw/models/reddit/subreddit.py b/asyncpraw/models/reddit/subreddit.py index 8b5d7521..9bf2e65a 100644 --- a/asyncpraw/models/reddit/subreddit.py +++ b/asyncpraw/models/reddit/subreddit.py @@ -4,6 +4,7 @@ import contextlib from asyncio import TimeoutError +from contextlib import asynccontextmanager from copy import deepcopy from csv import writer from io import StringIO @@ -12,8 +13,10 @@ from typing import ( TYPE_CHECKING, Any, + AsyncContextManager, AsyncGenerator, AsyncIterator, + Callable, Iterator, ) from urllib.parse import urljoin @@ -3266,14 +3269,16 @@ async def _parse_xml_response(self, response: ClientResponse): actual=int(actual), maximum_size=int(maximum_size) ) + @asynccontextmanager async def _read_and_post_media( self, file: Path, upload_url: str, upload_data: dict[str, Any] - ) -> ClientResponse: + ) -> Callable[..., AsyncContextManager[ClientResponse]]: with file.open("rb") as media: upload_data["file"] = media - return await self._reddit._core._requestor._http.post( - upload_url, data=upload_data - ) + async with self._reddit._core._requestor.request( + "POST", upload_url, data=upload_data + ) as response: + yield response async def _submit_media( self, *, data: dict[Any, Any], timeout: int, without_websockets: bool @@ -3380,13 +3385,13 @@ async def _upload_media( upload_url = f"https:{upload_lease['action']}" upload_data = {item["name"]: item["value"] for item in upload_lease["fields"]} - response = await self._read_and_post_media(file, upload_url, upload_data) - if response.status != 201: - await self._parse_xml_response(response) - try: - response.raise_for_status() - except HttpProcessingError: - raise ServerError(response=response) from None + async with self._read_and_post_media(file, upload_url, upload_data) as response: + if response.status != 201: + await self._parse_xml_response(response) + try: + response.raise_for_status() + except HttpProcessingError: + raise ServerError(response=response) from None if upload_type == "link": return f"{upload_url}/{upload_data['key']}" diff --git a/asyncpraw/models/reddit/widgets.py b/asyncpraw/models/reddit/widgets.py index 101edf28..ea4daf61 100644 --- a/asyncpraw/models/reddit/widgets.py +++ b/asyncpraw/models/reddit/widgets.py @@ -1902,9 +1902,9 @@ async def upload_image(self, file_path: str) -> str: # TODO(@LilSpazJoekp): This is a blocking operation. It should be made async. with file.open("rb") as image: # noqa: ASYNC230 upload_data["file"] = image - response = await self._reddit._core._requestor._http.post( - upload_url, data=upload_data - ) - response.raise_for_status() + async with self._reddit._core._requestor.request( + "POST", upload_url, data=upload_data + ) as response: + response.raise_for_status() return f"{upload_url}/{upload_data['key']}" diff --git a/pyproject.toml b/pyproject.toml index 9dfdc5c8..d46c5807 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,6 @@ readthedocs = [ "sphinxcontrib-trio" ] test = [ - "mock ==4.*", "pytest ==7.*", "pytest-asyncio ==0.18.*", "pytest-vcr ==1.*", diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index e3f1efe1..db266680 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -88,14 +88,14 @@ def cassette_name(self, request, vcr_cassette_name): return marker.args[0] @pytest.fixture - async def reddit(self, vcr, event_loop: asyncio.AbstractEventLoop): + async def reddit(self, vcr): """Configure Reddit.""" reddit_kwargs = { "client_id": pytest.placeholders.client_id, "client_secret": pytest.placeholders.client_secret, "requestor_kwargs": { "session": aiohttp.ClientSession( - loop=event_loop, headers={"Accept-Encoding": "identity"} + headers={"Accept-Encoding": "identity"} ) }, "user_agent": pytest.placeholders.user_agent, diff --git a/tests/integration/models/reddit/test_subreddit.py b/tests/integration/models/reddit/test_subreddit.py index cc080b32..ceb7af0b 100644 --- a/tests/integration/models/reddit/test_subreddit.py +++ b/tests/integration/models/reddit/test_subreddit.py @@ -2,6 +2,7 @@ import socket from asyncio import TimeoutError +from contextlib import asynccontextmanager import pytest from aiohttp import ClientResponse @@ -1634,18 +1635,20 @@ async def test_submit_image__large(self, reddit, tmp_path): "iYEVOuRfbLiKwMgHt2ewqQRIm0NWL79uiC2rPLj9P0PwW55MhjY2/O8d9JdKTf1iwzLjwWMnGQ=" "" ) - _post = reddit._core._requestor._http.post + _request = reddit._core._requestor.request - async def patch_request(url, *args, **kwargs): + def patch_request(method, url, *args, **kwargs): """Patch requests to return mock data on specific url.""" if "https://reddit-uploaded-media.s3-accelerate.amazonaws.com" in url: - response = ClientResponse - response.text = AsyncMock(return_value=mock_data) + response = MagicMock(spec=ClientResponse) + response.__aenter__.return_value.text = AsyncMock( + return_value=mock_data + ) response.status = 400 return response - return await _post(url, *args, **kwargs) + return _request(method, url, *args, **kwargs) - reddit._core._requestor._http.post = patch_request + reddit._core._requestor.request = patch_request fake_png = PNG_HEADER + b"\x1a" * 10 # Normally 1024 ** 2 * 20 (20 MB) with open(tmp_path.joinpath("fake_img.png"), "wb") as tempfile: diff --git a/tests/unit/models/reddit/test_subreddit.py b/tests/unit/models/reddit/test_subreddit.py index 63ea879d..15d0989c 100644 --- a/tests/unit/models/reddit/test_subreddit.py +++ b/tests/unit/models/reddit/test_subreddit.py @@ -6,6 +6,8 @@ from unittest import mock from unittest.mock import AsyncMock, MagicMock +from aiohttp import ClientResponse + from asyncpraw.exceptions import ClientException, MediaPostFailed from asyncpraw.models import InlineGif, InlineImage, InlineVideo, Subreddit, WikiPage from asyncpraw.models.reddit.subreddit import SubredditFlairTemplates @@ -97,12 +99,12 @@ async def test_media_upload_500(self, mock_method, reddit): from aiohttp.http_exceptions import HttpProcessingError from asyncprawcore.exceptions import ServerError - response = MagicMock() - response.status = 201 + response = MagicMock(spec=ClientResponse) response.raise_for_status = MagicMock( side_effect=HttpProcessingError(code=500, message="") ) - mock_method.return_value = response + response.status = 201 + mock_method.return_value.__aenter__.return_value = response with pytest.raises(ServerError): await Subreddit(reddit, display_name="test").submit_image( "Test", "/dev/null" diff --git a/tests/unit/test_reddit.py b/tests/unit/test_reddit.py index 7e5aae9f..1eaee67d 100644 --- a/tests/unit/test_reddit.py +++ b/tests/unit/test_reddit.py @@ -1,5 +1,4 @@ import configparser -import sys import types import pytest @@ -26,6 +25,15 @@ def pre_refresh_callback(self, authorizer): pass +class MockClientSession: + def __init__(self, *args, **kwargs): + self.closed = False + self.headers = {} + + async def close(self): + self.closed = True + + class TestReddit(UnitTest): REQUIRED_DUMMY_SETTINGS = { x: "dummy" for x in ["client_id", "client_secret", "user_agent"] @@ -46,8 +54,10 @@ async def test_check_for_updates_update_checker_missing(self, mock_update_check) assert not mock_update_check.called async def test_close_session(self): - temp_reddit = Reddit(**self.REQUIRED_DUMMY_SETTINGS) - assert not temp_reddit.requestor._http.closed + temp_reddit = Reddit( + **self.REQUIRED_DUMMY_SETTINGS, + requestor_kwargs={"session": MockClientSession()}, + ) async with temp_reddit as reddit: pass assert reddit.requestor._http.closed and temp_reddit.requestor._http.closed @@ -69,7 +79,10 @@ def test_conflicting_settings(self): ) async def test_context_manager(self): - async with Reddit(**self.REQUIRED_DUMMY_SETTINGS) as reddit: + async with Reddit( + **self.REQUIRED_DUMMY_SETTINGS, + requestor_kwargs={"session": MockClientSession()}, + ) as reddit: assert not reddit._validate_on_submit assert not reddit.requestor._http.closed assert reddit.requestor._http.closed