Skip to content

Commit

Permalink
small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Jan 7, 2025
1 parent 87d2c29 commit 0e6d384
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 91 deletions.
12 changes: 6 additions & 6 deletions R/build_fanhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@ build_fanhmm <- function(
!is.null(autoregression_formula) || !is.null(feedback_formula),
"Provide {.arg autoregression_formula} and/or {.arg feedback_formula} for FAN-HMM."
)
stopifnot_(
inherits(autoregression_formula, "formula"),
"Argument {.arg autoregression_formula} must be a {.cls formula} object.")
stopifnot_(
inherits(feedback_formula, "formula"),
"Argument {.arg feedback_formula} must be a {.cls formula} object.")
out <- create_base_nhmm(
observations, data, time, id, n_states, state_names, channel_names = NULL,
initial_formula, transition_formula, emission_formula)
Expand All @@ -39,6 +33,9 @@ build_fanhmm <- function(
)
np_rho_A <- 0
} else {
stopifnot_(
inherits(feedback_formula, "formula"),
"Argument {.arg feedback_formula} must be a {.cls formula} object.")
W_A <- model_matrix_feedback_formula(
feedback_formula, data,
out$model$n_sequences,
Expand All @@ -65,6 +62,9 @@ build_fanhmm <- function(
)
np_rho_B <- 0
} else {
stopifnot_(
inherits(autoregression_formula, "formula"),
"Argument {.arg autoregression_formula} must be a {.cls formula} object.")
W_B <- model_matrix_autoregression_formula(
autoregression_formula, data,
out$model$n_sequences,
Expand Down
12 changes: 6 additions & 6 deletions R/create_initial_values.R
Original file line number Diff line number Diff line change
Expand Up @@ -303,17 +303,17 @@ create_inits_matrix <- function(x, n, m, K, sd = 0) {
}

create_rho_A_inits <- function(x, S, M, L, init_sd = 0) {
if (is.null(x$rho_A)) {
if (is.null(x)) {
create_rho_A(numeric((S - 1) * L * (M - 1) * S), S, M, L, init_sd)
} else {
stopifnot_(
length(unlist(x$rho_A)) == (S - 1) * L * (M - 1) * S,
length(unlist(x)) == (S - 1) * L * (M - 1) * S,
paste0(
"Number of initial values for {.val rho_A} is not equal to ",
"(S - 1) * L * (M - 1) * S = {(S - 1) * L * (M - 1) * S}."
)
)
create_rho_A(unlist(x$rho_A), S, M, L, init_sd)
create_rho_A(unlist(x), S, M, L, init_sd)
}
}

Expand All @@ -325,17 +325,17 @@ create_rho_A <- function(x, S, M, L, init_sd = 0) {
}

create_rho_B_inits <- function(x, S, M, L, init_sd = 0) {
if (is.null(x$rho_B)) {
if (is.null(x)) {
create_rho_B(numeric((M - 1) * L * (M - 1) * S), S, M, L, init_sd)
} else {
stopifnot_(
length(unlist(x$rho_B)) == (M - 1) * L * (M - 1) * S,
length(unlist(x)) == (M - 1) * L * (M - 1) * S,
paste0(
"Number of initial values for {.val rho_B} is not equal to ",
"(M - 1) * L * (M - 1) * S = {(M - 1) * L * (M - 1) * S}."
)
)
create_rho_B(unlist(x$rho_B), S, M, L, init_sd)
create_rho_B(unlist(x), S, M, L, init_sd)
}
}

Expand Down
12 changes: 6 additions & 6 deletions R/dnm_fanhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ dnm_fanhmm <- function(model, inits, init_sd, restarts, lambda, bound, control,
init <- unlist(
c(
create_initial_values(inits, model, init_sd),
create_rho_A_inits(inits, S, M, L_A, init_sd),
create_rho_B_inits(inits, S, M, L_B, init_sd)
create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd),
create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd)
)
)
fit <- nloptr(
Expand Down Expand Up @@ -133,8 +133,8 @@ dnm_fanhmm <- function(model, inits, init_sd, restarts, lambda, bound, control,
init <- unlist(
c(
create_initial_values(inits, model, init_sd),
create_rho_A_inits(inits, S, M, L_A, init_sd),
create_rho_B_inits(inits, S, M, L_B, init_sd)
create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd),
create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd)
)
)
} else {
Expand All @@ -148,8 +148,8 @@ dnm_fanhmm <- function(model, inits, init_sd, restarts, lambda, bound, control,
init <- unlist(
c(
create_initial_values(inits, model, init_sd),
create_rho_A_inits(inits, S, M, L_A, init_sd),
create_rho_B_inits(inits, S, M, L_B, init_sd)
create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd),
create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd)
)
)
}
Expand Down
12 changes: 6 additions & 6 deletions R/em_dnm_fanhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ em_dnm_fanhmm <- function(model, inits, init_sd, restarts, lambda,
out <- future.apply::future_lapply(seq_len(restarts), function(i) {
init <- c(
create_initial_values(inits, model, init_sd),
rho_A = list(create_rho_A_inits(inits, S, M, L_A, init_sd)),
rho_B = list(create_rho_B_inits(inits, S, M, L_B, init_sd))
rho_A = list(create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd)),
rho_B = list(create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd))
)
fit <- EM_LBFGS_fanhmm_singlechannel(
init$eta_pi, model$X_pi, init$eta_A, model$X_A, init$eta_B, model$X_B,
Expand Down Expand Up @@ -149,8 +149,8 @@ em_dnm_fanhmm <- function(model, inits, init_sd, restarts, lambda,
em_return_code <- return_codes[1] + 1000
init <- c(
create_initial_values(inits, model, init_sd),
rho_A = list(create_rho_A_inits(inits, S, M, L_A, init_sd)),
rho_B = list(create_rho_B_inits(inits, S, M, L_B, init_sd))
rho_A = list(create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd)),
rho_B = list(create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd))
)
} else {
em_return_code <- 0 # generic success
Expand All @@ -175,8 +175,8 @@ em_dnm_fanhmm <- function(model, inits, init_sd, restarts, lambda,
} else {
init <- c(
create_initial_values(inits, model, init_sd),
rho_A = list(create_rho_A_inits(inits, S, M, L_A, init_sd)),
rho_B = list(create_rho_B_inits(inits, S, M, L_B, init_sd))
rho_A = list(create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd)),
rho_B = list(create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd))
)

out <- EM_LBFGS_fanhmm_singlechannel(
Expand Down
8 changes: 4 additions & 4 deletions R/em_fanhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ em_fanhmm <- function(model, inits, init_sd, restarts, lambda,
out <- future.apply::future_lapply(seq_len(restarts), function(i) {
init <- c(
create_initial_values(inits, model, init_sd),
rho_A = list(create_rho_A_inits(inits, S, M, L_A, init_sd)),
rho_B = list(create_rho_B_inits(inits, S, M, L_B, init_sd))
rho_A = list(create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd)),
rho_B = list(create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd))
)
fit <- EM_LBFGS_fanhmm_singlechannel(
init$eta_pi, model$X_pi, init$eta_A, model$X_A, init$eta_B, model$X_B,
Expand Down Expand Up @@ -78,8 +78,8 @@ em_fanhmm <- function(model, inits, init_sd, restarts, lambda,
} else {
init <- c(
create_initial_values(inits, model, init_sd),
rho_A = list(create_rho_A_inits(inits, S, M, L_A, init_sd)),
rho_B = list(create_rho_B_inits(inits, S, M, L_B, init_sd))
rho_A = list(create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd)),
rho_B = list(create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd))
)
}
out <- EM_LBFGS_fanhmm_singlechannel(
Expand Down
8 changes: 4 additions & 4 deletions R/fit_fanhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ fit_fanhmm <- function(model, inits, init_sd, restarts, lambda, method,
model$gammas$B <- eta_to_gamma_cube_field(model$etas$B)
}
model$rhos$A <- create_rho_A_inits(
inits, model$n_states, model$n_symbols, nrow(model$W_A), init_sd
inits$rho_A, model$n_states, model$n_symbols, nrow(model$W_A), init_sd
)
model$phis$A <- rho_to_phi(model$rhos$A)
model$phis$A <- rho_to_phi_field(model$rhos$A)
model$rhos$B <- create_rho_B_inits(
inits, model$n_states, model$n_symbols, nrow(model$W_B), init_sd
inits$rho_B, model$n_states, model$n_symbols, nrow(model$W_B), init_sd
)
model$phis$B <- rho_to_phi(model$rhos$B)
model$phis$B <- rho_to_phi_field(model$rhos$B)
return(model)
}
if (method == "EM-DNM") {
Expand Down
5 changes: 4 additions & 1 deletion R/fit_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
}

if (isTRUE(control$maxeval < 0)) {
model$etas <- create_initial_values(inits, model, init_sd)
model$etas <- setNames(
create_initial_values(inits, model, init_sd),
c("pi","A", "B")
)
model$gammas$pi <- eta_to_gamma_mat(model$etas$pi)
model$gammas$A <- eta_to_gamma_cube(model$etas$A)
if (model$n_channels == 1L) {
Expand Down
14 changes: 7 additions & 7 deletions R/get_probs.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ get_transition_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) {
ids <- rownames(model$observations[[1]])
times <- as.numeric(colnames(model$observations[[1]]))
}
if (!attr(model$X_A, "iv")) {
if (!attr(model$X_A, "iv") && !isTRUE(attr(model$W_A, "iv"))) {
X <- model$X_A[, , 1L, drop = FALSE]
ids <- "all"
} else {
Expand All @@ -115,7 +115,7 @@ get_transition_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) {
state_from = model$state_names,
state_to = rep(model$state_names, each = S),
estimate = unlist(get_A_all(
model$gammas$A, X, attr(model$X_A, "tv")
model$gammas$A, X, attr(model$X_A, "tv") || isTRUE(attr(model$W_A, "tv"))
))
)
d <- stats::setNames(
Expand All @@ -135,7 +135,7 @@ get_transition_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) {
)
qs <- get_A_qs(
model$boot$gamma_A,
X, attr(model$X_A, "tv"), probs
X, attr(model$X_A, "tv") || isTRUE(attr(model$W_A, "tv")), probs
)
for(i in seq_along(probs)) {
d[paste0("q", 100 * probs[i])] <- qs[, i]
Expand Down Expand Up @@ -192,7 +192,7 @@ get_emission_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) {
times <- as.numeric(colnames(model$observations[[1]]))
symbol_names <- model$symbol_names
}
if (!attr(model$X_B, "iv")) {
if (!attr(model$X_B, "iv") && !isTRUE(attr(model$W_B, "iv"))) {
X <- model$X_B[, , 1L, drop = FALSE]
ids <- "all"
} else {
Expand All @@ -208,7 +208,7 @@ get_emission_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) {
channel = model$channel_names[i],
observation = rep(symbol_names[[i]], each = S),
estimate = unlist(get_B_all(
model$gammas$B[[i]], X, attr(model$X_B, "tv")
model$gammas$B[[i]], X, attr(model$X_B, "tv") || isTRUE(attr(model$W_B, "tv"))
))
)
})
Expand All @@ -229,15 +229,15 @@ get_emission_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) {
if (C == 1) {
qs <- get_B_qs(
model$boot$gamma_B,
X, attr(model$X_B, "tv"), probs
X, attr(model$X_B, "tv") || isTRUE(attr(model$W_B, "tv")), probs
)
} else {
qs <- do.call(
rbind,
lapply(seq_len(C), function(i) {
get_B_qs(
lapply(model$boot$gamma_B, "[[", i),
X, attr(model$X_B, "tv"), probs
X, attr(model$X_B, "tv") || isTRUE(attr(model$W_B, "tv")), probs
)
})
)
Expand Down
66 changes: 40 additions & 26 deletions src/fanhmm_EM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,7 @@ void fanhmm_sc::mstep_A(const double ftol_abs, const double ftol_rel,
auto* self = static_cast<fanhmm_sc*>(data);
arma::vec x_vec(const_cast<double*>(x), n, false, true);
arma::vec grad_vec(grad, n, false, true);
if (self->L_A > 0) {
return self->objective_A(x_vec, grad_vec);
} else {
return self->nhmm_base::objective_A(x_vec, grad_vec);
}
return self->objective_A(x_vec, grad_vec);
};

arma::uword n_eta = eta_A.slice(0).n_elem;
Expand Down Expand Up @@ -234,11 +230,7 @@ void fanhmm_sc::mstep_B(const double ftol_abs, const double ftol_rel,
auto* self = static_cast<fanhmm_sc*>(data);
arma::vec x_vec(const_cast<double*>(x), n, false, true);
arma::vec grad_vec(grad, n, false, true);
if (self->L_B > 0) {
return self->objective_B(x_vec, grad_vec);
} else {
return self->nhmm_sc::objective_B(x_vec, grad_vec);
}
return self->objective_B(x_vec, grad_vec);
};
arma::uword n_eta = eta_B.slice(0).n_elem;
arma::uword n_rho = rho_B(0).n_elem;
Expand Down Expand Up @@ -330,13 +322,17 @@ Rcpp::List EM_LBFGS_fanhmm_singlechannel(
pars.cols(n_pi, n_pi + n_A - 1) = arma::vectorise(model.eta_A).t();
pars.cols(n_pi + n_A, n_pi + n_A + n_B - 1) = arma::vectorise(model.eta_B).t();
arma::uword ii = n_pi + n_A + n_B;
for (arma::uword s = 0; s < S; s++) {
pars.cols(ii, ii + n_Ar - 1) = arma::vectorise(model.rho_A(s)).t();
ii += n_Ar;
if (n_Ar > 0) {
for (arma::uword s = 0; s < S; s++) {
pars.cols(ii, ii + n_Ar - 1) = arma::vectorise(model.rho_A(s)).t();
ii += n_Ar;
}
}
for (arma::uword s = 0; s < S; s++) {
pars.cols(ii, ii + n_Br - 1) = arma::vectorise(model.rho_B(s)).t();
ii += n_Br;
if (n_Br > 0) {
for (arma::uword s = 0; s < S; s++) {
pars.cols(ii, ii + n_Br - 1) = arma::vectorise(model.rho_B(s)).t();
ii += n_Br;
}
}
double relative_change = ftol_rel + 1.0;
double absolute_change = ftol_abs + 1.0;
Expand Down Expand Up @@ -411,19 +407,33 @@ Rcpp::List EM_LBFGS_fanhmm_singlechannel(
model.mstep_return_code, model, iter, relative_change,
absolute_change, absolute_x_change, relative_x_change);
}
model.mstep_A(
ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, maxeval_m, bound,
print_level_m
);
if (model.L_A > 0) {
model.mstep_A(
ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, maxeval_m, bound,
print_level_m
);
} else {
model.nhmm_base::mstep_A(
ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, maxeval_m, bound,
print_level_m
);
}
if (model.mstep_return_code != 0) {
return mstep_error_fanhmm(
model.mstep_return_code, model, iter, relative_change,
absolute_change, absolute_x_change, relative_x_change);
}
if (model.L_B > 0) {
model.mstep_B(
ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, maxeval_m, bound,
print_level_m
);
} else {
model.nhmm_sc::mstep_B(
ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, maxeval_m, bound,
print_level_m
);
}
if (model.mstep_return_code != 0) {
return mstep_error_fanhmm(
model.mstep_return_code, model, iter, relative_change,
Expand Down Expand Up @@ -465,13 +475,17 @@ Rcpp::List EM_LBFGS_fanhmm_singlechannel(
pars_new.cols(n_pi, n_pi + n_A - 1) = arma::vectorise(model.eta_A).t();
pars_new.cols(n_pi + n_A, n_pi + n_A + n_B - 1) = arma::vectorise(model.eta_B).t();
ii = n_pi + n_A + n_B;
for (arma::uword s = 0; s < S; s++) {
pars_new.cols(ii, ii + n_Ar - 1) = arma::vectorise(model.rho_A(s)).t();
ii += n_Ar;
if (n_Ar > 0) {
for (arma::uword s = 0; s < S; s++) {
pars_new.cols(ii, ii + n_Ar - 1) = arma::vectorise(model.rho_A(s)).t();
ii += n_Ar;
}
}
for (arma::uword s = 0; s < S; s++) {
pars_new.cols(ii, ii + n_Br - 1) = arma::vectorise(model.rho_B(s)).t();
ii += n_Br;
if (n_Br > 0) {
for (arma::uword s = 0; s < S; s++) {
pars_new.cols(ii, ii + n_Br - 1) = arma::vectorise(model.rho_B(s)).t();
ii += n_Br;
}
}

penalty_term = 0.5 * lambda * std::pow(arma::norm(pars_new, 2), 2);
Expand Down
1 change: 1 addition & 0 deletions src/fanhmm_sc.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ struct fanhmm_sc : public nhmm_sc {
A.each_slice() = Atmp.t();
} else {
if (tv_A) {
// A_t = P(z_t+1 | z_t, y_t), pi = P(z_1)
for (arma::uword t = 0; t < Ti(i); t++) { // time
for (arma::uword s = 0; s < S; s ++) { // from states
Atmp.col(s) = softmax(
Expand Down
Loading

0 comments on commit 0e6d384

Please sign in to comment.