Skip to content

Commit

Permalink
Merge pull request #27 from ATISLabs/feature/generate_gaussian_quantiles
Browse files Browse the repository at this point in the history
[#1]feature/generate_gaussian_quantiles
  • Loading branch information
filipebraida authored Sep 10, 2020
2 parents e09eaa2 + 0fabb49 commit 2388f1b
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ make_regression | Generate a random regression problem.
make_classification | Generate a random n-class classification problem. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html)
make_low_rank_matrix | Generate a mostly low rank matrix with bell-shaped singular values. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_low_rank_matrix.html)
make_swiss_roll | Generate a swiss roll dataset. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_swiss_roll.html)
make_gaussian_quantiles | Generate a swiss roll dataset. | [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_gaussian_quantiles.html)

**Disclaimer**: SyntheticDatasets.jl borrows code and documentation from
[scikit-learn](https://scikit-learn.org/stable/modules/classes.html#samples-generator) in the dataset module, but *it is not an official part
Expand Down
2 changes: 1 addition & 1 deletion src/matlab.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ function generate_twospirals(; n_samples::Int = 2000,
labels = [zeros(Int, N1); ones(Int, N1)]

return convert(features, labels);
end
end
39 changes: 39 additions & 0 deletions src/sklearn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,42 @@ function generate_swiss_roll(; n_samples::Int = 100,

return convert(features, labels)
end

"""
function generate_gaussian_quantiles(; mean::Array{<:Union{Number, Nothing}, 1} = [nothing],
cov::Float64 = 1,
n_samples::Int = 100,
n_features::Int = 2,
n_classes::Int = 3,
shuffle::Bool = true,
random_state::Union{Int, Nothing} = nothing)
Generate isotropic Gaussian and label samples by quantile.
#Arguments
- `mean::Array{<:Union{Number, Nothing}, 1} = [nothing]`: The mean of the multi-dimensional normal distribution. If None then use the origin (0, 0, …).
- `cov::Float64 = 1`: The covariance matrix will be this value times the unit matrix.
- `n_samples::Int = 100`: The total number of points equally divided among classes.
- `n_features::Int = 2`: The number of features for each sample.
- `n_classes::Int = 3`: The number of classes.
- `shuffle::Bool = true`: Shuffle the samples.
- `random_state::Union{Int, Nothing} = nothing`: Determines random number generation for dataset creation. Pass an int for reproducible output across multiple function calls. See Glossary.
Reference: [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_gaussian_quantiles.html)
"""
function generate_gaussian_quantiles(; mean::Union{Array{<:Number, 1}, Nothing} = nothing,
cov::Float64 = 1.0,
n_samples::Int = 100,
n_features::Int = 2,
n_classes::Int = 3,
shuffle::Bool = true,
random_state::Union{Int, Nothing} = nothing)

(features, labels) = datasets.make_gaussian_quantiles(mean = mean,
cov = cov,
n_samples = n_samples,
n_features = n_features,
n_classes = n_classes,
shuffle = shuffle,
random_state = random_state)

return convert(features, labels)
end
12 changes: 8 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ using Test
@test size(data)[1] == samples
@test size(data)[2] == features + 1

@test size(data)[1] == samples
@test size(data)[2] == features + 1

data = SyntheticDatasets.generate_friedman1(n_samples = samples,
n_features = features)

Expand All @@ -74,13 +71,20 @@ using Test

@test size(data)[1] == samples
@test size(data)[2] == features

data = SyntheticDatasets.generate_swiss_roll(n_samples =samples,
noise = 2.2,
random_state = 5)

@test size(data)[1] == samples
@test size(data)[2] == 4

data = SyntheticDatasets.generate_gaussian_quantiles(n_samples = samples,
n_features = features,
random_state = 5)

@test size(data)[1] == samples
@test size(data)[2] == features + 1
end

@testset "Matlab Generators" begin
Expand Down

0 comments on commit 2388f1b

Please sign in to comment.