Skip to content

Commit

Permalink
Add separable argument and implement TGCCA for tau != 1
Browse files Browse the repository at this point in the history
  • Loading branch information
GFabien committed Dec 14, 2023
1 parent 3fd5432 commit 7e7d8be
Show file tree
Hide file tree
Showing 18 changed files with 231 additions and 23 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ S3method(block_init,block)
S3method(block_init,dual_block)
S3method(block_init,dual_regularized_block)
S3method(block_init,primal_regularized_block)
S3method(block_init,regularized_tensor_block)
S3method(block_init,separable_regularized_tensor_block)
S3method(block_init,tensor_block)
S3method(block_postprocess,block)
S3method(block_postprocess,separable_regularized_tensor_block)
S3method(block_postprocess,sparse_block)
S3method(block_project,block)
S3method(block_project,dual_block)
Expand Down
24 changes: 19 additions & 5 deletions R/block.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,32 @@ new_tensor_block <- function(x, j, rank, mode_orth, ..., class = character()) {
new_regularized_tensor_block <- function(x, j, rank, mode_orth, tau, ...) {
new_tensor_block(
x, j, rank = rank, mode_orth = mode_orth, tau = tau,
M = NULL, ..., class = "tensor_regularized_block"
M = NULL, ..., class = "regularized_tensor_block"
)
}

new_separable_regularized_tensor_block <- function(x, j, rank, mode_orth,
tau, ...) {
new_tensor_block(
x, j, rank = rank, mode_orth = mode_orth, tau = tau,
M = NULL, ..., class = "separable_regularized_tensor_block"
)
}

### Utility method to choose the adequate class
create_block <- function(x, j, bias, na.rm, tau, sparsity,
tol, rank, mode_orth) {
tol, rank, mode_orth, separable) {
if (length(dim(x)) > 2) { # TGCCA
if (tau < 1) {
res <- new_regularized_tensor_block(
x, j, rank, mode_orth, tau, bias = bias, na.rm = na.rm
)
if (separable) {
res <- new_separable_regularized_tensor_block(
x, j, rank, mode_orth, tau, bias = bias, na.rm = na.rm
)
} else {
res <- new_regularized_tensor_block(
x, j, rank, mode_orth, tau, bias = bias, na.rm = na.rm
)
}
} else {
res <- new_tensor_block(x, j, rank, mode_orth, bias = bias, na.rm = na.rm)
}
Expand Down
21 changes: 21 additions & 0 deletions R/block_init.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,24 @@ block_init.tensor_block <- function(x, init = "svd") {

return(block_project(x))
}

#' @export
block_init.regularized_tensor_block <- function(x, init = "svd") {
NextMethod()
}

#' @export
block_init.separable_regularized_tensor_block <- function(x, init = "svd") {
d <- length(dim(x$x)) - 1
x$M <- estimate_separable_covariance(x$x)
x$M <- lapply(x$M, function(y) {
sqrt_matrix(
x$tau^(1 / d) * diag(nrow(y)) + (1 - x$tau^(1 / d)) * y,
inv = TRUE
)
})
for (m in seq_len(d)) {
x$x <- mode_product(x$x, x$M[[m]], m = m + 1)
}
NextMethod()
}
9 changes: 9 additions & 0 deletions R/block_postprocess.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,12 @@ block_postprocess.sparse_block <- function(x, ctrl) {
}
NextMethod()
}

#' @export
block_postprocess.separable_regularized_tensor_block <- function(x, ctrl) {
x$factors <- lapply(seq_along(x$factors), function(m) {
x$M[[m]] %*% x$factors[[m]]
})
x$a <- Reduce(khatri_rao, rev(x$factors)) %*% x$weights
NextMethod()
}
27 changes: 27 additions & 0 deletions R/estimate_separable_covariance.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#' estimate_separable_covariance estimates the covariance matrix of a set of
#' random variables with an underlying tensor structure making the assumption
#' that the real covariance matrix has a separable structure.
#' @param x A numerical array with at least 3 dimensions.
#' @return The list composed of the estimated terms in the separable covariance.
#' @references Hoff, P. D. (2011), Separable covariance arrays via the Tucker
#' product, with applications to multivariate relational data.
#' Eun Jeong Min et al (2019), Tensor canonical correlation analysis.
#' @title Separable covariance estimator
#' @noRd
estimate_separable_covariance <- function(x) {
dim_x <- dim(x)
n <- dim_x[1]
d <- length(dim_x) - 1
x_bar <- apply(x, -1, mean)
r <- (1 / n) * norm(
matrix(x, nrow = n) - matrix(rep(x_bar, n), nrow = n, byrow = TRUE),
type = "F"
)^2
x_bar <- array(x_bar, dim = dim_x[-1])
lapply(seq_len(d), function(m) {
x_bar_m <- t(apply(x_bar, m, c))
x_bar_m <- x_bar_m %x% t(rep(1, n))
x_m <- t(apply(x, m + 1, c))
(1 / (n * r^((d - 1) / d))) * tcrossprod(x_m - x_bar_m)
})
}
4 changes: 3 additions & 1 deletion R/get_rgcca_args.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ get_rgcca_args <- function(object, default_args = list()) {
NA_method = tolower(default_args$NA_method),
comp_orth = default_args$comp_orth,
mode_orth = default_args$mode_orth,
separable = default_args$separable,
n_iter_max = default_args$n_iter_max,
connection = default_args$connection,
superblock = default_args$superblock,
Expand All @@ -61,7 +62,8 @@ get_rgcca_args <- function(object, default_args = list()) {
check_integer("tol", rgcca_args$tol, float = TRUE, min = 0)
check_integer("n_iter_max", rgcca_args$n_iter_max, min = 1)
for (i in c(
"superblock", "verbose", "scale", "bias", "quiet", "comp_orth"
"superblock", "verbose", "scale", "bias",
"quiet", "comp_orth", "separable"
)) {
check_boolean(i, rgcca_args[[i]])
}
Expand Down
29 changes: 29 additions & 0 deletions R/mode_product.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#' Compute the mode product between a tensor X and a matrix y on a given mode m.
#' @param X An array.
#' @param y A matrix.
#' @param m A scalar designating a mode of the tensor X.
#' @examples
#' X <- array(rnorm(40 * 5 * 7), dim = c(40, 5, 7))
#' y <- rnorm(5)
#' res <- mode_product(X, y, m = 2)
#' print(dim(X))
#' print(dim(res))
#' @noRd
mode_product <- function(X, y, m = 1) {
dim_X <- dim(X)

# Unfold the tensor on dimension m
perm <- c(m, seq_along(dim_X)[-m])
X <- aperm(X, perm)

# Compute the product
X <- matrix(X, nrow = nrow(X))
X <- t(y) %*% X

# Reshape the result back to a tensor
dim_X[m] <- NCOL(y)
X <- array(X, dim = dim_X[perm])

X <- aperm(X, c(1 + seq_len(m - 1), 1, m + seq_len(length(dim_X) - m)))
return(X)
}
13 changes: 8 additions & 5 deletions R/rgcca.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,20 +172,23 @@
#' iterations.
#' @param comp_orth A logical value indicating if the deflation should lead to
#' orthogonal block components or orthogonal block weight vectors.
#' @param A Deprecated argument, please use blocks instead.
#' @param C Deprecated argument, please use connection instead.
#' @param rank Either an integer, an integer vector of
#' size \eqn{J} or an integer matrix
#' of dimension \eqn{\textrm{max}(\textrm{ncomp}) \times J} giving the rank
#' of the decomposition sought for the canonical vectors in TGCCA.
#' If block \eqn{j} is an array with at least three dimensions, rank must be
#' comprised between 1 and the number of variables on the mode bearing the
#' orthogonality constraint. See \textrm{mode_orth}.
#' orthogonality constraint. See \eqn{\textrm{mode_orth}}.
#' @param mode_orth Either an integer or an integer vector of size \eqn{J}
#' designating the mode which associated set of factors will be orthogonal
#' in the decomposition sought by TGCCA. If block \eqn{j} is an array with
#' \eqn{d > 2} dimensions, \textrm{mode_orth} must be comprised between
#' \eqn{d > 2} dimensions, \eqn{\textrm{mode_orth}} must be comprised between
#' 1 and \eqn{d - 1}.
#' @param separable A logical value if the regularization matrices must be
#' estimated as separable matrices (i.e. products of Kronecker products
#' matching the dimensions of the modes in the data).
#' @param A Deprecated argument, please use blocks instead.
#' @param C Deprecated argument, please use connection instead.
#' @return A fitted rgcca object.
#' @return \item{Y}{A list of \eqn{J} elements. The jth element
#' of the list \eqn{Y}
Expand Down Expand Up @@ -445,7 +448,7 @@ rgcca <- function(blocks, connection = NULL, tau = 1, ncomp = 1,
superblock = FALSE,
NA_method = "na.ignore", quiet = TRUE,
n_iter_max = 1000, comp_orth = TRUE,
rank = 1, mode_orth = 1,
rank = 1, mode_orth = 1, separable = TRUE,
A = NULL, C = NULL) {
# Check for deprecated arguments
if (!missing(A)) {
Expand Down
1 change: 1 addition & 0 deletions R/rgcca_cv.r
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ rgcca_cv <- function(blocks,
comp_orth = TRUE,
rank = 1,
mode_orth = 1,
separable = TRUE,
...) {
### Try to retrieve parameters from a rgcca object
rgcca_args <- as.list(environment())
Expand Down
6 changes: 4 additions & 2 deletions R/rgcca_inner_loop.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ rgcca_inner_loop <- function(A, C, g, dg, tau = rep(1, length(A)),
verbose = FALSE, init = "svd", bias = TRUE,
tol = 1e-08, na.rm = TRUE, n_iter_max = 1000,
rank = rep(1, length(A)),
mode_orth = rep(1, length(A))) {
mode_orth = rep(1, length(A)),
separable = TRUE) {
if (!is.numeric(tau)) {
# From Schafer and Strimmer, 2005
tau <- vapply(A, tau.estimate, na.rm = na.rm, FUN.VALUE = 1.0)
Expand All @@ -18,7 +19,8 @@ rgcca_inner_loop <- function(A, C, g, dg, tau = rep(1, length(A)),
### Initialization
block_objects <- lapply(seq_along(A), function(j) {
create_block(
A[[j]], j, bias, na.rm, tau[j], sparsity[j], tol, rank[j], mode_orth[j]
A[[j]], j, bias, na.rm, tau[j], sparsity[j],
tol, rank[j], mode_orth[j], separable
)
})

Expand Down
5 changes: 3 additions & 2 deletions R/rgcca_outer_loop.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ rgcca_outer_loop <- function(blocks, connection = 1 - diag(length(blocks)),
na.rm = TRUE, superblock = FALSE,
response = NULL, disjunction = NULL,
n_iter_max = 1000, comp_orth = TRUE,
rank = 1, mode_orth = 1) {
rank = 1, mode_orth = 1, separable = TRUE) {
if (verbose) {
scheme_str <- ifelse(is(scheme, "function"), "user-defined", scheme)
cat(
Expand Down Expand Up @@ -81,7 +81,8 @@ rgcca_outer_loop <- function(blocks, connection = 1 - diag(length(blocks)),
init = init, bias = bias, tol = tol,
verbose = verbose, na.rm = na.rm,
n_iter_max = n_iter_max,
rank = rank[n, ], mode_orth = mode_orth
rank = rank[n, ], mode_orth = mode_orth,
separable = separable
)

# Store tau, crit
Expand Down
3 changes: 2 additions & 1 deletion R/rgcca_permutation.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ rgcca_permutation <- function(blocks, par_type = "tau", par_value = NULL,
response = NULL, superblock = FALSE,
NA_method = "na.ignore", rgcca_res = NULL,
verbose = TRUE, n_iter_max = 1000,
comp_orth = TRUE, rank = 1, mode_orth = 1) {
comp_orth = TRUE, rank = 1, mode_orth = 1,
separable = TRUE) {
### Try to retrieve parameters from a rgcca object
rgcca_args <- as.list(environment())
tmp <- get_rgcca_args(blocks, rgcca_args)
Expand Down
19 changes: 19 additions & 0 deletions R/sqrt_matrix.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#' Compute the square root or the inverse of the square root of a
#' symmetric matrix.
#' @param X A symmetric matrix.
#' @param tol A relative tolerance to detect zero singular values.
#' @param inv A boolean indicating if the inverse of the square root must be
#' computed.
#' @noRd
sqrt_matrix <- function(X, tol = sqrt(.Machine$double.eps), inv = FALSE) {
eig <- eigen(X, symmetric = TRUE)
positive <- eig$values > max(tol * eig$values[1], 0)
d <- eig$values
if (inv) {
d[positive] <- 1 / sqrt(d[positive])
} else {
d[positive] <- sqrt(d[positive])
}
d[!positive] <- 0
eig$vectors %*% diag(d, nrow = length(d)) %*% t(eig$vectors)
}
9 changes: 7 additions & 2 deletions man/rgcca.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 7 additions & 2 deletions man/rgcca_cv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 8 additions & 3 deletions man/rgcca_permutation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 7e7d8be

Please sign in to comment.