diff --git a/ppfleetx/core/engine/auto_engine.py b/ppfleetx/core/engine/auto_engine.py index b78e19931..a49014f88 100644 --- a/ppfleetx/core/engine/auto_engine.py +++ b/ppfleetx/core/engine/auto_engine.py @@ -37,7 +37,7 @@ class AutoEngine(BasicEngine): - def __init__(self, configs, module=None, mode='train'): + def __init__(self, configs, module=None, mode='train', cluster=None): super().__init__() version_check() @@ -99,7 +99,7 @@ def __init__(self, configs, module=None, mode='train'): # init engine self._auto_engine = auto.Engine( - model, loss_fn, optimizer, strategy=self._strategy) + model, loss_fn, optimizer, strategy=self._strategy, cluster=cluster) def fit(self, epoch=1, train_dataset=None, valid_dataset=None): diff --git a/ppfleetx/utils/config.py b/ppfleetx/utils/config.py index c51529633..2f3b484b5 100644 --- a/ppfleetx/utils/config.py +++ b/ppfleetx/utils/config.py @@ -511,7 +511,7 @@ def process_auto_strategy(config): process auto strategy for auto parallel """ strategy = auto.Strategy() - strategy.auto_mode = "semi" + strategy.auto_mode = config.Engine.get('auto_mode', "semi") strategy.seed = config['Global']['seed'] # amp config diff --git a/tools/auto.py b/tools/auto.py index 30d58079a..336bd2ce6 100644 --- a/tools/auto.py +++ b/tools/auto.py @@ -56,7 +56,10 @@ 'step_each_epoch': len(train_data) }) - engine = AutoEngine(configs=cfg, module=module) + device_count = paddle.distributed.get_world_size() + cluster = dist.auto.cluster.Cluster() + cluster.gen_default_config_cluster(node_count=device_count//8, device_count=8) + engine = AutoEngine(configs=cfg, module=module, cluster=cluster) if cfg.Engine.save_load.ckpt_dir is not None: engine.load()