Skip to content

Commit

Permalink
skip gpu tests if cuda not available.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed May 7, 2021
1 parent b510326 commit 5ea3d5b
Showing 1 changed file with 34 additions and 31 deletions.
65 changes: 34 additions & 31 deletions tests/mdn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,34 +27,37 @@ def test_mdn_for_diff_dimension_data(
dim: int, device: str, hidden_features: int = 50, num_components: int = 10
) -> None:

theta = torch.rand(3, dim)
likelihood_shift = torch.zeros(theta.shape)
likelihood_cov = eye(dim)
context = linear_gaussian(theta, likelihood_shift, likelihood_cov)

x_numel = theta[0].numel()
y_numel = context[0].numel()

distribution = MultivariateGaussianMDN(
features=x_numel,
context_features=y_numel,
hidden_features=hidden_features,
hidden_net=nn.Sequential(
nn.Linear(y_numel, hidden_features),
nn.ReLU(),
nn.Linear(hidden_features, hidden_features),
nn.ReLU(),
),
num_components=num_components,
custom_initialization=True,
)
distribution = distribution.to(device)

logits, means, precisions, _, _ = distribution.get_mixture_components(
theta.to(device)
)

# Test evaluation and sampling.
distribution.log_prob(context.to(device), theta.to(device))
distribution.sample(100, theta.to(device))
distribution.sample_mog(10, logits, means, precisions)
if device == "cuda:0" and not torch.cuda.is_available():
pass
else:
theta = torch.rand(3, dim)
likelihood_shift = torch.zeros(theta.shape)
likelihood_cov = eye(dim)
context = linear_gaussian(theta, likelihood_shift, likelihood_cov)

x_numel = theta[0].numel()
y_numel = context[0].numel()

distribution = MultivariateGaussianMDN(
features=x_numel,
context_features=y_numel,
hidden_features=hidden_features,
hidden_net=nn.Sequential(
nn.Linear(y_numel, hidden_features),
nn.ReLU(),
nn.Linear(hidden_features, hidden_features),
nn.ReLU(),
),
num_components=num_components,
custom_initialization=True,
)
distribution = distribution.to(device)

logits, means, precisions, _, _ = distribution.get_mixture_components(
theta.to(device)
)

# Test evaluation and sampling.
distribution.log_prob(context.to(device), theta.to(device))
distribution.sample(100, theta.to(device))
distribution.sample_mog(10, logits, means, precisions)

0 comments on commit 5ea3d5b

Please sign in to comment.