88import stringcase
99import yaml
1010from datamodel_code_generator import DataModelField , snooper_to_methods
11- from datamodel_code_generator .imports import IMPORT_LIST , Import , Imports
12- from datamodel_code_generator .model .pydantic .types import type_map
11+ from datamodel_code_generator .imports import Import , Imports
1312from datamodel_code_generator .parser .jsonschema import (
1413 JsonSchemaObject ,
1514 get_model_by_path ,
16- json_schema_data_formats ,
1715)
18- from pydantic import BaseModel , Field , root_validator
16+ from datamodel_code_generator .parser .openapi import OpenAPIParser as OpenAPIModelParser
17+ from datamodel_code_generator .types import DataType
18+ from pydantic import BaseModel , root_validator
1919
2020MODEL_PATH = ".models"
2121
@@ -89,6 +89,7 @@ class Operation(CachedPropertyModel):
8989 imports : List [Import ] = []
9090 security : Optional [List [Dict [str , List [str ]]]] = None
9191 components : Dict [str , Any ] = {}
92+ open_api_model_parser : OpenAPIModelParser
9293
9394 @cached_property
9495 def root_path (self ) -> UsefulStr :
@@ -108,19 +109,17 @@ def request(self) -> Optional[Argument]:
108109 for content_type , schema in requests .contents .items ():
109110 # TODO: support other content-types
110111 if content_type == "application/json" :
112+ data_type = self .get_data_type (schema )
111113 arguments .append (
112114 # TODO: support multiple body
113115 Argument (
114116 name = 'body' , # type: ignore
115- type_hint = schema . ref_object_name ,
117+ type_hint = data_type . type_hint ,
116118 required = requests .required ,
117119 )
118120 )
119- self .imports .append (
120- Import (
121- from_ = model_path_var .get (), import_ = schema .ref_object_name
122- )
123- )
121+ self .imports .extend (data_type .imports_ )
122+
124123 if not arguments :
125124 return None
126125 return arguments [0 ]
@@ -207,20 +206,37 @@ def get_argument_list(self, snake_case: bool) -> List[Argument]:
207206 arguments .append (self .request )
208207 return arguments
209208
209+ def get_data_type (self , schema : JsonSchemaObject ) -> DataType :
210+ if schema .ref :
211+ data_type = self .open_api_model_parser .get_ref_data_type (schema .ref )
212+ data_type .imports_ .append (
213+ Import (
214+ # TODO: Improve import statements
215+ from_ = model_path_var .get (),
216+ import_ = data_type .type ,
217+ )
218+ )
219+ return data_type
220+ elif schema .is_array :
221+ # TODO: Improve handling array
222+ items = schema .items if isinstance (schema .items , list ) else [schema .items ]
223+ return self .open_api_model_parser .data_type (
224+ data_types = [self .get_data_type (i ) for i in items ], is_list = True
225+ )
226+ return self .open_api_model_parser .get_data_type (schema )
227+
210228 def get_parameter_type (
211229 self , parameter : Dict [str , Union [str , Dict [str , str ]]], snake_case : bool
212230 ) -> Argument :
213- schema : JsonSchemaObject = JsonSchemaObject .parse_obj (parameter ["schema" ])
214- format_ = schema .format or "default"
215- type_ = json_schema_data_formats [schema .type ][format_ ]
216231 name : str = parameter ["name" ] # type: ignore
217232 orig_name = name
218233 if snake_case :
219234 name = stringcase .snakecase (name )
235+ schema : JsonSchemaObject = JsonSchemaObject .parse_obj (parameter ["schema" ])
220236
221237 field = DataModelField (
222238 name = name ,
223- data_type = type_map [ type_ ] ,
239+ data_type = self . get_data_type ( schema ) ,
224240 required = parameter .get ("required" ) or parameter .get ("in" ) == "path" ,
225241 )
226242 self .imports .extend (field .imports )
@@ -241,46 +257,21 @@ def get_parameter_type(
241257
242258 @cached_property
243259 def response (self ) -> str :
244- models : List [str ] = []
260+ data_types : List [DataType ] = []
245261 for response in self .response_objects :
246262 # expect 2xx
247263 if response .status_code .startswith ("2" ):
248264 for content_type , schema in response .contents .items ():
249265 if content_type == "application/json" :
250- if schema .is_array :
251- if isinstance (schema .items , list ):
252- type_ = f'List[{ "," .join (i .ref_object_name for i in schema .items )} ]'
253- self .imports .extend (
254- Import (
255- from_ = model_path_var .get (),
256- import_ = i .ref_object_name ,
257- )
258- for i in schema .items
259- )
260- else :
261- type_ = f'List[{ schema .items .ref_object_name } ]'
262- self .imports .append (
263- Import (
264- from_ = model_path_var .get (),
265- import_ = schema .items .ref_object_name ,
266- )
267- )
268- self .imports .append (IMPORT_LIST )
269- else :
270- type_ = schema .ref_object_name
271- self .imports .append (
272- Import (
273- from_ = model_path_var .get (),
274- import_ = schema .ref_object_name ,
275- )
276- )
277- models .append (type_ )
278-
279- if not models :
266+ data_type = self .get_data_type (schema )
267+ data_types .append (data_type )
268+ self .imports .extend (data_type .imports_ )
269+
270+ if not data_types :
280271 return "None"
281- if len (models ) > 1 :
282- return f'Union[ { "," . join ( models ) } ]'
283- return models [0 ]
272+ if len (data_types ) > 1 :
273+ return self . open_api_model_parser . data_type ( data_types = data_types ). type_hint
274+ return data_types [0 ]. type_hint
284275
285276
286277OPERATION_NAMES : List [str ] = [
@@ -296,6 +287,9 @@ def response(self) -> str:
296287
297288
298289class Operations (BaseModel ):
290+ class Config :
291+ arbitrary_types_allowed = (OpenAPIModelParser ,)
292+
299293 parameters : List [Dict [str , Any ]] = []
300294 get : Optional [Operation ] = None
301295 put : Optional [Operation ] = None
@@ -308,20 +302,29 @@ class Operations(BaseModel):
308302 path : UsefulStr
309303 security : Optional [List [Dict [str , List [str ]]]] = []
310304 components : Dict [str , Any ] = {}
305+ open_api_model_parser : OpenAPIModelParser
311306
312307 @root_validator (pre = True )
313308 def inject_path_and_type_to_operation (cls , values : Dict [str , Any ]) -> Any :
314309 path : Any = values .get ('path' )
310+ open_api_model_parser : OpenAPIModelParser = values .get ('open_api_model_parser' )
315311 return dict (
316312 ** {
317- o : dict (** v , path = path , type = o , components = values .get ('components' , {}))
313+ o : dict (
314+ ** v ,
315+ path = path ,
316+ type = o ,
317+ components = values .get ('components' , {}),
318+ open_api_model_parser = open_api_model_parser ,
319+ )
318320 for o in OPERATION_NAMES
319321 if (v := values .get (o ))
320322 },
321323 path = path ,
322324 parameters = values .get ('parameters' , []),
323325 security = values .get ('security' ),
324326 components = values .get ('components' , {}),
327+ open_api_model_parser = open_api_model_parser ,
325328 )
326329
327330 @root_validator
@@ -342,6 +345,7 @@ class Path(CachedPropertyModel):
342345 operations : Optional [Operations ] = None
343346 security : Optional [List [Dict [str , List [str ]]]] = []
344347 components : Dict [str , Any ] = {}
348+ open_api_model_parser : OpenAPIModelParser
345349
346350 @root_validator (pre = True )
347351 def validate_root (cls , values : Dict [str , Any ]) -> Any :
@@ -351,16 +355,19 @@ def validate_root(cls, values: Dict[str, Any]) -> Any:
351355 if isinstance (operations , dict ):
352356 security = values .get ('security' , [])
353357 components = values .get ('components' , {})
358+ open_api_model_parser = values .get ('open_api_model_parser' )
354359 return {
355360 'path' : path ,
356361 'operations' : dict (
357362 ** operations ,
358363 path = path ,
359364 security = security ,
360365 components = components ,
366+ open_api_model_parser = open_api_model_parser ,
361367 ),
362368 'security' : security ,
363369 'components' : components ,
370+ 'open_api_model_parser' : open_api_model_parser ,
364371 }
365372 return values
366373
@@ -407,6 +414,7 @@ def __init__(
407414 self .input_text : str = input_text
408415 if model_path :
409416 model_path_var .set (model_path )
417+ self .open_api_model_parser : OpenAPIModelParser = OpenAPIModelParser (source = '' )
410418
411419 def parse (self ) -> ParsedObject :
412420 openapi = yaml .safe_load (self .input_text )
@@ -434,6 +442,7 @@ def parse_paths(self, openapi: Dict[str, Any]) -> ParsedObject:
434442 operations = operations ,
435443 security = security ,
436444 components = openapi .get ('components' , {}),
445+ open_api_model_parser = self .open_api_model_parser ,
437446 ).exists_operations
438447 ],
439448 info ,
0 commit comments