Skip to content

Commit

Permalink
Merge pull request #77 from rgcca-factory/refactor/use_classes_for_bl…
Browse files Browse the repository at this point in the history
…ocks

Refactor/use classes for blocks
  • Loading branch information
Tenenhaus authored Feb 5, 2024
2 parents 17347b7 + b0cb631 commit eefad3b
Show file tree
Hide file tree
Showing 29 changed files with 473 additions and 1,302 deletions.
13 changes: 13 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# Generated by roxygen2: do not edit by hand

S3method(block_init,block)
S3method(block_init,dual_block)
S3method(block_init,dual_regularized_block)
S3method(block_init,primal_regularized_block)
S3method(block_postprocess,block)
S3method(block_postprocess,sparse_block)
S3method(block_project,block)
S3method(block_project,dual_block)
S3method(block_project,dual_regularized_block)
S3method(block_project,primal_regularized_block)
S3method(block_project,sparse_block)
S3method(block_update,block)
S3method(block_update,dual_block)
S3method(plot,rgcca)
S3method(plot,rgcca_bootstrap)
S3method(plot,rgcca_cv)
Expand Down
66 changes: 66 additions & 0 deletions R/block.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
### Create classes
new_block <- function(x, j, na.rm = TRUE, bias = TRUE,
..., class = character()) {
n <- NROW(x)
p <- NCOL(x)
N <- ifelse(bias, n, n - 1)

x <- list(
x = x,
j = j,
n = n,
p = p,
N = N,
na.rm = na.rm,
a = NULL,
Y = NULL,
...
)
class(x) <- c(class, "block")
return(x)
}

new_dual_block <- function(x, j, na.rm = TRUE, ..., class = character()) {
K <- pm(x, t(x), na.rm = na.rm)
new_block(
x, j, na.rm, alpha = NULL, K = K, ..., class = c(class, "dual_block")
)
}

new_primal_regularized_block <- function(x, j, tau, ...) {
new_block(x, j, tau = tau, M = NULL, ..., class = "primal_regularized_block")
}

new_dual_regularized_block <- function(x, j, tau, ...) {
new_dual_block(
x, j, tau = tau, M = NULL, ..., class = "dual_regularized_block"
)
}

new_sparse_block <- function(x, j, sparsity, tol = 1e-08, ...) {
const <- sqrt(NCOL(x)) * sparsity
new_block(
x, j, sparsity = sparsity, const = const,
tol = tol, ..., class = "sparse_block"
)
}

### Utility method to choose the adequate class
create_block <- function(x, j, bias, na.rm, tau, sparsity, tol) {
if (sparsity < 1) {
res <- new_sparse_block(x, j, sparsity, tol, bias = bias, na.rm = na.rm)
} else if (NROW(x) > NCOL(x)) {
if (tau < 1) {
res <- new_primal_regularized_block(x, j, tau, bias = bias, na.rm = na.rm)
} else {
res <- new_block(x, j, bias = bias, na.rm = na.rm)
}
} else {
if (tau < 1) {
res <- new_dual_regularized_block(x, j, tau, bias = bias, na.rm = na.rm)
} else {
res <- new_dual_block(x, j, bias = bias, na.rm = na.rm)
}
}
return(res)
}
41 changes: 41 additions & 0 deletions R/block_init.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#' @importFrom MASS ginv

block_init <- function(x, init = "svd") {
UseMethod("block_init")
}

#' @export
block_init.block <- function(x, init = "svd") {
if (init == "svd") {
x$a <- initsvd(x$x, dual = FALSE)
} else {
x$a <- rnorm(x$p)
}

return(block_project(x))
}

#' @export
block_init.dual_block <- function(x, init = "svd") {
if (init == "svd") {
x$alpha <- initsvd(x$x, dual = TRUE)
} else {
x$alpha <- rnorm(x$n)
}

return(block_project(x))
}

#' @export
block_init.primal_regularized_block <- function(x, init = "svd") {
x$M <- ginv(
x$tau * diag(x$p) + (1 - x$tau) * pm(t(x$x), x$x, na.rm = x$na.rm) / x$N
)
NextMethod()
}

#' @export
block_init.dual_regularized_block <- function(x, init = "svd") {
x$M <- ginv(x$tau * diag(x$n) + (1 - x$tau) * x$K / x$N)
NextMethod()
}
33 changes: 33 additions & 0 deletions R/block_postprocess.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
block_postprocess <- function(x, ctrl) {
UseMethod("block_postprocess")
}

#' @export
block_postprocess.block <- function(x, ctrl) {
if (ctrl && (x$a[1] < 0)) {
x$a <- -x$a
x$Y <- -x$Y
}
return(x)
}

#' @export
block_postprocess.sparse_block <- function(x, ctrl) {
l2_sat <- norm(x$a, "2")
if (abs(l2_sat - 1) > x$tol) {
if (l2_sat < .Machine$double.eps) {
warning(
"The l2 norm of the block weight vector #",
x$j, " is too small :", l2_sat
)
} else {
warning(
"The l2 constraint is not saturated for block #", x$j,
". The intersection of the l1 and l2 spheres is empty for ",
"a sparsity parameter equal to ", x$sparsity,
". Try to increase the value of the sparsity parameter."
)
}
}
NextMethod()
}
56 changes: 56 additions & 0 deletions R/block_project.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
block_project <- function(x) {
UseMethod("block_project")
}

#' @export
block_project.block <- function(x) {
if (any(x$a != 0)) {
x$a <- x$a / norm(x$a, type = "2")
}

x$Y <- pm(x$x, x$a, na.rm = x$na.rm)
return(x)
}

#' @export
block_project.dual_block <- function(x) {
if (any(x$alpha != 0)) {
x$alpha <- x$alpha / drop(sqrt(t(x$alpha) %*% x$K %*% x$alpha))
}
x$a <- pm(t(x$x), x$alpha, na.rm = x$na.rm)

x$Y <- pm(x$x, x$a, na.rm = x$na.rm)
return(x)
}

#' @export
block_project.primal_regularized_block <- function(x) {
if (any(x$a != 0)) {
x$a <- x$M %*% x$a / drop(sqrt(t(x$a) %*% x$M %*% x$a))
}

x$Y <- pm(x$x, x$a, na.rm = x$na.rm)
return(x)
}

#' @export
block_project.dual_regularized_block <- function(x) {
if (any(x$alpha != 0)) {
x$alpha <- x$M %*% x$alpha / drop(sqrt(
t(x$alpha) %*% x$M %*% x$K %*% x$alpha
))
}
x$a <- pm(t(x$x), x$alpha, na.rm = x$na.rm)

x$Y <- pm(x$x, x$a, na.rm = x$na.rm)
return(x)
}

#' @export
block_project.sparse_block <- function(x) {
if (any(x$a != 0)) {
x$a <- soft_threshold(x$a, x$const)
}
x$Y <- pm(x$x, x$a, na.rm = x$na.rm)
return(x)
}
15 changes: 15 additions & 0 deletions R/block_update.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
block_update <- function(x, grad) {
UseMethod("block_update")
}

#' @export
block_update.block <- function(x, grad) {
x$a <- pm(t(x$x), grad, na.rm = x$na.rm)
return(block_project(x))
}

#' @export
block_update.dual_block <- function(x, grad) {
x$alpha <- grad
return(block_project(x))
}
2 changes: 1 addition & 1 deletion R/rgcca.R
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ rgcca <- function(blocks, connection = NULL, tau = 1, ncomp = 1,
gcca_args[["blocks"]] <- blocks
gcca_args[["disjunction"]] <- opt$disjunction
gcca_args[[opt$param]] <- rgcca_args[[opt$param]]
func_out <- do.call(opt$gcca, gcca_args)
func_out <- do.call(rgcca_outer_loop, gcca_args)

### Format the output
func_out <- format_output(func_out, rgcca_args, opt, blocks)
Expand Down
2 changes: 1 addition & 1 deletion R/rgcca_cv.r
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
#' plot(cv_out, display_order = FALSE)
#' }
rgcca_cv <- function(blocks,
connection = NULL,
method = "rgcca",
response = NULL,
par_type = "tau",
Expand Down Expand Up @@ -228,7 +229,6 @@ rgcca_cv <- function(blocks,
par_type <- "sparsity"
} else if (par_type == "sparsity") {
rgcca_args$method <- "sgcca"
opt$gcca <- sgcca
opt$param <- "sparsity"
}

Expand Down
73 changes: 0 additions & 73 deletions R/rgcca_init.R

This file was deleted.

Loading

0 comments on commit eefad3b

Please sign in to comment.