Skip to content

Commit 62be80d

Browse files
authored
Merge pull request #45 from koxudaxi/improve_model_management
improve model management
2 parents 9e3228f + 3794656 commit 62be80d

File tree

3 files changed

+108
-52
lines changed

3 files changed

+108
-52
lines changed

fastapi_code_generator/parser.py

Lines changed: 58 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
import stringcase
99
import yaml
1010
from datamodel_code_generator import DataModelField, snooper_to_methods
11-
from datamodel_code_generator.imports import IMPORT_LIST, Import, Imports
12-
from datamodel_code_generator.model.pydantic.types import type_map
11+
from datamodel_code_generator.imports import Import, Imports
1312
from datamodel_code_generator.parser.jsonschema import (
1413
JsonSchemaObject,
1514
get_model_by_path,
16-
json_schema_data_formats,
1715
)
18-
from pydantic import BaseModel, Field, root_validator
16+
from datamodel_code_generator.parser.openapi import OpenAPIParser as OpenAPIModelParser
17+
from datamodel_code_generator.types import DataType
18+
from pydantic import BaseModel, root_validator
1919

2020
MODEL_PATH = ".models"
2121

@@ -89,6 +89,7 @@ class Operation(CachedPropertyModel):
8989
imports: List[Import] = []
9090
security: Optional[List[Dict[str, List[str]]]] = None
9191
components: Dict[str, Any] = {}
92+
open_api_model_parser: OpenAPIModelParser
9293

9394
@cached_property
9495
def root_path(self) -> UsefulStr:
@@ -108,19 +109,17 @@ def request(self) -> Optional[Argument]:
108109
for content_type, schema in requests.contents.items():
109110
# TODO: support other content-types
110111
if content_type == "application/json":
112+
data_type = self.get_data_type(schema)
111113
arguments.append(
112114
# TODO: support multiple body
113115
Argument(
114116
name='body', # type: ignore
115-
type_hint=schema.ref_object_name,
117+
type_hint=data_type.type_hint,
116118
required=requests.required,
117119
)
118120
)
119-
self.imports.append(
120-
Import(
121-
from_=model_path_var.get(), import_=schema.ref_object_name
122-
)
123-
)
121+
self.imports.extend(data_type.imports_)
122+
124123
if not arguments:
125124
return None
126125
return arguments[0]
@@ -207,20 +206,37 @@ def get_argument_list(self, snake_case: bool) -> List[Argument]:
207206
arguments.append(self.request)
208207
return arguments
209208

209+
def get_data_type(self, schema: JsonSchemaObject) -> DataType:
210+
if schema.ref:
211+
data_type = self.open_api_model_parser.get_ref_data_type(schema.ref)
212+
data_type.imports_.append(
213+
Import(
214+
# TODO: Improve import statements
215+
from_=model_path_var.get(),
216+
import_=data_type.type,
217+
)
218+
)
219+
return data_type
220+
elif schema.is_array:
221+
# TODO: Improve handling array
222+
items = schema.items if isinstance(schema.items, list) else [schema.items]
223+
return self.open_api_model_parser.data_type(
224+
data_types=[self.get_data_type(i) for i in items], is_list=True
225+
)
226+
return self.open_api_model_parser.get_data_type(schema)
227+
210228
def get_parameter_type(
211229
self, parameter: Dict[str, Union[str, Dict[str, str]]], snake_case: bool
212230
) -> Argument:
213-
schema: JsonSchemaObject = JsonSchemaObject.parse_obj(parameter["schema"])
214-
format_ = schema.format or "default"
215-
type_ = json_schema_data_formats[schema.type][format_]
216231
name: str = parameter["name"] # type: ignore
217232
orig_name = name
218233
if snake_case:
219234
name = stringcase.snakecase(name)
235+
schema: JsonSchemaObject = JsonSchemaObject.parse_obj(parameter["schema"])
220236

221237
field = DataModelField(
222238
name=name,
223-
data_type=type_map[type_],
239+
data_type=self.get_data_type(schema),
224240
required=parameter.get("required") or parameter.get("in") == "path",
225241
)
226242
self.imports.extend(field.imports)
@@ -241,46 +257,21 @@ def get_parameter_type(
241257

242258
@cached_property
243259
def response(self) -> str:
244-
models: List[str] = []
260+
data_types: List[DataType] = []
245261
for response in self.response_objects:
246262
# expect 2xx
247263
if response.status_code.startswith("2"):
248264
for content_type, schema in response.contents.items():
249265
if content_type == "application/json":
250-
if schema.is_array:
251-
if isinstance(schema.items, list):
252-
type_ = f'List[{",".join(i.ref_object_name for i in schema.items)}]'
253-
self.imports.extend(
254-
Import(
255-
from_=model_path_var.get(),
256-
import_=i.ref_object_name,
257-
)
258-
for i in schema.items
259-
)
260-
else:
261-
type_ = f'List[{schema.items.ref_object_name}]'
262-
self.imports.append(
263-
Import(
264-
from_=model_path_var.get(),
265-
import_=schema.items.ref_object_name,
266-
)
267-
)
268-
self.imports.append(IMPORT_LIST)
269-
else:
270-
type_ = schema.ref_object_name
271-
self.imports.append(
272-
Import(
273-
from_=model_path_var.get(),
274-
import_=schema.ref_object_name,
275-
)
276-
)
277-
models.append(type_)
278-
279-
if not models:
266+
data_type = self.get_data_type(schema)
267+
data_types.append(data_type)
268+
self.imports.extend(data_type.imports_)
269+
270+
if not data_types:
280271
return "None"
281-
if len(models) > 1:
282-
return f'Union[{",".join(models)}]'
283-
return models[0]
272+
if len(data_types) > 1:
273+
return self.open_api_model_parser.data_type(data_types=data_types).type_hint
274+
return data_types[0].type_hint
284275

285276

286277
OPERATION_NAMES: List[str] = [
@@ -296,6 +287,9 @@ def response(self) -> str:
296287

297288

298289
class Operations(BaseModel):
290+
class Config:
291+
arbitrary_types_allowed = (OpenAPIModelParser,)
292+
299293
parameters: List[Dict[str, Any]] = []
300294
get: Optional[Operation] = None
301295
put: Optional[Operation] = None
@@ -308,20 +302,29 @@ class Operations(BaseModel):
308302
path: UsefulStr
309303
security: Optional[List[Dict[str, List[str]]]] = []
310304
components: Dict[str, Any] = {}
305+
open_api_model_parser: OpenAPIModelParser
311306

312307
@root_validator(pre=True)
313308
def inject_path_and_type_to_operation(cls, values: Dict[str, Any]) -> Any:
314309
path: Any = values.get('path')
310+
open_api_model_parser: OpenAPIModelParser = values.get('open_api_model_parser')
315311
return dict(
316312
**{
317-
o: dict(**v, path=path, type=o, components=values.get('components', {}))
313+
o: dict(
314+
**v,
315+
path=path,
316+
type=o,
317+
components=values.get('components', {}),
318+
open_api_model_parser=open_api_model_parser,
319+
)
318320
for o in OPERATION_NAMES
319321
if (v := values.get(o))
320322
},
321323
path=path,
322324
parameters=values.get('parameters', []),
323325
security=values.get('security'),
324326
components=values.get('components', {}),
327+
open_api_model_parser=open_api_model_parser,
325328
)
326329

327330
@root_validator
@@ -342,6 +345,7 @@ class Path(CachedPropertyModel):
342345
operations: Optional[Operations] = None
343346
security: Optional[List[Dict[str, List[str]]]] = []
344347
components: Dict[str, Any] = {}
348+
open_api_model_parser: OpenAPIModelParser
345349

346350
@root_validator(pre=True)
347351
def validate_root(cls, values: Dict[str, Any]) -> Any:
@@ -351,16 +355,19 @@ def validate_root(cls, values: Dict[str, Any]) -> Any:
351355
if isinstance(operations, dict):
352356
security = values.get('security', [])
353357
components = values.get('components', {})
358+
open_api_model_parser = values.get('open_api_model_parser')
354359
return {
355360
'path': path,
356361
'operations': dict(
357362
**operations,
358363
path=path,
359364
security=security,
360365
components=components,
366+
open_api_model_parser=open_api_model_parser,
361367
),
362368
'security': security,
363369
'components': components,
370+
'open_api_model_parser': open_api_model_parser,
364371
}
365372
return values
366373

@@ -407,6 +414,7 @@ def __init__(
407414
self.input_text: str = input_text
408415
if model_path:
409416
model_path_var.set(model_path)
417+
self.open_api_model_parser: OpenAPIModelParser = OpenAPIModelParser(source='')
410418

411419
def parse(self) -> ParsedObject:
412420
openapi = yaml.safe_load(self.input_text)
@@ -434,6 +442,7 @@ def parse_paths(self, openapi: Dict[str, Any]) -> ParsedObject:
434442
operations=operations,
435443
security=security,
436444
components=openapi.get('components', {}),
445+
open_api_model_parser=self.open_api_model_parser,
437446
).exists_operations
438447
],
439448
info,

tests/data/expected/openapi/default_template/body_and_parameters/main.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,15 @@
1313
app = FastAPI(version="1.0.0", title="Swagger Petstore", license="{'name': 'MIT'}",)
1414

1515

16-
@app.get('/food/{food_id}', response_model=None)
17-
def show_food_by_id(food_id: str) -> None:
16+
@app.post('/food', response_model=None)
17+
def post_food(body: str) -> None:
18+
pass
19+
20+
21+
@app.get('/food/{food_id}', response_model=List[int])
22+
def show_food_by_id(
23+
food_id: str, message_texts: Optional[List[str]] = None
24+
) -> List[int]:
1825
pass
1926

2027

tests/data/openapi/default_template/body_and_parameters.yaml

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,26 @@ paths:
134134
passthroughBehavior: when_no_templates
135135
httpMethod: POST
136136
type: aws_proxy
137+
/food:
138+
post:
139+
summary: Create a pet
140+
tags:
141+
- pets
142+
requestBody:
143+
required: true
144+
content:
145+
application/json:
146+
schema:
147+
type: string
148+
responses:
149+
'201':
150+
description: Null response
151+
default:
152+
description: unexpected error
153+
content:
154+
application/json:
155+
schema:
156+
type: string
137157
/food/{food_id}:
138158
get:
139159
summary: Info for a specific pet
@@ -146,9 +166,29 @@ paths:
146166
description: The id of the food to retrieve
147167
schema:
148168
type: string
169+
- name: message_texts
170+
in: query
171+
required: false
172+
explode: true
173+
schema:
174+
type: array
175+
items:
176+
type: string
149177
responses:
150178
'200':
151-
description: Expected response to a valid request
179+
description: OK
180+
content:
181+
application/json:
182+
schema:
183+
type: array
184+
items:
185+
type: integer
186+
examples:
187+
example-1:
188+
value:
189+
- 0
190+
- 1
191+
- 3
152192
x-amazon-apigateway-integration:
153193
uri:
154194
Fn::Sub: arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${PythonVersionFunction.Arn}/invocations

0 commit comments

Comments
 (0)