Skip to content

Commit

Permalink
Convert dbt_tags parameter to general select parameter (#96)
Browse files Browse the repository at this point in the history
Closes [94](#94)

I also snuck in there a max line length of 120 so that pre-commit
stopped failing.
  • Loading branch information
chrishronek authored Jan 25, 2023
1 parent a7013db commit b4ea9e6
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 12 deletions.
9 changes: 6 additions & 3 deletions cosmos/providers/dbt/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class DbtDag(CosmosDag):
:param emit_datasets: If enabled test nodes emit Airflow Datasets for downstream cross-DAG dependencies
:param test_behavior: The behavior for running tests. Options are "none", "after_each", and "after_all".
Defaults to "after_each"
:param dbt_tags: A list of dbt tags to filter the dbt models by
:param select: A dict of dbt selector arguments (e.g., {"tags": ["tag_1", "tag_2"]})
:param exclude: A dict of dbt exclude arguments (e.g., {"tags": ["tag_1", "tag_2"]})
"""

def __init__(
Expand All @@ -36,7 +37,8 @@ def __init__(
emit_datasets: bool = True,
dbt_root_path: str = "/usr/local/airflow/dbt",
test_behavior: Literal["none", "after_each", "after_all"] = "after_each",
dbt_tags: List[str] = [],
select: Dict[str, List[str]] = {},
exclude: Dict[str, List[str]] = {},
*args: Any,
**kwargs: Any,
) -> None:
Expand All @@ -54,7 +56,8 @@ def __init__(
test_behavior=test_behavior,
emit_datasets=emit_datasets,
conn_id=conn_id,
dbt_tags=dbt_tags,
select=select,
exclude=exclude,
)

# call the airflow DAG constructor
Expand Down
26 changes: 22 additions & 4 deletions cosmos/providers/dbt/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any, Dict, List

from airflow.datasets import Dataset
from airflow.exceptions import AirflowException

from cosmos.core.graph.entities import CosmosEntity, Group, Task
from cosmos.providers.dbt.parser.project import DbtProject
Expand All @@ -25,7 +26,8 @@ def render_project(
test_behavior: Literal["none", "after_each", "after_all"] = "after_each",
emit_datasets: bool = True,
conn_id: str = "default_conn_id",
dbt_tags: List[str] = [],
select: Dict[str, List[str]] = {},
exclude: Dict[str, List[str]] = {},
) -> Group:
"""
Turn a dbt project into a Group
Expand All @@ -37,7 +39,8 @@ def render_project(
Defaults to "after_each"
:param emit_datasets: If enabled test nodes emit Airflow Datasets for downstream cross-DAG dependencies
:param conn_id: The Airflow connection ID to use in Airflow Datasets
:param dbt_tags: A list of dbt tags to filter the dbt models by
:param select: A dict of dbt selector arguments (e.g., {"tags": ["tag_1", "tag_2"]})
:param exclude: A dict of dbt exclude arguments (e.g., {"tags": ["tag_1", "tag_2"]})
"""
# first, get the dbt project
project = DbtProject(
Expand All @@ -53,11 +56,26 @@ def render_project(
# add project_dir arg to task_args
task_args["project_dir"] = project.project_dir

# ensures the same tag isn't in select & exclude
if "tags" in select and "tags" in exclude:
if set(select["tags"]).intersection(exclude["tags"]):
raise AirflowException(
f"Can't specify the same tag in `select` and `exclude`: "
f"{set(select['tags']).intersection(exclude['tags'])}"
)

# iterate over each model once to create the initial tasks
for model_name, model in project.models.items():
# if we have tags, only include models that have at least one of the tags
if dbt_tags and not set(dbt_tags).intersection(model.config.tags):
continue
# filters down to a set of specified tags
if "tags" in select:
if not set(select["tags"]).intersection(model.config.tags):
continue

# filters out any specified tags
if "tags" in exclude:
if set(exclude["tags"]).intersection(model.config.tags):
continue

run_args: Dict[str, Any] = {**task_args, "models": model_name}
test_args: Dict[str, Any] = {**task_args, "models": model_name}
Expand Down
9 changes: 6 additions & 3 deletions cosmos/providers/dbt/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class DbtTaskGroup(CosmosTaskGroup):
:param emit_datasets: If enabled test nodes emit Airflow Datasets for downstream cross-DAG dependencies
:param test_behavior: The behavior for running tests. Options are "none", "after_each", and "after_all".
Defaults to "after_each"
:param dbt_tags: A list of dbt tags to filter the dbt models by
:param select: A dict of dbt selector arguments (e.g., {"tags": ["tag_1", "tag_2"]})
:param exclude: A dict of dbt exclude arguments (e.g., {"tags": ["tag_1", "tag_2"]})
"""

def __init__(
Expand All @@ -36,7 +37,8 @@ def __init__(
emit_datasets: bool = True,
dbt_root_path: str = "/usr/local/airflow/dbt",
test_behavior: Literal["none", "after_each", "after_all"] = "after_each",
dbt_tags: List[str] = [],
select: Dict[str, List[str]] = {},
exclude: Dict[str, List[str]] = {},
*args: Any,
**kwargs: Any,
) -> None:
Expand All @@ -54,7 +56,8 @@ def __init__(
test_behavior=test_behavior,
emit_datasets=emit_datasets,
conn_id=conn_id,
dbt_tags=dbt_tags,
select=select,
exclude=exclude,
)

# call the airflow constructor
Expand Down
4 changes: 2 additions & 2 deletions docs/dbt/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Example:
Tags
----------------------

Cosmos allows you to filter by tags using the ``dbt_tags`` parameter. If a model contains any of the tags, it gets included as part of the DAG/Task Group. Otherwise, it doesn't get included (even if rendered models depend on a non-tagged model).
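Cosmos allows you to filter by tags using the ``select`` and ``exclude`` parameters. If a model contains any of the tags listed under ``select``, it gets included as part of the DAG/Task Group. Otherwise, it doesn't get included (even if rendered models depend on a non-tagged model). A model that contains any of the tags listed under ``exclude`` is left out. The same tag cannot appear in both ``select`` and ``exclude``.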

.. note::
Cosmos currently reads from (1) config calls in the model code and (2) .yml files in the models directory for tags. It does not read from the dbt_project.yml file.
Expand All @@ -40,5 +40,5 @@ Example:
jaffle_shop = DbtDag(
# ...
dbt_tags=['daily'],
select={"tags": ['daily']},
)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,6 @@ known_third_party = ["airflow", "jinja2"]

[tool.mypy]
strict = true

[tool.ruff]
line-length = 120

0 comments on commit b4ea9e6

Please sign in to comment.