Skip to content

Commit

Permalink
update and fix
Browse files Browse the repository at this point in the history
  • Loading branch information
BanananaFish committed Apr 1, 2024
1 parent b4bf25a commit 688705d
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 11 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/exe_build_action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ jobs:
spec: 'src/comsol/cmdline.py'
requirements: 'requirements.txt'
exe_path: 'dist/windows'
upload_exe_with_name: 'consol_CLI'
options: --onefile, --name "consol_CLI"
upload_exe_with_name: 'comsol_CLI'
options: --onefile, --name "comsol_CLI"
- name: create release
id: create_release
uses: ncipollo/[email protected]
Expand All @@ -40,5 +40,5 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
files: |
dist/windows/consol_CLI.exe
dist/windows/comsol_CLI.exe
dist/windows/README.md
2 changes: 1 addition & 1 deletion config/cell.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ cell:
step: 0.001

train:
epoch: 500
epoch: 200
batch_size: 16
lr: 0.001
21 changes: 21 additions & 0 deletions config/cell_local.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
cell:
zeta:
init: 0.0005
max: 180
min: 0
step: 30
rr:
init: 0.001
max: 0.0045
min: 0.0035
step: 0.0005
rrr:
init: 0.001
max: 0.003
min: 0.001
step: 0.001

train:
epoch: 300
batch_size: 8
lr: 0.001
5 changes: 4 additions & 1 deletion src/comsol/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def train(saved, config, ckpt_path):
dataset = BandDataset(saved)
model = MLP()
trainer = Trainer(dataset, model, cfg)
trainer.train()
try:
trainer.train()
except KeyboardInterrupt:
trainer.save_ckpt(f"earlystop_best_{trainer.best_loss:.3f}.pth", best=True)


@main.command()
Expand Down
10 changes: 7 additions & 3 deletions src/comsol/ga.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy
import pygad
import torch

from loguru import logger

from comsol.model import MLP
from comsol.utils import BandDataset

Expand All @@ -19,7 +19,7 @@ def fitness_func(ga_instance, solution, solution_idx):
Bs: numpy.ndarray = (
net(torch.tensor([solution]).float()).detach().numpy().flatten()
)
return abs(Bs[2] - Bs[0])
return abs((Bs[1] + Bs[2]) / 2 - Bs[0])

fitness_function = fitness_func

Expand Down Expand Up @@ -57,6 +57,10 @@ def fitness_func(ga_instance, solution, solution_idx):

prediction = net(torch.tensor([solution]).float()).detach().numpy().flatten()
solution, prediction = dataset.denormalization(solution, prediction)
solution[0] = solution[0] * 360
logger.info(f"Parameters of the best solution : {solution}")
logger.info(f"Predicted output based on the best solution : {prediction}")
logger.info(f"Fitness value of the best solution = {solution_fitness}")
logger.info(f"Fitness value of the best solution : {solution_fitness}")
logger.info(
f"Fitness value of the best solution denormalization= {solution_fitness * (dataset.res_max - dataset.res_min) + dataset.res_min:.5e}"
)
10 changes: 7 additions & 3 deletions src/comsol/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import pickle
from dataclasses import dataclass
from datetime import datetime
Expand Down Expand Up @@ -124,7 +125,7 @@ def train(self):
self.test()
self.save_ckpt("lastest")
logger.info(f"Training finished, best loss: {self.best_loss:.3f}")
self.save_ckpt(f"best_loss_{self.best_loss:.3f}")
self.save_ckpt(f"best_loss_{self.best_loss:.3f}", best=True)

def test(self):
self.model.eval()
Expand All @@ -142,7 +143,7 @@ def test(self):
self.best_loss = now_loss
self.best_ckpt = self.model.state_dict()

def save_ckpt(self, name):
def save_ckpt(self, name, best=False):
ckpt_path = Path(f"ckpt") / f"{self.start_time:%Y.%m.%d_%H.%M.%S}"
ckpt_path.mkdir(parents=True, exist_ok=True)

Expand All @@ -151,7 +152,10 @@ def save_ckpt(self, name):
if not cfg_path.exists():
self.cfg.dump(cfg_path)
logger.info(f"Dumped config to {cfg_path}")
torch.save(self.model.state_dict(), pth_path)
if best:
torch.save(self.best_ckpt, pth_path)
else:
torch.save(self.model.state_dict(), pth_path)
logger.info(f"Saved model to {pth_path}")


Expand Down

0 comments on commit 688705d

Please sign in to comment.