diff --git a/R/build_fanhmm.R b/R/build_fanhmm.R index f3b1d9d..26a60a5 100644 --- a/R/build_fanhmm.R +++ b/R/build_fanhmm.R @@ -10,12 +10,6 @@ build_fanhmm <- function( !is.null(autoregression_formula) || !is.null(feedback_formula), "Provide {.arg autoregression_formula} and/or {.arg feedback_formula} for FAN-HMM." ) - stopifnot_( - inherits(autoregression_formula, "formula"), - "Argument {.arg autoregression_formula} must be a {.cls formula} object.") - stopifnot_( - inherits(feedback_formula, "formula"), - "Argument {.arg feedback_formula} must be a {.cls formula} object.") out <- create_base_nhmm( observations, data, time, id, n_states, state_names, channel_names = NULL, initial_formula, transition_formula, emission_formula) @@ -39,6 +33,9 @@ build_fanhmm <- function( ) np_rho_A <- 0 } else { + stopifnot_( + inherits(feedback_formula, "formula"), + "Argument {.arg feedback_formula} must be a {.cls formula} object.") W_A <- model_matrix_feedback_formula( feedback_formula, data, out$model$n_sequences, @@ -65,6 +62,9 @@ build_fanhmm <- function( ) np_rho_B <- 0 } else { + stopifnot_( + inherits(autoregression_formula, "formula"), + "Argument {.arg autoregression_formula} must be a {.cls formula} object.") W_B <- model_matrix_autoregression_formula( autoregression_formula, data, out$model$n_sequences, diff --git a/R/create_initial_values.R b/R/create_initial_values.R index 16f2973..4464d3d 100644 --- a/R/create_initial_values.R +++ b/R/create_initial_values.R @@ -303,17 +303,17 @@ create_inits_matrix <- function(x, n, m, K, sd = 0) { } create_rho_A_inits <- function(x, S, M, L, init_sd = 0) { - if (is.null(x$rho_A)) { + if (is.null(x)) { create_rho_A(numeric((S - 1) * L * (M - 1) * S), S, M, L, init_sd) } else { stopifnot_( - length(unlist(x$rho_A)) == (S - 1) * L * (M - 1) * S, + length(unlist(x)) == (S - 1) * L * (M - 1) * S, paste0( "Number of initial values for {.val rho_A} is not equal to ", "(S - 1) * L * (M - 1) * S = {(S - 1) * L * (M - 1) * S}." ) ) - create_rho_A(unlist(x$rho_A), S, M, L, init_sd) + create_rho_A(unlist(x), S, M, L, init_sd) } } @@ -325,17 +325,17 @@ create_rho_A <- function(x, S, M, L, init_sd = 0) { } create_rho_B_inits <- function(x, S, M, L, init_sd = 0) { - if (is.null(x$rho_B)) { + if (is.null(x)) { create_rho_B(numeric((M - 1) * L * (M - 1) * S), S, M, L, init_sd) } else { stopifnot_( - length(unlist(x$rho_B)) == (M - 1) * L * (M - 1) * S, + length(unlist(x)) == (M - 1) * L * (M - 1) * S, paste0( "Number of initial values for {.val rho_B} is not equal to ", "(M - 1) * L * (M - 1) * S = {(M - 1) * L * (M - 1) * S}." ) ) - create_rho_B(unlist(x$rho_B), S, M, L, init_sd) + create_rho_B(unlist(x), S, M, L, init_sd) } } diff --git a/R/dnm_fanhmm.R b/R/dnm_fanhmm.R index ff40605..0e5b4f9 100644 --- a/R/dnm_fanhmm.R +++ b/R/dnm_fanhmm.R @@ -95,8 +95,8 @@ dnm_fanhmm <- function(model, inits, init_sd, restarts, lambda, bound, control, init <- unlist( c( create_initial_values(inits, model, init_sd), - create_rho_A_inits(inits, S, M, L_A, init_sd), - create_rho_B_inits(inits, S, M, L_B, init_sd) + create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd), + create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd) ) ) fit <- nloptr( @@ -133,8 +133,8 @@ dnm_fanhmm <- function(model, inits, init_sd, restarts, lambda, bound, control, init <- unlist( c( create_initial_values(inits, model, init_sd), - create_rho_A_inits(inits, S, M, L_A, init_sd), - create_rho_B_inits(inits, S, M, L_B, init_sd) + create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd), + create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd) ) ) } else { @@ -148,8 +148,8 @@ dnm_fanhmm <- function(model, inits, init_sd, restarts, lambda, bound, control, init <- unlist( c( create_initial_values(inits, model, init_sd), - create_rho_A_inits(inits, S, M, L_A, init_sd), - create_rho_B_inits(inits, S, M, L_B, init_sd) + create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd), + create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd) ) ) } diff --git a/R/em_dnm_fanhmm.R b/R/em_dnm_fanhmm.R index eb842d6..bf038f1 100644 --- a/R/em_dnm_fanhmm.R +++ b/R/em_dnm_fanhmm.R @@ -95,8 +95,8 @@ em_dnm_fanhmm <- function(model, inits, init_sd, restarts, lambda, out <- future.apply::future_lapply(seq_len(restarts), function(i) { init <- c( create_initial_values(inits, model, init_sd), - rho_A = list(create_rho_A_inits(inits, S, M, L_A, init_sd)), - rho_B = list(create_rho_B_inits(inits, S, M, L_B, init_sd)) + rho_A = list(create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd)), + rho_B = list(create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd)) ) fit <- EM_LBFGS_fanhmm_singlechannel( init$eta_pi, model$X_pi, init$eta_A, model$X_A, init$eta_B, model$X_B, @@ -149,8 +149,8 @@ em_dnm_fanhmm <- function(model, inits, init_sd, restarts, lambda, em_return_code <- return_codes[1] + 1000 init <- c( create_initial_values(inits, model, init_sd), - rho_A = list(create_rho_A_inits(inits, S, M, L_A, init_sd)), - rho_B = list(create_rho_B_inits(inits, S, M, L_B, init_sd)) + rho_A = list(create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd)), + rho_B = list(create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd)) ) } else { em_return_code <- 0 # generic success @@ -175,8 +175,8 @@ em_dnm_fanhmm <- function(model, inits, init_sd, restarts, lambda, } else { init <- c( create_initial_values(inits, model, init_sd), - rho_A = list(create_rho_A_inits(inits, S, M, L_A, init_sd)), - rho_B = list(create_rho_B_inits(inits, S, M, L_B, init_sd)) + rho_A = list(create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd)), + rho_B = list(create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd)) ) out <- EM_LBFGS_fanhmm_singlechannel( diff --git a/R/em_fanhmm.R b/R/em_fanhmm.R index 40804e6..459e3f1 100644 --- a/R/em_fanhmm.R +++ b/R/em_fanhmm.R @@ -44,8 +44,8 @@ em_fanhmm <- function(model, inits, init_sd, restarts, lambda, out <- future.apply::future_lapply(seq_len(restarts), function(i) { init <- c( create_initial_values(inits, model, init_sd), - rho_A = list(create_rho_A_inits(inits, S, M, L_A, init_sd)), - rho_B = list(create_rho_B_inits(inits, S, M, L_B, init_sd)) + rho_A = list(create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd)), + rho_B = list(create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd)) ) fit <- EM_LBFGS_fanhmm_singlechannel( init$eta_pi, model$X_pi, init$eta_A, model$X_A, init$eta_B, model$X_B, @@ -78,8 +78,8 @@ em_fanhmm <- function(model, inits, init_sd, restarts, lambda, } else { init <- c( create_initial_values(inits, model, init_sd), - rho_A = list(create_rho_A_inits(inits, S, M, L_A, init_sd)), - rho_B = list(create_rho_B_inits(inits, S, M, L_B, init_sd)) + rho_A = list(create_rho_A_inits(inits$rho_A, S, M, L_A, init_sd)), + rho_B = list(create_rho_B_inits(inits$rho_B, S, M, L_B, init_sd)) ) } out <- EM_LBFGS_fanhmm_singlechannel( diff --git a/R/fit_fanhmm.R b/R/fit_fanhmm.R index c7b0851..04c8b59 100644 --- a/R/fit_fanhmm.R +++ b/R/fit_fanhmm.R @@ -84,13 +84,13 @@ fit_fanhmm <- function(model, inits, init_sd, restarts, lambda, method, model$gammas$B <- eta_to_gamma_cube_field(model$etas$B) } model$rhos$A <- create_rho_A_inits( - inits, model$n_states, model$n_symbols, nrow(model$W_A), init_sd + inits$rho_A, model$n_states, model$n_symbols, nrow(model$W_A), init_sd ) - model$phis$A <- rho_to_phi(model$rhos$A) + model$phis$A <- rho_to_phi_field(model$rhos$A) model$rhos$B <- create_rho_B_inits( - inits, model$n_states, model$n_symbols, nrow(model$W_B), init_sd + inits$rho_B, model$n_states, model$n_symbols, nrow(model$W_B), init_sd ) - model$phis$B <- rho_to_phi(model$rhos$B) + model$phis$B <- rho_to_phi_field(model$rhos$B) return(model) } if (method == "EM-DNM") { diff --git a/R/fit_nhmm.R b/R/fit_nhmm.R index ae538ed..d5a0ba4 100644 --- a/R/fit_nhmm.R +++ b/R/fit_nhmm.R @@ -73,7 +73,10 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method, } if (isTRUE(control$maxeval < 0)) { - model$etas <- create_initial_values(inits, model, init_sd) + model$etas <- setNames( + create_initial_values(inits, model, init_sd), + c("pi","A", "B") + ) model$gammas$pi <- eta_to_gamma_mat(model$etas$pi) model$gammas$A <- eta_to_gamma_cube(model$etas$A) if (model$n_channels == 1L) { diff --git a/R/get_probs.R b/R/get_probs.R index 4086426..be77c34 100644 --- a/R/get_probs.R +++ b/R/get_probs.R @@ -103,7 +103,7 @@ get_transition_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) { ids <- rownames(model$observations[[1]]) times <- as.numeric(colnames(model$observations[[1]])) } - if (!attr(model$X_A, "iv")) { + if (!attr(model$X_A, "iv") && !isTRUE(attr(model$W_A, "iv"))) { X <- model$X_A[, , 1L, drop = FALSE] ids <- "all" } else { @@ -115,7 +115,7 @@ get_transition_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) { state_from = model$state_names, state_to = rep(model$state_names, each = S), estimate = unlist(get_A_all( - model$gammas$A, X, attr(model$X_A, "tv") + model$gammas$A, X, attr(model$X_A, "tv") || isTRUE(attr(model$W_A, "tv")) )) ) d <- stats::setNames( @@ -135,7 +135,7 @@ get_transition_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) { ) qs <- get_A_qs( model$boot$gamma_A, - X, attr(model$X_A, "tv"), probs + X, attr(model$X_A, "tv") || isTRUE(attr(model$W_A, "tv")), probs ) for(i in seq_along(probs)) { d[paste0("q", 100 * probs[i])] <- qs[, i] @@ -192,7 +192,7 @@ get_emission_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) { times <- as.numeric(colnames(model$observations[[1]])) symbol_names <- model$symbol_names } - if (!attr(model$X_B, "iv")) { + if (!attr(model$X_B, "iv") && !isTRUE(attr(model$W_B, "iv"))) { X <- model$X_B[, , 1L, drop = FALSE] ids <- "all" } else { @@ -208,7 +208,7 @@ get_emission_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) { channel = model$channel_names[i], observation = rep(symbol_names[[i]], each = S), estimate = unlist(get_B_all( - model$gammas$B[[i]], X, attr(model$X_B, "tv") + model$gammas$B[[i]], X, attr(model$X_B, "tv") || isTRUE(attr(model$W_B, "tv")) )) ) }) @@ -229,7 +229,7 @@ get_emission_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) { if (C == 1) { qs <- get_B_qs( model$boot$gamma_B, - X, attr(model$X_B, "tv"), probs + X, attr(model$X_B, "tv") || isTRUE(attr(model$W_B, "tv")), probs ) } else { qs <- do.call( @@ -237,7 +237,7 @@ get_emission_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) { lapply(seq_len(C), function(i) { get_B_qs( lapply(model$boot$gamma_B, "[[", i), - X, attr(model$X_B, "tv"), probs + X, attr(model$X_B, "tv") || isTRUE(attr(model$W_B, "tv")), probs ) }) ) diff --git a/src/fanhmm_EM.cpp b/src/fanhmm_EM.cpp index 7cc7c6e..1200bc8 100644 --- a/src/fanhmm_EM.cpp +++ b/src/fanhmm_EM.cpp @@ -81,11 +81,7 @@ void fanhmm_sc::mstep_A(const double ftol_abs, const double ftol_rel, auto* self = static_cast(data); arma::vec x_vec(const_cast(x), n, false, true); arma::vec grad_vec(grad, n, false, true); - if (self->L_A > 0) { - return self->objective_A(x_vec, grad_vec); - } else { - return self->nhmm_base::objective_A(x_vec, grad_vec); - } + return self->objective_A(x_vec, grad_vec); }; arma::uword n_eta = eta_A.slice(0).n_elem; @@ -234,11 +230,7 @@ void fanhmm_sc::mstep_B(const double ftol_abs, const double ftol_rel, auto* self = static_cast(data); arma::vec x_vec(const_cast(x), n, false, true); arma::vec grad_vec(grad, n, false, true); - if (self->L_B > 0) { - return self->objective_B(x_vec, grad_vec); - } else { - return self->nhmm_sc::objective_B(x_vec, grad_vec); - } + return self->objective_B(x_vec, grad_vec); }; arma::uword n_eta = eta_B.slice(0).n_elem; arma::uword n_rho = rho_B(0).n_elem; @@ -330,13 +322,17 @@ Rcpp::List EM_LBFGS_fanhmm_singlechannel( pars.cols(n_pi, n_pi + n_A - 1) = arma::vectorise(model.eta_A).t(); pars.cols(n_pi + n_A, n_pi + n_A + n_B - 1) = arma::vectorise(model.eta_B).t(); arma::uword ii = n_pi + n_A + n_B; - for (arma::uword s = 0; s < S; s++) { - pars.cols(ii, ii + n_Ar - 1) = arma::vectorise(model.rho_A(s)).t(); - ii += n_Ar; + if (n_Ar > 0) { + for (arma::uword s = 0; s < S; s++) { + pars.cols(ii, ii + n_Ar - 1) = arma::vectorise(model.rho_A(s)).t(); + ii += n_Ar; + } } - for (arma::uword s = 0; s < S; s++) { - pars.cols(ii, ii + n_Br - 1) = arma::vectorise(model.rho_B(s)).t(); - ii += n_Br; + if (n_Br > 0) { + for (arma::uword s = 0; s < S; s++) { + pars.cols(ii, ii + n_Br - 1) = arma::vectorise(model.rho_B(s)).t(); + ii += n_Br; + } } double relative_change = ftol_rel + 1.0; double absolute_change = ftol_abs + 1.0; @@ -411,19 +407,33 @@ Rcpp::List EM_LBFGS_fanhmm_singlechannel( model.mstep_return_code, model, iter, relative_change, absolute_change, absolute_x_change, relative_x_change); } - model.mstep_A( - ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, maxeval_m, bound, - print_level_m - ); + if (model.L_A > 0) { + model.mstep_A( + ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, maxeval_m, bound, + print_level_m + ); + } else { + model.nhmm_base::mstep_A( + ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, maxeval_m, bound, + print_level_m + ); + } if (model.mstep_return_code != 0) { return mstep_error_fanhmm( model.mstep_return_code, model, iter, relative_change, absolute_change, absolute_x_change, relative_x_change); } + if (model.L_B > 0) { model.mstep_B( ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, maxeval_m, bound, print_level_m ); + } else { + model.nhmm_sc::mstep_B( + ftol_abs_m, ftol_rel_m, xtol_abs_m, xtol_rel_m, maxeval_m, bound, + print_level_m + ); + } if (model.mstep_return_code != 0) { return mstep_error_fanhmm( model.mstep_return_code, model, iter, relative_change, @@ -465,13 +475,17 @@ Rcpp::List EM_LBFGS_fanhmm_singlechannel( pars_new.cols(n_pi, n_pi + n_A - 1) = arma::vectorise(model.eta_A).t(); pars_new.cols(n_pi + n_A, n_pi + n_A + n_B - 1) = arma::vectorise(model.eta_B).t(); ii = n_pi + n_A + n_B; - for (arma::uword s = 0; s < S; s++) { - pars_new.cols(ii, ii + n_Ar - 1) = arma::vectorise(model.rho_A(s)).t(); - ii += n_Ar; + if (n_Ar > 0) { + for (arma::uword s = 0; s < S; s++) { + pars_new.cols(ii, ii + n_Ar - 1) = arma::vectorise(model.rho_A(s)).t(); + ii += n_Ar; + } } - for (arma::uword s = 0; s < S; s++) { - pars_new.cols(ii, ii + n_Br - 1) = arma::vectorise(model.rho_B(s)).t(); - ii += n_Br; + if (n_Br > 0) { + for (arma::uword s = 0; s < S; s++) { + pars_new.cols(ii, ii + n_Br - 1) = arma::vectorise(model.rho_B(s)).t(); + ii += n_Br; + } } penalty_term = 0.5 * lambda * std::pow(arma::norm(pars_new, 2), 2); diff --git a/src/fanhmm_sc.h b/src/fanhmm_sc.h index 4c22a51..0d52468 100644 --- a/src/fanhmm_sc.h +++ b/src/fanhmm_sc.h @@ -78,6 +78,7 @@ struct fanhmm_sc : public nhmm_sc { A.each_slice() = Atmp.t(); } else { if (tv_A) { + // A_t = P(z_t+1 | z_t, y_t), pi = P(z_1) for (arma::uword t = 0; t < Ti(i); t++) { // time for (arma::uword s = 0; s < S; s ++) { // from states Atmp.col(s) = softmax( diff --git a/src/nhmm_backward.h b/src/nhmm_backward.h index 99d1edd..0970516 100644 --- a/src/nhmm_backward.h +++ b/src/nhmm_backward.h @@ -24,25 +24,6 @@ void univariate_backward_nhmm( } } } -// // time-invariant A -// template -// void univariate_backward_nhmm( -// submat& log_beta, -// const arma::mat& log_A, -// const arma::mat& log_py) { -// -// arma::uword S = log_py.n_rows; -// arma::uword T = log_py.n_cols; -// arma::mat log_tA = log_A.t(); -// log_beta.col(T - 1).zeros(); -// for (int t = (T - 2); t >= 0; t--) { -// for (arma::uword i = 0; i < S; i++) { -// log_beta(i, t) = logSumExp( -// log_beta.col(t + 1) + log_tA.col(i) + log_py.col(t + 1) -// ); -// } -// } -// } template void backward_nhmm(Model& model, arma::cube& log_beta) { diff --git a/src/nhmm_gradients.cpp b/src/nhmm_gradients.cpp index 9a6fe5e..5297018 100644 --- a/src/nhmm_gradients.cpp +++ b/src/nhmm_gradients.cpp @@ -236,14 +236,13 @@ void gradient_wrt_A( const arma::mat& log_beta, const double ll, const arma::cube& A, const arma::cube& X, const arma::cube& W, const arma::uword i, const arma::uword t, const arma::uword s) { - arma::uword n = W.n_rows; tmpmat = -A.slice(t).row(s).t() * A.slice(t).row(s); tmpmat.diag() += A.slice(t).row(s); arma::vec tmpvec = tmpmat * exp(log_alpha(s, t) + log_py.col(t + 1) + log_beta.col(t + 1) - ll); grad.cols(0, X.n_rows - 1) += tmpvec * X.slice(i).col(t).t(); - if (obs(t, i) > 0) { + if (obs(t, i) > 0 && n > 0) { grad.cols(X.n_rows + (obs(t, i) - 1) * n, X.n_rows + obs(t, i) * n - 1) += tmpvec * W.slice(i).col(t).t(); } @@ -264,7 +263,7 @@ void gradient_wrt_B_t0( tmpvec(idx) += brow; tmpvec = exp(log_pi(s) + log_beta(s, 0) - ll) * tmpvec; grad.cols(0, X.n_rows - 1) += tmpvec * X.slice(i).col(0).t(); - if (obs_0(i) > 0) { + if (obs_0(i) > 0 && n > 0) { grad.cols(X.n_rows + (obs_0(i) - 1) * n, X.n_rows + obs_0(i) * n - 1) += tmpvec * W.slice(i).col(0).t(); } @@ -285,7 +284,7 @@ void gradient_wrt_B( tmpvec = arma::accu(exp(log_alpha.col(t) + log_A.slice(t).col(s) + log_beta(s, t + 1) - ll)) * tmpvec; grad.cols(0, X.n_rows - 1) += tmpvec * X.slice(i).col(t + 1).t(); - if (obs(t, i) > 0) { + if (obs(t, i) > 0 && n > 0) { grad.cols(X.n_rows + (obs(t, i) - 1) * n, X.n_rows + obs(t, i) * n - 1) += tmpvec * W.slice(i).col(t + 1).t(); } diff --git a/src/nhmm_log_objective.cpp b/src/nhmm_log_objective.cpp index c7876f0..c826d2a 100644 --- a/src/nhmm_log_objective.cpp +++ b/src/nhmm_log_objective.cpp @@ -552,6 +552,7 @@ Rcpp::List log_objective_fanhmm_singlechannel( ); } } + // gradient wrt gamma_B for (arma::uword s = 0; s < model.S; s++) { gradient_wrt_B_t0( @@ -571,12 +572,20 @@ Rcpp::List log_objective_fanhmm_singlechannel( grad_A2.slice(s) = model.Qs.t() * grad_A.slice(s); grad_B2.slice(s) = model.Qm.t() * grad_B.slice(s); } + arma::cube grad_rho_A(model.S - 1, (model.M - 1) * model.L_A, model.S); + if (model.L_A > 0) { + grad_rho_A = grad_A2.cols(model.K_A, grad_A2.n_cols - 1); + } + arma::cube grad_rho_B(model.M - 1, (model.M - 1) * model.L_B, model.S); + if (model.L_B > 0) { + grad_rho_B = grad_B2.cols(model.K_B, grad_B2.n_cols - 1); + } return Rcpp::List::create( Rcpp::Named("loglik") = sum(loglik), Rcpp::Named("gradient_pi") = Rcpp::wrap(grad_pi2), Rcpp::Named("gradient_A") = Rcpp::wrap(grad_A2.cols(0, model.K_A - 1)), Rcpp::Named("gradient_B") = Rcpp::wrap(grad_B2.cols(0, model.K_B - 1)), - Rcpp::Named("gradient_rho_A") = Rcpp::wrap(grad_A2.cols(model.K_A, grad_A2.n_cols - 1)), - Rcpp::Named("gradient_rho_B") = Rcpp::wrap(grad_B2.cols(model.K_A, grad_B2.n_cols - 1)) + Rcpp::Named("gradient_rho_A") = Rcpp::wrap(grad_rho_A), + Rcpp::Named("gradient_rho_B") = Rcpp::wrap(grad_rho_B) ); }