From bea83d7fbcdcb7655810ac13aa0a4b476830e3da Mon Sep 17 00:00:00 2001
From: BanananaFish
Date: Tue, 2 Apr 2024 11:44:02 +0800
Subject: [PATCH] ga improve; early stop implement

---
 models/cell_org.mph.lock |  1 +
 src/comsol/cmdline.py    |  4 +--
 src/comsol/ga.py         | 60 ++++++++++++++++++++++++++++------------
 src/comsol/utils.py      | 35 +++++++++++++----------
 4 files changed, 67 insertions(+), 33 deletions(-)
 create mode 100644 models/cell_org.mph.lock

diff --git a/models/cell_org.mph.lock b/models/cell_org.mph.lock
new file mode 100644
index 0000000..46e1e86
--- /dev/null
+++ b/models/cell_org.mph.lock
@@ -0,0 +1 @@
+10951
diff --git a/src/comsol/cmdline.py b/src/comsol/cmdline.py
index 719068c..bce961e 100644
--- a/src/comsol/cmdline.py
+++ b/src/comsol/cmdline.py
@@ -46,8 +46,8 @@ def train(saved, config, ckpt_path):
     trainer = Trainer(dataset, model, cfg, ckpt_path)
     try:
         trainer.train()
-    except KeyboardInterrupt:
-        trainer.save_ckpt(f"earlystop_best_{trainer.best_loss:.3f}", best=True)
+    except (KeyboardInterrupt, Trainer.EarlyStop):
+        trainer.save_ckpt(f"earlystop_best_{trainer.best_loss:.6f}", best=True)


 @main.command()
diff --git a/src/comsol/ga.py b/src/comsol/ga.py
index 8f9ad6a..6c4889a 100644
--- a/src/comsol/ga.py
+++ b/src/comsol/ga.py
@@ -1,3 +1,5 @@
+import warnings
+
 import numpy
 import pygad
 import torch
@@ -6,36 +8,55 @@ from comsol.model import MLP
 from comsol.utils import BandDataset, Config

+warnings.filterwarnings("ignore")
+

-def fitness_warper(fitness_metric):
+def fitness_warper(fitness_metric, net):
     def fitness_func(ga_instance, solution, solution_idx):
-        fitness = fitness_metric(solution)
+        fitness = fitness_metric(solution, net)
         return fitness

     return fitness_func


+def max_min_distance_four(solution, net):
+    # line0_0, line1_0, line0_1, line1_1
+    # 1 . 3
+    # . . .
+    # 0 . 2
+    Bs: numpy.ndarray = net(torch.tensor([solution]).float()).detach().numpy().flatten()
+    return abs(Bs[3] - Bs[2])
+
+
+def max_min_distance_six(solution, net):
+    # line0_0, line1_0, line0_1, line1_1
+    # 1 . 3 . 5
+    # . . . . .
+    # 0 . 2 . 4
+    if any(solution < 0) or any(solution > 1):
+        return -1000
+    Bs: numpy.ndarray = net(torch.tensor([solution]).float()).detach().numpy().flatten()
+    return min(abs(Bs[3] - Bs[2]), abs(Bs[5] - Bs[4]))
+
+
 def fit(ckpt, pkl_path, cfg: Config):
     net = MLP(cfg)
     net.load_state_dict(torch.load(ckpt))
     net.eval()
     dataset = BandDataset(pkl_path, cfg)

-    def fitness_func(ga_instance, solution, solution_idx):
-        if any(solution < 0) or any(solution > 1):
-            return -1000
-        Bs: numpy.ndarray = (
-            net(torch.tensor([solution]).float()).detach().numpy().flatten()
-        )
-        return Bs[2] - Bs[0]
-
-    fitness_function = fitness_func
+    if cfg["dataset"]["sampler"] == "four_points":
+        fitness_func = fitness_warper(max_min_distance_four, net)
+    elif cfg["dataset"]["sampler"] == "six_points":
+        fitness_func = fitness_warper(max_min_distance_six, net)
+    else:
+        raise ValueError(f"Unknown sampler: {cfg['dataset']['sampler']}")

     num_generations = 100
     num_parents_mating = 4

     sol_per_pop = 30
-    num_genes = 3
+    num_genes = len(cfg["cell"].values())

     init_range_low = 0
     init_range_high = 1

@@ -49,7 +70,7 @@ def fitness_func(ga_instance, solution, solution_idx):
     ga_instance = pygad.GA(
         num_generations=num_generations,
         num_parents_mating=num_parents_mating,
-        fitness_func=fitness_function,
+        fitness_func=fitness_func,
         sol_per_pop=sol_per_pop,
         num_genes=num_genes,
         init_range_low=init_range_low,
@@ -66,9 +87,14 @@ def fitness_func(ga_instance, solution, solution_idx):
     prediction = net(torch.tensor([solution]).float()).detach().numpy().flatten()
     solution, prediction = dataset.denormalization(solution, prediction)
     solution[0] = solution[0] * 360
-    logger.info(f"Parameters of the best solution : {solution}")
-    logger.info(f"Predicted output based on the best solution : {prediction}")
-    logger.info(f"Fitness value of the best solution : {solution_fitness}")
+    # logger.info(f"Parameters of the best solution : {solution}")
+    params_dict = dict(zip(cfg["cell"].keys(), solution))
+    logger.info(f"BEST Parameters : {params_dict}")
+    logger.info(f"Predicted Band Structure : {prediction}")
+    logger.info(f"Fitness: {solution_fitness}")
     logger.info(
-        f"Fitness value of the best solution denormalization= {solution_fitness * (dataset.res_max - dataset.res_min) + dataset.res_min:.5e}"
+        f"Fitness (Denormal)= {solution_fitness * (dataset.res_max - dataset.res_min) + dataset.res_min:.5e}"
     )
+
+
+# TODO: auto evaluate
diff --git a/src/comsol/utils.py b/src/comsol/utils.py
index e26ba7b..f7ec8e2 100644
--- a/src/comsol/utils.py
+++ b/src/comsol/utils.py
@@ -76,6 +76,9 @@ def dump(self, path: PathLike | str):


 class Trainer:
+    class EarlyStop(Exception):
+        pass
+
     def __init__(self, dataset, model, cfg: Config, ckpt_path):
         self.model = model
         self.cfg = cfg
@@ -87,6 +90,9 @@ def __init__(self, dataset, model, cfg: Config, ckpt_path):
         self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)
         self.best_loss = float("inf")
         self.best_ckpt = model.state_dict()
+        self.stuck_count = 0
+        self.early_stop = cfg["train"]["early_stop"]
+
         self.ckpt_path = ckpt_path

         dataset_size = len(dataset)
@@ -108,40 +114,40 @@ def train(self):
         self.model.train()
         self.model = self.to_cuda(self.model)
         for epoch in range(1, self.epoch + 1):
-            for i, (x, y) in track(
-                enumerate(self.train_loader),
-                total=len(self.train_loader),
-                description=f"Epoch {epoch}",
-                auto_refresh=False,
-            ):
+            for i, (x, y) in enumerate(self.train_loader):
                 x, y = self.to_cuda(x), self.to_cuda(y)
                 self.optimizer.zero_grad()
                 y_pred = self.model(x)
                 loss = self.loss(y_pred, y)
                 loss.backward()
                 self.optimizer.step()
-                if i % 10 == 0:
-                    logger.info(f"Epoch {epoch}, iter {i}, loss: {loss.item():.3f}")
+                if i % 25 == 0:
+                    logger.info(
+                        f"Epoch [{epoch}/{self.cfg['train']['epoch']}], iter {i}, loss: {loss.item():.6f}"
+                    )
             self.test()
             self.save_ckpt("lastest")
-        logger.info(f"Training finished, best loss: {self.best_loss:.3f}")
-        self.save_ckpt(f"best_loss_{self.best_loss:.3f}", best=True)
+        logger.info(f"Training finished, best loss: {self.best_loss:.6f}")
+        self.save_ckpt(f"best_loss_{self.best_loss:.6f}", best=True)

     def test(self):
         self.model.eval()
         losses = 0
         with torch.no_grad():
-            for x, y in track(
-                self.test_loader, description="Testing", auto_refresh=False
-            ):
+            for x, y in self.test_loader:
                 x, y = self.to_cuda(x), self.to_cuda(y)
                 y_pred = self.model(x)
                 losses += self.loss(y_pred, y)
         now_loss = losses / len(self.test_loader)
-        logger.info(f"Test loss: {now_loss:.3f}")
+        logger.info(f"Test loss: {now_loss:.6f}")
         if now_loss < self.best_loss:
+            self.stuck_count = 0
             self.best_loss = now_loss
             self.best_ckpt = self.model.state_dict()
+        else:
+            self.stuck_count += 1
+            if self.stuck_count >= self.early_stop:
+                raise self.EarlyStop

     def save_ckpt(self, name, best=False):
         ckpt_path = Path(self.ckpt_path) / f"{self.start_time:%Y.%m.%d_%H.%M.%S}"
@@ -283,6 +289,7 @@ def __init__(self, saved_path: PathLike | str, cfg: Config):
             with open(pkl, "rb") as f:
                 data: tuple[dict[str, str], np.ndarray] = pickle.load(f)
                 params, res = data
+                # normalize the first-position parameter
                 params_list.append(self.to_rand(params))
                 res_arr_list.append(self.get_Bs(res, cfg["dataset"]["sampler"]))
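
Note (not part of the patch): a minimal, self-contained sketch of the early-stop flow introduced above, for reference. In the patch, the counter lives in Trainer.test() and the EarlyStop exception is caught in cmdline.train(), which then saves the best checkpoint; the loss values and the `patience` argument below are hypothetical and used only for illustration.

# Illustration only; not applied by this patch.
# `run` mimics the epoch loop: reset the counter on improvement,
# raise EarlyStop after `patience` consecutive non-improving epochs.

class EarlyStop(Exception):
    """Raised after `patience` consecutive epochs without improvement."""


def run(losses, patience):
    best_loss = float("inf")
    stuck_count = 0
    for loss in losses:
        if loss < best_loss:
            best_loss = loss
            stuck_count = 0        # reset whenever the test loss improves
        else:
            stuck_count += 1       # count consecutive non-improving epochs
            if stuck_count >= patience:
                raise EarlyStop
    return best_loss


try:
    run([0.9, 0.5, 0.4, 0.41, 0.42, 0.43], patience=3)
except EarlyStop:
    print("early stop triggered; best checkpoint would be saved here")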