Skip to content

Commit

Permalink
Removed the ability for Operators to specify their own "scheduling de…
Browse files Browse the repository at this point in the history
…ps". (apache#45713)

This is not talking about the relationship between tasks, but the conditions
on an operator that the scheduler checks before it can be scheduled -- things
like "are my upstreams complete" or "am I out of my retry period" etc.

With the split of Task SDK and Task Execution interface this feature has
become untenable to support with the split-responsibility model, and it is
such a rarely used feature that the right approach is to remove it.

This makes future code and PRs much much easier.
  • Loading branch information
ashb authored Jan 17, 2025
1 parent 16eaa5e commit 5ff07fa
Show file tree
Hide file tree
Showing 18 changed files with 123 additions and 301 deletions.
1 change: 0 additions & 1 deletion airflow/api_fastapi/core_api/datamodels/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class PluginResponse(BaseModel):
global_operator_extra_links: list[str]
operator_extra_links: list[str]
source: Annotated[str, BeforeValidator(coerce_to_string)]
ti_deps: list[Annotated[str, BeforeValidator(coerce_to_string)]]
listeners: list[str]
timetables: list[str]

Expand Down
6 changes: 0 additions & 6 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8699,11 +8699,6 @@ components:
source:
type: string
title: Source
ti_deps:
items:
type: string
type: array
title: Ti Deps
listeners:
items:
type: string
Expand All @@ -8725,7 +8720,6 @@ components:
- global_operator_extra_links
- operator_extra_links
- source
- ti_deps
- listeners
- timetables
title: PluginResponse
Expand Down
2 changes: 1 addition & 1 deletion airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,14 +505,14 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
partial_kwargs=partial_kwargs,
task_id=task_id,
params=partial_params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
template_ext=self.operator_class.template_ext,
template_fields=self.operator_class.template_fields,
template_fields_renderers=self.operator_class.template_fields_renderers,
ui_color=self.operator_class.ui_color,
ui_fgcolor=self.operator_class.ui_fgcolor,
is_empty=False,
is_sensor=self.operator_class._is_sensor,
task_module=self.operator_class.__module__,
task_type=self.operator_class.__name__,
operator_name=operator_name,
Expand Down
4 changes: 0 additions & 4 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,6 @@ def serialize(self):
)


class UnmappableOperator(AirflowException):
"""Raise when an operator is not implemented to be mappable."""


class XComForMappingNotPushed(AirflowException):
"""Raise when a mapped downstream's dependency fails to push XCom for task mapping."""

Expand Down
37 changes: 16 additions & 21 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@
from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Union

import attr
import attrs
import methodtools

from airflow.exceptions import UnmappableOperator
from airflow.models.abstractoperator import (
DEFAULT_EXECUTOR,
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
Expand All @@ -51,7 +50,6 @@
from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.context import context_update_for_unmapped
from airflow.utils.helpers import is_container, prevent_duplicates
Expand Down Expand Up @@ -140,7 +138,7 @@ def ensure_xcomarg_return_value(arg: Any) -> None:
ensure_xcomarg_return_value(v)


@attr.define(kw_only=True, repr=False)
@attrs.define(kw_only=True, repr=False)
class OperatorPartial:
"""
An "intermediate state" returned by ``BaseOperator.partial()``.
Expand Down Expand Up @@ -193,6 +191,7 @@ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool =

def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
from airflow.operators.empty import EmptyOperator
from airflow.sensors.base import BaseSensorOperator

self._expand_called = True
ensure_xcomarg_return_value(expand_input.value)
Expand All @@ -215,14 +214,14 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
partial_kwargs=partial_kwargs,
task_id=task_id,
params=self.params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
template_ext=self.operator_class.template_ext,
template_fields=self.operator_class.template_fields,
template_fields_renderers=self.operator_class.template_fields_renderers,
ui_color=self.operator_class.ui_color,
ui_fgcolor=self.operator_class.ui_fgcolor,
is_empty=issubclass(self.operator_class, EmptyOperator),
is_sensor=issubclass(self.operator_class, BaseSensorOperator),
task_module=self.operator_class.__module__,
task_type=self.operator_class.__name__,
operator_name=operator_name,
Expand All @@ -240,7 +239,7 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
return op


@attr.define(
@attrs.define(
kw_only=True,
# Disable custom __getstate__ and __setstate__ generation since it interacts
# badly with Airflow's DAG serialization and pickling. When a mapped task is
Expand All @@ -267,14 +266,15 @@ class MappedOperator(AbstractOperator):
# Needed for serialization.
task_id: str
params: ParamsDict | dict
deps: frozenset[BaseTIDep]
deps: frozenset[BaseTIDep] = attrs.field(init=False)
operator_extra_links: Collection[BaseOperatorLink]
template_ext: Sequence[str]
template_fields: Collection[str]
template_fields_renderers: dict[str, str]
ui_color: str
ui_fgcolor: str
_is_empty: bool
_is_sensor: bool = False
_task_module: str
_task_type: str
_operator_name: str
Expand All @@ -286,8 +286,8 @@ class MappedOperator(AbstractOperator):
task_group: TaskGroup | None
start_date: pendulum.DateTime | None
end_date: pendulum.DateTime | None
upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
downstream_task_ids: set[str] = attr.ib(factory=set, init=False)
upstream_task_ids: set[str] = attrs.field(factory=set, init=False)
downstream_task_ids: set[str] = attrs.field(factory=set, init=False)

_disallow_kwargs_override: bool
"""Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.
Expand All @@ -308,6 +308,12 @@ class MappedOperator(AbstractOperator):
("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger")
)

@deps.default
def _deps(self):
from airflow.models.baseoperator import BaseOperator

return BaseOperator.deps

def __hash__(self):
return id(self)

Expand All @@ -333,7 +339,7 @@ def __attrs_post_init__(self):
@classmethod
def get_serialized_fields(cls):
# Not using 'cls' here since we only want to serialize base fields.
return (frozenset(attr.fields_dict(MappedOperator)) | {"task_type"}) - {
return (frozenset(attrs.fields_dict(MappedOperator)) | {"task_type"}) - {
"_task_type",
"dag",
"deps",
Expand All @@ -346,17 +352,6 @@ def get_serialized_fields(cls):
"_on_failure_fail_dagrun",
}

@methodtools.lru_cache(maxsize=None)
@staticmethod
def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]:
operator_deps = operator_class.deps
if not isinstance(operator_deps, collections.abc.Set):
raise UnmappableOperator(
f"'deps' must be a set defined as a class-level variable on {operator_class.__name__}, "
f"not a {type(operator_deps).__name__}"
)
return operator_deps | {MappedTaskIsExpanded()}

@property
def task_type(self) -> str:
"""Implementing Operator."""
Expand Down
25 changes: 0 additions & 25 deletions airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
global_operator_extra_links: list[Any] | None = None
operator_extra_links: list[Any] | None = None
registered_operator_link_classes: dict[str, type] | None = None
registered_ti_dep_classes: dict[str, type] | None = None
timetable_classes: dict[str, type[Timetable]] | None = None
hook_lineage_reader_classes: list[type[HookLineageReader]] | None = None
priority_weight_strategy_classes: dict[str, type[PriorityWeightStrategy]] | None = None
Expand All @@ -95,7 +94,6 @@
"global_operator_extra_links",
"operator_extra_links",
"source",
"ti_deps",
"timetables",
"listeners",
"priority_weight_strategies",
Expand Down Expand Up @@ -171,8 +169,6 @@ class AirflowPlugin:
# buttons.
operator_extra_links: list[Any] = []

ti_deps: list[Any] = []

# A list of timetable classes that can be used for DAG scheduling.
timetables: list[type[Timetable]] = []

Expand Down Expand Up @@ -427,27 +423,6 @@ def initialize_fastapi_plugins():
fastapi_apps.extend(plugin.fastapi_apps)


def initialize_ti_deps_plugins():
"""Create modules for loaded extension from custom task instance dependency rule plugins."""
global registered_ti_dep_classes
if registered_ti_dep_classes is not None:
return

ensure_plugins_loaded()

if plugins is None:
raise AirflowPluginException("Can't load plugins.")

log.debug("Initialize custom taskinstance deps plugins")

registered_ti_dep_classes = {}

for plugin in plugins:
registered_ti_dep_classes.update(
{qualname(ti_dep.__class__): ti_dep.__class__ for ti_dep in plugin.ti_deps}
)


def initialize_extra_operators_links_plugins():
"""Create modules for loaded extension from extra operators links plugins."""
global global_operator_extra_links
Expand Down
4 changes: 3 additions & 1 deletion airflow/sensors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
ui_color: str = "#e6f1f2"
valid_modes: Iterable[str] = ["poke", "reschedule"]

_is_sensor: bool = True

# Adds one additional dependency for all sensor operators that checks if a
# sensor task instance can be rescheduled.
deps = BaseOperator.deps | {ReadyToRescheduleDep()}
Expand Down Expand Up @@ -406,7 +408,7 @@ def reschedule(self):

@classmethod
def get_serialized_fields(cls):
return super().get_serialized_fields() | {"reschedule"}
return super().get_serialized_fields() | {"reschedule", "_is_sensor"}


def poke_mode_only(cls):
Expand Down
7 changes: 1 addition & 6 deletions airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,6 @@
"items": { "type": "string" }
},
"_is_dummy": { "type": "boolean" },
"deps": {
"description": "list of dep classes -- if non-standard",
"type": "array",
"items": { "type": "string" },
"uniqueItems": true
},
"doc": { "type": "string" },
"doc_md": { "type": "string" },
"doc_json": { "type": "string" },
Expand All @@ -293,6 +287,7 @@
"_logger_name": { "type": "string" },
"_log_config_logger_name": { "type": "string" },
"_is_mapped": { "const": true, "$comment": "only present when True" },
"_is_sensor": { "const": true, "$comment": "only present when True" },
"expand_input": { "type": "object" },
"partial_kwargs": { "type": "object" },
"map_index_template": { "type": "string" },
Expand Down
Loading

0 comments on commit 5ff07fa

Please sign in to comment.