From 7b68be7ccbc67261770dd31cbf21ff183b8ba999 Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Sun, 21 Jun 2026 21:25:02 +0100 Subject: [PATCH 1/2] feat: add models only strategy --- ariadne_codegen/client_generators/package.py | 104 ++++++++++++------- ariadne_codegen/config.py | 58 ++++++++++- ariadne_codegen/main.py | 55 +++++++++- ariadne_codegen/settings.py | 93 +++++++++++++++++ 4 files changed, 268 insertions(+), 42 deletions(-) diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index c84e48f4..bdfa12bc 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -92,6 +92,7 @@ def __init__( default_optional_fields_to_none: bool = False, include_typename: bool = True, ignore_extra_fields: bool = True, + models_only: bool = False, ) -> None: self.package_path = Path(target_path) / package_name @@ -150,13 +151,15 @@ def __init__( self._unpacked_fragments: set[str] = set() self._used_enums: list[str] = [] + self.models_only = models_only self.enable_custom_operations = enable_custom_operations if self.enable_custom_operations: self.files_to_include.append(self.base_schema_root_file_path) def generate(self) -> list[str]: """Generate package with graphql client.""" - self._include_exceptions() + if not self.models_only: + self._include_exceptions() self._validate_unique_file_names() if not self.package_path.exists(): self.package_path.mkdir() @@ -164,7 +167,7 @@ def generate(self) -> list[str]: self._generate_result_types() self._generate_fragments() self._copy_files() - if self.enable_custom_operations: + if not self.models_only and self.enable_custom_operations: self._generate_custom_fields_typing() self._generate_custom_fields() self.client_generator.add_execute_custom_operation_method(self.async_client) @@ -179,7 +182,8 @@ def generate(self) -> list[str]: "mutation", OperationType.MUTATION.value.upper(), self.async_client ) - self._generate_client() + if not self.models_only: + self._generate_client() self._generate_enums() self._generate_init() @@ -223,14 +227,15 @@ def add_operation(self, definition: OperationDefinitionNode): query_types_generator.get_generated_public_names(), module_name, 1 ) - self.client_generator.add_method( - definition=definition, - name=method_name, - return_type=return_type_name, - return_type_module=module_name, - operation_str=operation_str, - async_=self.async_client, - ) + if not self.models_only: + self.client_generator.add_method( + definition=definition, + name=method_name, + return_type=return_type_name, + return_type_module=module_name, + operation_str=operation_str, + async_=self.async_client, + ) def _include_exceptions(self): if self.base_client_file_path in ( @@ -247,18 +252,19 @@ def _include_exceptions(self): ) def _validate_unique_file_names(self): - file_names = ( - [ + file_names = [ + self.base_model_file_path.name, + f"{self.enums_module_name}.py", + f"{self.input_types_module_name}.py", + f"{self.fragments_module_name}.py", + ] + if not self.models_only: + file_names += [ f"{self.client_file_name}.py", self.base_client_file_path.name, - self.base_model_file_path.name, - f"{self.enums_module_name}.py", - f"{self.input_types_module_name}.py", - f"{self.fragments_module_name}.py", ] - + list(self._result_types_files.keys()) - + [f.name for f in self.files_to_include] - ) + file_names += list(self._result_types_files.keys()) + file_names += [f.name for f in self.files_to_include] if len(file_names) != len(set(file_names)): seen = set() @@ -310,7 +316,7 @@ def _generate_enums(self): ) def _generate_input_types(self): - if self.include_all_inputs: + if self.include_all_inputs or self.models_only: module = self.input_types_generator.generate() else: used_inputs = self.client_generator.arguments_generator.get_used_inputs() @@ -359,10 +365,9 @@ def _generate_fragments(self): ) def _copy_files(self): - files_to_copy = self.files_to_include + [ - self.base_client_file_path, - self.base_model_file_path, - ] + files_to_copy = self.files_to_include + [self.base_model_file_path] + if not self.models_only: + files_to_copy.append(self.base_client_file_path) for source_path in files_to_copy: code = self._add_comments_to_code(source_path.read_text(encoding="utf-8")) if not self.ignore_extra_fields and source_path.name == "base_model.py": @@ -373,11 +378,12 @@ def _copy_files(self): target_path.write_text(code) self._generated_files.append(target_path.name) - self.init_generator.add_import( - names=[self.base_client_name], - from_=self.base_client_file_path.stem, - level=1, - ) + if not self.models_only: + self.init_generator.add_import( + names=[self.base_client_name], + from_=self.base_client_file_path.stem, + level=1, + ) self.init_generator.add_import( names=[BASE_MODEL_CLASS_NAME, UPLOAD_CLASS_NAME], from_=self.base_model_file_path.stem, @@ -427,22 +433,39 @@ def get_package_generator( fragments: list[FragmentDefinitionNode], settings: ClientSettings, plugin_manager: PluginManager, + models_only: bool = False, ) -> PackageGenerator: init_generator = InitFileGenerator(plugin_manager=plugin_manager) - client_generator = ClientGenerator( - base_client_import=generate_import_from( + if models_only: + base_client_import = generate_import_from( + names=["object"], from_="builtins", level=0 + ) + client_name = "Client" + base_client = "object" + async_client = True + base_client_file_path = "" + client_file_name = "client" + else: + base_client_import = generate_import_from( names=[settings.base_client_name], from_=Path(settings.base_client_file_path).stem, level=1, - ), + ) + client_name = settings.client_name + base_client = settings.base_client_name + async_client = settings.async_client + base_client_file_path = settings.base_client_file_path + client_file_name = settings.client_file_name + client_generator = ClientGenerator( + base_client_import=base_client_import, arguments_generator=ArgumentsGenerator( schema=schema, convert_to_snake_case=settings.convert_to_snake_case, custom_scalars=settings.scalars, plugin_manager=plugin_manager, ), - name=settings.client_name, - base_client=settings.base_client_name, + name=client_name, + base_client=base_client, enums_module_name=settings.enums_module_name, input_types_module_name=settings.input_types_module_name, unset_import=UNSET_IMPORT, @@ -524,11 +547,11 @@ def get_package_generator( input_types_generator=input_types_generator, fragments_generator=fragments_generator, fragments_definitions=fragments_definitions, - client_name=settings.client_name, - async_client=settings.async_client, - base_client_name=settings.base_client_name, - base_client_file_path=settings.base_client_file_path, - client_file_name=settings.client_file_name, + client_name=client_name, + async_client=async_client, + base_client_name=base_client, + base_client_file_path=base_client_file_path, + client_file_name=client_file_name, enums_module_name=settings.enums_module_name, input_types_module_name=settings.input_types_module_name, fragments_module_name=settings.fragments_module_name, @@ -553,4 +576,5 @@ def get_package_generator( default_optional_fields_to_none=settings.default_optional_fields_to_none, include_typename=settings.include_typename, ignore_extra_fields=settings.ignore_extra_fields, + models_only=models_only, ) diff --git a/ariadne_codegen/config.py b/ariadne_codegen/config.py index 7c14228a..08c97131 100644 --- a/ariadne_codegen/config.py +++ b/ariadne_codegen/config.py @@ -7,7 +7,12 @@ from .client_generators.scalars import ScalarData from .exceptions import ConfigFileNotFound, MissingConfiguration -from .settings import ClientSettings, CommentsStrategy, GraphQLSchemaSettings +from .settings import ( + ClientSettings, + CommentsStrategy, + GraphQLSchemaSettings, + ModelsOnlySettings, +) simplefilter("default", DeprecationWarning) @@ -103,6 +108,57 @@ def get_section(config_dict: dict) -> dict: raise MissingConfiguration(f"Config has no [{tool_key}.{codegen_key}] section.") +def get_models_only_settings(config_dict: dict) -> ModelsOnlySettings: + """Parse configuration dict and return ModelsOnlySettings instance.""" + section = get_section(config_dict).copy() + settings_fields_names = {f.name for f in fields(ModelsOnlySettings)} + try: + section["scalars"] = { + name: ScalarData( + type_=data["type"], + serialize=data.get("serialize"), + parse=data.get("parse"), + import_=data.get("import"), + ) + for name, data in section.get("scalars", {}).items() + } + except KeyError as exc: + raise MissingConfiguration( + "Missing 'type' field for scalar definition" + ) from exc + + try: + if "include_comments" in section and isinstance( + section["include_comments"], bool + ): + section["include_comments"] = ( + CommentsStrategy.TIMESTAMP.value + if section["include_comments"] + else CommentsStrategy.NONE.value + ) + options = ", ".join(strategy.value for strategy in CommentsStrategy) + warn( + "Support for boolean 'include_comments' value has been deprecated " + "and will be dropped in future release. " + f"Instead use one of following options: {options}", + DeprecationWarning, + stacklevel=2, + ) + + return ModelsOnlySettings( + **{ + key: value + for key, value in section.items() + if key in settings_fields_names + } + ) + except TypeError as exc: + missing_fields = settings_fields_names.difference(section) + raise MissingConfiguration( + f"Missing configuration fields: {', '.join(missing_fields)}" + ) from exc + + def get_graphql_schema_settings(config_dict: dict) -> GraphQLSchemaSettings: """Parse configuration dict and return GraphQLSchemaSettings instance.""" section = get_section(config_dict) diff --git a/ariadne_codegen/main.py b/ariadne_codegen/main.py index ea71f152..fa00479d 100644 --- a/ariadne_codegen/main.py +++ b/ariadne_codegen/main.py @@ -4,7 +4,12 @@ from graphql import assert_valid_schema from .client_generators.package import get_package_generator -from .config import get_client_settings, get_config_dict, get_graphql_schema_settings +from .config import ( + get_client_settings, + get_config_dict, + get_graphql_schema_settings, + get_models_only_settings, +) from .graphql_schema_generators.schema import ( generate_graphql_schema_graphql_file, generate_graphql_schema_python_file, @@ -39,6 +44,9 @@ def main(strategy=Strategy.CLIENT.value, config=None): if strategy == Strategy.GRAPHQL_SCHEMA: graphql_schema(config_dict) + if strategy == Strategy.MODELS_ONLY: + models_only(config_dict) + def client(config_dict): settings = get_client_settings(config_dict) @@ -84,6 +92,51 @@ def client(config_dict): sys.stdout.write("\nGenerated files:\n " + "\n ".join(generated_files) + "\n") +def models_only(config_dict): + settings = get_models_only_settings(config_dict) + + if settings.schema_path: + schema = get_graphql_schema_from_path(settings.schema_path) + else: + schema = get_graphql_schema_from_url( + url=settings.remote_schema_url, + headers=settings.remote_schema_headers, + verify_ssl=settings.remote_schema_verify_ssl, + timeout=settings.remote_schema_timeout, + ) + + plugin_manager = PluginManager( + schema=schema, + config_dict=config_dict, + plugins_types=get_plugins_types(settings.plugins), + ) + schema = add_mixin_directive_to_schema(schema) + schema = plugin_manager.process_schema(schema) + assert_valid_schema(schema) + + fragments = [] + queries = [] + if settings.queries_path: + definitions = get_graphql_queries(settings.queries_path, schema) + queries = filter_operations_definitions(definitions) + fragments = filter_fragments_definitions(definitions) + + sys.stdout.write(settings.used_settings_message) + + package_generator = get_package_generator( + schema=schema, + fragments=fragments, + settings=settings, + plugin_manager=plugin_manager, + models_only=True, + ) + for query in queries: + package_generator.add_operation(query) + generated_files = package_generator.generate() + + sys.stdout.write("\nGenerated files:\n " + "\n ".join(generated_files) + "\n") + + def graphql_schema(config_dict): settings = get_graphql_schema_settings(config_dict) diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 74d29e17..89fcf308 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -29,6 +29,7 @@ class CommentsStrategy(str, enum.Enum): class Strategy(str, enum.Enum): CLIENT = "client" GRAPHQL_SCHEMA = "graphqlschema" + MODELS_ONLY = "models_only" @dataclass @@ -200,6 +201,98 @@ def used_settings_message(self) -> str: ) +@dataclass +class ModelsOnlySettings(BaseSettings): + queries_path: str = "" + target_package_name: str = "graphql_client" + target_package_path: str = field(default_factory=lambda: Path.cwd().as_posix()) + enums_module_name: str = "enums" + input_types_module_name: str = "input_types" + fragments_module_name: str = "fragments" + include_comments: CommentsStrategy = field(default=CommentsStrategy.STABLE) + convert_to_snake_case: bool = True + include_all_inputs: bool = True + include_all_enums: bool = True + files_to_include: list[str] = field(default_factory=list) + scalars: dict[str, ScalarData] = field(default_factory=dict) + default_optional_fields_to_none: bool = False + include_typename: bool = True + ignore_extra_fields: bool = True + + def __post_init__(self): + super().__post_init__() + + try: + self.include_comments = CommentsStrategy(self.include_comments) + except ValueError as exc: + valid_options = ", ".join(strategy.value for strategy in CommentsStrategy) + raise InvalidConfiguration( + f"'{self.include_comments}' is not a valid choice. " + f"Valid options are: {valid_options}" + ) from exc + + for name, data in self.scalars.items(): + data.graphql_name = name + + if self.queries_path: + assert_path_exists(self.queries_path) + + assert_string_is_valid_python_identifier(self.target_package_name) + assert_path_is_valid_directory(self.target_package_path) + + assert_string_is_valid_python_identifier(self.enums_module_name) + assert_string_is_valid_python_identifier(self.input_types_module_name) + assert_string_is_valid_python_identifier(self.fragments_module_name) + + for file_path in self.files_to_include: + assert_path_is_valid_file(file_path) + + @property + def schema_source(self) -> str: + return self.schema_path if self.schema_path else self.remote_schema_url + + @property + def used_settings_message(self) -> str: + snake_case_msg = ( + "Converting fields and arguments name to snake case." + if self.convert_to_snake_case + else "Not converting fields and arguments name to snake case." + ) + files_to_include_list = ",".join(self.files_to_include) + files_to_include_msg = ( + f"Copying the following files into the package: {files_to_include_list}" + if self.files_to_include + else "No files to copy." + ) + plugins_list = ",".join(self.plugins) + plugins_msg = ( + f"Plugins to use: {plugins_list}" + if self.plugins + else "No plugin is being used." + ) + queries_msg = ( + f"Reading queries from '{self.queries_path}'." + if self.queries_path + else "No queries path provided, generating models only." + ) + return dedent( + f"""\ + Selected strategy: {Strategy.MODELS_ONLY} + Using schema from '{self.schema_source}'. + {queries_msg} + Using '{self.target_package_name}' as package name. + Generating package into '{self.target_package_path}'. + Generating enums into '{self.enums_module_name}.py'. + Generating inputs into '{self.input_types_module_name}.py'. + Generating fragments into '{self.fragments_module_name}.py'. + Comments type: {self.include_comments.value} + {snake_case_msg} + {files_to_include_msg} + {plugins_msg} + """ + ) + + @dataclass class GraphQLSchemaSettings(BaseSettings): target_file_path: str = "schema.py" From b7da2c34790fce83c767b413bf70ccdbced14a59 Mon Sep 17 00:00:00 2001 From: Amin Farjadi Date: Sun, 21 Jun 2026 21:25:21 +0100 Subject: [PATCH 2/2] add tests --- .../package_generator/test_generated_files.py | 45 ++++++++++++++++++ .../example/expected_client/__init__.py | 16 +++++++ .../example/expected_client/base_model.py | 28 +++++++++++ .../example/expected_client/create_user.py | 17 +++++++ .../example/expected_client/enums.py | 9 ++++ .../example/expected_client/input_types.py | 13 +++++ .../example/expected_client/list_users.py | 20 ++++++++ tests/main/models_only/example/pyproject.toml | 5 ++ .../main/models_only/example/queries.graphql | 15 ++++++ tests/main/models_only/example/schema.graphql | 35 ++++++++++++++ tests/main/test_main.py | 32 +++++++++++++ tests/test_settings.py | 47 ++++++++++++++++++- 12 files changed, 281 insertions(+), 1 deletion(-) create mode 100644 tests/main/models_only/example/expected_client/__init__.py create mode 100644 tests/main/models_only/example/expected_client/base_model.py create mode 100644 tests/main/models_only/example/expected_client/create_user.py create mode 100644 tests/main/models_only/example/expected_client/enums.py create mode 100644 tests/main/models_only/example/expected_client/input_types.py create mode 100644 tests/main/models_only/example/expected_client/list_users.py create mode 100644 tests/main/models_only/example/pyproject.toml create mode 100644 tests/main/models_only/example/queries.graphql create mode 100644 tests/main/models_only/example/schema.graphql diff --git a/tests/client_generators/package_generator/test_generated_files.py b/tests/client_generators/package_generator/test_generated_files.py index d0e12932..0243ca4d 100644 --- a/tests/client_generators/package_generator/test_generated_files.py +++ b/tests/client_generators/package_generator/test_generated_files.py @@ -764,3 +764,48 @@ def test_generate_creates_client_with_custom_scalars_imports( f"{generator.client_file_name}.py" ).open() as client_file: assert "from .abc import ScalarABC" in client_file.read() + + +def test_generate_models_only(tmp_path, schema, async_base_client_import): + package_name = "test_graphql_client" + generator = PackageGenerator( + package_name=package_name, + target_path=tmp_path.as_posix(), + schema=schema, + init_generator=InitFileGenerator(), + client_generator=ClientGenerator( + base_client_import=async_base_client_import, + arguments_generator=ArgumentsGenerator(schema=schema), + ), + enums_generator=EnumsGenerator(schema=schema), + input_types_generator=InputTypesGenerator(schema=schema), + fragments_generator=FragmentsGenerator(schema=schema, fragments_definitions={}), + models_only=True, + ) + query_str = """ + query CustomQuery($id: ID!) { + query1(id: $id) { + field1 + } + } + """ + generator.add_operation(parse(query_str).definitions[0]) + generated_files = generator.generate() + + package_path = tmp_path / package_name + assert (package_path / "__init__.py").exists() + assert (package_path / "base_model.py").exists() + assert (package_path / f"{generator.enums_module_name}.py").exists() + assert (package_path / f"{generator.input_types_module_name}.py").exists() + assert (package_path / "custom_query.py").exists() + assert "custom_query.py" in generated_files + assert not (package_path / "client.py").exists() + assert not (package_path / generator.base_client_file_path.name).exists() + assert not (package_path / EXCEPTIONS_FILE_PATH.name).exists() + assert "client.py" not in generated_files + assert EXCEPTIONS_FILE_PATH.name not in generated_files + init_content = (package_path / "__init__.py").read_text() + assert "from .base_model import BaseModel, Upload" in init_content + assert "Client" not in init_content + assert "AsyncBaseClient" not in init_content + assert "GraphQLClientError" not in init_content diff --git a/tests/main/models_only/example/expected_client/__init__.py b/tests/main/models_only/example/expected_client/__init__.py new file mode 100644 index 00000000..55bb6be1 --- /dev/null +++ b/tests/main/models_only/example/expected_client/__init__.py @@ -0,0 +1,16 @@ +from .base_model import BaseModel, Upload +from .create_user import CreateUser, CreateUserUserCreate +from .enums import Color +from .input_types import UserCreateInput +from .list_users import ListUsers, ListUsersUsers + +__all__ = [ + "BaseModel", + "Color", + "CreateUser", + "CreateUserUserCreate", + "ListUsers", + "ListUsersUsers", + "Upload", + "UserCreateInput", +] diff --git a/tests/main/models_only/example/expected_client/base_model.py b/tests/main/models_only/example/expected_client/base_model.py new file mode 100644 index 00000000..68e2f9ea --- /dev/null +++ b/tests/main/models_only/example/expected_client/base_model.py @@ -0,0 +1,28 @@ +from io import IOBase + +from pydantic import BaseModel as PydanticBaseModel +from pydantic import ConfigDict + + +class UnsetType: + def __bool__(self) -> bool: + return False + + +UNSET = UnsetType() + + +class BaseModel(PydanticBaseModel): + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + arbitrary_types_allowed=True, + protected_namespaces=(), + ) + + +class Upload: + def __init__(self, filename: str, content: IOBase, content_type: str): + self.filename = filename + self.content = content + self.content_type = content_type diff --git a/tests/main/models_only/example/expected_client/create_user.py b/tests/main/models_only/example/expected_client/create_user.py new file mode 100644 index 00000000..cafe1a76 --- /dev/null +++ b/tests/main/models_only/example/expected_client/create_user.py @@ -0,0 +1,17 @@ +from typing import Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class CreateUser(BaseModel): + user_create: Optional["CreateUserUserCreate"] = Field(alias="userCreate") + + +class CreateUserUserCreate(BaseModel): + id: str + email: str + + +CreateUser.model_rebuild() diff --git a/tests/main/models_only/example/expected_client/enums.py b/tests/main/models_only/example/expected_client/enums.py new file mode 100644 index 00000000..c8efdc27 --- /dev/null +++ b/tests/main/models_only/example/expected_client/enums.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class Color(str, Enum): + BLACK = "BLACK" + WHITE = "WHITE" + RED = "RED" + GREEN = "GREEN" + BLUE = "BLUE" diff --git a/tests/main/models_only/example/expected_client/input_types.py b/tests/main/models_only/example/expected_client/input_types.py new file mode 100644 index 00000000..77db3d64 --- /dev/null +++ b/tests/main/models_only/example/expected_client/input_types.py @@ -0,0 +1,13 @@ +from typing import Optional + +from pydantic import Field + +from .base_model import BaseModel +from .enums import Color + + +class UserCreateInput(BaseModel): + first_name: Optional[str] = Field(alias="firstName", default=None) + last_name: Optional[str] = Field(alias="lastName", default=None) + email: str + favourite_color: Optional[Color] = Field(alias="favouriteColor", default=None) diff --git a/tests/main/models_only/example/expected_client/list_users.py b/tests/main/models_only/example/expected_client/list_users.py new file mode 100644 index 00000000..831e2c47 --- /dev/null +++ b/tests/main/models_only/example/expected_client/list_users.py @@ -0,0 +1,20 @@ +from typing import Optional + +from pydantic import Field + +from .base_model import BaseModel +from .enums import Color + + +class ListUsers(BaseModel): + users: list["ListUsersUsers"] + + +class ListUsersUsers(BaseModel): + id: str + first_name: Optional[str] = Field(alias="firstName") + email: str + favourite_color: Optional[Color] = Field(alias="favouriteColor") + + +ListUsers.model_rebuild() diff --git a/tests/main/models_only/example/pyproject.toml b/tests/main/models_only/example/pyproject.toml new file mode 100644 index 00000000..7cfc9d3b --- /dev/null +++ b/tests/main/models_only/example/pyproject.toml @@ -0,0 +1,5 @@ +[tool.ariadne-codegen] +schema_path = "schema.graphql" +queries_path = "queries.graphql" +include_comments = "none" +target_package_name = "example_client" diff --git a/tests/main/models_only/example/queries.graphql b/tests/main/models_only/example/queries.graphql new file mode 100644 index 00000000..748564fb --- /dev/null +++ b/tests/main/models_only/example/queries.graphql @@ -0,0 +1,15 @@ +mutation CreateUser($userData: UserCreateInput!) { + userCreate(userData: $userData) { + id + email + } +} + +query ListUsers { + users { + id + firstName + email + favouriteColor + } +} diff --git a/tests/main/models_only/example/schema.graphql b/tests/main/models_only/example/schema.graphql new file mode 100644 index 00000000..332f8d6c --- /dev/null +++ b/tests/main/models_only/example/schema.graphql @@ -0,0 +1,35 @@ +schema { + query: Query + mutation: Mutation +} + +type Query { + users(country: String): [User!]! +} + +type Mutation { + userCreate(userData: UserCreateInput!): User +} + +input UserCreateInput { + firstName: String + lastName: String + email: String! + favouriteColor: Color +} + +type User { + id: ID! + firstName: String + lastName: String + email: String! + favouriteColor: Color +} + +enum Color { + BLACK + WHITE + RED + GREEN + BLUE +} diff --git a/tests/main/test_main.py b/tests/main/test_main.py index 23a9f6d5..892d6dd6 100644 --- a/tests/main/test_main.py +++ b/tests/main/test_main.py @@ -16,6 +16,7 @@ CLIENTS_PATH = Path(__file__).parent / "clients" GRAPHQL_SCHEMAS_PATH = Path(__file__).parent / "graphql_schemas" +MODELS_ONLY_PATH = Path(__file__).parent / "models_only" @pytest.fixture(scope="function") @@ -438,3 +439,34 @@ def normalise(schema: str) -> str: expected = normalise(expected_file_path.read_text()) assert actual == expected + + +@pytest.mark.parametrize( + "project_dir, package_name, expected_package_path", + [ + ( + ( + MODELS_ONLY_PATH / "example" / "pyproject.toml", + ( + MODELS_ONLY_PATH / "example" / "schema.graphql", + MODELS_ONLY_PATH / "example" / "queries.graphql", + ), + ), + "example_client", + MODELS_ONLY_PATH / "example" / "expected_client", + ), + ], + indirect=["project_dir"], +) +def test_main_generates_correct_models_only_package( + project_dir, package_name, expected_package_path +): + result = CliRunner().invoke(main, "models_only") + + assert result.exit_code == 0 + package_path = project_dir / package_name + assert package_path.is_dir() + assert_the_same_files_in_directories(package_path, expected_package_path) + assert not (package_path / "client.py").exists() + assert not (package_path / "async_base_client.py").exists() + assert not (package_path / "exceptions.py").exists() diff --git a/tests/test_settings.py b/tests/test_settings.py index 047bd1e1..fcc0a71a 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -10,7 +10,11 @@ base_client, base_client_open_telemetry, ) -from ariadne_codegen.config import ClientSettings, GraphQLSchemaSettings +from ariadne_codegen.config import ( + ClientSettings, + GraphQLSchemaSettings, + ModelsOnlySettings, +) from ariadne_codegen.exceptions import InvalidConfiguration @@ -479,3 +483,44 @@ def test_client_settings_include_typename_can_be_set_to_true(tmp_path): ) assert settings.include_typename is True + + +def test_models_only_settings_can_be_created_without_queries_path(tmp_path): + schema_path = tmp_path / "schema.graphql" + schema_path.touch() + + settings = ModelsOnlySettings(schema_path=schema_path.as_posix()) + + assert settings.queries_path == "" + assert settings.target_package_name == "graphql_client" + assert settings.enums_module_name == "enums" + assert settings.input_types_module_name == "input_types" + assert settings.fragments_module_name == "fragments" + + +def test_models_only_settings_used_settings_message_contains_strategy(tmp_path): + schema_path = tmp_path / "schema.graphql" + schema_path.touch() + + settings = ModelsOnlySettings(schema_path=schema_path.as_posix()) + message = settings.used_settings_message + + assert "models_only" in message + assert schema_path.as_posix() in message + assert "Client" not in message + assert "base_client" not in message + + +def test_models_only_settings_accepts_optional_queries_path(tmp_path): + schema_path = tmp_path / "schema.graphql" + schema_path.touch() + queries_path = tmp_path / "queries.graphql" + queries_path.touch() + + settings = ModelsOnlySettings( + schema_path=schema_path.as_posix(), + queries_path=queries_path.as_posix(), + ) + + assert settings.queries_path == queries_path.as_posix() + assert "queries" in settings.used_settings_message