11from __future__ import annotations
22
33import re
4+ from contextvars import ContextVar
45from functools import cached_property
56from typing import Any , Dict , List , Optional , Union
67
1011 load_json_or_yaml ,
1112 snooper_to_methods ,
1213)
13- from datamodel_code_generator .imports import Import , Imports
14+ from datamodel_code_generator .imports import IMPORT_LIST , Import , Imports
1415from datamodel_code_generator .model .pydantic .types import type_map
1516from datamodel_code_generator .parser .jsonschema import (
1617 JsonSchemaObject ,
2223
2324MODEL_PATH = ".models"
2425
26+ model_path_var : ContextVar [str ] = ContextVar ('model_path' , default = MODEL_PATH )
27+
2528
2629class CachedPropertyModel (BaseModel ):
2730 class Config :
@@ -41,13 +44,40 @@ class Request(BaseModel):
4144 required : bool
4245
4346
47+ class UsefulStr (str ):
48+ @property
49+ def snakecase (self ) -> str :
50+ return stringcase .snakecase (self )
51+
52+ @property
53+ def pascalcase (self ) -> str :
54+ return stringcase .pascalcase (self )
55+
56+ @property
57+ def camelcase (self ) -> str :
58+ return stringcase .camelcase (self )
59+
60+
61+ class Argument (BaseModel ):
62+ name : UsefulStr
63+
64+ @validator ('name' )
65+ def validate_name (cls , value : Any ) -> Any :
66+ if type (value ) == str :
67+ return UsefulStr (value )
68+ return value
69+
70+ # def __str__(self) -> UsefulStr:
71+ # return self.name
72+
73+
4474class Operation (CachedPropertyModel ):
45- type : Optional [str ]
46- path : Optional [str ]
47- operationId : Optional [str ]
48- rootPath : Optional [str ]
75+ type : Optional [UsefulStr ]
76+ path : Optional [UsefulStr ]
77+ operationId : Optional [UsefulStr ]
78+ root_path : Optional [UsefulStr ]
4979 parameters : Optional [Any ]
50- responses : Dict [str , Any ] = {}
80+ responses : Dict [UsefulStr , Any ] = {}
5181 requestBody : Dict [str , Any ] = {}
5282 imports : List [Import ] = []
5383
@@ -59,7 +89,7 @@ def snake_case_path(self) -> str:
5989
6090 def set_path (self , path : Path ) -> None :
6191 self .path = path .path
62- self .rootPath = path .root_path
92+ self .root_path = UsefulStr ( path .root_path )
6393
6494 @cached_property
6595 def request (self ) -> Optional [str ]:
@@ -69,7 +99,9 @@ def request(self) -> Optional[str]:
6999 if content_type == "application/json" :
70100 models .append (schema .ref_object_name )
71101 self .imports .append (
72- Import (from_ = MODEL_PATH , import_ = schema .ref_object_name )
102+ Import (
103+ from_ = model_path_var .get (), import_ = schema .ref_object_name
104+ )
73105 )
74106 if not models :
75107 return None
@@ -81,9 +113,9 @@ def request(self) -> Optional[str]:
81113 def request_objects (self ) -> List [Request ]:
82114 requests : List [Request ] = []
83115 contents : Dict [str , JsonSchemaObject ] = {}
84- for content_type , obj in self .requestBody .get (" content" , {}).items ():
116+ for content_type , obj in self .requestBody .get (' content' , {}).items ():
85117 contents [content_type ] = (
86- JsonSchemaObject .parse_obj (obj [" schema" ]) if " schema" in obj else None
118+ JsonSchemaObject .parse_obj (obj [' schema' ]) if ' schema' in obj else None
87119 )
88120 requests .append (
89121 Request (
@@ -131,29 +163,36 @@ def dump_imports(self) -> str:
131163
132164 @cached_property
133165 def arguments (self ) -> str :
134- parameters : List [str ] = []
166+ return self .get_arguments (snake_case = False )
167+
168+ @cached_property
169+ def snake_case_arguments (self ) -> str :
170+ return self .get_arguments (snake_case = True )
171+
172+ def get_arguments (self , snake_case : bool ) -> str :
173+ arguments : List [str ] = []
135174
136175 if self .parameters :
137176 for parameter in self .parameters :
138- parameters .append (self .get_parameter_type (parameter , False ))
177+ arguments .append (self .get_parameter_type (parameter , snake_case ))
139178
140179 if self .request :
141- parameters .append (f"body: { self .request } " )
180+ arguments .append (f"body: { self .request } " )
142181
143- return ", " .join (parameters )
182+ return ", " .join (arguments )
144183
145184 @cached_property
146- def snake_case_arguments (self ) -> str :
147- parameters : List [str ] = []
185+ def argument_list (self ) -> List [ Argument ] :
186+ arguments : List [Argument ] = []
148187
149188 if self .parameters :
150189 for parameter in self .parameters :
151- parameters .append (self . get_parameter_type (parameter , True ))
190+ arguments .append (Argument . parse_obj (parameter ))
152191
153192 if self .request :
154- parameters .append (f" body: { self . request } " )
193+ arguments .append (Argument ( name = UsefulStr ( ' body' )) )
155194
156- return ", " . join ( parameters )
195+ return arguments
157196
158197 def get_parameter_type (
159198 self , parameter : Dict [str , Union [str , Dict [str , str ]]], snake_case : bool
@@ -200,10 +239,35 @@ def response(self) -> str:
200239 if response .status_code .startswith ("2" ):
201240 for content_type , schema in response .contents .items ():
202241 if content_type == "application/json" :
203- models .append (schema .ref_object_name )
204- self .imports .append (
205- Import (from_ = MODEL_PATH , import_ = schema .ref_object_name )
206- )
242+ if schema .is_array :
243+ if isinstance (schema .items , list ):
244+ type_ = f'List[{ "," .join (i .ref_object_name for i in schema .items )} ]'
245+ self .imports .extend (
246+ Import (
247+ from_ = model_path_var .get (),
248+ import_ = i .ref_object_name ,
249+ )
250+ for i in schema .items
251+ )
252+ else :
253+ type_ = f'List[{ schema .items .ref_object_name } ]'
254+ self .imports .append (
255+ Import (
256+ from_ = model_path_var .get (),
257+ import_ = schema .items .ref_object_name ,
258+ )
259+ )
260+ self .imports .append (IMPORT_LIST )
261+ else :
262+ type_ = schema .ref_object_name
263+ self .imports .append (
264+ Import (
265+ from_ = model_path_var .get (),
266+ import_ = schema .ref_object_name ,
267+ )
268+ )
269+ models .append (type_ )
270+
207271 if not models :
208272 return "None"
209273 if len (models ) > 1 :
@@ -237,12 +301,12 @@ class Operations(BaseModel):
237301 @validator (* OPERATION_NAMES )
238302 def validate_operations (cls , value : Any , field : ModelField ) -> Any :
239303 if isinstance (value , Operation ):
240- value .type = field .name
304+ value .type = UsefulStr ( field .name )
241305 return value
242306
243307
244308class Path (BaseModel ):
245- path : Optional [str ]
309+ path : Optional [UsefulStr ]
246310 operations : Optional [Operations ] = None
247311 children : List [Path ] = []
248312 parent : Optional [Path ] = None
@@ -273,7 +337,9 @@ def init(self) -> None:
273337
274338class ParsedObject :
275339 def __init__ (self , parsed_operations : List [Operation ]):
276- self .operations = sorted (parsed_operations , key = lambda m : m .path )
340+ self .operations : List [Operation ] = sorted (
341+ parsed_operations , key = lambda m : m .path
342+ )
277343 self .imports : Imports = Imports ()
278344 for operation in self .operations :
279345 # create imports
@@ -285,9 +351,13 @@ def __init__(self, parsed_operations: List[Operation]):
285351
286352@snooper_to_methods (max_variable_length = None )
287353class OpenAPIParser :
288- def __init__ (self , input_name : str , input_text : str ,) -> None :
354+ def __init__ (
355+ self , input_name : str , input_text : str , model_path : Optional [str ] = None
356+ ) -> None :
289357 self .input_name : str = input_name
290358 self .input_text : str = input_text
359+ if model_path :
360+ model_path_var .set (model_path )
291361
292362 def parse (self ) -> ParsedObject :
293363 openapi = load_json_or_yaml (self .input_text )
@@ -312,7 +382,7 @@ def parse_paths(self, path_tree: Dict[str, Any]) -> ParsedObject:
312382 if me :
313383 continue
314384
315- last = Path (path = "/" .join (tree ), parent = parent )
385+ last = Path (path = UsefulStr ( "/" .join (tree ) ), parent = parent )
316386
317387 paths .append (last )
318388
0 commit comments