33
44import datetime
55import inspect
6+ import itertools
67import os
78import pathlib
89import sys
@@ -366,6 +367,65 @@ def _validate_return_callback(func: Callable) -> None:
366367_T = TypeVar ("_T" , bound = type )
367368
368369
370+ def _register_type_callback (
371+ resolved_type : _T ,
372+ return_callback : ReturnCallback | None = None ,
373+ ) -> list [type ]:
374+ modified_callbacks = []
375+ if return_callback is None :
376+ return []
377+ _validate_return_callback (return_callback )
378+ # if the type is a Union, add the callback to all of the types in the union
379+ # (except NoneType)
380+ if get_origin (resolved_type ) is Union :
381+ for type_per in _generate_union_variants (resolved_type ):
382+ if return_callback not in _RETURN_CALLBACKS [type_per ]:
383+ _RETURN_CALLBACKS [type_per ].append (return_callback )
384+ modified_callbacks .append (type_per )
385+
386+ for t in get_args (resolved_type ):
387+ if not _is_none_type (t ) and return_callback not in _RETURN_CALLBACKS [t ]:
388+ _RETURN_CALLBACKS [t ].append (return_callback )
389+ modified_callbacks .append (t )
390+ elif return_callback not in _RETURN_CALLBACKS [resolved_type ]:
391+ _RETURN_CALLBACKS [resolved_type ].append (return_callback )
392+ modified_callbacks .append (resolved_type )
393+ return modified_callbacks
394+
395+
396+ def _register_widget (
397+ resolved_type : _T ,
398+ widget_type : WidgetRef | None = None ,
399+ ** options : Any ,
400+ ) -> WidgetTuple | None :
401+ _options = cast (dict , options )
402+
403+ previous_widget = _TYPE_DEFS .get (resolved_type )
404+
405+ if "choices" in _options :
406+ _TYPE_DEFS [resolved_type ] = (widgets .ComboBox , _options )
407+ if widget_type is not None :
408+ warnings .warn (
409+ "Providing `choices` overrides `widget_type`. Categorical widget "
410+ f"will be used for type { resolved_type } " ,
411+ stacklevel = 2 ,
412+ )
413+ elif widget_type is not None :
414+ if not isinstance (widget_type , (str , WidgetProtocol )) and not (
415+ inspect .isclass (widget_type ) and issubclass (widget_type , widgets .Widget )
416+ ):
417+ raise TypeError (
418+ '"widget_type" must be either a string, WidgetProtocol, or '
419+ "Widget subclass"
420+ )
421+ _TYPE_DEFS [resolved_type ] = (widget_type , _options )
422+ elif "bind" in _options :
423+ # if we're binding a value to this parameter, it doesn't matter what type
424+ # of ValueWidget is used... it usually won't be shown
425+ _TYPE_DEFS [resolved_type ] = (widgets .EmptyWidget , _options )
426+ return previous_widget
427+
428+
369429@overload
370430def register_type (
371431 type_ : _T ,
@@ -435,43 +495,11 @@ def register_type(
435495 "must be provided."
436496 )
437497
438- def _deco (type_ : _T ) -> _T :
439- resolved_type = resolve_single_type (type_ )
440- if return_callback is not None :
441- _validate_return_callback (return_callback )
442- # if the type is a Union, add the callback to all of the types in the union
443- # (except NoneType)
444- if get_origin (resolved_type ) is Union :
445- for t in get_args (resolved_type ):
446- if not _is_none_type (t ):
447- _RETURN_CALLBACKS [t ].append (return_callback )
448- else :
449- _RETURN_CALLBACKS [resolved_type ].append (return_callback )
450-
451- _options = cast (dict , options )
452-
453- if "choices" in _options :
454- _TYPE_DEFS [resolved_type ] = (widgets .ComboBox , _options )
455- if widget_type is not None :
456- warnings .warn (
457- "Providing `choices` overrides `widget_type`. Categorical widget "
458- f"will be used for type { resolved_type } " ,
459- stacklevel = 2 ,
460- )
461- elif widget_type is not None :
462- if not isinstance (widget_type , (str , WidgetProtocol )) and not (
463- inspect .isclass (widget_type ) and issubclass (widget_type , widgets .Widget )
464- ):
465- raise TypeError (
466- '"widget_type" must be either a string, WidgetProtocol, or '
467- "Widget subclass"
468- )
469- _TYPE_DEFS [resolved_type ] = (widget_type , _options )
470- elif "bind" in _options :
471- # if we're binding a value to this parameter, it doesn't matter what type
472- # of ValueWidget is used... it usually won't be shown
473- _TYPE_DEFS [resolved_type ] = (widgets .EmptyWidget , _options )
474- return type_
498+ def _deco (type__ : _T ) -> _T :
499+ resolved_type = resolve_single_type (type__ )
500+ _register_type_callback (resolved_type , return_callback )
501+ _register_widget (resolved_type , widget_type , ** options )
502+ return type__
475503
476504 return _deco if type_ is None else _deco (type_ )
477505
@@ -507,23 +535,19 @@ def type_registered(
507535 """
508536 resolved_type = resolve_single_type (type_ )
509537
510- # check if return_callback is already registered
511- rc_was_present = return_callback in _RETURN_CALLBACKS .get (resolved_type , [])
512538 # store any previous widget_type and options for this type
513- prev_type_def : WidgetTuple | None = _TYPE_DEFS .get (resolved_type , None )
514- resolved_type = register_type (
515- resolved_type ,
516- widget_type = widget_type ,
517- return_callback = return_callback ,
518- ** options ,
519- )
539+
540+ revert_list = _register_type_callback (resolved_type , return_callback )
541+ prev_type_def = _register_widget (resolved_type , widget_type , ** options )
542+
520543 new_type_def : WidgetTuple | None = _TYPE_DEFS .get (resolved_type , None )
521544 try :
522545 yield
523546 finally :
524547 # restore things to before the context
525- if return_callback is not None and not rc_was_present :
526- _RETURN_CALLBACKS [resolved_type ].remove (return_callback )
548+ if return_callback is not None : # this if is only for mypy
549+ for return_callback_type in revert_list :
550+ _RETURN_CALLBACKS [return_callback_type ].remove (return_callback )
527551
528552 if _TYPE_DEFS .get (resolved_type , None ) is not new_type_def :
529553 warnings .warn ("Type definition changed during context" , stacklevel = 2 )
@@ -537,9 +561,6 @@ def type_registered(
537561def type2callback (type_ : type ) -> list [ReturnCallback ]:
538562 """Return any callbacks that have been registered for ``type_``.
539563
540- Note that if the return type is X, then the callbacks registered for Optional[X]
541- will be returned also be returned.
542-
543564 Parameters
544565 ----------
545566 type_ : type
@@ -555,7 +576,7 @@ def type2callback(type_: type) -> list[ReturnCallback]:
555576
556577 # look for direct hits ...
557578 # if it's an Optional, we need to look for the type inside the Optional
558- _ , type_ = _is_optional ( resolve_single_type (type_ ) )
579+ type_ = resolve_single_type (type_ )
559580 if type_ in _RETURN_CALLBACKS :
560581 return _RETURN_CALLBACKS [type_ ]
561582
@@ -566,10 +587,8 @@ def type2callback(type_: type) -> list[ReturnCallback]:
566587 return []
567588
568589
569- def _is_optional (type_ : Any ) -> tuple [bool , type ]:
570- # TODO: this function is too similar to _type_optional above... need to combine
571- if get_origin (type_ ) is Union :
572- args = get_args (type_ )
573- if len (args ) == 2 and any (_is_none_type (i ) for i in args ):
574- return True , next (i for i in args if not _is_none_type (i ))
575- return False , type_
590+ def _generate_union_variants (type_ : Any ) -> Iterator [type ]:
591+ type_args = get_args (type_ )
592+ for i in range (2 , len (type_args ) + 1 ):
593+ for per in itertools .combinations (type_args , i ):
594+ yield cast (type , Union [per ])
0 commit comments