Skip to content

Commit

Permalink
ga improve; early stop implement
Browse files Browse the repository at this point in the history
  • Loading branch information
BanananaFish committed Apr 2, 2024
1 parent 35811ca commit bea83d7
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 33 deletions.
1 change: 1 addition & 0 deletions models/cell_org.mph.lock
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
10951
4 changes: 2 additions & 2 deletions src/comsol/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def train(saved, config, ckpt_path):
trainer = Trainer(dataset, model, cfg, ckpt_path)
try:
trainer.train()
except KeyboardInterrupt:
trainer.save_ckpt(f"earlystop_best_{trainer.best_loss:.3f}", best=True)
except (KeyboardInterrupt, Trainer.EarlyStop):
trainer.save_ckpt(f"earlystop_best_{trainer.best_loss:.6f}", best=True)


@main.command()
Expand Down
60 changes: 43 additions & 17 deletions src/comsol/ga.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import numpy
import pygad
import torch
Expand All @@ -6,36 +8,55 @@
from comsol.model import MLP
from comsol.utils import BandDataset, Config

warnings.filterwarnings("ignore")


def fitness_warper(fitness_metric, net):
    """Adapt a ``metric(solution, net)`` callable to a three-argument fitness function.

    The returned closure has the ``(ga_instance, solution, solution_idx)``
    signature and is what gets handed to ``pygad.GA(fitness_func=...)``.

    Args:
        fitness_metric: Callable invoked as ``fitness_metric(solution, net)``;
            whatever it returns is used as the fitness of ``solution``.
        net: Object forwarded unchanged to ``fitness_metric`` on every call.

    Returns:
        A function ``fitness_func(ga_instance, solution, solution_idx)`` that
        returns ``fitness_metric(solution, net)``.
    """

    # Kept as a plain ``def`` (not ``functools.partial``) so the result is a
    # real function object with exactly three positional parameters.
    def fitness_func(ga_instance, solution, solution_idx):
        # ``ga_instance`` and ``solution_idx`` are part of the required
        # signature but are not needed to score a single solution.
        return fitness_metric(solution, net)

    return fitness_func


def max_min_distance_four(solution, net):
    """Fitness for the four-point sampler: gap between outputs 3 and 2 of ``net``.

    Layout of the flattened network output (indices into ``Bs``), kept from
    the original comment -- presumably two band lines sampled at two points:

        line0_0, line1_0, line0_1, line1_1
        1 . 3
        . . .
        0 . 2

    Args:
        solution: 1-D sequence of gene values (NumPy array or list).
        net: Callable taking a float32 tensor of shape ``(1, len(solution))``
            and returning a tensor with at least four elements.

    Returns:
        ``abs(Bs[3] - Bs[2])`` for in-range solutions, or ``-1000`` when any
        gene lies outside ``[0, 1]`` (same penalty as ``max_min_distance_six``).
    """
    genes = numpy.asarray(solution)
    # Penalise out-of-range genes, consistent with the six-point metric.
    # The check runs on the uncast values so float32 rounding cannot hide a
    # gene that is only slightly out of range.
    if (genes < 0).any() or (genes > 1).any():
        return -1000

    # Build the (1, n) batch directly from the array; wrapping an ndarray in a
    # Python list before ``torch.tensor`` is slow and triggers a UserWarning.
    batch = torch.as_tensor(genes, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():  # pure inference, no autograd graph needed
        Bs: numpy.ndarray = net(batch).detach().cpu().numpy().flatten()
    return abs(Bs[3] - Bs[2])


def max_min_distance_six(solution, net):
    """Fitness for the six-point sampler: the smaller of two output gaps of ``net``.

    Layout of the flattened network output (indices into ``Bs``), kept from
    the original diagram -- presumably two band lines sampled at three points:

        1 . 3 . 5
        . . . . .
        0 . 2 . 4

    Args:
        solution: 1-D sequence of gene values (NumPy array or list).
        net: Callable taking a float32 tensor of shape ``(1, len(solution))``
            and returning a tensor with at least six elements.

    Returns:
        ``min(abs(Bs[3] - Bs[2]), abs(Bs[5] - Bs[4]))`` for in-range
        solutions, or ``-1000`` when any gene lies outside ``[0, 1]``.
    """
    genes = numpy.asarray(solution)
    # Out-of-range genes get a fixed penalty. The check runs on the uncast
    # values so float32 rounding cannot hide a slightly out-of-range gene.
    if (genes < 0).any() or (genes > 1).any():
        return -1000

    # Build the (1, n) batch directly from the array; wrapping an ndarray in a
    # Python list before ``torch.tensor`` is slow and triggers a UserWarning.
    batch = torch.as_tensor(genes, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():  # pure inference, no autograd graph needed
        Bs: numpy.ndarray = net(batch).detach().cpu().numpy().flatten()
    return min(abs(Bs[3] - Bs[2]), abs(Bs[5] - Bs[4]))


def fit(ckpt, pkl_path, cfg: Config):
net = MLP(cfg)
net.load_state_dict(torch.load(ckpt))
net.eval()
dataset = BandDataset(pkl_path, cfg)

def fitness_func(ga_instance, solution, solution_idx):
if any(solution < 0) or any(solution > 1):
return -1000
Bs: numpy.ndarray = (
net(torch.tensor([solution]).float()).detach().numpy().flatten()
)
return Bs[2] - Bs[0]

fitness_function = fitness_func
if cfg["dataset"]["sampler"] == "four_points":
fitness_func = fitness_warper(max_min_distance_four, net)
elif cfg["dataset"]["sampler"] == "six_points":
fitness_func = fitness_warper(max_min_distance_six, net)
else:
raise ValueError(f"Unknown sampler: {cfg['dataset']['sampler']}")

num_generations = 100
num_parents_mating = 4

sol_per_pop = 30
num_genes = 3
num_genes = len(cfg["cell"].values())
init_range_low = 0
init_range_high = 1

Expand All @@ -49,7 +70,7 @@ def fitness_func(ga_instance, solution, solution_idx):
ga_instance = pygad.GA(
num_generations=num_generations,
num_parents_mating=num_parents_mating,
fitness_func=fitness_function,
fitness_func=fitness_func,
sol_per_pop=sol_per_pop,
num_genes=num_genes,
init_range_low=init_range_low,
Expand All @@ -66,9 +87,14 @@ def fitness_func(ga_instance, solution, solution_idx):
prediction = net(torch.tensor([solution]).float()).detach().numpy().flatten()
solution, prediction = dataset.denormalization(solution, prediction)
solution[0] = solution[0] * 360
logger.info(f"Parameters of the best solution : {solution}")
logger.info(f"Predicted output based on the best solution : {prediction}")
logger.info(f"Fitness value of the best solution : {solution_fitness}")
# logger.info(f"Parameters of the best solution : {solution}")
params_dict = dict(zip(cfg["cell"].keys(), solution))
logger.info(f"BEST Parameters : {params_dict}")
logger.info(f"Predicted Band Structure : {prediction}")
logger.info(f"Fitness: {solution_fitness}")
logger.info(
f"Fitness value of the best solution denormalization= {solution_fitness * (dataset.res_max - dataset.res_min) + dataset.res_min:.5e}"
f"Fitness (Denormal)= {solution_fitness * (dataset.res_max - dataset.res_min) + dataset.res_min:.5e}"
)


# TODO: auto evaluate
35 changes: 21 additions & 14 deletions src/comsol/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def dump(self, path: PathLike | str):


class Trainer:
class EarlyStop(Exception):
    """Signal that training should stop early.

    Raised by ``test`` once ``stuck_count`` reaches ``early_stop``, i.e. the
    test loss has gone that many consecutive evaluations without improving.
    """

    pass

def __init__(self, dataset, model, cfg: Config, ckpt_path):
self.model = model
self.cfg = cfg
Expand All @@ -87,6 +90,9 @@ def __init__(self, dataset, model, cfg: Config, ckpt_path):
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)
self.best_loss = float("inf")
self.best_ckpt = model.state_dict()
self.stuck_count = 0
self.early_stop = cfg["train"]["early_stop"]

self.ckpt_path = ckpt_path

dataset_size = len(dataset)
Expand All @@ -108,40 +114,40 @@ def train(self):
self.model.train()
self.model = self.to_cuda(self.model)
for epoch in range(1, self.epoch + 1):
for i, (x, y) in track(
enumerate(self.train_loader),
total=len(self.train_loader),
description=f"Epoch {epoch}",
auto_refresh=False,
):
for i, (x, y) in enumerate(self.train_loader):
x, y = self.to_cuda(x), self.to_cuda(y)
self.optimizer.zero_grad()
y_pred = self.model(x)
loss = self.loss(y_pred, y)
loss.backward()
self.optimizer.step()
if i % 10 == 0:
logger.info(f"Epoch {epoch}, iter {i}, loss: {loss.item():.3f}")
if i % 25 == 0:
logger.info(
f"Epoch [{epoch}/{self.cfg['train']['epoch']}], iter {i}, loss: {loss.item():.6f}"
)
self.test()
self.save_ckpt("lastest")
logger.info(f"Training finished, best loss: {self.best_loss:.3f}")
self.save_ckpt(f"best_loss_{self.best_loss:.3f}", best=True)
logger.info(f"Training finished, best loss: {self.best_loss:.6f}")
self.save_ckpt(f"best_loss_{self.best_loss:.6f}", best=True)

def test(self):
    """Evaluate on ``self.test_loader`` and update best-loss / early-stop state.

    Computes the mean per-batch loss over the test loader, logs it, and then:

    * on improvement: resets ``stuck_count``, records ``best_loss`` and takes
      a snapshot of the model weights in ``best_ckpt``;
    * otherwise: increments ``stuck_count`` and raises once it reaches
      ``early_stop``.

    Raises:
        Trainer.EarlyStop: when ``stuck_count >= early_stop`` after a
            non-improving evaluation.
    """
    # Local import: the file's top-level import block is not visible from here.
    import copy

    # Remember the current mode so evaluation does not leave the model stuck
    # in eval mode for the training epochs that follow.
    was_training = self.model.training
    self.model.eval()
    total_loss = 0.0
    try:
        with torch.no_grad():
            for x, y in self.test_loader:
                x, y = self.to_cuda(x), self.to_cuda(y)
                y_pred = self.model(x)
                # ``.item()`` keeps the running sum (and hence ``best_loss``)
                # a plain Python float instead of a device tensor.
                total_loss += self.loss(y_pred, y).item()
    finally:
        self.model.train(was_training)

    now_loss = total_loss / len(self.test_loader)
    logger.info(f"Test loss: {now_loss:.6f}")
    if now_loss < self.best_loss:
        self.stuck_count = 0
        self.best_loss = now_loss
        # ``state_dict()`` returns references to the live parameters; without
        # a deep copy the "best" checkpoint would keep changing as training
        # continues.
        self.best_ckpt = copy.deepcopy(self.model.state_dict())
    else:
        self.stuck_count += 1
        if self.stuck_count >= self.early_stop:
            raise self.EarlyStop

def save_ckpt(self, name, best=False):
ckpt_path = Path(self.ckpt_path) / f"{self.start_time:%Y.%m.%d_%H.%M.%S}"
Expand Down Expand Up @@ -283,6 +289,7 @@ def __init__(self, saved_path: PathLike | str, cfg: Config):
with open(pkl, "rb") as f:
data: tuple[dict[str, str], np.ndarray] = pickle.load(f)
params, res = data
# Normalize the parameter in the first position
params_list.append(self.to_rand(params))
res_arr_list.append(self.get_Bs(res, cfg["dataset"]["sampler"]))

Expand Down

0 comments on commit bea83d7

Please sign in to comment.