-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathslurm_submit.py
137 lines (107 loc) · 4.35 KB
/
slurm_submit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
A script to run multinode training with submitit.
"""
import argparse
import os
import sys
import uuid
from argparse import Namespace
from pathlib import Path
import submitit
from transformer_lightning import get_parser
WORK_DIR = str(Path(__file__).parent.absolute())
def add_slurm_args(parser=None):
    """Attach Slurm/submitit command-line options to *parser*.

    Args:
        parser: An existing ``argparse.ArgumentParser`` to extend. If
            ``None``, a new parser is created. (The original body
            unconditionally dereferenced ``parser`` and crashed with
            ``AttributeError`` when the documented default was used.)

    Returns:
        The parser, with a new ``'Slurm submitit'`` argument group added.
    """
    if parser is None:
        parser = argparse.ArgumentParser()
    p = parser.add_argument_group('Slurm submitit')
    p.add_argument("--ngpus", default=1, type=int,
                   help="Number of gpus to request on each node")
    p.add_argument("--vram", default="12GB", type=str)
    p.add_argument("--num_gpus", default=1, type=int)
    p.add_argument("--mem_per_gpu", default=24, type=int)
    p.add_argument("--nodes", default=1, type=int,
                   help="Number of nodes to request")
    p.add_argument("--timeout", default=4320, type=int,
                   help="Maximum duration of the job in minutes")
    p.add_argument("--job_dir", default="", type=str,
                   help="Job dir. Leave empty for automatic.")
    p.add_argument("--cluster", default=None, type=str,
                   help="Use to run jobs locally.")
    p.add_argument("--slurm_partition", default="NORMAL", type=str,
                   help="Partition. Leave empty for automatic.")
    p.add_argument("--slurm_constraint", default="", type=str,
                   help="Constraint. Leave empty for automatic.")
    p.add_argument("--slurm_comment", default="", type=str)
    p.add_argument("--slurm_gres", default="", type=str)
    p.add_argument("--slurm_exclude", default="", type=str)
    p.add_argument("--checkpoint_name", default="last.ckpt", type=str)
    p.add_argument("--jobname", default='test_job', type=str)
    return parser
def get_shared_folder() -> Path:
    """Return the per-user run directory on the shared cluster filesystem.

    Creates ``/storage/slurm/$USER/runs`` if needed.

    Raises:
        RuntimeError: if the shared storage root is not mounted.
    """
    user = os.getenv("USER")
    if Path("/storage/slurm").is_dir():
        path = Path(f"/storage/slurm/{user}/runs")
        # parents=True: the per-user directory may not exist yet on a fresh
        # cluster account; the original mkdir(exist_ok=True) raised
        # FileNotFoundError in that case.
        path.mkdir(parents=True, exist_ok=True)
        return path
    raise RuntimeError("No shared folder available")
def get_init_file() -> Path:
    """Return a fresh path for the distributed-init rendezvous file.

    The file itself must not exist yet, but its parent directory must,
    so the shared folder is created first.
    """
    os.makedirs(str(get_shared_folder()), exist_ok=True)
    candidate = get_shared_folder() / f"{uuid.uuid4().hex}_init"
    # A uuid4 collision is essentially impossible, but be defensive anyway.
    if candidate.exists():
        os.remove(str(candidate))
    return candidate
class Trainer:
    """Picklable submitit callable that runs one training task per GPU.

    An instance is submitted to the cluster; submitit pickles it, ships it
    to the allocated node, and invokes it there.
    """

    def __init__(self, args: Namespace, jobid=None) -> None:
        # jobid is accepted for interface compatibility but not used.
        self.args = args

    def __call__(self) -> None:
        # Import lazily: this body executes on the compute node, where the
        # project directory must first be placed on sys.path.
        sys.path.append(WORK_DIR)
        import transformer_lightning as train_file

        self._setup_gpu_args()
        train_file.train(self.args)

    def _setup_gpu_args(self) -> None:
        # Local imports keep the pickled callable lightweight.
        from pathlib import Path
        import submitit

        # Translate the submitit job environment into the rank/world-size
        # attributes the training script expects.
        job_env = submitit.JobEnvironment()
        self.args.gpu = job_env.local_rank
        self.args.rank = job_env.global_rank
        self.args.world_size = job_env.num_tasks
        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
def main():
    """Parse CLI arguments and submit the training run through submitit."""
    parser = get_parser()
    args = add_slurm_args(parser).parse_args()

    # '%j' is expanded by submitit to the Slurm job id, so every
    # submission gets its own experiment folder.
    if args.job_dir == "":
        args.job_dir = get_shared_folder() / "%j"

    executor = submitit.AutoExecutor(
        folder=args.job_dir, cluster=args.cluster, slurm_max_num_timeout=30)

    # Cluster setup is defined by environment variables.
    gpus_per_node = args.num_gpus
    # An explicit --slurm_gres wins; otherwise request GPUs with the VRAM
    # constraint baked into the gres string.
    gres = args.slurm_gres or f'gpu:{gpus_per_node},VRAM:{args.vram}'

    executor.update_parameters(
        mem_gb=args.mem_per_gpu * gpus_per_node,
        tasks_per_node=gpus_per_node,  # one task per GPU
        cpus_per_task=8,
        nodes=args.nodes,
        timeout_min=args.timeout,
        slurm_partition=args.slurm_partition,
        slurm_constraint=args.slurm_constraint,
        slurm_comment=args.slurm_comment,
        slurm_exclude=args.slurm_exclude,
        slurm_gres=gres,
    )
    executor.update_parameters(name=args.jobname)

    # Rendezvous file URI for torch.distributed-style initialization.
    args.dist_url = get_init_file().as_uri()

    job = executor.submit(Trainer(args))
    print("Submitted job_id:", job.job_id)

    # With cluster='debug' the job runs in-process; block until done.
    if args.cluster == 'debug':
        job.wait()
if __name__ == "__main__":
main()