Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
default_language_version:
python: "3.9"
python: "3.10"

default_stages:
- pre-commit
Expand All @@ -19,6 +19,7 @@ repos:
- id: detect-private-key
- id: debug-statements
- id: end-of-file-fixer
exclude: "^.*__snapshots__/.*$"
- id: trailing-whitespace
exclude: "^.*__snapshots__/.*$"

Expand All @@ -36,7 +37,7 @@ repos:
rev: v1.7.7
hooks:
- id: actionlint
args: ["-shellcheck", ""]
args: [ "-shellcheck", "" ]

# Commitizen
- repo: https://github.com/commitizen-tools/commitizen
Expand All @@ -48,24 +49,24 @@ repos:
rev: v0.9.0
hooks:
- id: unasyncd
additional_dependencies: ["ruff"]
additional_dependencies: [ "ruff" ]

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.13.0
hooks:
- id: ruff-check
types_or: [python, pyi]
args: [--fix]
types_or: [ python, pyi ]
args: [ --fix ]
- id: ruff-format
types_or: [python, pyi]
types_or: [ python, pyi ]

- repo: local
hooks:
- id: lint
name: lint
entry: mise run lint:pre-commit
language: python
types: [python]
types: [ python ]
require_serial: true

- repo: local
Expand Down
1 change: 0 additions & 1 deletion examples/testapp/testapp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from litestar import Litestar
from litestar.plugins.sqlalchemy import EngineConfig, SQLAlchemyAsyncConfig, SQLAlchemyPlugin

from strawberry.litestar import BaseContext, make_graphql_controller

from .models import Base
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if TYPE_CHECKING:
from nox import Session

SUPPORTED_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13"]
SUPPORTED_PYTHON_VERSIONS = ["3.10", "3.11", "3.12", "3.13"]
COMMON_PYTEST_OPTIONS = ["-n=2", "--showlocals", "-vv"]

here = Path(__file__).parent
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ reportPrivateUsage = false
reportUnnecessaryTypeIgnoreComment = true
reportImplicitOverride = true
reportPropertyTypeMismatch = true
reportShadowedImports = true
strictGenericNarrowing = true

[tool.bumpversion]
Expand Down Expand Up @@ -370,7 +369,7 @@ max-complexity = 12
convention = "google"

[tool.ruff.lint.per-file-ignores]
"tests/*" = ["TC001", "UP037", "PLR2004"]
"tests/*" = ["TC001", "UP037", "PLR2004", "PLC0415"]

[tool.unasyncd]
add_editors_note = true
Expand Down
15 changes: 14 additions & 1 deletion src/strawchemy/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,20 @@

from importlib.util import find_spec

__all__ = ("GEO_INSTALLED",)
__all__ = (
"AGGREGATIONS_KEY",
"DATA_KEY",
"DISTINCT_ON_KEY",
"FILTER_KEY",
"GEO_INSTALLED",
"JSON_PATH_KEY",
"LIMIT_KEY",
"NODES_KEY",
"OFFSET_KEY",
"ORDER_BY_KEY",
"UPSERT_CONFLICT_FIELDS",
"UPSERT_UPDATE_FIELDS",
)

GEO_INSTALLED: bool = all(find_spec(package) is not None for package in ("geoalchemy2", "shapely"))

Expand Down
7 changes: 2 additions & 5 deletions src/strawchemy/dto/backend/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,12 @@ def build(
if model_module := getmodule(self.dto_base):
module = model_module.__name__

dto = create_model(
dto = create_model( # pyright: ignore[reportCallIssue]
name,
__base__=(self.dto_base,),
__config__=None,
__module__=module,
__validators__=None,
__doc__=f"Pydantic generated DTO for {model.__name__} model" if docstring else None,
__cls_kwargs__=None,
**fields,
**fields, # pyright: ignore[reportArgumentType]
)

if config_dict:
Expand Down
2 changes: 1 addition & 1 deletion src/strawchemy/dto/backend/strawberry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from types import new_class
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, get_origin

from strawberry.types.field import StrawberryField
from typing_extensions import override

import strawberry
from strawberry.types.field import StrawberryField
from strawchemy.dto.base import DTOBackend, DTOBase, MappedDTO, ModelFieldT, ModelT
from strawchemy.dto.types import DTOMissing
from strawchemy.utils import get_annotations
Expand Down
38 changes: 20 additions & 18 deletions src/strawchemy/dto/inspectors/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,6 @@
from inspect import getmodule, signature
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast, get_args, get_origin, get_type_hints

from typing_extensions import TypeIs, override

from sqlalchemy import (
Column,
PrimaryKeyConstraint,
Sequence,
SQLColumnExpression,
Table,
UniqueConstraint,
event,
inspect,
orm,
sql,
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import (
NO_VALUE,
Expand All @@ -35,6 +21,21 @@
RelationshipProperty,
registry,
)
from typing_extensions import TypeIs, override

from sqlalchemy import (
Column,
ColumnElement,
PrimaryKeyConstraint,
Sequence,
SQLColumnExpression,
Table,
UniqueConstraint,
event,
inspect,
orm,
sql,
)
from strawchemy.constants import GEO_INSTALLED
from strawchemy.dto.base import TYPING_NS, DTOFieldDefinition, ModelInspector, Relation
from strawchemy.dto.constants import DTO_INFO_KEY
Expand All @@ -47,9 +48,9 @@
from types import ModuleType

from shapely import Geometry

from sqlalchemy.orm import MapperProperty
from sqlalchemy.sql.schema import ColumnCollectionConstraint

from strawchemy.graph import Node


Expand Down Expand Up @@ -247,7 +248,7 @@ def _relationship_required(self, prop: RelationshipProperty[Any]) -> bool:
return False

def _field_definitions_from_columns(
self, model: type[DeclarativeBase], columns: Iterable[Column[Any]], dto_config: DTOConfig
self, model: type[DeclarativeBase], columns: Iterable[ColumnElement[Any]], dto_config: DTOConfig
) -> list[tuple[str, DTOFieldDefinition[DeclarativeBase, QueryableAttribute[Any]]]]:
mapper = inspect(model)
type_hints = self.get_type_hints(model)
Expand All @@ -262,11 +263,12 @@ def _field_definitions_from_columns(
),
)
for column in columns
if column.key
]

@classmethod
def pk_attributes(cls, mapper: Mapper[Any]) -> list[QueryableAttribute[Any]]:
return [mapper.attrs[column.key].class_attribute for column in mapper.primary_key]
return [mapper.attrs[column.key].class_attribute for column in mapper.primary_key if column.key]

@classmethod
def loaded_attributes(cls, model: DeclarativeBase) -> set[str]:
Expand Down Expand Up @@ -315,7 +317,7 @@ def field_definition(

# If column type is a geoalchemy geometry type, override type hint with the corresponding shapely type
if GEO_INSTALLED and (column_prop := mapper.columns.get(model_field.key)) is not None:
from geoalchemy2 import Geometry
from geoalchemy2 import Geometry # noqa: PLC0415

if (
isinstance(column_prop.type, Geometry)
Expand Down
2 changes: 1 addition & 1 deletion src/strawchemy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
if TYPE_CHECKING:
from collections.abc import Callable, Generator, Hashable

__all__ = ("GraphMetadata", "IterationMode", "MatchOn", "Node", "NodeMetadataT", "NodeValueT")
__all__ = ("GraphMetadata", "IterationMode", "MatchOn", "Node", "NodeMetadataT", "NodeValueT", "merge_trees")

T = TypeVar("T")
NodeValueT = TypeVar("NodeValueT", bound="Any")
Expand Down
21 changes: 11 additions & 10 deletions src/strawchemy/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from strawberry.annotation import StrawberryAnnotation
from strawberry.schema.config import StrawberryConfig

from strawchemy.strawberry.factories.aggregations import EnumDTOFactory
from strawchemy.strawberry.factories.enum import EnumDTOBackend, UpsertConflictFieldsEnumDTOBackend

Expand Down Expand Up @@ -37,10 +38,10 @@
from collections.abc import Callable, Mapping, Sequence

from sqlalchemy.orm import DeclarativeBase
from strawberry import BasePermission
from strawberry.extensions.field_extension import FieldExtension
from strawberry.types.arguments import StrawberryArgument
from strawberry.types.field import _RESOLVER_TYPE

from strawberry import BasePermission
from strawchemy.sqlalchemy.hook import QueryHook
from strawchemy.validation.pydantic import PydanticMapper

Expand Down Expand Up @@ -157,14 +158,14 @@ def pydantic(self) -> PydanticMapper:
Returns:
An instance of PydanticMapper.
"""
from .validation.pydantic import PydanticMapper
from .validation.pydantic import PydanticMapper # noqa: PLC0415

return PydanticMapper(self)

@overload
def field(
self,
resolver: _RESOLVER_TYPE[Any],
resolver: Any,
*,
filter_input: Optional[type[BooleanFilterDTO]] = None,
order_by: Optional[type[OrderByDTO]] = None,
Expand Down Expand Up @@ -220,7 +221,7 @@ def field(

def field(
self,
resolver: Optional[_RESOLVER_TYPE[Any]] = None,
resolver: Optional[Any] = None,
*,
filter_input: Optional[type[BooleanFilterDTO]] = None,
order_by: Optional[type[OrderByDTO]] = None,
Expand Down Expand Up @@ -325,7 +326,7 @@ def field(
def create(
self,
input_type: type[MappedGraphQLDTO[T]],
resolver: Optional[_RESOLVER_TYPE[Any]] = None,
resolver: Optional[Any] = None,
*,
repository_type: Optional[AnyRepository] = None,
name: Optional[str] = None,
Expand Down Expand Up @@ -403,7 +404,7 @@ def upsert(
input_type: type[MappedGraphQLDTO[T]],
update_fields: type[EnumDTO],
conflict_fields: type[EnumDTO],
resolver: Optional[_RESOLVER_TYPE[Any]] = None,
resolver: Optional[Any] = None,
*,
repository_type: Optional[AnyRepository] = None,
name: Optional[str] = None,
Expand Down Expand Up @@ -487,7 +488,7 @@ def update(
self,
input_type: type[MappedGraphQLDTO[T]],
filter_input: type[BooleanFilterDTO],
resolver: Optional[_RESOLVER_TYPE[Any]] = None,
resolver: Optional[Any] = None,
*,
repository_type: Optional[AnyRepository] = None,
name: Optional[str] = None,
Expand Down Expand Up @@ -568,7 +569,7 @@ def update(
def update_by_ids(
self,
input_type: type[MappedGraphQLDTO[T]],
resolver: Optional[_RESOLVER_TYPE[Any]] = None,
resolver: Optional[Any] = None,
*,
repository_type: Optional[AnyRepository] = None,
name: Optional[str] = None,
Expand Down Expand Up @@ -647,7 +648,7 @@ def update_by_ids(
def delete(
self,
filter_input: Optional[type[BooleanFilterDTO]] = None,
resolver: Optional[_RESOLVER_TYPE[Any]] = None,
resolver: Optional[Any] = None,
*,
repository_type: Optional[AnyRepository] = None,
name: Optional[str] = None,
Expand Down
26 changes: 13 additions & 13 deletions src/strawchemy/sqlalchemy/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, Optional, Union, cast

from sqlalchemy.orm import (
QueryableAttribute,
RelationshipDirection,
RelationshipProperty,
aliased,
class_mapper,
raiseload,
)
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql.elements import NamedColumn
from typing_extensions import Self

from sqlalchemy import (
Expand All @@ -22,17 +32,6 @@
null,
select,
)
from sqlalchemy.orm import (
QueryableAttribute,
RelationshipDirection,
RelationshipProperty,
aliased,
class_mapper,
raiseload,
)
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql import ColumnElement, SQLColumnExpression
from sqlalchemy.sql.elements import NamedColumn
from strawchemy.constants import AGGREGATIONS_KEY, NODES_KEY
from strawchemy.graph import merge_trees
from strawchemy.strawberry.dto import (
Expand All @@ -52,8 +51,10 @@
from collections.abc import Sequence

from sqlalchemy.orm.strategy_options import _AbstractLoad
from sqlalchemy.sql import ColumnElement, SQLColumnExpression
from sqlalchemy.sql._typing import _OnClauseArgument
from sqlalchemy.sql.selectable import NamedFromClause

from strawchemy.config.databases import DatabaseFeatures
from strawchemy.strawberry.typing import QueryNodeType

Expand Down Expand Up @@ -657,8 +658,7 @@ def _distinct_on(self, statement: Select[Any], order_by_expressions: list[UnaryE
*[
expression.element
for expression in order_by_expressions
if isinstance(expression.element, ColumnElement)
and not any(elem.compare(expression.element) for elem in statement.selected_columns)
if not any(elem.compare(expression.element) for elem in statement.selected_columns)
]
)
statement = statement.distinct(*distinct_expressions)
Expand Down
Loading