Skip to content
This repository has been archived by the owner on Jul 8, 2023. It is now read-only.

Draft: support annotate parameter in field to allow ORM annotations #255

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion strawberry_django_plus/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from .descriptors import ModelProperty
from .permissions import filter_with_perms
from .utils import resolvers
from .utils.typing import TypeOrSequence
from .utils.typing import TypeOrMapping, TypeOrSequence

if TYPE_CHECKING:
from strawberry_django_plus.type import StrawberryDjangoType
Expand Down Expand Up @@ -87,6 +87,7 @@ def __init__(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
**kwargs,
):
Expand All @@ -95,6 +96,7 @@ def __init__(
only=only,
select_related=select_related,
prefetch_related=prefetch_related,
annotate=annotate,
)
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -486,6 +488,7 @@ def field(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
extensions: List[FieldExtension] = (), # type: ignore
) -> _T:
Expand Down Expand Up @@ -513,6 +516,7 @@ def field(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
extensions: List[FieldExtension] = (), # type: ignore
) -> Any:
Expand Down Expand Up @@ -540,6 +544,7 @@ def field(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
extensions: List[FieldExtension] = (), # type: ignore
) -> StrawberryDjangoField:
Expand All @@ -566,6 +571,7 @@ def field(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
extensions: List[FieldExtension] = (), # type: ignore
# This init parameter is used by pyright to determine whether this field
Expand Down Expand Up @@ -606,6 +612,7 @@ def field(
only=only,
select_related=select_related,
prefetch_related=prefetch_related,
annotate=annotate,
disable_optimization=disable_optimization,
extensions=extensions,
)
Expand All @@ -631,6 +638,7 @@ def node(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
extensions: List[FieldExtension] = (), # type: ignore
# This init parameter is used by pyright to determine whether this field
Expand Down Expand Up @@ -675,6 +683,7 @@ def node(
only=only,
select_related=select_related,
prefetch_related=prefetch_related,
annotate=annotate,
disable_optimization=disable_optimization,
extensions=extensions,
)
Expand All @@ -699,6 +708,7 @@ def connection(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
) -> Any:
...
Expand All @@ -725,6 +735,7 @@ def connection(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
) -> Any:
...
Expand All @@ -749,6 +760,7 @@ def connection(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[optimizer.PrefetchType]] = None,
annotate: Optional[TypeOrMapping[optimizer.AnnotateType]] = None,
disable_optimization: bool = False,
# This init parameter is used by pyright to determine whether this field
# is added in the constructor or not. It is not used to change
Expand Down Expand Up @@ -823,6 +835,7 @@ def connection(
only=only,
select_related=select_related,
prefetch_related=prefetch_related,
annotate=annotate,
disable_optimization=disable_optimization,
extensions=extensions,
)
Expand Down
66 changes: 60 additions & 6 deletions strawberry_django_plus/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from django.db import models
from django.db.models import Prefetch
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import BaseExpression
from django.db.models.fields.reverse_related import (
ManyToManyRel,
ManyToOneRel,
Expand Down Expand Up @@ -53,7 +54,7 @@
get_possible_type_definitions,
get_selections,
)
from .utils.typing import TypeOrSequence
from .utils.typing import TypeOrMapping, TypeOrSequence

try:
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
Expand All @@ -79,6 +80,7 @@
else:
_relation_fields = (models.ManyToManyField, ManyToManyRel, ManyToOneRel)
_sentinel = object()
_annotate_placeholder = "______annotate_placeholder______"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just __annotated_placeholder__ would be enough? Any reason to be using 4 underscores?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can change to __annotated_placeholder__, it should be more than enough to avoid clashes. Also, if you have a suggestion on how to avoid this, please LMK. Unfortunately, I needed this workaround because field names aren't bind at class declaration moment.

_interfaces: """
defaultdict[
Schema,
Expand All @@ -90,6 +92,8 @@

PrefetchCallable: TypeAlias = Callable[[GraphQLResolveInfo], Prefetch]
PrefetchType: TypeAlias = Union[str, Prefetch, PrefetchCallable]
AnnotateCallable: TypeAlias = Callable[[GraphQLResolveInfo], BaseExpression]
AnnotateType: TypeAlias = Union[BaseExpression, AnnotateCallable]


def _get_prefetch_queryset(
Expand Down Expand Up @@ -207,6 +211,15 @@ def _get_model_hints(
# Add annotations from the field if they exist
field_store = getattr(field, "store", None)
if field_store is not None:
if len(field_store.annotate) == 1 and _annotate_placeholder in field_store.annotate:
# This is a special case where we need to update the field name,
# because when field_store was created on __init__, the field name wasn't available.
# This allows for annotate expressions to be declared as:
# total: int = gql.django.field(annotate=Sum("price")) # noqa: ERA001
# Instead of the more redundant:
# total: int = gql.django.field(annotate={"total": Sum("price")}) # noqa: ERA001
field_store.annotate = {field.name: field_store.annotate[_annotate_placeholder]}

store |= field_store.with_prefix(prefix, info=info) if prefix else field_store

# Then from the model property if one is defined
Expand Down Expand Up @@ -446,6 +459,8 @@ class OptimizerConfig:
Enable `QuerySet.select_related` optimizations
enable_prefetch_related:
Enable `QuerySet.prefetch_related` optimizations
enable_annotate:
Enable `QuerySet.annotate` optimizations
prefetch_custom_queryset:
Use custom instead of _base_manager for prefetch querysets

Expand All @@ -454,6 +469,7 @@ class OptimizerConfig:
enable_only: bool = dataclasses.field(default=True)
enable_select_related: bool = dataclasses.field(default=True)
enable_prefetch_related: bool = dataclasses.field(default=True)
enable_annotate: bool = dataclasses.field(default=True)
prefetch_custom_queryset: bool = dataclasses.field(default=False)


Expand All @@ -468,20 +484,23 @@ class OptimizerStore:
Set of values to optimize using `QuerySet.select_related`
prefetch_related:
Set of values to optimize using `QuerySet.prefetch_related`

annotate:
Dict to use on `QuerySet.annotate`
"""

only: List[str] = dataclasses.field(default_factory=list)
select_related: List[str] = dataclasses.field(default_factory=list)
prefetch_related: List[PrefetchType] = dataclasses.field(default_factory=list)
annotate: Dict[str, AnnotateType] = dataclasses.field(default_factory=dict)

def __bool__(self):
return any([self.only, self.select_related, self.prefetch_related])
return any([self.only, self.select_related, self.prefetch_related, self.annotate])

def __ior__(self, other: "OptimizerStore"):
self.only.extend(other.only)
self.select_related.extend(other.select_related)
self.prefetch_related.extend(other.prefetch_related)
self.annotate.update(other.annotate)
return self

def __or__(self, other: "OptimizerStore"):
Expand All @@ -492,6 +511,7 @@ def copy(self):
only=self.only[:],
select_related=self.select_related[:],
prefetch_related=self.prefetch_related[:],
annotate=self.annotate.copy(),
)

@classmethod
Expand All @@ -501,6 +521,7 @@ def with_hints(
only: Optional[TypeOrSequence[str]] = None,
select_related: Optional[TypeOrSequence[str]] = None,
prefetch_related: Optional[TypeOrSequence[PrefetchType]] = None,
annotate: Optional[TypeOrMapping[AnnotateType]] = None,
):
return cls(
only=[only] if isinstance(only, str) else list(only or []),
Expand All @@ -512,6 +533,12 @@ def with_hints(
if isinstance(prefetch_related, (str, Prefetch, Callable))
else list(prefetch_related or [])
),
annotate=(
# placeholder here, because field name is evaluated later on .annotate call:
{_annotate_placeholder: annotate}
if isinstance(annotate, (BaseExpression, Callable))
else dict(annotate or {})
),
)

def with_prefix(self, prefix: str, *, info: GraphQLResolveInfo):
Expand All @@ -529,10 +556,19 @@ def with_prefix(self, prefix: str, *, info: GraphQLResolveInfo):
else: # pragma:nocover
assert_never(p)

annotate = {}
for k, v in self.annotate.items():
if isinstance(v, Callable):
assert_type(v, AnnotateCallable)
v = v(info) # noqa: PLW2901

annotate[f"{prefix}{LOOKUP_SEP}{k}"] = v

return self.__class__(
only=[f"{prefix}{LOOKUP_SEP}{i}" for i in self.only],
select_related=[f"{prefix}{LOOKUP_SEP}{i}" for i in self.select_related],
prefetch_related=prefetch_related,
annotate=annotate,
)

def apply(
Expand Down Expand Up @@ -601,6 +637,17 @@ def apply(
if config.enable_only and self.only:
qs = qs.only(*self.only)

if config.enable_annotate and self.annotate:
to_annotate = {}
for k, v in self.annotate.items():
if isinstance(v, Callable):
assert_type(v, AnnotateCallable)
v = v(info) # noqa: PLW2901

to_annotate[k] = v

qs = qs.annotate(**to_annotate)

return qs


Expand All @@ -620,6 +667,9 @@ class DjangoOptimizerExtension(SchemaExtension):
Enable `QuerySet.select_related` optimizations
enable_prefetch_related_optimization:
Enable `QuerySet.prefetch_related` optimizations
enable_annotate_optimization:
Enable `QuerySet.annotate` optimizations


Examples:
Add the following to your schema configuration.
Expand Down Expand Up @@ -647,13 +697,15 @@ def __init__(
enable_only_optimization: bool = True,
enable_select_related_optimization: bool = True,
enable_prefetch_related_optimization: bool = True,
enable_annotate_optimization: bool = True,
execution_context: Optional[ExecutionContext] = None,
prefetch_custom_queryset: bool = False,
):
super().__init__(execution_context=execution_context) # type: ignore
self._enable_ony = enable_only_optimization
self._enable_only = enable_only_optimization
self._enable_select_related = enable_select_related_optimization
self._enable_prefetch_related = enable_prefetch_related_optimization
self._enable_annotate = enable_annotate_optimization
self._prefetch_custom_queryset = prefetch_custom_queryset

def on_execute(self) -> Generator[None, None, None]:
Expand Down Expand Up @@ -684,10 +736,11 @@ def resolve(
if ret._result_cache is None: # type: ignore
config = OptimizerConfig(
enable_only=(
self._enable_ony and info.operation.operation == OperationType.QUERY
self._enable_only and info.operation.operation == OperationType.QUERY
),
enable_select_related=self._enable_select_related,
enable_prefetch_related=self._enable_prefetch_related,
enable_annotate=self._enable_annotate,
prefetch_custom_queryset=self._prefetch_custom_queryset,
)
return resolvers.resolve_qs(optimize(qs=ret, info=info, config=config))
Expand All @@ -705,9 +758,10 @@ def optimize(
return qs

config = OptimizerConfig(
enable_only=self._enable_ony and info.operation.operation == OperationType.QUERY,
enable_only=self._enable_only and info.operation.operation == OperationType.QUERY,
enable_select_related=self._enable_select_related,
enable_prefetch_related=self._enable_prefetch_related,
enable_annotate=self._enable_annotate,
prefetch_custom_queryset=self._prefetch_custom_queryset,
)
return optimize(qs, info, config=config, store=store)
3 changes: 2 additions & 1 deletion strawberry_django_plus/utils/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Iterable, Sequence, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, Sequence, TypeVar, Union

from django.contrib.auth.base_user import AbstractBaseUser
from graphql.type.definition import GraphQLResolveInfo
Expand All @@ -14,6 +14,7 @@

DictTree: TypeAlias = Dict[str, "DictTree"]
TypeOrSequence: TypeAlias = Union[_T, Sequence[_T]]
TypeOrMapping: TypeAlias = Union[_T, Mapping[str, _T]]
TypeOrIterable: TypeAlias = Union[_T, Iterable[_T]]
UserType: TypeAlias = Union[AbstractBaseUser, "AnonymousUser"]
ResolverInfo: TypeAlias = Union[Info[StrawberryDjangoContext, Any], GraphQLResolveInfo]
Expand Down