diff --git a/strawberry_django_plus/field.py b/strawberry_django_plus/field.py index 546df64..de042b9 100644 --- a/strawberry_django_plus/field.py +++ b/strawberry_django_plus/field.py @@ -57,7 +57,7 @@ from .descriptors import ModelProperty from .permissions import filter_with_perms from .utils import resolvers -from .utils.typing import TypeOrSequence +from .utils.typing import TypeOrMapping, TypeOrSequence if TYPE_CHECKING: from strawberry_django_plus.type import StrawberryDjangoType @@ -87,6 +87,7 @@ def __init__( only: Optional[TypeOrSequence[str]] = None, select_related: Optional[TypeOrSequence[str]] = None, prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None, + annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None, disable_optimization: bool = False, **kwargs, ): @@ -95,6 +96,7 @@ def __init__( only=only, select_related=select_related, prefetch_related=prefetch_related, + annotate=annotate, ) super().__init__(*args, **kwargs) @@ -486,6 +488,7 @@ def field( only: Optional[TypeOrSequence[str]] = None, select_related: Optional[TypeOrSequence[str]] = None, prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None, + annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None, disable_optimization: bool = False, extensions: List[FieldExtension] = (), # type: ignore ) -> _T: @@ -513,6 +516,7 @@ def field( only: Optional[TypeOrSequence[str]] = None, select_related: Optional[TypeOrSequence[str]] = None, prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None, + annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None, disable_optimization: bool = False, extensions: List[FieldExtension] = (), # type: ignore ) -> Any: @@ -540,6 +544,7 @@ def field( only: Optional[TypeOrSequence[str]] = None, select_related: Optional[TypeOrSequence[str]] = None, prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None, + annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None, disable_optimization: bool = False, extensions: List[FieldExtension] = (), # type: ignore ) -> StrawberryDjangoField: @@ -566,6 +571,7 @@ def field( only: Optional[TypeOrSequence[str]] = None, select_related: Optional[TypeOrSequence[str]] = None, prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None, + annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None, disable_optimization: bool = False, extensions: List[FieldExtension] = (), # type: ignore # This init parameter is used by pyright to determine whether this field @@ -606,6 +612,7 @@ def field( only=only, select_related=select_related, prefetch_related=prefetch_related, + annotate=annotate, disable_optimization=disable_optimization, extensions=extensions, ) @@ -631,6 +638,7 @@ def node( only: Optional[TypeOrSequence[str]] = None, select_related: Optional[TypeOrSequence[str]] = None, prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None, + annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None, disable_optimization: bool = False, extensions: List[FieldExtension] = (), # type: ignore # This init parameter is used by pyright to determine whether this field @@ -675,6 +683,7 @@ def node( only=only, select_related=select_related, prefetch_related=prefetch_related, + annotate=annotate, disable_optimization=disable_optimization, extensions=extensions, ) @@ -699,6 +708,7 @@ def connection( only: Optional[TypeOrSequence[str]] = None, select_related: Optional[TypeOrSequence[str]] = None, prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None, + annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None, disable_optimization: bool = False, ) -> Any: ... @@ -725,6 +735,7 @@ def connection( only: Optional[TypeOrSequence[str]] = None, select_related: Optional[TypeOrSequence[str]] = None, prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None, + annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None, disable_optimization: bool = False, ) -> Any: ... @@ -749,6 +760,7 @@ def connection( only: Optional[TypeOrSequence[str]] = None, select_related: Optional[TypeOrSequence[str]] = None, prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None, + annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None, disable_optimization: bool = False, # This init parameter is used by pyright to determine whether this field # is added in the constructor or not. It is not used to change @@ -823,6 +835,7 @@ def connection( only=only, select_related=select_related, prefetch_related=prefetch_related, + annotate=annotate, disable_optimization=disable_optimization, extensions=extensions, ) diff --git a/strawberry_django_plus/optimizer.py b/strawberry_django_plus/optimizer.py index 3cb4c53..e84b2e7 100644 --- a/strawberry_django_plus/optimizer.py +++ b/strawberry_django_plus/optimizer.py @@ -21,6 +21,7 @@ from django.db import models from django.db.models import Prefetch from django.db.models.constants import LOOKUP_SEP +from django.db.models.expressions import BaseExpression from django.db.models.fields.reverse_related import ( ManyToManyRel, ManyToOneRel, @@ -53,7 +54,7 @@ get_possible_type_definitions, get_selections, ) -from .utils.typing import TypeOrSequence +from .utils.typing import TypeOrMapping, TypeOrSequence try: from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation @@ -79,6 +80,7 @@ else: _relation_fields = (models.ManyToManyField, ManyToManyRel, ManyToOneRel) _sentinel = object() +_annotate_placeholder = "______annotate_placeholder______" _interfaces: """ defaultdict[ Schema, @@ -90,6 +92,8 @@ PrefetchCallable: TypeAlias = Callable[[GraphQLResolveInfo], Prefetch] PrefetchType: TypeAlias = Union[str, Prefetch, PrefetchCallable] +AnnotateCallable: TypeAlias = Callable[[GraphQLResolveInfo], BaseExpression] +AnnotateType: TypeAlias = Union[BaseExpression, AnnotateCallable] def _get_prefetch_queryset( @@ -207,6 +211,15 @@ def _get_model_hints( # Add annotations from the field if they exist field_store = getattr(field, "store", None) if field_store is not None: + if len(field_store.annotate) == 1 and _annotate_placeholder in field_store.annotate: + # This is a special case where we need to update the field name, + # because when field_store was created on __init__, the field name wasn't available. + # This allows for annotate expressions to be declared as: + # total: int = gql.django.field(annotate=Sum("price")) # noqa: ERA001 + # Instead of the more redundant: + # total: int = gql.django.field(annotate={"total": Sum("price")}) # noqa: ERA001 + field_store.annotate = {field.name: field_store.annotate[_annotate_placeholder]} + store |= field_store.with_prefix(prefix, info=info) if prefix else field_store # Then from the model property if one is defined @@ -446,6 +459,8 @@ class OptimizerConfig: Enable `QuerySet.select_related` optimizations enable_prefetch_related: Enable `QuerySet.prefetch_related` optimizations + enable_annotate: + Enable `QuerySet.annotate` optimizations prefetch_custom_queryset: Use custom instead of _base_manager for prefetch querysets @@ -454,6 +469,7 @@ class OptimizerConfig: enable_only: bool = dataclasses.field(default=True) enable_select_related: bool = dataclasses.field(default=True) enable_prefetch_related: bool = dataclasses.field(default=True) + enable_annotate: bool = dataclasses.field(default=True) prefetch_custom_queryset: bool = dataclasses.field(default=False) @@ -468,20 +484,23 @@ class OptimizerStore: Set of values to optimize using `QuerySet.select_related` prefetch_related: Set of values to optimize using `QuerySet.prefetch_related` - + annotate: + Dict to use on `QuerySet.annotate` """ only: List[str] = dataclasses.field(default_factory=list) select_related: List[str] = dataclasses.field(default_factory=list) prefetch_related: List[PrefetchType] = dataclasses.field(default_factory=list) + annotate: Dict[str, AnnotateType] = dataclasses.field(default_factory=dict) def __bool__(self): - return any([self.only, self.select_related, self.prefetch_related]) + return any([self.only, self.select_related, self.prefetch_related, self.annotate]) def __ior__(self, other: "OptimizerStore"): self.only.extend(other.only) self.select_related.extend(other.select_related) self.prefetch_related.extend(other.prefetch_related) + self.annotate.update(other.annotate) return self def __or__(self, other: "OptimizerStore"): @@ -492,6 +511,7 @@ def copy(self): only=self.only[:], select_related=self.select_related[:], prefetch_related=self.prefetch_related[:], + annotate=self.annotate.copy(), ) @classmethod @@ -501,6 +521,7 @@ def with_hints( only: Optional[TypeOrSequence[str]] = None, select_related: Optional[TypeOrSequence[str]] = None, prefetch_related: Optional[TypeOrSequence[PrefetchType]] = None, + annotate: Optional[TypeOrMapping[AnnotateType]] = None, ): return cls( only=[only] if isinstance(only, str) else list(only or []), @@ -512,6 +533,12 @@ def with_hints( if isinstance(prefetch_related, (str, Prefetch, Callable)) else list(prefetch_related or []) ), + annotate=( + # placeholder here, because field name is evaluated later on .annotate call: + {_annotate_placeholder: annotate} + if isinstance(annotate, (BaseExpression, Callable)) + else dict(annotate or {}) + ), ) def with_prefix(self, prefix: str, *, info: GraphQLResolveInfo): @@ -529,10 +556,19 @@ def with_prefix(self, prefix: str, *, info: GraphQLResolveInfo): else: # pragma:nocover assert_never(p) + annotate = {} + for k, v in self.annotate.items(): + if isinstance(v, Callable): + assert_type(v, AnnotateCallable) + v = v(info) # noqa: PLW2901 + + annotate[f"{prefix}{LOOKUP_SEP}{k}"] = v + return self.__class__( only=[f"{prefix}{LOOKUP_SEP}{i}" for i in self.only], select_related=[f"{prefix}{LOOKUP_SEP}{i}" for i in self.select_related], prefetch_related=prefetch_related, + annotate=annotate, ) def apply( @@ -601,6 +637,17 @@ def apply( if config.enable_only and self.only: qs = qs.only(*self.only) + if config.enable_annotate and self.annotate: + to_annotate = {} + for k, v in self.annotate.items(): + if isinstance(v, Callable): + assert_type(v, AnnotateCallable) + v = v(info) # noqa: PLW2901 + + to_annotate[k] = v + + qs = qs.annotate(**to_annotate) + return qs @@ -620,6 +667,9 @@ class DjangoOptimizerExtension(SchemaExtension): Enable `QuerySet.select_related` optimizations enable_prefetch_related_optimization: Enable `QuerySet.prefetch_related` optimizations + enable_annotate_optimization: + Enable `QuerySet.annotate` optimizations + Examples: Add the following to your schema configuration. @@ -647,13 +697,15 @@ def __init__( enable_only_optimization: bool = True, enable_select_related_optimization: bool = True, enable_prefetch_related_optimization: bool = True, + enable_annotate_optimization: bool = True, execution_context: Optional[ExecutionContext] = None, prefetch_custom_queryset: bool = False, ): super().__init__(execution_context=execution_context) # type: ignore - self._enable_ony = enable_only_optimization + self._enable_only = enable_only_optimization self._enable_select_related = enable_select_related_optimization self._enable_prefetch_related = enable_prefetch_related_optimization + self._enable_annotate = enable_annotate_optimization self._prefetch_custom_queryset = prefetch_custom_queryset def on_execute(self) -> Generator[None, None, None]: @@ -684,10 +736,11 @@ def resolve( if ret._result_cache is None: # type: ignore config = OptimizerConfig( enable_only=( - self._enable_ony and info.operation.operation == OperationType.QUERY + self._enable_only and info.operation.operation == OperationType.QUERY ), enable_select_related=self._enable_select_related, enable_prefetch_related=self._enable_prefetch_related, + enable_annotate=self._enable_annotate, prefetch_custom_queryset=self._prefetch_custom_queryset, ) return resolvers.resolve_qs(optimize(qs=ret, info=info, config=config)) @@ -705,9 +758,10 @@ def optimize( return qs config = OptimizerConfig( - enable_only=self._enable_ony and info.operation.operation == OperationType.QUERY, + enable_only=self._enable_only and info.operation.operation == OperationType.QUERY, enable_select_related=self._enable_select_related, enable_prefetch_related=self._enable_prefetch_related, + enable_annotate=self._enable_annotate, prefetch_custom_queryset=self._prefetch_custom_queryset, ) return optimize(qs, info, config=config, store=store) diff --git a/strawberry_django_plus/utils/typing.py b/strawberry_django_plus/utils/typing.py index d8c29d2..fc947b2 100644 --- a/strawberry_django_plus/utils/typing.py +++ b/strawberry_django_plus/utils/typing.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, Iterable, Sequence, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, Sequence, TypeVar, Union from django.contrib.auth.base_user import AbstractBaseUser from graphql.type.definition import GraphQLResolveInfo @@ -14,6 +14,7 @@ DictTree: TypeAlias = Dict[str, "DictTree"] TypeOrSequence: TypeAlias = Union[_T, Sequence[_T]] +TypeOrMapping: TypeAlias = Union[_T, Mapping[str, _T]] TypeOrIterable: TypeAlias = Union[_T, Iterable[_T]] UserType: TypeAlias = Union[AbstractBaseUser, "AnonymousUser"] ResolverInfo: TypeAlias = Union[Info[StrawberryDjangoContext, Any], GraphQLResolveInfo]