|
| 1 | +import asyncio |
| 2 | +import functools |
| 3 | +import logging |
| 4 | +from collections import defaultdict |
| 5 | +from collections.abc import AsyncIterable, Iterable |
| 6 | + |
| 7 | +from httpcore import Response |
| 8 | +from httpcore._models import ByteStream |
| 9 | + |
| 10 | +from vcr.errors import CannotOverwriteExistingCassetteException |
| 11 | +from vcr.filters import decode_response |
| 12 | +from vcr.request import Request as VcrRequest |
| 13 | +from vcr.serializers.compat import convert_body_to_bytes |
| 14 | + |
| 15 | +_logger = logging.getLogger(__name__) |
| 16 | + |
| 17 | + |
| 18 | +async def _convert_byte_stream(stream): |
| 19 | + if isinstance(stream, Iterable): |
| 20 | + return list(stream) |
| 21 | + |
| 22 | + if isinstance(stream, AsyncIterable): |
| 23 | + return [part async for part in stream] |
| 24 | + |
| 25 | + |
| 26 | +def _serialize_headers(real_response): |
| 27 | + """ |
| 28 | + Some headers can appear multiple times, like "Set-Cookie". |
| 29 | + Therefore serialize every header key to a list of values. |
| 30 | + """ |
| 31 | + |
| 32 | + headers = defaultdict(list) |
| 33 | + |
| 34 | + for name, value in real_response.headers: |
| 35 | + headers[name.decode("ascii")].append(value.decode("ascii")) |
| 36 | + |
| 37 | + return dict(headers) |
| 38 | + |
| 39 | + |
| 40 | +async def _serialize_response(real_response): |
| 41 | + # The reason_phrase may not exist |
| 42 | + try: |
| 43 | + reason_phrase = real_response.extensions["reason_phrase"].decode("ascii") |
| 44 | + except KeyError: |
| 45 | + reason_phrase = None |
| 46 | + |
| 47 | + # Reading the response stream consumes the iterator, so we need to restore it afterwards |
| 48 | + content = b"".join(await _convert_byte_stream(real_response.stream)) |
| 49 | + real_response.stream = ByteStream(content) |
| 50 | + |
| 51 | + return { |
| 52 | + "status": {"code": real_response.status, "message": reason_phrase}, |
| 53 | + "headers": _serialize_headers(real_response), |
| 54 | + "body": {"string": content}, |
| 55 | + } |
| 56 | + |
| 57 | + |
| 58 | +def _deserialize_headers(headers): |
| 59 | + """ |
| 60 | + httpcore accepts headers as list of tuples of header key and value. |
| 61 | + """ |
| 62 | + |
| 63 | + return [ |
| 64 | + (name.encode("ascii"), value.encode("ascii")) for name, values in headers.items() for value in values |
| 65 | + ] |
| 66 | + |
| 67 | + |
| 68 | +def _deserialize_response(vcr_response): |
| 69 | + # Cassette format generated for HTTPX requests by older versions of |
| 70 | + # vcrpy. We restructure the content to resemble what a regular |
| 71 | + # cassette looks like. |
| 72 | + if "status_code" in vcr_response: |
| 73 | + vcr_response = decode_response( |
| 74 | + convert_body_to_bytes( |
| 75 | + { |
| 76 | + "headers": vcr_response["headers"], |
| 77 | + "body": {"string": vcr_response["content"]}, |
| 78 | + "status": {"code": vcr_response["status_code"]}, |
| 79 | + }, |
| 80 | + ), |
| 81 | + ) |
| 82 | + extensions = None |
| 83 | + else: |
| 84 | + extensions = ( |
| 85 | + {"reason_phrase": vcr_response["status"]["message"].encode("ascii")} |
| 86 | + if vcr_response["status"]["message"] |
| 87 | + else None |
| 88 | + ) |
| 89 | + |
| 90 | + return Response( |
| 91 | + vcr_response["status"]["code"], |
| 92 | + headers=_deserialize_headers(vcr_response["headers"]), |
| 93 | + content=vcr_response["body"]["string"], |
| 94 | + extensions=extensions, |
| 95 | + ) |
| 96 | + |
| 97 | + |
| 98 | +async def _make_vcr_request(real_request): |
| 99 | + # Reading the request stream consumes the iterator, so we need to restore it afterwards |
| 100 | + body = b"".join(await _convert_byte_stream(real_request.stream)) |
| 101 | + real_request.stream = ByteStream(body) |
| 102 | + |
| 103 | + uri = bytes(real_request.url).decode("ascii") |
| 104 | + headers = {name.decode("ascii"): value.decode("ascii") for name, value in real_request.headers} |
| 105 | + |
| 106 | + return VcrRequest(real_request.method.decode("ascii"), uri, body, headers) |
| 107 | + |
| 108 | + |
| 109 | +async def _vcr_request(cassette, real_request): |
| 110 | + vcr_request = await _make_vcr_request(real_request) |
| 111 | + |
| 112 | + if cassette.can_play_response_for(vcr_request): |
| 113 | + return vcr_request, _play_responses(cassette, vcr_request) |
| 114 | + |
| 115 | + if cassette.write_protected and cassette.filter_request(vcr_request): |
| 116 | + raise CannotOverwriteExistingCassetteException( |
| 117 | + cassette=cassette, |
| 118 | + failed_request=vcr_request, |
| 119 | + ) |
| 120 | + |
| 121 | + _logger.info("%s not in cassette, sending to real server", vcr_request) |
| 122 | + |
| 123 | + return vcr_request, None |
| 124 | + |
| 125 | + |
| 126 | +async def _record_responses(cassette, vcr_request, real_response): |
| 127 | + cassette.append(vcr_request, await _serialize_response(real_response)) |
| 128 | + |
| 129 | + |
| 130 | +def _play_responses(cassette, vcr_request): |
| 131 | + vcr_response = cassette.play_response(vcr_request) |
| 132 | + real_response = _deserialize_response(vcr_response) |
| 133 | + |
| 134 | + return real_response |
| 135 | + |
| 136 | + |
| 137 | +async def _vcr_handle_async_request( |
| 138 | + cassette, |
| 139 | + real_handle_async_request, |
| 140 | + self, |
| 141 | + real_request, |
| 142 | +): |
| 143 | + vcr_request, vcr_response = await _vcr_request(cassette, real_request) |
| 144 | + |
| 145 | + if vcr_response: |
| 146 | + return vcr_response |
| 147 | + |
| 148 | + real_response = await real_handle_async_request(self, real_request) |
| 149 | + await _record_responses(cassette, vcr_request, real_response) |
| 150 | + |
| 151 | + return real_response |
| 152 | + |
| 153 | + |
| 154 | +def vcr_handle_async_request(cassette, real_handle_async_request): |
| 155 | + @functools.wraps(real_handle_async_request) |
| 156 | + def _inner_handle_async_request(self, real_request): |
| 157 | + return _vcr_handle_async_request( |
| 158 | + cassette, |
| 159 | + real_handle_async_request, |
| 160 | + self, |
| 161 | + real_request, |
| 162 | + ) |
| 163 | + |
| 164 | + return _inner_handle_async_request |
| 165 | + |
| 166 | + |
| 167 | +def _run_async_function(sync_func, *args, **kwargs): |
| 168 | + """ |
| 169 | + Safely run an asynchronous function from a synchronous context. |
| 170 | + Handles both cases: |
| 171 | + - An event loop is already running. |
| 172 | + - No event loop exists yet. |
| 173 | + """ |
| 174 | + try: |
| 175 | + asyncio.get_running_loop() |
| 176 | + except RuntimeError: |
| 177 | + return asyncio.run(sync_func(*args, **kwargs)) |
| 178 | + else: |
| 179 | + # If inside a running loop, create a task and wait for it |
| 180 | + return asyncio.ensure_future(sync_func(*args, **kwargs)) |
| 181 | + |
| 182 | + |
| 183 | +def _vcr_handle_request(cassette, real_handle_request, self, real_request): |
| 184 | + vcr_request, vcr_response = _run_async_function( |
| 185 | + _vcr_request, |
| 186 | + cassette, |
| 187 | + real_request, |
| 188 | + ) |
| 189 | + |
| 190 | + if vcr_response: |
| 191 | + return vcr_response |
| 192 | + |
| 193 | + real_response = real_handle_request(self, real_request) |
| 194 | + _run_async_function(_record_responses, cassette, vcr_request, real_response) |
| 195 | + |
| 196 | + return real_response |
| 197 | + |
| 198 | + |
| 199 | +def vcr_handle_request(cassette, real_handle_request): |
| 200 | + @functools.wraps(real_handle_request) |
| 201 | + def _inner_handle_request(self, real_request): |
| 202 | + return _vcr_handle_request(cassette, real_handle_request, self, real_request) |
| 203 | + |
| 204 | + return _inner_handle_request |
0 commit comments