Skip to content

Commit 5cba8b6

Browse files
committed
Cherry-picked simplified domain models
1 parent 17a1105 commit 5cba8b6

File tree

7 files changed

+145
-209
lines changed

7 files changed

+145
-209
lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.0.1-rc3
2+
current_version = 0.0.2-rc1
33
commit = False
44
tag = False
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?

orchestrator/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
"""This is the orchestrator workflow engine."""
1515

16-
__version__ = "0.0.1-rc3"
16+
__version__ = "0.0.2-rc1"
1717

1818
from orchestrator.app import OrchestratorCore
1919
from orchestrator.settings import app_settings, oauth2_settings

orchestrator/api/api_v1/endpoints/subscriptions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
"""Module that implements subscription related API endpoints."""
1515

16-
from dataclasses import asdict
1716
from http import HTTPStatus
1817
from typing import Any, Dict, List, Optional, Union
1918
from uuid import UUID
@@ -336,7 +335,7 @@ def subscription_details_by_id_with_domain_model(subscription_id: UUID) -> Dict[
336335
SubscriptionCustomerDescriptionTable.subscription_id == subscription_id
337336
).all()
338337

339-
subscription = asdict(SubscriptionModel.from_subscription(subscription_id))
338+
subscription = SubscriptionModel.from_subscription(subscription_id).dict()
340339

341340
subscription["customer_descriptions"] = customer_descriptions
342341

orchestrator/domain/base.py

Lines changed: 69 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
# limitations under the License.
1313

1414
from collections import defaultdict
15-
from dataclasses import field
1615
from datetime import datetime
1716
from itertools import groupby, zip_longest
1817
from operator import attrgetter
19-
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
18+
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Type, TypeVar, Union
2019
from uuid import UUID, uuid4
2120

2221
import structlog
2322
from more_itertools import flatten, only
24-
from pydantic import ValidationError
23+
from pydantic import BaseModel, Field, ValidationError
24+
from pydantic.main import ModelMetaclass
2525
from pydantic.types import ConstrainedList
2626
from sqlalchemy import and_
2727
from sqlalchemy.orm import selectinload
@@ -35,19 +35,10 @@
3535
SubscriptionTable,
3636
db,
3737
)
38-
from orchestrator.domain.config import PydanticConfig
3938
from orchestrator.domain.lifecycle import ProductLifecycle, lookup_specialized_type, register_specialized_type
4039
from orchestrator.types import State, SubscriptionLifecycle, UUIDstr, is_list_type, is_of_type, is_optional_type
4140
from orchestrator.utils.docs import make_product_block_docstring, make_subscription_model_docstring
4241

43-
if TYPE_CHECKING:
44-
# Workaround for the fact that the pre-commit hook uses it's own env.
45-
from dataclasses import dataclass
46-
else:
47-
from pydantic.dataclasses import dataclass
48-
49-
DataclassType = Any
50-
5142
logger = structlog.get_logger(__name__)
5243

5344

@@ -67,41 +58,22 @@ def _is_constrained_list_type(type: Type) -> bool:
6758

6859

6960
T = TypeVar("T")
70-
S = TypeVar("S", "SubscriptionModel", "SubscriptionModel")
71-
B = TypeVar("B", "ProductBlockModel", "ProductBlockModel")
72-
73-
74-
class DomainMeta(type):
75-
__base_type__: "DomainMeta"
76-
77-
def __new__(metacls: Type[T], name: str, bases: Tuple[type, ...], namespace: Dict[str, Any], **kwds: Any) -> T:
78-
"""Create a new domain model type.
79-
80-
We make sure pydantic dataclasses are properly set up.
81-
82-
Each domain model is a pydantic dataclass. By calling it here we do
83-
not require the decorator on every class definition.
84-
85-
Because of how python works just setting a field with a new annotation will not override
86-
the runtime checks for that field by pydantic. We do that here explicitly using `dataclasses.field`
87-
This has the desired effect of actually changing the validations
88-
"""
89-
if name not in ("ProductBlockModel", "SubscriptionModel", "DomainModel"):
90-
for field_name in namespace["__annotations__"]:
91-
if field_name not in namespace:
92-
namespace[field_name] = field()
93-
94-
cls = super().__new__(metacls, name, bases, namespace, **kwds) # type: ignore
95-
96-
return dataclass(cls, config=PydanticConfig) # type: ignore
61+
S = TypeVar("S", bound="SubscriptionModel")
62+
B = TypeVar("B", bound="ProductBlockModel")
9763

9864

99-
class DomainModel(metaclass=DomainMeta):
65+
class DomainModel(BaseModel):
10066
"""Base class for domain models.
10167
10268
Contains all common Product block/Subscription instance code
10369
"""
10470

71+
class Config:
72+
validate_assignment = True
73+
validate_all = True
74+
arbitrary_types_allowed = True
75+
76+
__base_type__: ClassVar[Type["DomainModel"]]
10577
_product_block_fields_: ClassVar[Dict[str, Type]]
10678
_non_product_block_fields_: ClassVar[Dict[str, Type]]
10779

@@ -328,7 +300,7 @@ def _save_instances(
328300
return saved_instances
329301

330302

331-
class ProductBlockModelMeta(DomainMeta):
303+
class ProductBlockModelMeta(ModelMetaclass):
332304
"""Metaclass used to create product block instances.
333305
334306
This metaclass is used to make sure the class contains product block metadata.
@@ -339,9 +311,13 @@ class ProductBlockModelMeta(DomainMeta):
339311
You can find some examples in: :ref:`domain-models`
340312
"""
341313

342-
def __call__( # type:ignore
343-
self: Type[B], *args: Any, **kwargs: Any
344-
) -> B:
314+
__names__: List[str]
315+
name: str
316+
product_block_id: UUID
317+
description: str
318+
tag: str
319+
320+
def __call__(self, *args: Any, **kwargs: Any) -> B:
345321

346322
# Would have been nice to do this in __init_subclass__ but that runs outside the app context so we cant access the db
347323
# So now we do it just before we instantiate the instance
@@ -395,12 +371,8 @@ class ProductBlockModel(DomainModel, metaclass=ProductBlockModelMeta):
395371
description: ClassVar[str]
396372
tag: ClassVar[str]
397373

398-
# None of the fields defined here should have defaults.
399-
# Python dataclasses prohibits subclasses dataclasses from using non-default fields when
400-
# the superclass has a default field. To set a default add it to `new()`.
401-
product_block_name: str
402-
subscription_instance_id: UUID
403-
label: Optional[str]
374+
subscription_instance_id: UUID = Field(default_factory=uuid4)
375+
label: Optional[str] = None
404376

405377
def __init_subclass__(
406378
cls,
@@ -432,9 +404,7 @@ def new(cls: Type[B], **kwargs: Any) -> B:
432404
"""
433405
sub_instances = cls._init_instances(list(kwargs.keys()))
434406

435-
return cls( # type: ignore
436-
subscription_instance_id=uuid4(), label=None, product_block_name=cls.name, **sub_instances, **kwargs
437-
)
407+
return cls(**sub_instances, **kwargs) # type: ignore
438408

439409
@classmethod
440410
def _load_instances_values(cls, instance_values: List[SubscriptionInstanceValueTable]) -> Dict[str, str]:
@@ -510,7 +480,6 @@ def from_db(
510480
return cls( # type: ignore
511481
subscription_instance_id=subscription_instance_id,
512482
label=label,
513-
product_block_name=cls.name,
514483
**instance_values,
515484
**sub_instances,
516485
)
@@ -688,10 +657,14 @@ def save(
688657
return children + [subscription_instance]
689658

690659

691-
@dataclass(config=PydanticConfig) # type: ignore
692-
class ProductModel:
660+
class ProductModel(BaseModel):
693661
"""Represent the product as defined in the database as a dataclass."""
694662

663+
class Config:
664+
validate_assignment = True
665+
validate_all = True
666+
arbitrary_types_allowed = True
667+
695668
product_id: UUID
696669
name: str
697670
description: str
@@ -700,7 +673,7 @@ class ProductModel:
700673
status: ProductLifecycle
701674

702675

703-
class SubscriptionModel(DomainModel, metaclass=DomainMeta):
676+
class SubscriptionModel(DomainModel):
704677
"""Base class for all product subscription models.
705678
706679
Define a subscription model:
@@ -725,18 +698,15 @@ class SubscriptionModel(DomainModel, metaclass=DomainMeta):
725698
>>> SubscriptionInactive.from_subscription(subscription_id) # doctest:+SKIP
726699
"""
727700

728-
# None of the fields defined here should have defaults.
729-
# Python dataclasses prohibits subclasses dataclasses from using non-default fields when
730-
# the superclass has a default field. To set a default add it to `from_product_id()`.
731701
product: ProductModel
732702
customer_id: UUID
733-
subscription_id: UUID
734-
description: str
735-
status: SubscriptionLifecycle
736-
insync: bool
737-
start_date: Optional[datetime]
738-
end_date: Optional[datetime]
739-
note: Optional[str]
703+
subscription_id: UUID = Field(default_factory=uuid4)
704+
description: str = "Initial subscription"
705+
status: SubscriptionLifecycle = SubscriptionLifecycle.INITIAL
706+
insync: bool = False
707+
start_date: Optional[datetime] = None
708+
end_date: Optional[datetime] = None
709+
note: Optional[str] = None
740710

741711
def __new__(cls, *args: Any, status: Optional[SubscriptionLifecycle] = None, **kwargs: Any) -> "SubscriptionModel":
742712

@@ -777,7 +747,7 @@ def diff_product_in_database(cls, product_id: UUID) -> Dict[str, Any]:
777747
missing_fixed_inputs_in_model=missing_fixed_inputs_in_model,
778748
)
779749

780-
def find_product_block_in(cls: Type[Union[S, B]]) -> List[ProductBlockModel]:
750+
def find_product_block_in(cls: Type[DomainModel]) -> List[ProductBlockModel]:
781751
product_blocks_in_model = []
782752
for product_block_field_type in cls._product_block_fields_.values():
783753
if is_list_type(product_block_field_type) or is_optional_type(product_block_field_type):
@@ -866,12 +836,12 @@ def from_product_id(
866836
# Caller wants a new instance and provided a product_id and customer_id
867837
product_db = ProductTable.query.get(product_id)
868838
product = ProductModel(
869-
product_db.product_id,
870-
product_db.name,
871-
product_db.description,
872-
product_db.product_type,
873-
product_db.tag,
874-
product_db.status,
839+
product_id=product_db.product_id,
840+
name=product_db.name,
841+
description=product_db.description,
842+
product_type=product_db.product_type,
843+
tag=product_db.tag,
844+
status=product_db.status,
875845
)
876846

877847
if description is None:
@@ -881,15 +851,14 @@ def from_product_id(
881851
instances = cls._init_instances()
882852

883853
return cls(
884-
product,
885-
customer_id,
886-
uuid4(),
887-
description,
888-
status,
889-
insync,
890-
start_date,
891-
end_date,
892-
note,
854+
product=product,
855+
customer_id=customer_id, # type:ignore
856+
description=description,
857+
status=status,
858+
insync=insync,
859+
start_date=start_date,
860+
end_date=end_date,
861+
note=note,
893862
**fixed_inputs,
894863
**instances,
895864
)
@@ -905,12 +874,12 @@ def from_subscription(cls: Type[S], subscription_id: Union[UUID, UUIDstr]) -> S:
905874
selectinload(SubscriptionTable.instances).selectinload(SubscriptionInstanceTable.values),
906875
).get(subscription_id)
907876
product = ProductModel(
908-
subscription.product.product_id,
909-
subscription.product.name,
910-
subscription.product.description,
911-
subscription.product.product_type,
912-
subscription.product.tag,
913-
subscription.product.status,
877+
product_id=subscription.product.product_id,
878+
name=subscription.product.name,
879+
description=subscription.product.description,
880+
product_type=subscription.product.product_type,
881+
tag=subscription.product.tag,
882+
status=subscription.product.status,
914883
)
915884
status = SubscriptionLifecycle(subscription.status)
916885

@@ -920,7 +889,7 @@ def from_subscription(cls: Type[S], subscription_id: Union[UUID, UUIDstr]) -> S:
920889
# Import here to prevent cyclic imports
921890
from orchestrator.domain import SUBSCRIPTION_MODEL_REGISTRY
922891

923-
cls = SUBSCRIPTION_MODEL_REGISTRY.get(subscription.product.name, cls)
892+
cls = SUBSCRIPTION_MODEL_REGISTRY.get(subscription.product.name, cls) # type:ignore
924893
cls = lookup_specialized_type(cls, status)
925894
if cls != old_cls and not issubclass(cls, old_cls):
926895
raise ValueError(
@@ -931,16 +900,16 @@ def from_subscription(cls: Type[S], subscription_id: Union[UUID, UUIDstr]) -> S:
931900
instances = cls._load_instances(subscription.instances, status)
932901

933902
try:
934-
return cls(
935-
product,
936-
subscription.customer_id,
937-
subscription_id,
938-
subscription.description,
939-
status,
940-
subscription.insync,
941-
subscription.start_date,
942-
subscription.end_date,
943-
subscription.note,
903+
return cls( # type: ignore
904+
product=product,
905+
customer_id=subscription.customer_id,
906+
subscription_id=subscription_id, # type:ignore
907+
description=subscription.description,
908+
status=status,
909+
insync=subscription.insync,
910+
start_date=subscription.start_date,
911+
end_date=subscription.end_date,
912+
note=subscription.note,
944913
**fixed_inputs,
945914
**instances,
946915
)

orchestrator/domain/config.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

orchestrator/domain/lifecycle.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313

14-
from dataclasses import asdict
1514
from typing import Dict, List, Optional, Tuple, Type, TypeVar
1615

1716
import structlog
1817

19-
from orchestrator.types import SubscriptionLifecycle, is_list_type, is_optional_type, strEnum
18+
from orchestrator.types import SubscriptionLifecycle, strEnum
2019
from orchestrator.utils.datetime import nowtz
2120

2221
logger = structlog.get_logger(__name__)
@@ -57,21 +56,12 @@ def lookup_specialized_type(block: Type, lifecycle: Optional[SubscriptionLifecyc
5756

5857
def change_lifecycle(subscription: T, status: SubscriptionLifecycle) -> T:
5958
new_klass = lookup_specialized_type(subscription.__class__, status)
60-
data = asdict(subscription)
59+
data = subscription.dict() # type:ignore
6160

6261
data["status"] = status
6362
if data["start_date"] is None and status == SubscriptionLifecycle.ACTIVE:
6463
data["start_date"] = nowtz()
6564
if data["end_date"] is None and status == SubscriptionLifecycle.TERMINATED:
6665
data["end_date"] = nowtz()
6766

68-
for product_block_field_name, product_block_field_type in new_klass._product_block_fields_.items():
69-
current = getattr(subscription, product_block_field_name)
70-
if is_list_type(product_block_field_type):
71-
data[product_block_field_name] = [asdict(item) for item in current]
72-
elif is_optional_type(product_block_field_type) and current is None:
73-
data[product_block_field_name] = None
74-
else:
75-
data[product_block_field_name] = asdict(current)
76-
7767
return new_klass(**data)

0 commit comments

Comments
 (0)