11import asyncio
22import contextlib
3+ from dataclasses import dataclass
34import json
45import os
6+ import queue
57import threading
68import time
79import typing
1517import trustme
1618import uvicorn
1719import websockets
20+ from websockets .exceptions import ConnectionClosedOK , ConnectionClosedError
1821from cryptography .hazmat .backends import default_backend
1922from cryptography .hazmat .primitives .serialization import (
2023 BestAvailableEncryption ,
@@ -613,30 +616,11 @@ async def watch_restarts(self): # pragma: nocover
613616 await self .startup ()
614617
615618
616- async def echo (websocket ):
617- while True :
618- try :
619- # Echo every text or binary message
620- async for message in websocket :
621- await websocket .send (message )
622-
623- except websockets .exceptions .ConnectionClosed as e :
624- # Client sent us a close frame: echo it back exactly
625- await websocket .close (code = e .code , reason = e .reason )
626-
627-
628- class TestWebsocketServer :
629- def __init__ (self , port ):
630- self .url = f"ws://127.0.0.1:{ port } "
631- self .port = port
632-
633- def run (self ):
634- async def serve (port ):
635- # GitHub actions only likes 127, not localhost, wtf...
636- async with websockets .serve (echo , "127.0.0.1" , port ): # pyright: ignore
637- await asyncio .Future () # run forever
638-
639- asyncio .run (serve (self .port ))
619+ @pytest .fixture (scope = "session" )
620+ def server ():
621+ config = Config (app = app , lifespan = "off" , loop = "asyncio" )
622+ server = TestServer (config = config )
623+ yield from serve_in_thread (server )
640624
641625
642626def serve_in_thread (server : Server ):
@@ -651,26 +635,6 @@ def serve_in_thread(server: Server):
651635 thread .join ()
652636
653637
654- @pytest .fixture (scope = "session" )
655- def ws_server ():
656- server = TestWebsocketServer (port = 8964 )
657- thread = threading .Thread (target = server .run , daemon = True )
658- thread .start ()
659- try :
660- time .sleep (2 ) # FIXME find a reliable way to check the server is up
661- yield server
662- finally :
663- pass
664- # thread.join()
665-
666-
667- @pytest .fixture (scope = "session" )
668- def server ():
669- config = Config (app = app , lifespan = "off" , loop = "asyncio" )
670- server = TestServer (config = config )
671- yield from serve_in_thread (server )
672-
673-
674638@pytest .fixture (scope = "session" )
675639def https_server (cert_pem_file , cert_private_key_file ):
676640 config = Config (
@@ -685,6 +649,72 @@ def https_server(cert_pem_file, cert_private_key_file):
685649 yield from serve_in_thread (server )
686650
687651
652+ async def echo (ws ):
653+ try :
654+ async for msg in ws :
655+ await ws .send (msg )
656+ except (ConnectionClosedOK , ConnectionClosedError ):
657+ # Normal / abnormal close — nothing extra to do.
658+ pass
659+
660+
661+ def start_ws_server (port : int = 8964 ):
662+ """
663+ Start a websockets server on 127.0.0.1:port in a background thread.
664+ Returns (url, stop) where stop() shuts it down.
665+ """
666+ ready = threading .Event ()
667+ stop_callable_q : queue .Queue [typing .Callable ] = queue .Queue ()
668+
669+ def _thread_target ():
670+ loop = asyncio .new_event_loop ()
671+ asyncio .set_event_loop (loop )
672+
673+ stop_async = asyncio .Event ()
674+
675+ def _stop ():
676+ # can be called from main thread
677+ loop .call_soon_threadsafe (stop_async .set )
678+
679+ async def _run ():
680+ async with websockets .serve (echo , "127.0.0.1" , port ) as _ :
681+ stop_callable_q .put (_stop )
682+ ready .set ()
683+ await stop_async .wait ()
684+
685+ try :
686+ loop .run_until_complete (_run ())
687+ finally :
688+ loop .run_until_complete (loop .shutdown_asyncgens ())
689+ loop .close ()
690+
691+ t = threading .Thread (target = _thread_target , daemon = True )
692+ t .start ()
693+
694+ # Wait until server is really listening and we have a stop() handle
695+ stop = stop_callable_q .get () # blocks until put()
696+ ready .wait () # the socket is bound now
697+
698+ url = f"ws://127.0.0.1:{ port } "
699+ return url , stop , t
700+
701+
702+ @dataclass
703+ class WSServer :
704+ url : str
705+ stop : typing .Callable
706+
707+
708+ @pytest .fixture (scope = "session" )
709+ def ws_server ():
710+ url , stop , thread = start_ws_server (port = 8964 )
711+ try :
712+ yield WSServer (url = url , stop = stop )
713+ finally :
714+ stop () # trigger graceful shutdown
715+ thread .join (5 ) # optional: wait up to 5s for thread to exit
716+
717+
688718@pytest .fixture (scope = "session" )
689719def proxy_server (request ):
690720 ps = proxy .Proxy (port = 8002 , plugins = ["proxy.plugin.ManInTheMiddlePlugin" ])
0 commit comments