Skip to content

Commit 48263c1

Browse files
authored
Merge pull request #4 from koxudaxi/refactor_openapi_parser
Refactor openapi parser
2 parents ec26b04 + 0bea693 commit 48263c1

File tree

2 files changed

+100
-30
lines changed

2 files changed

+100
-30
lines changed

fastapi_code_generator/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from datamodel_code_generator.format import format_code
99
from jinja2 import Environment, FileSystemLoader
1010

11-
from fastapi_code_generator.parser import OpenAPIParser, ParsedObject
11+
from fastapi_code_generator.parser import OpenAPIParser, Operation, ParsedObject
1212

1313
app = typer.Typer() # type: ignore
1414

@@ -27,7 +27,7 @@ def main(
2727
output_dir.mkdir(parents=True)
2828
if not template_dir:
2929
template_dir = BUILTIN_TEMPLATE_DIR
30-
parser = OpenAPIParser(input_name, input_text,)
30+
parser = OpenAPIParser(input_name, input_text)
3131
parsed_object: ParsedObject = parser.parse()
3232

3333
environment: Environment = Environment(

fastapi_code_generator/parser.py

Lines changed: 98 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import re
4+
from contextvars import ContextVar
45
from functools import cached_property
56
from typing import Any, Dict, List, Optional, Union
67

@@ -10,7 +11,7 @@
1011
load_json_or_yaml,
1112
snooper_to_methods,
1213
)
13-
from datamodel_code_generator.imports import Import, Imports
14+
from datamodel_code_generator.imports import IMPORT_LIST, Import, Imports
1415
from datamodel_code_generator.model.pydantic.types import type_map
1516
from datamodel_code_generator.parser.jsonschema import (
1617
JsonSchemaObject,
@@ -22,6 +23,8 @@
2223

2324
MODEL_PATH = ".models"
2425

26+
model_path_var: ContextVar[str] = ContextVar('model_path', default=MODEL_PATH)
27+
2528

2629
class CachedPropertyModel(BaseModel):
2730
class Config:
@@ -41,13 +44,40 @@ class Request(BaseModel):
4144
required: bool
4245

4346

47+
class UsefulStr(str):
48+
@property
49+
def snakecase(self) -> str:
50+
return stringcase.snakecase(self)
51+
52+
@property
53+
def pascalcase(self) -> str:
54+
return stringcase.pascalcase(self)
55+
56+
@property
57+
def camelcase(self) -> str:
58+
return stringcase.camelcase(self)
59+
60+
61+
class Argument(BaseModel):
62+
name: UsefulStr
63+
64+
@validator('name')
65+
def validate_name(cls, value: Any) -> Any:
66+
if type(value) == str:
67+
return UsefulStr(value)
68+
return value
69+
70+
# def __str__(self) -> UsefulStr:
71+
# return self.name
72+
73+
4474
class Operation(CachedPropertyModel):
45-
type: Optional[str]
46-
path: Optional[str]
47-
operationId: Optional[str]
48-
rootPath: Optional[str]
75+
type: Optional[UsefulStr]
76+
path: Optional[UsefulStr]
77+
operationId: Optional[UsefulStr]
78+
root_path: Optional[UsefulStr]
4979
parameters: Optional[Any]
50-
responses: Dict[str, Any] = {}
80+
responses: Dict[UsefulStr, Any] = {}
5181
requestBody: Dict[str, Any] = {}
5282
imports: List[Import] = []
5383

@@ -59,7 +89,7 @@ def snake_case_path(self) -> str:
5989

6090
def set_path(self, path: Path) -> None:
6191
self.path = path.path
62-
self.rootPath = path.root_path
92+
self.root_path = UsefulStr(path.root_path)
6393

6494
@cached_property
6595
def request(self) -> Optional[str]:
@@ -69,7 +99,9 @@ def request(self) -> Optional[str]:
6999
if content_type == "application/json":
70100
models.append(schema.ref_object_name)
71101
self.imports.append(
72-
Import(from_=MODEL_PATH, import_=schema.ref_object_name)
102+
Import(
103+
from_=model_path_var.get(), import_=schema.ref_object_name
104+
)
73105
)
74106
if not models:
75107
return None
@@ -81,9 +113,9 @@ def request(self) -> Optional[str]:
81113
def request_objects(self) -> List[Request]:
82114
requests: List[Request] = []
83115
contents: Dict[str, JsonSchemaObject] = {}
84-
for content_type, obj in self.requestBody.get("content", {}).items():
116+
for content_type, obj in self.requestBody.get('content', {}).items():
85117
contents[content_type] = (
86-
JsonSchemaObject.parse_obj(obj["schema"]) if "schema" in obj else None
118+
JsonSchemaObject.parse_obj(obj['schema']) if 'schema' in obj else None
87119
)
88120
requests.append(
89121
Request(
@@ -131,29 +163,36 @@ def dump_imports(self) -> str:
131163

132164
@cached_property
133165
def arguments(self) -> str:
134-
parameters: List[str] = []
166+
return self.get_arguments(snake_case=False)
167+
168+
@cached_property
169+
def snake_case_arguments(self) -> str:
170+
return self.get_arguments(snake_case=True)
171+
172+
def get_arguments(self, snake_case: bool) -> str:
173+
arguments: List[str] = []
135174

136175
if self.parameters:
137176
for parameter in self.parameters:
138-
parameters.append(self.get_parameter_type(parameter, False))
177+
arguments.append(self.get_parameter_type(parameter, snake_case))
139178

140179
if self.request:
141-
parameters.append(f"body: {self.request}")
180+
arguments.append(f"body: {self.request}")
142181

143-
return ", ".join(parameters)
182+
return ", ".join(arguments)
144183

145184
@cached_property
146-
def snake_case_arguments(self) -> str:
147-
parameters: List[str] = []
185+
def argument_list(self) -> List[Argument]:
186+
arguments: List[Argument] = []
148187

149188
if self.parameters:
150189
for parameter in self.parameters:
151-
parameters.append(self.get_parameter_type(parameter, True))
190+
arguments.append(Argument.parse_obj(parameter))
152191

153192
if self.request:
154-
parameters.append(f"body: {self.request}")
193+
arguments.append(Argument(name=UsefulStr('body')))
155194

156-
return ", ".join(parameters)
195+
return arguments
157196

158197
def get_parameter_type(
159198
self, parameter: Dict[str, Union[str, Dict[str, str]]], snake_case: bool
@@ -200,10 +239,35 @@ def response(self) -> str:
200239
if response.status_code.startswith("2"):
201240
for content_type, schema in response.contents.items():
202241
if content_type == "application/json":
203-
models.append(schema.ref_object_name)
204-
self.imports.append(
205-
Import(from_=MODEL_PATH, import_=schema.ref_object_name)
206-
)
242+
if schema.is_array:
243+
if isinstance(schema.items, list):
244+
type_ = f'List[{",".join(i.ref_object_name for i in schema.items)}]'
245+
self.imports.extend(
246+
Import(
247+
from_=model_path_var.get(),
248+
import_=i.ref_object_name,
249+
)
250+
for i in schema.items
251+
)
252+
else:
253+
type_ = f'List[{schema.items.ref_object_name}]'
254+
self.imports.append(
255+
Import(
256+
from_=model_path_var.get(),
257+
import_=schema.items.ref_object_name,
258+
)
259+
)
260+
self.imports.append(IMPORT_LIST)
261+
else:
262+
type_ = schema.ref_object_name
263+
self.imports.append(
264+
Import(
265+
from_=model_path_var.get(),
266+
import_=schema.ref_object_name,
267+
)
268+
)
269+
models.append(type_)
270+
207271
if not models:
208272
return "None"
209273
if len(models) > 1:
@@ -237,12 +301,12 @@ class Operations(BaseModel):
237301
@validator(*OPERATION_NAMES)
238302
def validate_operations(cls, value: Any, field: ModelField) -> Any:
239303
if isinstance(value, Operation):
240-
value.type = field.name
304+
value.type = UsefulStr(field.name)
241305
return value
242306

243307

244308
class Path(BaseModel):
245-
path: Optional[str]
309+
path: Optional[UsefulStr]
246310
operations: Optional[Operations] = None
247311
children: List[Path] = []
248312
parent: Optional[Path] = None
@@ -273,7 +337,9 @@ def init(self) -> None:
273337

274338
class ParsedObject:
275339
def __init__(self, parsed_operations: List[Operation]):
276-
self.operations = sorted(parsed_operations, key=lambda m: m.path)
340+
self.operations: List[Operation] = sorted(
341+
parsed_operations, key=lambda m: m.path
342+
)
277343
self.imports: Imports = Imports()
278344
for operation in self.operations:
279345
# create imports
@@ -285,9 +351,13 @@ def __init__(self, parsed_operations: List[Operation]):
285351

286352
@snooper_to_methods(max_variable_length=None)
287353
class OpenAPIParser:
288-
def __init__(self, input_name: str, input_text: str,) -> None:
354+
def __init__(
355+
self, input_name: str, input_text: str, model_path: Optional[str] = None
356+
) -> None:
289357
self.input_name: str = input_name
290358
self.input_text: str = input_text
359+
if model_path:
360+
model_path_var.set(model_path)
291361

292362
def parse(self) -> ParsedObject:
293363
openapi = load_json_or_yaml(self.input_text)
@@ -312,7 +382,7 @@ def parse_paths(self, path_tree: Dict[str, Any]) -> ParsedObject:
312382
if me:
313383
continue
314384

315-
last = Path(path="/".join(tree), parent=parent)
385+
last = Path(path=UsefulStr("/".join(tree)), parent=parent)
316386

317387
paths.append(last)
318388

0 commit comments

Comments
 (0)