Skip to content

Commit 2e1b78f

Browse files
Incremental Changes
1 parent 6886283 commit 2e1b78f

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

src/strawberry_sqlalchemy_mapper/mapper.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ def __init__(
160160
extra_sqlalchemy_type_to_strawberry_type_map: Optional[
161161
Mapping[Type[TypeEngine], Type[Any]]
162162
] = None,
163+
edge_type: Optional[Type] = None,
164+
connection_type: Optional[Type] = None,
163165
) -> None:
164166
if model_to_type_name is None:
165167
model_to_type_name = self._default_model_to_type_name
@@ -181,6 +183,9 @@ def __init__(
181183
self._related_type_models = set()
182184
self._related_interface_models = set()
183185

186+
self.edge_type = edge_type
187+
self.connection_type = connection_type
188+
184189
@staticmethod
185190
def _default_model_to_type_name(model: Type[BaseModelType]) -> str:
186191
return model.__name__
@@ -220,6 +225,8 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
220225
Get or create a corresponding Edge model for the given type
221226
(to support future pagination)
222227
"""
228+
if self.edge_type is not None:
229+
return self.edge_type
223230
edge_name = f"{type_name}Edge"
224231
if edge_name not in self.edge_types:
225232
self.edge_types[edge_name] = edge_type = strawberry.type(
@@ -238,6 +245,8 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
238245
Get or create a corresponding Connection model for the given type
239246
(to support future pagination)
240247
"""
248+
if self.connection_type is not None:
249+
return self.connection_type[ForwardRef(type_name)]
241250
connection_name = f"{type_name}Connection"
242251
if connection_name not in self.connection_types:
243252
self.connection_types[connection_name] = connection_type = strawberry.type(
@@ -269,6 +278,8 @@ def _convert_column_to_strawberry_type(
269278
"""
270279
if isinstance(column.type, Enum):
271280
type_annotation = column.type.python_type
281+
if not hasattr(column.type, "_enum_definition"):
282+
type_annotation = strawberry.enum(type_annotation)
272283
elif isinstance(column.type, ARRAY):
273284
item_type = self._convert_column_to_strawberry_type(
274285
Column(column.type.item_type, nullable=False)

0 commit comments

Comments
 (0)