10
10
from enum import IntEnum
11
11
from json import loads
12
12
from os import environ
13
+ from urllib .parse import parse_qsl , urlencode , urlsplit , urlunsplit
13
14
14
- from httpx import AsyncClient , Client , Headers , Limits , ReadTimeout , Request , Response
15
- from httpx import __version__ as httpx_version
15
+ from niquests import AsyncSession , ReadTimeout , Request , Response , Session
16
+ from niquests import __version__ as niquests_version
17
+ from niquests .structures import CaseInsensitiveDict
16
18
from starlette .requests import HTTPConnection
17
19
18
20
from . import options
@@ -49,6 +51,13 @@ class ServerVersion(typing.TypedDict):
49
51
"""Indicates if the subscription has extended support"""
50
52
51
53
54
+ @dataclass
55
+ class Limits :
56
+ max_keepalive_connections : int | None = 20
57
+ max_connections : int | None = 100
58
+ keepalive_expiry : int | float | None = 5
59
+
60
+
52
61
@dataclass
53
62
class RuntimeOptions :
54
63
xdebug_session : str
@@ -134,11 +143,11 @@ def __init__(self, **kwargs):
134
143
135
144
136
145
class NcSessionBase (ABC ):
137
- adapter : AsyncClient | Client
138
- adapter_dav : AsyncClient | Client
146
+ adapter : AsyncSession | Session
147
+ adapter_dav : AsyncSession | Session
139
148
cfg : BasicConfig
140
149
custom_headers : dict
141
- response_headers : Headers
150
+ response_headers : CaseInsensitiveDict
142
151
_user : str
143
152
_capabilities : dict
144
153
@@ -150,7 +159,7 @@ def __init__(self, **kwargs):
150
159
self .limits = Limits (max_keepalive_connections = 20 , max_connections = 20 , keepalive_expiry = 60.0 )
151
160
self .init_adapter ()
152
161
self .init_adapter_dav ()
153
- self .response_headers = Headers ()
162
+ self .response_headers = CaseInsensitiveDict ()
154
163
self ._ocs_regexp = re .compile (r"/ocs/v[12]\.php/|/apps/groupfolders/" )
155
164
156
165
def init_adapter (self , restart = False ) -> None :
@@ -172,7 +181,7 @@ def init_adapter_dav(self, restart=False) -> None:
172
181
self .adapter_dav .cookies .set ("XDEBUG_SESSION" , options .XDEBUG_SESSION )
173
182
174
183
@abstractmethod
175
- def _create_adapter (self , dav : bool = False ) -> AsyncClient | Client :
184
+ def _create_adapter (self , dav : bool = False ) -> AsyncSession | Session :
176
185
pass # pragma: no cover
177
186
178
187
@property
@@ -187,8 +196,8 @@ def ae_url_v2(self) -> str:
187
196
188
197
189
198
class NcSessionBasic (NcSessionBase , ABC ):
190
- adapter : Client
191
- adapter_dav : Client
199
+ adapter : Session
200
+ adapter_dav : Session
192
201
193
202
def ocs (
194
203
self ,
@@ -206,9 +215,7 @@ def ocs(
206
215
info = f"request: { method } { path } "
207
216
nested_req = kwargs .pop ("nested_req" , False )
208
217
try :
209
- response = self .adapter .request (
210
- method , path , content = content , json = json , params = params , files = files , ** kwargs
211
- )
218
+ response = self .adapter .request (method , path , data = content , json = json , params = params , files = files , ** kwargs )
212
219
except ReadTimeout :
213
220
raise NextcloudException (408 , info = info ) from None
214
221
@@ -281,18 +288,18 @@ def _get_adapter_kwargs(self, dav: bool) -> dict[str, typing.Any]:
281
288
return {
282
289
"base_url" : self .cfg .dav_endpoint ,
283
290
"timeout" : self .cfg .options .timeout_dav ,
284
- "event_hooks" : {"request " : [], "response" : [self ._response_event ]},
291
+ "event_hooks" : {"pre_request " : [], "response" : [self ._response_event ]},
285
292
}
286
293
return {
287
294
"base_url" : self .cfg .endpoint ,
288
295
"timeout" : self .cfg .options .timeout ,
289
- "event_hooks" : {"request " : [self ._request_event_ocs ], "response" : [self ._response_event ]},
296
+ "event_hooks" : {"pre_request " : [self ._request_event_ocs ], "response" : [self ._response_event ]},
290
297
}
291
298
292
299
def _request_event_ocs (self , request : Request ) -> None :
293
300
str_url = str (request .url )
294
301
if re .search (self ._ocs_regexp , str_url ) is not None : # this is OCS call
295
- request .url = request .url . copy_merge_params ({ "format" : "json" } )
302
+ request .url = patch_param ( request .url , "format" , "json" )
296
303
request .headers ["Accept" ] = "application/json"
297
304
298
305
def _response_event (self , response : Response ) -> None :
@@ -305,15 +312,15 @@ def _response_event(self, response: Response) -> None:
305
312
306
313
def download2fp (self , url_path : str , fp , dav : bool , params = None , ** kwargs ):
307
314
adapter = self .adapter_dav if dav else self .adapter
308
- with adapter .stream ( "GET" , url_path , params = params , headers = kwargs .get ("headers" )) as response :
315
+ with adapter .get ( url_path , params = params , headers = kwargs .get ("headers" ), stream = True ) as response :
309
316
check_error (response )
310
- for data_chunk in response .iter_bytes (chunk_size = kwargs .get ("chunk_size" , 5 * 1024 * 1024 )):
317
+ for data_chunk in response .iter_raw (chunk_size = kwargs .get ("chunk_size" , - 1 )):
311
318
fp .write (data_chunk )
312
319
313
320
314
321
class AsyncNcSessionBasic (NcSessionBase , ABC ):
315
- adapter : AsyncClient
316
- adapter_dav : AsyncClient
322
+ adapter : AsyncSession
323
+ adapter_dav : AsyncSession
317
324
318
325
async def ocs (
319
326
self ,
@@ -332,7 +339,7 @@ async def ocs(
332
339
nested_req = kwargs .pop ("nested_req" , False )
333
340
try :
334
341
response = await self .adapter .request (
335
- method , path , content = content , json = json , params = params , files = files , ** kwargs
342
+ method , path , data = content , json = json , params = params , files = files , ** kwargs
336
343
)
337
344
except ReadTimeout :
338
345
raise NextcloudException (408 , info = info ) from None
@@ -350,7 +357,7 @@ async def ocs(
350
357
and ocs_meta ["statuscode" ] == 403
351
358
and str (ocs_meta ["message" ]).lower ().find ("password confirmation is required" ) != - 1
352
359
):
353
- await self .adapter .aclose ()
360
+ await self .adapter .close ()
354
361
self .init_adapter (restart = True )
355
362
return await self .ocs (
356
363
method , path , ** kwargs , content = content , json = json , params = params , nested_req = True
@@ -408,18 +415,18 @@ def _get_adapter_kwargs(self, dav: bool) -> dict[str, typing.Any]:
408
415
return {
409
416
"base_url" : self .cfg .dav_endpoint ,
410
417
"timeout" : self .cfg .options .timeout_dav ,
411
- "event_hooks" : {"request " : [], "response" : [self ._response_event ]},
418
+ "event_hooks" : {"pre_request " : [], "response" : [self ._response_event ]},
412
419
}
413
420
return {
414
421
"base_url" : self .cfg .endpoint ,
415
422
"timeout" : self .cfg .options .timeout ,
416
- "event_hooks" : {"request " : [self ._request_event_ocs ], "response" : [self ._response_event ]},
423
+ "event_hooks" : {"pre_request " : [self ._request_event_ocs ], "response" : [self ._response_event ]},
417
424
}
418
425
419
426
async def _request_event_ocs (self , request : Request ) -> None :
420
427
str_url = str (request .url )
421
428
if re .search (self ._ocs_regexp , str_url ) is not None : # this is OCS call
422
- request .url = request .url . copy_merge_params ({ "format" : "json" } )
429
+ request .url = patch_param ( request .url , "format" , "json" )
423
430
request .headers ["Accept" ] = "application/json"
424
431
425
432
async def _response_event (self , response : Response ) -> None :
@@ -432,10 +439,12 @@ async def _response_event(self, response: Response) -> None:
432
439
433
440
async def download2fp (self , url_path : str , fp , dav : bool , params = None , ** kwargs ):
434
441
adapter = self .adapter_dav if dav else self .adapter
435
- async with adapter .stream ("GET" , url_path , params = params , headers = kwargs .get ("headers" )) as response :
436
- check_error (response )
437
- async for data_chunk in response .aiter_bytes (chunk_size = kwargs .get ("chunk_size" , 5 * 1024 * 1024 )):
438
- fp .write (data_chunk )
442
+ response = await adapter .get (url_path , params = params , headers = kwargs .get ("headers" ), stream = True )
443
+
444
+ check_error (response )
445
+
446
+ async for data_chunk in await response .iter_raw (chunk_size = kwargs .get ("chunk_size" , - 1 )):
447
+ fp .write (data_chunk )
439
448
440
449
441
450
class NcSession (NcSessionBasic ):
@@ -445,15 +454,20 @@ def __init__(self, **kwargs):
445
454
self .cfg = Config (** kwargs )
446
455
super ().__init__ ()
447
456
448
- def _create_adapter (self , dav : bool = False ) -> AsyncClient | Client :
449
- return Client (
450
- follow_redirects = True ,
451
- limits = self .limits ,
452
- verify = self .cfg .options .nc_cert ,
453
- ** self ._get_adapter_kwargs (dav ),
454
- auth = self .cfg .auth ,
457
+ def _create_adapter (self , dav : bool = False ) -> AsyncSession | Session :
458
+ session_kwargs = self ._get_adapter_kwargs (dav )
459
+ hooks = session_kwargs .pop ("event_hooks" )
460
+
461
+ session = Session (
462
+ keepalive_delay = self .limits .keepalive_expiry , pool_maxsize = self .limits .max_connections , ** session_kwargs
455
463
)
456
464
465
+ session .auth = self .cfg .auth
466
+ session .verify = self .cfg .options .nc_cert
467
+ session .hooks .update (hooks )
468
+
469
+ return session
470
+
457
471
458
472
class AsyncNcSession (AsyncNcSessionBasic ):
459
473
cfg : Config
@@ -462,21 +476,28 @@ def __init__(self, **kwargs):
462
476
self .cfg = Config (** kwargs )
463
477
super ().__init__ ()
464
478
465
- def _create_adapter (self , dav : bool = False ) -> AsyncClient | Client :
466
- return AsyncClient (
467
- follow_redirects = True ,
468
- limits = self .limits ,
469
- verify = self .cfg .options .nc_cert ,
470
- ** self ._get_adapter_kwargs (dav ),
471
- auth = self .cfg .auth ,
479
+ def _create_adapter (self , dav : bool = False ) -> AsyncSession | Session :
480
+ session_kwargs = self ._get_adapter_kwargs (dav )
481
+ hooks = session_kwargs .pop ("event_hooks" )
482
+
483
+ session = AsyncSession (
484
+ keepalive_delay = self .limits .keepalive_expiry ,
485
+ pool_maxsize = self .limits .max_connections ,
486
+ ** session_kwargs ,
472
487
)
473
488
489
+ session .verify = self .cfg .options .nc_cert
490
+ session .auth = self .cfg .auth
491
+ session .hooks .update (hooks )
492
+
493
+ return session
494
+
474
495
475
496
class NcSessionAppBasic (ABC ):
476
497
cfg : AppConfig
477
498
_user : str
478
- adapter : AsyncClient | Client
479
- adapter_dav : AsyncClient | Client
499
+ adapter : AsyncSession | Session
500
+ adapter_dav : AsyncSession | Session
480
501
481
502
def __init__ (self , ** kwargs ):
482
503
self .cfg = AppConfig (** kwargs )
@@ -505,22 +526,29 @@ def sign_check(self, request: HTTPConnection) -> str:
505
526
class NcSessionApp (NcSessionAppBasic , NcSessionBasic ):
506
527
cfg : AppConfig
507
528
508
- def _create_adapter (self , dav : bool = False ) -> AsyncClient | Client :
509
- r = self ._get_adapter_kwargs (dav )
510
- r ["event_hooks" ]["request" ].append (self ._add_auth )
511
- return Client (
512
- follow_redirects = True ,
513
- limits = self .limits ,
514
- verify = self .cfg .options .nc_cert ,
515
- ** r ,
516
- headers = {
517
- "AA-VERSION" : self .cfg .aa_version ,
518
- "EX-APP-ID" : self .cfg .app_name ,
519
- "EX-APP-VERSION" : self .cfg .app_version ,
520
- "user-agent" : f"ExApp/{ self .cfg .app_name } /{ self .cfg .app_version } (httpx/{ httpx_version } )" ,
521
- },
529
+ def _create_adapter (self , dav : bool = False ) -> AsyncSession | Session :
530
+ session_kwargs = self ._get_adapter_kwargs (dav )
531
+ session_kwargs ["event_hooks" ]["pre_request" ].append (self ._add_auth )
532
+
533
+ hooks = session_kwargs .pop ("event_hooks" )
534
+
535
+ session = Session (
536
+ keepalive_delay = self .limits .keepalive_expiry ,
537
+ pool_maxsize = self .limits .max_connections ,
538
+ ** session_kwargs ,
522
539
)
523
540
541
+ session .verify = self .cfg .options .nc_cert
542
+ session .headers = {
543
+ "AA-VERSION" : self .cfg .aa_version ,
544
+ "EX-APP-ID" : self .cfg .app_name ,
545
+ "EX-APP-VERSION" : self .cfg .app_version ,
546
+ "user-agent" : f"ExApp/{ self .cfg .app_name } /{ self .cfg .app_version } (niquests/{ niquests_version } )" ,
547
+ }
548
+ session .hooks .update (hooks )
549
+
550
+ return session
551
+
524
552
def _add_auth (self , request : Request ):
525
553
request .headers .update (
526
554
{"AUTHORIZATION-APP-API" : b64encode (f"{ self ._user } :{ self .cfg .app_secret } " .encode ("UTF=8" ))}
@@ -530,23 +558,39 @@ def _add_auth(self, request: Request):
530
558
class AsyncNcSessionApp (NcSessionAppBasic , AsyncNcSessionBasic ):
531
559
cfg : AppConfig
532
560
533
- def _create_adapter (self , dav : bool = False ) -> AsyncClient | Client :
534
- r = self ._get_adapter_kwargs (dav )
535
- r ["event_hooks" ]["request" ].append (self ._add_auth )
536
- return AsyncClient (
537
- follow_redirects = True ,
538
- limits = self .limits ,
539
- verify = self .cfg .options .nc_cert ,
540
- ** r ,
541
- headers = {
542
- "AA-VERSION" : self .cfg .aa_version ,
543
- "EX-APP-ID" : self .cfg .app_name ,
544
- "EX-APP-VERSION" : self .cfg .app_version ,
545
- "User-Agent" : f"ExApp/{ self .cfg .app_name } /{ self .cfg .app_version } (httpx/{ httpx_version } )" ,
546
- },
561
+ def _create_adapter (self , dav : bool = False ) -> AsyncSession | Session :
562
+ session_kwargs = self ._get_adapter_kwargs (dav )
563
+ session_kwargs ["event_hooks" ]["pre_request" ].append (self ._add_auth )
564
+
565
+ hooks = session_kwargs .pop ("event_hooks" )
566
+
567
+ session = AsyncSession (
568
+ keepalive_delay = self .limits .keepalive_expiry ,
569
+ pool_maxsize = self .limits .max_connections ,
570
+ ** session_kwargs ,
547
571
)
572
+ session .verify = self .cfg .options .nc_cert
573
+ session .headers = {
574
+ "AA-VERSION" : self .cfg .aa_version ,
575
+ "EX-APP-ID" : self .cfg .app_name ,
576
+ "EX-APP-VERSION" : self .cfg .app_version ,
577
+ "User-Agent" : f"ExApp/{ self .cfg .app_name } /{ self .cfg .app_version } (niquests/{ niquests_version } )" ,
578
+ }
579
+ session .hooks .update (hooks )
580
+
581
+ return session
548
582
549
583
async def _add_auth (self , request : Request ):
550
584
request .headers .update (
551
585
{"AUTHORIZATION-APP-API" : b64encode (f"{ self ._user } :{ self .cfg .app_secret } " .encode ("UTF=8" ))}
552
586
)
587
+
588
+
589
+ def patch_param (url : str , key : str , value : str ) -> str :
590
+ parts = urlsplit (url )
591
+ query = dict (parse_qsl (parts .query , keep_blank_values = True ))
592
+ query [key ] = value
593
+
594
+ new_query = urlencode (query , doseq = True )
595
+
596
+ return urlunsplit ((parts .scheme , parts .netloc , parts .path , new_query , parts .fragment ))
0 commit comments