-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a7fc14b
commit 3541ac8
Showing
7 changed files
with
315 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
// [[Rcpp::depends(ast2ast)]] | ||
// [[Rcpp::depends(RcppArmadillo)]] | ||
// [[Rcpp::plugins(cpp2a)]] | ||
#include "etr.hpp" | ||
|
||
using namespace Rcpp; | ||
using namespace arma; | ||
|
||
// Define the PSO function | ||
// [[Rcpp::export]] | ||
List pso(const vec &lb, const vec &ub, Function loss, int ngen, int npop, | ||
double error_threshold) { | ||
int npar = lb.n_elem; | ||
mat swarm = randu<mat>(npop, npar); | ||
mat v = zeros<mat>(npop, npar); | ||
vec swarm_bests = zeros<vec>(npop); | ||
vec swarm_errors = zeros<vec>(npop); | ||
|
||
const double initial_cog = 2.5; | ||
const double final_cog = 0.5; | ||
const double initial_soc = 0.5; | ||
const double final_soc = 2.5; | ||
const double w_max = 0.9; | ||
const double w_min = 0.4; | ||
|
||
for (int i = 0; i < npop; ++i) { | ||
swarm.row(i) = (randu<rowvec>(npar) % (ub.t() - lb.t())) + lb.t(); | ||
swarm_errors[i] = as<double>(loss(swarm.row(i).t())); | ||
swarm_bests[i] = swarm_errors[i]; | ||
} | ||
|
||
uword global_best = index_min(swarm_bests); | ||
rowvec global_best_vec = swarm.row(global_best); | ||
double global_best_error = swarm_bests[global_best]; | ||
mat swarm_best_params = swarm; | ||
|
||
auto correctBounds = [](const rowvec &target, const vec &lb, const vec &ub) { | ||
rowvec res(target.size()); | ||
for (size_t i = 0; i < lb.size(); i++) { | ||
if (target(i) < lb(i)) { | ||
res(i) = lb(i); | ||
} else if (target(i) > ub(i)) { | ||
res(i) = ub(i); | ||
} else { | ||
res(i) = target(i); | ||
} | ||
} | ||
return res; | ||
}; | ||
|
||
int iter = 0; | ||
while (iter < ngen) { | ||
double w = w_max - iter * (w_max - w_min) / ngen; | ||
double cog = initial_cog - (initial_cog - final_cog) * (iter + 1) / ngen; | ||
double soc = initial_soc - (initial_soc - final_soc) * (iter + 1) / ngen; | ||
|
||
for (int i = 0; i < npop; ++i) { | ||
rowvec local_best_vec = swarm.row(global_best); | ||
|
||
v.row(i) = | ||
w * v.row(i) + | ||
cog * randu<double>() * (swarm_best_params.row(i) - swarm.row(i)) + | ||
soc * randu<double>() * (local_best_vec - swarm.row(i)); | ||
swarm.row(i) += v.row(i); | ||
|
||
swarm.row(i) = correctBounds(swarm.row(i), lb, ub); | ||
|
||
double error = as<double>(loss(swarm.row(i).t())); | ||
|
||
if (!std::isinf(error) && !std::isnan(error) && error < swarm_bests[i]) { | ||
swarm_bests[i] = error; | ||
swarm_best_params.row(i) = swarm.row(i); | ||
} | ||
if (!std::isinf(error) && !std::isnan(error) && | ||
error < global_best_error) { | ||
global_best = i; | ||
global_best_vec = swarm.row(i); | ||
global_best_error = error; | ||
} | ||
} | ||
|
||
iter++; | ||
|
||
if (global_best_error < error_threshold) { | ||
break; | ||
} | ||
} | ||
|
||
double insilico = as<double>(loss(global_best_vec.t())); | ||
return List::create(Named("insilico") = insilico, | ||
Named("global_best_vec") = global_best_vec); | ||
} | ||
|
||
typedef double (*FP)(etr::Vec<double> &a); | ||
|
||
// [[Rcpp::export]] | ||
List pso_xptr(const vec &lb, const vec &ub, Rcpp::XPtr<FP> loss_xptr, int ngen, | ||
int npop, double error_threshold) { | ||
FP loss = *loss_xptr; | ||
int npar = lb.n_elem; | ||
mat swarm = randu<mat>(npop, npar); | ||
mat v = zeros<mat>(npop, npar); | ||
vec swarm_bests = zeros<vec>(npop); | ||
vec swarm_errors = zeros<vec>(npop); | ||
|
||
const double initial_cog = 2.5; | ||
const double final_cog = 0.5; | ||
const double initial_soc = 0.5; | ||
const double final_soc = 2.5; | ||
const double w_max = 0.9; | ||
const double w_min = 0.4; | ||
|
||
etr::Vec<double> param(etr::SI{npar}); | ||
|
||
for (int i = 0; i < npop; ++i) { | ||
swarm.row(i) = (randu<rowvec>(npar) % (ub.t() - lb.t())) + lb.t(); | ||
for (size_t j = 0; j < npar; j++) { | ||
param[j] = swarm.row(i)(j); | ||
} | ||
swarm_errors[i] = loss(param); | ||
swarm_bests[i] = swarm_errors[i]; | ||
} | ||
|
||
uword global_best = index_min(swarm_bests); | ||
rowvec global_best_vec = swarm.row(global_best); | ||
double global_best_error = swarm_bests[global_best]; | ||
mat swarm_best_params = swarm; | ||
|
||
auto correctBounds = [](const rowvec &target, const vec &lb, const vec &ub) { | ||
rowvec res(target.size()); | ||
for (size_t i = 0; i < lb.size(); i++) { | ||
if (target(i) < lb(i)) { | ||
res(i) = lb(i); | ||
} else if (target(i) > ub(i)) { | ||
res(i) = ub(i); | ||
} else { | ||
res(i) = target(i); | ||
} | ||
} | ||
return res; | ||
}; | ||
|
||
int iter = 0; | ||
while (iter < ngen) { | ||
double w = w_max - iter * (w_max - w_min) / ngen; | ||
double cog = initial_cog - (initial_cog - final_cog) * (iter + 1) / ngen; | ||
double soc = initial_soc - (initial_soc - final_soc) * (iter + 1) / ngen; | ||
|
||
for (int i = 0; i < npop; ++i) { | ||
rowvec local_best_vec = swarm.row(global_best); | ||
|
||
v.row(i) = | ||
w * v.row(i) + | ||
cog * randu<double>() * (swarm_best_params.row(i) - swarm.row(i)) + | ||
soc * randu<double>() * (local_best_vec - swarm.row(i)); | ||
swarm.row(i) += v.row(i); | ||
|
||
swarm.row(i) = correctBounds(swarm.row(i), lb, ub); | ||
for (size_t j = 0; j < npar; j++) { | ||
param[j] = swarm.row(i)(j); | ||
} | ||
|
||
double error = loss(param); | ||
|
||
if (!std::isinf(error) && !std::isnan(error) && error < swarm_bests[i]) { | ||
swarm_bests[i] = error; | ||
swarm_best_params.row(i) = swarm.row(i); | ||
} | ||
if (!std::isinf(error) && !std::isnan(error) && | ||
error < global_best_error) { | ||
global_best = i; | ||
global_best_vec = swarm.row(i); | ||
global_best_error = error; | ||
} | ||
} | ||
|
||
iter++; | ||
|
||
if (global_best_error < error_threshold) { | ||
break; | ||
} | ||
} | ||
|
||
for (size_t j = 0; j < npar; j++) { | ||
param[j] = global_best_vec(j); | ||
} | ||
double insilico = loss(param); | ||
return List::create(Named("insilico") = insilico, | ||
Named("global_best_vec") = global_best_vec); | ||
} |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
setwd("/home/konrad/Documents/GitHub/RProjects/ast2ast_supplement/ast2ast/.development/") | ||
Rcpp::sourceCpp("PSO.cpp") | ||
|
||
|
||
rosenbrock <- function(parameter) { | ||
value <- 0 | ||
for (i in 1:(length(parameter) - 1)) { | ||
value <- value + | ||
100 * (parameter[i + 1] - parameter[i]^2)^2 + | ||
(1 - parameter[i])^2 | ||
} | ||
return(value) | ||
} | ||
|
||
|
||
lb <- -10000 | ||
ub <- 10000 | ||
error_threshold <- 0.0000001 | ||
npop <- 400 | ||
|
||
pso(rep(lb, 3), rep(ub, 3), rosenbrock, 10000, npop, error_threshold) | ||
|
||
|
||
rosenbrock_cpp <- function(parameter) { | ||
value::double <- 0 | ||
sz::int <- length(parameter) | ||
for (i in 1:(sz - 1)) { | ||
value <- value + | ||
100 * (parameter[i + 1] - parameter[i]^2)^2 + | ||
(1 - parameter[i])^2 | ||
} | ||
return(value) | ||
} | ||
|
||
rosenbrock_cpp <- ast2ast::translate(rosenbrock_cpp, | ||
output = "XPtr", | ||
handle_inputs = "", | ||
references = TRUE, | ||
verbose = TRUE | ||
) | ||
|
||
pso_xptr(rep(lb, 3), rep(ub, 3), rosenbrock_cpp, 10000, npop, error_threshold) | ||
|
||
|
||
res <- microbenchmark::microbenchmark( | ||
pso(rep(lb, 3), rep(ub, 3), rosenbrock, 10000, npop, error_threshold), | ||
pso_xptr(rep(lb, 3), rep(ub, 3), rosenbrock_cpp, 10000, npop, error_threshold) | ||
) | ||
res | ||
|
||
boxplot(res) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
/* file lorenz.c */ | ||
#include <R.h> | ||
static double parms[3]; | ||
#define a parms[0] | ||
#define b parms[1] | ||
#define c parms[2] | ||
|
||
/* initializer */ | ||
void initmod(void (*odeparms)(int *, double *)) { | ||
int N = 3; | ||
odeparms(&N, parms); | ||
} | ||
|
||
/* Derivatives */ | ||
void derivs(int *neq, double *t, double *y, double *ydot, double *yout, | ||
int *ip) { | ||
ydot[0] = a * y[0] + y[1] * y[2]; | ||
ydot[1] = b * (y[1] - y[2]); | ||
ydot[2] = -y[0] * y[1] + c * y[1] - y[2]; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// [[Rcpp::depends(ast2ast, RcppArmadillo)]] | ||
// [[Rcpp::plugins(cpp20)]] | ||
|
||
#include "etr.hpp" | ||
using namespace Rcpp; | ||
using namespace etr; | ||
|
||
static double parms[3]; | ||
#define a parms[0] | ||
#define b parms[1] | ||
#define c parms[2] | ||
|
||
extern "C" { | ||
void initmod_a2a(void (*odeparms)(int *, double *)); | ||
} | ||
|
||
extern "C" { | ||
void derivs_a2a(int *neq, double *t, double *y, double *ydot, double *yout, | ||
int *ip); | ||
} | ||
|
||
void initmod_a2a(void (*odeparms)(int *, double *)) { | ||
int N = 3; | ||
odeparms(&N, parms); | ||
} | ||
|
||
void user_fct(double t, BorrowPtr &y, BorrowPtr &ydot) { | ||
at(ydot, 1) = a * at(y, 1) + at(y, 2) * at(y, 3); | ||
at(ydot, 2) = b * (at(y, 2) - at(y, 3)); | ||
at(ydot, 3) = -at(y, 1) * at(y, 2) + c * at(y, 2) - at(y, 3); | ||
} | ||
|
||
void derivs_a2a(int *neq, double *t, double *y, double *ydot, double *yout, | ||
int *ip) { | ||
BorrowPtr y_(y, *neq); | ||
BorrowPtr ydot_(ydot, *neq); | ||
double t_ = *t; | ||
user_fct(t_, y_, ydot_); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters