Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 64 additions & 40 deletions ariadne_codegen/client_generators/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(
default_optional_fields_to_none: bool = False,
include_typename: bool = True,
ignore_extra_fields: bool = True,
models_only: bool = False,
) -> None:
self.package_path = Path(target_path) / package_name

Expand Down Expand Up @@ -150,21 +151,23 @@ def __init__(
self._unpacked_fragments: set[str] = set()
self._used_enums: list[str] = []

self.models_only = models_only
self.enable_custom_operations = enable_custom_operations
if self.enable_custom_operations:
self.files_to_include.append(self.base_schema_root_file_path)

def generate(self) -> list[str]:
"""Generate package with graphql client."""
self._include_exceptions()
if not self.models_only:
self._include_exceptions()
self._validate_unique_file_names()
if not self.package_path.exists():
self.package_path.mkdir()
self._generate_input_types()
self._generate_result_types()
self._generate_fragments()
self._copy_files()
if self.enable_custom_operations:
if not self.models_only and self.enable_custom_operations:
self._generate_custom_fields_typing()
self._generate_custom_fields()
self.client_generator.add_execute_custom_operation_method(self.async_client)
Expand All @@ -179,7 +182,8 @@ def generate(self) -> list[str]:
"mutation", OperationType.MUTATION.value.upper(), self.async_client
)

self._generate_client()
if not self.models_only:
self._generate_client()
self._generate_enums()
self._generate_init()

Expand Down Expand Up @@ -223,14 +227,15 @@ def add_operation(self, definition: OperationDefinitionNode):
query_types_generator.get_generated_public_names(), module_name, 1
)

self.client_generator.add_method(
definition=definition,
name=method_name,
return_type=return_type_name,
return_type_module=module_name,
operation_str=operation_str,
async_=self.async_client,
)
if not self.models_only:
self.client_generator.add_method(
definition=definition,
name=method_name,
return_type=return_type_name,
return_type_module=module_name,
operation_str=operation_str,
async_=self.async_client,
)

def _include_exceptions(self):
if self.base_client_file_path in (
Expand All @@ -247,18 +252,19 @@ def _include_exceptions(self):
)

def _validate_unique_file_names(self):
file_names = (
[
file_names = [
self.base_model_file_path.name,
f"{self.enums_module_name}.py",
f"{self.input_types_module_name}.py",
f"{self.fragments_module_name}.py",
]
if not self.models_only:
file_names += [
f"{self.client_file_name}.py",
self.base_client_file_path.name,
self.base_model_file_path.name,
f"{self.enums_module_name}.py",
f"{self.input_types_module_name}.py",
f"{self.fragments_module_name}.py",
]
+ list(self._result_types_files.keys())
+ [f.name for f in self.files_to_include]
)
file_names += list(self._result_types_files.keys())
file_names += [f.name for f in self.files_to_include]

if len(file_names) != len(set(file_names)):
seen = set()
Expand Down Expand Up @@ -310,7 +316,7 @@ def _generate_enums(self):
)

def _generate_input_types(self):
if self.include_all_inputs:
if self.include_all_inputs or self.models_only:
module = self.input_types_generator.generate()
else:
used_inputs = self.client_generator.arguments_generator.get_used_inputs()
Expand Down Expand Up @@ -359,10 +365,9 @@ def _generate_fragments(self):
)

def _copy_files(self):
files_to_copy = self.files_to_include + [
self.base_client_file_path,
self.base_model_file_path,
]
files_to_copy = self.files_to_include + [self.base_model_file_path]
if not self.models_only:
files_to_copy.append(self.base_client_file_path)
for source_path in files_to_copy:
code = self._add_comments_to_code(source_path.read_text(encoding="utf-8"))
if not self.ignore_extra_fields and source_path.name == "base_model.py":
Expand All @@ -373,11 +378,12 @@ def _copy_files(self):
target_path.write_text(code)
self._generated_files.append(target_path.name)

self.init_generator.add_import(
names=[self.base_client_name],
from_=self.base_client_file_path.stem,
level=1,
)
if not self.models_only:
self.init_generator.add_import(
names=[self.base_client_name],
from_=self.base_client_file_path.stem,
level=1,
)
self.init_generator.add_import(
names=[BASE_MODEL_CLASS_NAME, UPLOAD_CLASS_NAME],
from_=self.base_model_file_path.stem,
Expand Down Expand Up @@ -431,22 +437,39 @@ def get_package_generator(
fragments: list[FragmentDefinitionNode],
settings: ClientSettings,
plugin_manager: PluginManager,
models_only: bool = False,
) -> PackageGenerator:
init_generator = InitFileGenerator(plugin_manager=plugin_manager)
client_generator = ClientGenerator(
base_client_import=generate_import_from(
if models_only:
base_client_import = generate_import_from(
names=["object"], from_="builtins", level=0
)
client_name = "Client"
base_client = "object"
async_client = True
base_client_file_path = ""
client_file_name = "client"
else:
base_client_import = generate_import_from(
names=[settings.base_client_name],
from_=Path(settings.base_client_file_path).stem,
level=1,
),
)
client_name = settings.client_name
base_client = settings.base_client_name
async_client = settings.async_client
base_client_file_path = settings.base_client_file_path
client_file_name = settings.client_file_name
client_generator = ClientGenerator(
base_client_import=base_client_import,
arguments_generator=ArgumentsGenerator(
schema=schema,
convert_to_snake_case=settings.convert_to_snake_case,
custom_scalars=settings.scalars,
plugin_manager=plugin_manager,
),
name=settings.client_name,
base_client=settings.base_client_name,
name=client_name,
base_client=base_client,
enums_module_name=settings.enums_module_name,
input_types_module_name=settings.input_types_module_name,
unset_import=UNSET_IMPORT,
Expand Down Expand Up @@ -528,11 +551,11 @@ def get_package_generator(
input_types_generator=input_types_generator,
fragments_generator=fragments_generator,
fragments_definitions=fragments_definitions,
client_name=settings.client_name,
async_client=settings.async_client,
base_client_name=settings.base_client_name,
base_client_file_path=settings.base_client_file_path,
client_file_name=settings.client_file_name,
client_name=client_name,
async_client=async_client,
base_client_name=base_client,
base_client_file_path=base_client_file_path,
client_file_name=client_file_name,
enums_module_name=settings.enums_module_name,
input_types_module_name=settings.input_types_module_name,
fragments_module_name=settings.fragments_module_name,
Expand All @@ -557,4 +580,5 @@ def get_package_generator(
default_optional_fields_to_none=settings.default_optional_fields_to_none,
include_typename=settings.include_typename,
ignore_extra_fields=settings.ignore_extra_fields,
models_only=models_only,
)
58 changes: 57 additions & 1 deletion ariadne_codegen/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@

from .client_generators.scalars import ScalarData
from .exceptions import ConfigFileNotFound, MissingConfiguration
from .settings import ClientSettings, CommentsStrategy, GraphQLSchemaSettings
from .settings import (
ClientSettings,
CommentsStrategy,
GraphQLSchemaSettings,
ModelsOnlySettings,
)

simplefilter("default", DeprecationWarning)

Expand Down Expand Up @@ -103,6 +108,57 @@ def get_section(config_dict: dict) -> dict:
raise MissingConfiguration(f"Config has no [{tool_key}.{codegen_key}] section.")


def get_models_only_settings(config_dict: dict) -> ModelsOnlySettings:
"""Parse configuration dict and return ModelsOnlySettings instance."""
section = get_section(config_dict).copy()
settings_fields_names = {f.name for f in fields(ModelsOnlySettings)}
try:
section["scalars"] = {
name: ScalarData(
type_=data["type"],
serialize=data.get("serialize"),
parse=data.get("parse"),
import_=data.get("import"),
)
for name, data in section.get("scalars", {}).items()
}
except KeyError as exc:
raise MissingConfiguration(
"Missing 'type' field for scalar definition"
) from exc

try:
if "include_comments" in section and isinstance(
section["include_comments"], bool
):
section["include_comments"] = (
CommentsStrategy.TIMESTAMP.value
if section["include_comments"]
else CommentsStrategy.NONE.value
)
options = ", ".join(strategy.value for strategy in CommentsStrategy)
warn(
"Support for boolean 'include_comments' value has been deprecated "
"and will be dropped in future release. "
f"Instead use one of following options: {options}",
DeprecationWarning,
stacklevel=2,
)

return ModelsOnlySettings(
**{
key: value
for key, value in section.items()
if key in settings_fields_names
}
)
except TypeError as exc:
missing_fields = settings_fields_names.difference(section)
raise MissingConfiguration(
f"Missing configuration fields: {', '.join(missing_fields)}"
) from exc


def get_graphql_schema_settings(config_dict: dict) -> GraphQLSchemaSettings:
"""Parse configuration dict and return GraphQLSchemaSettings instance."""
section = get_section(config_dict)
Expand Down
55 changes: 54 additions & 1 deletion ariadne_codegen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from graphql import assert_valid_schema

from .client_generators.package import get_package_generator
from .config import get_client_settings, get_config_dict, get_graphql_schema_settings
from .config import (
get_client_settings,
get_config_dict,
get_graphql_schema_settings,
get_models_only_settings,
)
from .graphql_schema_generators.schema import (
generate_graphql_schema_graphql_file,
generate_graphql_schema_python_file,
Expand Down Expand Up @@ -39,6 +44,9 @@ def main(strategy=Strategy.CLIENT.value, config=None):
if strategy == Strategy.GRAPHQL_SCHEMA:
graphql_schema(config_dict)

if strategy == Strategy.MODELS_ONLY:
models_only(config_dict)


def client(config_dict):
settings = get_client_settings(config_dict)
Expand Down Expand Up @@ -89,6 +97,51 @@ def client(config_dict):
sys.stdout.write("\nGenerated files:\n " + "\n ".join(generated_files) + "\n")


def models_only(config_dict):
settings = get_models_only_settings(config_dict)

if settings.schema_path:
schema = get_graphql_schema_from_path(settings.schema_path)
else:
schema = get_graphql_schema_from_url(
url=settings.remote_schema_url,
headers=settings.remote_schema_headers,
verify_ssl=settings.remote_schema_verify_ssl,
timeout=settings.remote_schema_timeout,
)

plugin_manager = PluginManager(
schema=schema,
config_dict=config_dict,
plugins_types=get_plugins_types(settings.plugins),
)
schema = add_mixin_directive_to_schema(schema)
schema = plugin_manager.process_schema(schema)
assert_valid_schema(schema)

fragments = []
queries = []
if settings.queries_path:
definitions = get_graphql_queries(settings.queries_path, schema)
queries = filter_operations_definitions(definitions)
fragments = filter_fragments_definitions(definitions)

sys.stdout.write(settings.used_settings_message)

package_generator = get_package_generator(
schema=schema,
fragments=fragments,
settings=settings,
plugin_manager=plugin_manager,
models_only=True,
)
for query in queries:
package_generator.add_operation(query)
generated_files = package_generator.generate()

sys.stdout.write("\nGenerated files:\n " + "\n ".join(generated_files) + "\n")


def graphql_schema(config_dict):
settings = get_graphql_schema_settings(config_dict)

Expand Down
Loading