Skip to content

Commit

Permalink
collapsed catchability stan files and fixed jointSummarize error
Browse files Browse the repository at this point in the history
  • Loading branch information
abigailkeller committed Dec 23, 2024
1 parent 9d73549 commit 9f25e9f
Show file tree
Hide file tree
Showing 14 changed files with 711 additions and 1,050 deletions.
930 changes: 465 additions & 465 deletions .Rhistory

Large diffs are not rendered by default.

53 changes: 21 additions & 32 deletions R/jointModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -369,17 +369,17 @@ jointModel <- function(data, cov = NULL, family = 'poisson',
model_data <- rlist::list.append(
model_data,
nparams = length(q_names),
mat = as.integer(count_all$count.type)#, ###### CHANGED
#ctch = 1
mat = as.integer(count_all$count.type),
ctch = 1
)
} #else {
#model_data <- rlist::list.append(
# model_data,
# nparams = 0,
# mat = matrix(NA,nrow=nrow(count_all),ncol=0)#,
#ctch = 0
#)
#}
} else {
model_data <- rlist::list.append(
model_data,
nparams = 0,
mat = as.integer(rep(1,nrow(count_all))),
ctch = 0
)
}
# append data if family == negbin
if(get_family_index(family)==2){
model_data <- rlist::list.append(
Expand All @@ -406,7 +406,7 @@ jointModel <- function(data, cov = NULL, family = 'poisson',
sample.int(.Machine$integer.max, 1))

# get stan model
model_index <- get_stan_model(q, family)
model_index <- get_stan_model(family)

# get initial values
if(isCatch_type(q)){
Expand All @@ -419,10 +419,8 @@ jointModel <- function(data, cov = NULL, family = 'poisson',

# run model
out <- rstan::sampling(
c(stanmodels$joint_binary_cov_catchability_count,
stanmodels$joint_binary_cov_catchability_gamma,
stanmodels$joint_binary_cov_count,
stanmodels$joint_binary_cov_gamma)[model_index][[1]],
c(stanmodels$joint_count,
stanmodels$joint_continuous)[model_index][[1]],
data = model_data,
cores = cores,
seed = SEED,
Expand Down Expand Up @@ -503,20 +501,11 @@ get_family_index <- function(family){
# helper functions: get stan model #
####################################

get_stan_model <- function(q, family){
get_stan_model <- function(family){

index <- get_family_index(family)
index <- ifelse(family %in% c('poisson','negbin'),1,2)

if(isCatch_type(q) && index %in% c(1,2)){
final_index <- 1
} else if(isCatch_type(q) && index == 3){
final_index <- 2
} else if(!isCatch_type(q) && index %in% c(1,2)){
final_index <- 3
} else {
final_index <- 4
}
return(final_index)
return(index)
}


Expand Down Expand Up @@ -625,9 +614,9 @@ init_joint_cov_catchability <- function(n.chain,qPCR_all,q_names,cov,
for(i in 1:n.chain){
A[[i]] <- list(
if('mu' %in% names(initial_values[[i]])){
mu_trad_1 <- initial_values[[i]]$mu[L_match_trad$L]
mu_trad <- initial_values[[i]]$mu[L_match_trad$L]
} else {
mu_trad_1 <- mu_means_trad
mu_trad <- mu_means_trad
},

if('mu' %in% names(initial_values[[i]])){
Expand Down Expand Up @@ -656,22 +645,22 @@ init_joint_cov_catchability <- function(n.chain,qPCR_all,q_names,cov,
p_dna <- rep(0.4,dim(L_match_dna)[1]),
p11_dna <- rep(0.4,dim(L_match_dna)[1]) - 0.01
)
names(A[[i]]) <- c('mu_trad_1','mu','log_p10','alpha','q_trans',
names(A[[i]]) <- c('mu_trad','mu','log_p10','alpha','q_trans',
'p_dna','p11_dna')

}
} else {
for(i in 1:n.chain){
A[[i]] <- list(
mu_trad_1 <- mu_means_trad,
mu_trad <- mu_means_trad,
mu <- mu_means_all,
log_p10 <- stats::runif(1,log(0.0001),log(0.01)),
alpha <- as.array(c(3.5,rep(0,length(cov)))),
q_trans <- as.data.frame(stats::runif(length(q_names),0.01,1)),
p_dna <- rep(0.4,dim(L_match_dna)[1]),
p11_dna <- rep(0.4,dim(L_match_dna)[1]) - 0.01
)
names(A[[i]]) <- c('mu_trad_1','mu','log_p10','alpha','q_trans',
names(A[[i]]) <- c('mu_trad','mu','log_p10','alpha','q_trans',
'p_dna','p11_dna')
}
}
Expand Down
18 changes: 17 additions & 1 deletion R/jointSummarize.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ jointSummarize <- function(modelfit, par = 'all', probs = c(0.025,0.975),
use_cache = FALSE)$summary, as.integer(digits))
}

# fix row name if phi present
if('phi[1]'%in% rownames(out)){
row_index <- which(rownames(out)=='phi[1]')
rownames(out)[row_index] <- 'phi'
}

# fix row names if q = FALSE
if(modelfit@par_dims$q == 0){
names <- rownames(out)
rownames(out) <- gsub("\\[([0-9]+),1\\]", "[\\1]", names)
}

return(out)

}
Expand All @@ -110,7 +122,11 @@ isCatch <- function(pars){
return(out)
}
isNegbin <- function(modelfit){
out <- ifelse(modelfit@par_dims$phi == 1,TRUE,FALSE)
if('phi' %in% modelfit@model_pars){
out <- ifelse(modelfit@par_dims$phi == 1,TRUE,FALSE)
} else {
out <- FALSE
}
return(out)
}

Expand Down
189 changes: 71 additions & 118 deletions R/traditionalModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -205,131 +205,84 @@ traditionalModel <- function(data, family = 'poisson',
cores <- 1
}

# get seed
SEED <- ifelse(!is.null(seed),
as.integer(seed),
sample.int(.Machine$integer.max, 1))
# create data that will be present in all model variations
model_data <- list(
Nloc = length(unique(count_all$L)),
C = nrow(count_all),
R = count_all$L,
E = count_all$count,
control = list(adapt_delta = adapt_delta)
)

# append data based on catchability
if(q==TRUE){
model_data <- rlist::list.append(
model_data,
nparams = length(q_names),
mat = as.integer(count_all$count.type),
ctch = 1
)
} else {
model_data <- rlist::list.append(
model_data,
nparams = 0,
mat = as.integer(rep(1,nrow(count_all))),
ctch = 0
)
}

##run model, catchability, pois/gamma
if(q == TRUE && family != 'negbin'){
model_index <- dplyr::case_when(family == 'poisson'~ 1,
family == 'gamma' ~ 2)
inits <- init_trad_catchability(n.chain, count_all, q_names,
initial_values)
out <- rstan::sampling(c(
stanmodels$traditional_catchability_count,
stanmodels$traditional_catchability_gamma)[model_index][[1]],
data = list(
Nloc = length(unique(count_all$L)),
C = nrow(count_all),
R = count_all$L,
E = count_all$count,
nparams = length(q_names),
mat = as.integer(count_all$count.type),
phipriors = c(1,1),
negbin = 0,
control = list(adapt_delta = adapt_delta)
),
cores = cores,
seed = SEED,
#' @srrstats {G2.4,G2.4a} explicit conversion to
#' integers for sampling arguments
chains = as.integer(n.chain),
thin = as.integer(thin),
warmup = as.integer(n.iter.burn),
iter = (
as.integer(n.iter.burn) + as.integer(n.iter.sample)
),
init = inits,
refresh = ifelse(verbose == TRUE,500,0)
# append data if family == negbin
if(family=='negbin'){
model_data <- rlist::list.append(
model_data,
phipriors = phipriors,
negbin = 1
)
} else if(family=='poisson'){
model_data <- rlist::list.append(
model_data,
phipriors = c(1,1),
negbin = 0
)
} else if(q == TRUE && family == 'negbin'){
##run model, catchability, negbin
}

# get stan model
model_index <- get_stan_model(family)

# get initial values
if(q == TRUE){
inits <- init_trad_catchability(n.chain, count_all, q_names,
initial_values)
out <- rstan::sampling(stanmodels$traditional_catchability_count,
data = list(
Nloc = length(unique(count_all$L)),
C = nrow(count_all),
R = count_all$L,
E = count_all$count,
nparams = length(q_names),
mat = as.integer(count_all$count.type),
phipriors = phipriors,
negbin = 1,
control = list(adapt_delta = adapt_delta)
),
cores = cores,
seed = SEED,
#' @srrstats {G2.4,G2.4a} explicit conversion
#' to integers for sampling arguments
chains = as.integer(n.chain),
thin = as.integer(thin),
warmup = as.integer(n.iter.burn),
iter = (
as.integer(n.iter.burn) + as.integer(n.iter.sample)
),
init = inits,
refresh = ifelse(verbose == TRUE,500,0)
)
} else if(q == FALSE && family != 'negbin'){
##run model, no catchability, pois/gamma
model_index <- dplyr::case_when(family == 'poisson'~ 1,
family == 'gamma' ~ 2)
inits <- init_trad(n.chain, count_all, initial_values)
out <- rstan::sampling(c(stanmodels$traditional_count,
stanmodels$traditional_gamma)[model_index][[1]],
data = list(
Nloc = length(unique(count_all$L)),
C = nrow(count_all),
R = count_all$L,
E = count_all$count,
phipriors = c(1,1),
negbin = 0,
control = list(adapt_delta = adapt_delta,
stepsize = 0.5)
),
cores = cores,
seed = SEED,
#' @srrstats {G2.4,G2.4a} explicit conversion to
#' integers for sampling arguments
chains = as.integer(n.chain),
thin = as.integer(thin),
warmup = as.integer(n.iter.burn),
iter = (
as.integer(n.iter.burn) + as.integer(n.iter.sample)
),
init = inits,
refresh = ifelse(verbose == TRUE,500,0)
)
} else if(q == FALSE && family == 'negbin'){
##run model, no catchability, negbin
} else {
inits <- init_trad(n.chain, count_all, initial_values)
out <- rstan::sampling(stanmodels$traditional_count,
data = list(
Nloc = length(unique(count_all$L)),
C = nrow(count_all),
R = count_all$L,
E = count_all$count,
phipriors = phipriors,
negbin = 1,
control = list(adapt_delta = adapt_delta)
),
cores = cores,
seed = SEED,
#' @srrstats {G2.4,G2.4a} explicit conversion to
#' integers for sampling arguments
chains = as.integer(n.chain),
thin = as.integer(thin),
warmup = as.integer(n.iter.burn),
iter = (
as.integer(n.iter.burn) + as.integer(n.iter.sample)
),
init = inits,
refresh = ifelse(verbose == TRUE,500,0)
)
}

# get seed
SEED <- ifelse(!is.null(seed),
as.integer(seed),
sample.int(.Machine$integer.max, 1))

# run model
out <- rstan::sampling(
c(stanmodels$traditional_count,
stanmodels$traditional_continuous)[model_index][[1]],
data = model_data,
cores = cores,
seed = SEED,
#' @srrstats {G2.4,G2.4a} explicit conversion to
#' integers for sampling arguments
chains = as.integer(n.chain),
thin = as.integer(thin),
warmup = as.integer(n.iter.burn),
iter = (
as.integer(n.iter.burn) + as.integer(n.iter.sample)
),
init = inits,
refresh = ifelse(verbose == TRUE,500,0)
)



# assert that the log likelihood is a double
#' @srrstats {G5.3} assert that model run worked and the log likelihood is
#' valid (i.e., not NA)
Expand Down
Loading

0 comments on commit 9f25e9f

Please sign in to comment.