Skip to content

Commit 0c7ef52

Browse files
authored
Merge pull request #10 from koxudaxi/improve_argument_list
add field to argument
2 parents 628e8c3 + d49a7d4 commit 0c7ef52

File tree

1 file changed

+44
-51
lines changed

1 file changed

+44
-51
lines changed

fastapi_code_generator/parser.py

Lines changed: 44 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,20 @@ def camelcase(self) -> str:
6565
return stringcase.camelcase(self)
6666

6767

68-
class Argument(BaseModel):
68+
class Argument(CachedPropertyModel):
6969
name: UsefulStr
70+
type_hint: UsefulStr
71+
default: Optional[UsefulStr]
72+
required: bool
73+
74+
def __str__(self) -> str:
75+
return self.argument
7076

71-
# def __str__(self) -> UsefulStr:
72-
# return self.name
77+
@cached_property
78+
def argument(self) -> str:
79+
if not self.default and self.required:
80+
return f'{self.name}: {self.type_hint}'
81+
return f'{self.name}: {self.type_hint} = {self.default}'
7382

7483

7584
class Operation(CachedPropertyModel):
@@ -93,22 +102,28 @@ def snake_case_path(self) -> str:
93102
)
94103

95104
@cached_property
96-
def request(self) -> Optional[str]:
97-
models: List[str] = []
105+
def request(self) -> Optional[Argument]:
106+
arguments: List[Argument] = []
98107
for requests in self.request_objects:
99108
for content_type, schema in requests.contents.items():
109+
# TODO: support other content-types
100110
if content_type == "application/json":
101-
models.append(schema.ref_object_name)
111+
arguments.append(
112+
# TODO: support multiple body
113+
Argument(
114+
name='body', # type: ignore
115+
type_hint=schema.ref_object_name,
116+
required=requests.required,
117+
)
118+
)
102119
self.imports.append(
103120
Import(
104121
from_=model_path_var.get(), import_=schema.ref_object_name
105122
)
106123
)
107-
if not models:
124+
if not arguments:
108125
return None
109-
if len(models) > 1:
110-
return f'Union[{",".join(models)}]'
111-
return models[0]
126+
return arguments[0]
112127

113128
@cached_property
114129
def request_objects(self) -> List[Request]:
@@ -171,69 +186,47 @@ def snake_case_arguments(self) -> str:
171186
return self.get_arguments(snake_case=True)
172187

173188
def get_arguments(self, snake_case: bool) -> str:
174-
arguments: List[str] = []
175-
176-
if self.parameters:
177-
for parameter in self.parameters:
178-
arguments.append(self.get_parameter_type(parameter, snake_case))
179-
180-
if self.request:
181-
arguments.append(f"body: {self.request}")
182-
183-
return ", ".join(arguments)
189+
return ", ".join(
190+
argument.argument for argument in self.get_argument_list(snake_case)
191+
)
184192

185193
@cached_property
186194
def argument_list(self) -> List[Argument]:
195+
return self.get_argument_list(False)
196+
197+
def get_argument_list(self, snake_case: bool) -> List[Argument]:
187198
arguments: List[Argument] = []
188199

189200
if self.parameters:
190201
for parameter in self.parameters:
191-
arguments.append(Argument.parse_obj(parameter))
202+
arguments.append(self.get_parameter_type(parameter, snake_case))
192203

193204
if self.request:
194-
arguments.append(Argument(name=UsefulStr('body')))
195-
205+
arguments.append(self.request)
196206
return arguments
197207

198208
def get_parameter_type(
199209
self, parameter: Dict[str, Union[str, Dict[str, str]]], snake_case: bool
200-
) -> str:
210+
) -> Argument:
201211
schema: JsonSchemaObject = JsonSchemaObject.parse_obj(parameter["schema"])
202212
format_ = schema.format or "default"
203213
type_ = json_schema_data_formats[schema.type][format_]
204-
return self.get_data_type_hint(
205-
name=stringcase.snakecase(parameter["name"])
206-
if snake_case
207-
else parameter["name"],
214+
name: str = parameter["name"] # type: ignore
215+
216+
field = DataModelField(
217+
name=stringcase.snakecase(name) if snake_case else name,
208218
data_types=[type_map[type_]],
209219
required=parameter.get("required") == "true"
210220
or parameter.get("in") == "path",
211-
snake_case=snake_case,
212221
default=schema.typed_default,
213222
)
214-
215-
def get_data_type_hint(
216-
self,
217-
name: str,
218-
data_types: List[DataType],
219-
required: bool,
220-
snake_case: bool,
221-
default: Optional[str] = None,
222-
auto_import: bool = True,
223-
) -> str:
224-
field = DataModelField(
225-
name=stringcase.snakecase(name) if snake_case else name,
226-
data_types=data_types,
227-
required=required,
228-
default=default,
223+
self.imports.extend(field.imports)
224+
return Argument(
225+
name=field.name,
226+
type_hint=field.type_hint,
227+
default=field.default,
228+
required=field.required,
229229
)
230-
if auto_import:
231-
self.imports.extend(field.imports)
232-
233-
if not default and field.required:
234-
return f"{field.name}: {field.type_hint}"
235-
236-
return f'{field.name}: {field.type_hint} = {default}'
237230

238231
@cached_property
239232
def response(self) -> str:

0 commit comments

Comments
 (0)