Skip to content

Commit

Permalink
fix type check
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed May 23, 2024
1 parent 9705a13 commit 0ec4c19
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
8 changes: 4 additions & 4 deletions src/exts/mlj_ext/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
tst_overall = @timed for opt in model.optimizers
tst_epochs = @timed for ep in 1:(model.n_epochs)
if model.use_batch
if model.m.compute_mode <: VectorMode
if model.m.compute_mode isa VectorMode

Check warning on line 46 in src/exts/mlj_ext/core_cond_icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/exts/mlj_ext/core_cond_icnf.jl#L46

Added line #L46 was not covered by tests
data = MLUtils.DataLoader(
(x, y);
batchsize = -1,
Expand All @@ -52,7 +52,7 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
parallel = false,
buffer = false,
)
elseif model.m.compute_mode <: MatrixMode
elseif model.m.compute_mode isa MatrixMode

Check warning on line 55 in src/exts/mlj_ext/core_cond_icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/exts/mlj_ext/core_cond_icnf.jl#L55

Added line #L55 was not covered by tests
data = MLUtils.DataLoader(
(x, y);
batchsize = model.batch_size,
Expand Down Expand Up @@ -110,13 +110,13 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew)
end
(ps, st) = fitresult

tst = @timed if model.m.compute_mode <: VectorMode
tst = @timed if model.m.compute_mode isa VectorMode

Check warning on line 113 in src/exts/mlj_ext/core_cond_icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/exts/mlj_ext/core_cond_icnf.jl#L113

Added line #L113 was not covered by tests
logp̂x = broadcast(
(x, y) -> first(inference(model.m, TestMode(), x, y, ps, st)),
eachcol(xnew),
eachcol(ynew),
)
elseif model.m.compute_mode <: MatrixMode
elseif model.m.compute_mode isa MatrixMode

Check warning on line 119 in src/exts/mlj_ext/core_cond_icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/exts/mlj_ext/core_cond_icnf.jl#L119

Added line #L119 was not covered by tests
logp̂x = first(inference(model.m, TestMode(), xnew, ynew, ps, st))
else
error("Not Implemented")
Expand Down
8 changes: 4 additions & 4 deletions src/exts/mlj_ext/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
tst_overall = @timed for opt in model.optimizers
tst_epochs = @timed for ep in 1:(model.n_epochs)
if model.use_batch
if model.m.compute_mode <: VectorMode
if model.m.compute_mode isa VectorMode

Check warning on line 44 in src/exts/mlj_ext/core_icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/exts/mlj_ext/core_icnf.jl#L44

Added line #L44 was not covered by tests
data = MLUtils.DataLoader(
(x,);
batchsize = -1,
Expand All @@ -50,7 +50,7 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
parallel = false,
buffer = false,
)
elseif model.m.compute_mode <: MatrixMode
elseif model.m.compute_mode isa MatrixMode

Check warning on line 53 in src/exts/mlj_ext/core_icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/exts/mlj_ext/core_icnf.jl#L53

Added line #L53 was not covered by tests
data = MLUtils.DataLoader(
(x,);
batchsize = model.batch_size,
Expand Down Expand Up @@ -105,9 +105,9 @@ function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew)
end
(ps, st) = fitresult

tst = @timed if model.m.compute_mode <: VectorMode
tst = @timed if model.m.compute_mode isa VectorMode

Check warning on line 108 in src/exts/mlj_ext/core_icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/exts/mlj_ext/core_icnf.jl#L108

Added line #L108 was not covered by tests
logp̂x = broadcast(x -> first(inference(model.m, TestMode(), x, ps, st)), eachcol(xnew))
elseif model.m.compute_mode <: MatrixMode
elseif model.m.compute_mode isa MatrixMode

Check warning on line 110 in src/exts/mlj_ext/core_icnf.jl

View check run for this annotation

Codecov / codecov/patch

src/exts/mlj_ext/core_icnf.jl#L110

Added line #L110 was not covered by tests
logp̂x = first(inference(model.m, TestMode(), xnew, ps, st))
else
error("Not Implemented")
Expand Down
4 changes: 2 additions & 2 deletions test/call_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ Test.@testset "Call Tests" begin
Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (2, 4))...)
data_dist2 =
Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (4, 2))...)
if compute_mode <: ContinuousNormalizingFlows.VectorMode
if compute_mode isa ContinuousNormalizingFlows.VectorMode
r = convert.(data_type, rand(data_dist, nvars))
r2 = convert.(data_type, rand(data_dist2, nvars))
elseif compute_mode <: ContinuousNormalizingFlows.MatrixMode
elseif compute_mode isa ContinuousNormalizingFlows.MatrixMode
r = convert.(data_type, rand(data_dist, nvars, ndata))
r2 = convert.(data_type, rand(data_dist2, nvars, ndata))
end
Expand Down

0 comments on commit 0ec4c19

Please sign in to comment.