diff --git a/torchx/schedulers/kubernetes_scheduler.py b/torchx/schedulers/kubernetes_scheduler.py
index e1e73b91b..c06eebc8a 100644
--- a/torchx/schedulers/kubernetes_scheduler.py
+++ b/torchx/schedulers/kubernetes_scheduler.py
@@ -169,6 +169,17 @@
 LABEL_INSTANCE_TYPE = "node.kubernetes.io/instance-type"
 
+# role.env translates to static env variables in the yaml
+# {"FOO" : "bar"} =====> - name: FOO
+#                          value: bar
+# unless this placeholder is present at the start of the role.env value then the env variable
+# in the yaml will be dynamically populated at runtime (placeholder is stripped out of the value)
+# {"FOO" : "[FIELD_PATH]bar"} =====> - name: FOO
+#                                      valueFrom:
+#                                        fieldRef:
+#                                          fieldPath: bar
+PLACEHOLDER_FIELD_PATH = "[FIELD_PATH]"
+
 
 def sanitize_for_serialization(obj: object) -> object:
     from kubernetes import client
 
@@ -183,7 +194,9 @@ def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod
         V1ContainerPort,
         V1EmptyDirVolumeSource,
         V1EnvVar,
+        V1EnvVarSource,
         V1HostPathVolumeSource,
+        V1ObjectFieldSelector,
         V1ObjectMeta,
         V1PersistentVolumeClaimVolumeSource,
         V1Pod,
@@ -303,9 +316,22 @@ def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod
         image=role.image,
         name=name,
         env=[
-            V1EnvVar(
-                name=name,
-                value=value,
+            (
+                V1EnvVar(
+                    name=name,
+                    value_from=V1EnvVarSource(
+                        field_ref=V1ObjectFieldSelector(
+                            # slice, don't str.strip(): strip() removes a *set* of
+                            # chars from both ends ("status.podIP" -> "status.pod")
+                            field_path=value[len(PLACEHOLDER_FIELD_PATH) :]
+                        )
+                    ),
+                )
+                if value.startswith(PLACEHOLDER_FIELD_PATH)
+                else V1EnvVar(
+                    name=name,
+                    value=value,
+                )
             )
             for name, value in role.env.items()
         ],
diff --git a/torchx/schedulers/test/kubernetes_scheduler_test.py b/torchx/schedulers/test/kubernetes_scheduler_test.py
index 8ad0b0e87..a03b2834b 100644
--- a/torchx/schedulers/test/kubernetes_scheduler_test.py
+++ b/torchx/schedulers/test/kubernetes_scheduler_test.py
@@ -28,6 +28,7 @@
     KubernetesOpts,
     KubernetesScheduler,
     LABEL_INSTANCE_TYPE,
+    PLACEHOLDER_FIELD_PATH,
     role_to_pod,
 )
 from torchx.specs import AppState
@@ -75,7 +76,7 @@ def _test_app(num_replicas: int = 1) -> specs.AppDef:
             "--rank0-env",
             specs.macros.rank0_env,
         ],
-        env={"FOO": "bar"},
+        env={"FOO": "bar", "FOO_FIELD_PATH": f"{PLACEHOLDER_FIELD_PATH}status.podIP"},
         resource=specs.Resource(
             cpu=2,
             memMB=3000,
@@ -149,7 +150,9 @@ def test_role_to_pod(self) -> None:
             V1ContainerPort,
             V1EmptyDirVolumeSource,
             V1EnvVar,
+            V1EnvVarSource,
             V1HostPathVolumeSource,
+            V1ObjectFieldSelector,
             V1ObjectMeta,
             V1Pod,
             V1PodSpec,
@@ -188,7 +191,15 @@ def test_role_to_pod(self) -> None:
             ],
             image="pytorch/torchx:latest",
             name="name",
-            env=[V1EnvVar(name="FOO", value="bar")],
+            env=[
+                V1EnvVar(name="FOO", value="bar"),
+                V1EnvVar(
+                    name="FOO_FIELD_PATH",
+                    value_from=V1EnvVarSource(
+                        field_ref=V1ObjectFieldSelector(field_path="status.podIP")
+                    ),
+                ),
+            ],
             resources=resources,
             ports=[V1ContainerPort(name="foo", container_port=1234)],
             security_context=V1SecurityContext(),
@@ -303,6 +314,10 @@ def test_submit_dryrun(self) -> None:
           env:
           - name: FOO
             value: bar
+          - name: FOO_FIELD_PATH
+            valueFrom:
+              fieldRef:
+                fieldPath: status.podIP
           - name: TORCHX_RANK0_HOST
             value: localhost
           image: pytorch/torchx:latest