Skip to content

Commit

Permalink
added NumberList - and avoid heterogeneous Number array in NumberList …
Browse files Browse the repository at this point in the history
…and Samples
  • Loading branch information
B0SKAMP committed Mar 4, 2024
1 parent 65efc04 commit 59b61ef
Show file tree
Hide file tree
Showing 11 changed files with 218 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.fairdatapipeline.estimate.ImmutableEstimate;
import org.fairdatapipeline.file.CleanableFileChannel;
import org.fairdatapipeline.parameters.ImmutableBoolList;
import org.fairdatapipeline.parameters.ImmutableNumberList;
import org.fairdatapipeline.parameters.ImmutableStringList;
import org.fairdatapipeline.parameters.ReadComponent;
import org.fairdatapipeline.samples.ImmutableSamples;
Expand Down Expand Up @@ -149,6 +150,25 @@ public List<String> readStrings() {
return ((ImmutableStringList) data).getStrings();
}

/**
 * Read the Numbers that were stored as this component in a TOML file.
 *
 * @return the list of numbers held by the stored NumberList
 * @throws RuntimeException if an IOException occurs while reading the file, or if the component
 *     that was read is not a NumberList
 */
public List<Number> readNumbers() {
  ReadComponent data;
  try (CleanableFileChannel fileChannel = this.getFileChannel()) {
    data = this.dp.coderun.parameterDataReader.read(fileChannel, this.component_name);
  } catch (IOException e) {
    // Name this method (not readStrings) so the failure points at the right call.
    throw (new RuntimeException("readNumbers() -- IOException trying to read from file.", e));
  }
  if (!(data instanceof ImmutableNumberList)) {
    throw (new RuntimeException(
        "readNumbers() -- this objComponent (" + this.component_name + ") is not a NumberList"));
  }
  return ((ImmutableNumberList) data).getNumbers();
}

/**
* read the Samples that were stored as this component in a TOML file.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.fairdatapipeline.estimate.ImmutableEstimate;
import org.fairdatapipeline.file.CleanableFileChannel;
import org.fairdatapipeline.parameters.BoolList;
import org.fairdatapipeline.parameters.NumberList;
import org.fairdatapipeline.parameters.StringList;
import org.fairdatapipeline.samples.Samples;

Expand Down Expand Up @@ -116,7 +117,7 @@ public void writeBools(BoolList bools) {
}

/**
* write a BoolList, as this named component in the data product.
* write a StringList, as this named component in the data product.
*
* @param strings the Strings to write
*/
Expand All @@ -132,6 +133,23 @@ public void writeStrings(StringList strings) {
this.been_used = true;
}

/**
 * Write a NumberList, as this named component in the data product.
 *
 * <p>A component can only be written once; after a successful write it is marked as used.
 *
 * @param numbers the Numbers to write
 * @throws RuntimeException if this component has already been written, or if an IOException
 *     occurs while writing to the file
 */
public void writeNumbers(NumberList numbers) {
  if (this.been_used) {
    throw (new RuntimeException("obj component already written"));
  }
  try (CleanableFileChannel fileChannel = this.getFileChannel()) {
    this.dp.coderun.parameterDataWriter.write(fileChannel, this.component_name, numbers);
  } catch (IOException e) {
    // Name this method (not writeStrings) so the failure points at the right call.
    throw (new RuntimeException("writeNumbers() -- IOException trying to write to file.", e));
  }
  this.been_used = true;
}

/**
* write Samples, as this named component in the data product.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ public ComponentsDeserializer(RandomGenerator rng) {
"bools",
ImmutableBoolList.class,
"strings",
ImmutableStringList.class);
ImmutableStringList.class,
"numbers",
ImmutableNumberList.class);

@Override
public Components deserialize(JsonParser jsonParser, DeserializationContext ctxt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ public ComponentsSerializer(RandomGenerator rng) {
ImmutableStringList.class,
"strings",
ImmutableBoolList.class,
"bools");
"bools",
ImmutableNumberList.class,
"numbers");

@Override
public void serialize(Components components, JsonGenerator gen, SerializerProvider serializers)
Expand Down
31 changes: 31 additions & 0 deletions api/src/main/java/org/fairdatapipeline/parameters/NumberList.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package org.fairdatapipeline.parameters;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import java.util.List;
import java.util.stream.Collectors;
import org.immutables.value.Value;
import org.immutables.value.Value.Immutable;

/**
 * A component holding a list of numbers.
 *
 * <p>Instances are normalised on construction (see {@link #avoidHeterogeneous()}) so that the list
 * never mixes integral and non-integral values.
 */
@Immutable
@JsonSerialize
@JsonDeserialize
public interface NumberList extends Component {
  /**
   * The numbers held by this component.
   *
   * @return the list of numbers
   */
  List<Number> numbers();

  /**
   * Getter-style alias for {@link #numbers()}; ignored by Jackson.
   *
   * @return the list of numbers
   */
  @JsonIgnore
  default List<Number> getNumbers() {
    return numbers();
  }

  /**
   * Normalise a heterogeneous list: if the list contains at least one integral value and at least
   * one non-integral value, return a copy in which every element is a {@code Double}. A list that
   * is entirely integral, or contains no integral values at all, is left untouched.
   *
   * @return this instance, or an all-double replacement when the list was mixed
   */
  @Value.Check
  default NumberList avoidHeterogeneous() {
    // Decide "is an integer" by type. Comparing a freshly boxed Integer to the element with ==
    // is a reference comparison and only holds for values in the Integer box cache.
    long integralCount =
        numbers().stream()
            .filter(
                x ->
                    x instanceof Integer
                        || x instanceof Long
                        || x instanceof Short
                        || x instanceof Byte)
            .count();
    if (integralCount != 0 && integralCount < numbers().size()) {
      return ImmutableNumberList.builder()
          .numbers(numbers().stream().map(Number::doubleValue).collect(Collectors.toList()))
          .build();
    }
    return this;
  }
}
15 changes: 15 additions & 0 deletions api/src/main/java/org/fairdatapipeline/samples/Samples.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import java.util.List;
import java.util.stream.Collectors;

import org.fairdatapipeline.distribution.Distribution;
import org.fairdatapipeline.distribution.ImmutableDistribution;
import org.fairdatapipeline.parameters.ImmutableNumberList;
import org.fairdatapipeline.parameters.NumberList;
import org.fairdatapipeline.parameters.RngComponent;
import org.immutables.value.Value;
import org.immutables.value.Value.Immutable;

@Immutable
Expand Down Expand Up @@ -41,4 +46,14 @@ default Distribution getDistribution() {
.rng(rng())
.build();
}

/**
 * Normalise heterogeneous samples: if the samples contain at least one integral value and at
 * least one non-integral value, return a copy (with the same rng) in which every sample is a
 * {@code Double}. Samples that are entirely integral, or contain no integral values, are left
 * untouched.
 *
 * @return this instance, or an all-double replacement when the samples were mixed
 */
@Value.Check
default Samples avoidHeterogeneous() {
  // Decide "is an integer" by type. Comparing a freshly boxed Integer to the element with ==
  // is a reference comparison and only holds for values in the Integer box cache.
  long integralCount =
      samples().stream()
          .filter(
              x ->
                  x instanceof Integer
                      || x instanceof Long
                      || x instanceof Short
                      || x instanceof Byte)
          .count();
  if (integralCount != 0 && integralCount < samples().size()) {
    return ImmutableSamples.builder()
        .samples(samples().stream().map(Number::doubleValue).collect(Collectors.toList()))
        .rng(rng())
        .build();
  }
  return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@
import org.fairdatapipeline.distribution.ImmutableMinMax;
import org.fairdatapipeline.distribution.MinMax;
import org.fairdatapipeline.file.CleanableFileChannel;
import org.fairdatapipeline.parameters.BoolList;
import org.fairdatapipeline.parameters.ImmutableBoolList;
import org.fairdatapipeline.parameters.ImmutableStringList;
import org.fairdatapipeline.parameters.StringList;
import org.fairdatapipeline.parameters.*;
import org.fairdatapipeline.samples.ImmutableSamples;
import org.fairdatapipeline.samples.Samples;
import org.javatuples.Triplet;
Expand All @@ -57,6 +54,7 @@ class CoderunIntegrationTest {
private Samples samples, samples2, samples3, samples4;
private StringList stringlist1, stringlist2;
private BoolList boollist;
private NumberList numberlist;
private Distribution distribution;
private Distribution categoricalDistribution;
private final Number estimate = 1.0;
Expand Down Expand Up @@ -108,6 +106,7 @@ void setup_data() {
stringlist1 = ImmutableStringList.builder().addStrings("do", "re", "mi").build();
stringlist2 = ImmutableStringList.builder().addStrings("just the one").build();
boollist = ImmutableBoolList.builder().addBools(true).build();
numberlist = ImmutableNumberList.builder().addNumbers(1, 1.5, 12345.67).build();

csv_data = new ArrayList<>();
csv_data.add(new String[] {"apple", "12", "green"});
Expand Down Expand Up @@ -542,18 +541,20 @@ void testWriteAllSortsComponents() {
dp.getComponent("e").writeStrings(stringlist1);
dp.getComponent("f").writeStrings(stringlist2);
dp.getComponent("g").writeBools(boollist);
dp.getComponent("h").writeNumbers(numberlist);
}
String hash = "b756331d77b31ab6ab9c55d1825e3f625997fdf9";
String hash = "68608005664459de456d75941430e7290031ddb5";
check_last_coderun(
null,
Arrays.asList(
new Triplet<>(dataProduct, "a", hash),
new Triplet<>(dataProduct, "b", hash),
new Triplet<>(dataProduct, "c", hash),
new Triplet<>(dataProduct, "d", hash),
new Triplet<>(dataProduct, "e", hash),
new Triplet<>(dataProduct, "f", hash),
new Triplet<>(dataProduct, "g", hash)));
null,
Arrays.asList(
new Triplet<>(dataProduct, "a", hash),
new Triplet<>(dataProduct, "b", hash),
new Triplet<>(dataProduct, "c", hash),
new Triplet<>(dataProduct, "d", hash),
new Triplet<>(dataProduct, "e", hash),
new Triplet<>(dataProduct, "f", hash),
new Triplet<>(dataProduct, "g", hash),
new Triplet<>(dataProduct, "h", hash)));
}

@Test
Expand All @@ -562,25 +563,28 @@ void testReadAllSortsComponents() {
String dataProduct = "human/allsortscomp";
try (var coderun = new Coderun(configPath, scriptPath, token)) {
Data_product_read dc = coderun.get_dp_for_read(dataProduct);
assertThat(dc.getComponent("a").readSamples()).containsExactly(1,2,3);
assertThat(dc.getComponent("b").readSamples()).containsExactly(4,5,6);
assertThat(dc.getComponent("c").readDistribution().getDistribution().internalType()).isEqualTo(DistributionType.categorical);
assertThat(dc.getComponent("a").readSamples()).containsExactly(1, 2, 3);
assertThat(dc.getComponent("b").readSamples()).containsExactly(4, 5, 6);
assertThat(dc.getComponent("c").readDistribution().getDistribution().internalType())
.isEqualTo(DistributionType.categorical);
assertThat(dc.getComponent("d").readEstimate()).isEqualTo(estimate);
assertThat(dc.getComponent("e").readStrings()).containsExactly("do", "re", "mi");
assertThat(dc.getComponent("f").readStrings()).containsExactly("just the one");
assertThat(dc.getComponent("g").readBools()).containsExactly(true);
assertThat(dc.getComponent("h").readNumbers()).containsExactly(1.0, 1.5, 12345.67);
}
String hash = "b756331d77b31ab6ab9c55d1825e3f625997fdf9";
String hash = "68608005664459de456d75941430e7290031ddb5";
check_last_coderun(
Arrays.asList(
new Triplet<>(dataProduct, "a", hash),
new Triplet<>(dataProduct, "b", hash),
new Triplet<>(dataProduct, "c", hash),
new Triplet<>(dataProduct, "d", hash),
new Triplet<>(dataProduct, "e", hash),
new Triplet<>(dataProduct, "f", hash),
new Triplet<>(dataProduct, "g", hash)),
null);
Arrays.asList(
new Triplet<>(dataProduct, "a", hash),
new Triplet<>(dataProduct, "b", hash),
new Triplet<>(dataProduct, "c", hash),
new Triplet<>(dataProduct, "d", hash),
new Triplet<>(dataProduct, "e", hash),
new Triplet<>(dataProduct, "f", hash),
new Triplet<>(dataProduct, "g", hash),
new Triplet<>(dataProduct, "h", hash)),
null);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package org.fairdatapipeline.parameters;

import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.api.Test;

/** Tests for the normalisation applied when a NumberList is built. */
class NumberListTest {

  /**
   * Test that a mix of ints and doubles will end up as all doubles. (Necessary for TOML - it
   * can't read a 'heterogeneous array' of mixed int and double numbers.)
   */
  @Test
  void makeNumberList() {
    var numberlist = ImmutableNumberList.builder().addNumbers(1, 1.5, 12345.67).build();
    // containsExactly pins order and size as well as membership.
    assertThat(numberlist.numbers()).containsExactly(1.0, 1.5, 12345.67);
  }

  /** Given just ints, they should NOT be converted to doubles. */
  @Test
  void makeNumberList2() {
    var numberlist = ImmutableNumberList.builder().addNumbers(1, 2, 12).build();
    // containsExactly pins order and size as well as membership.
    assertThat(numberlist.numbers()).containsExactly(1, 2, 12);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ void derivedSamplesFromSamples() {
assertThat(samples.getSamples()).containsExactly(1, 2, 3);
}

/** Samples built from a mix of doubles and an int should be read back as doubles only. */
@Test
void heterogenousToDoubles() {
  var mixedSamples =
      ImmutableSamples.builder()
          .addSamples(1.4, 2.5, 3)
          .rng(rng)
          .build();
  var actualValues = mixedSamples.getSamples();
  assertThat(actualValues).containsExactly(1.4, 2.5, 3.0);
}

@Test
void derivedDistributionFromSamples() {
var samples = ImmutableSamples.builder().addSamples(1, 2, 3).rng(rng).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,36 @@
import com.fasterxml.jackson.core.type.TypeReference;
import java.io.StringReader;
import org.apache.commons.math3.random.RandomGenerator;
import org.fairdatapipeline.distribution.Distribution;
import org.fairdatapipeline.distribution.ImmutableDistribution;
import org.fairdatapipeline.estimate.ImmutableEstimate;
import org.fairdatapipeline.parameters.Components;
import org.fairdatapipeline.parameters.ImmutableComponents;
import org.fairdatapipeline.parameters.ImmutableStringList;
import org.fairdatapipeline.samples.ImmutableSamples;
import org.junit.jupiter.api.*;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class TomlReaderPairwiseIntegrationTest {
// Minimal TOML fixture: a single [example-estimate] table of type "point-estimate" with value 1.0.
private final String toml =
"[example-estimate]\n" + "type = \"point-estimate\"\n" + "value = 1.0";

// Larger TOML fixture with four tables, one per component type: a point-estimate, a gamma
// distribution (shape 1.0, scale 2.0), a samples array and a strings array. Used by read2().
private final String toml2 =
"[example-estimate]\n"
+ "value = 1.0\n"
+ "type = \"point-estimate\"\n"
+ "[example-distribution]\n"
+ "distribution = \"gamma\"\n"
+ "shape = 1.0\n"
+ "scale = 2.0\n"
+ "type = \"distribution\"\n"
+ "[example-samples]\n"
+ "samples = [1.5, 2.0, 3.0]\n"
+ "type = \"samples\"\n"
+ "[example-strings]\n"
+ "strings = [\"bla\", \"blo\"]\n"
+ "type = \"strings\"\n";

private RandomGenerator rng;

@BeforeAll
Expand All @@ -36,4 +56,31 @@ void read() {

assertThat(components_read).isEqualTo(components);
}

/**
 * Read the multi-table TOML fixture and check that every expected component (estimate,
 * distribution, samples and strings) is present in what was read.
 */
@Test
void read2() {
  TomlReader tomlReader = new TomlReader(new TOMLMapper(rng));

  // Expected components, one per table in the fixture.
  var expectedEstimate = ImmutableEstimate.builder().internalValue(1.0).rng(rng).build();
  var expectedDistribution =
      ImmutableDistribution.builder()
          .internalShape(1)
          .internalScale(2)
          .internalType(Distribution.DistributionType.gamma)
          .rng(rng)
          .build();
  var expectedSamples = ImmutableSamples.builder().addSamples(1.5, 2, 3).rng(rng).build();
  var expectedStrings = ImmutableStringList.builder().addStrings("bla", "blo").build();

  Components expected =
      ImmutableComponents.builder()
          .putComponents("example-estimate", expectedEstimate)
          .putComponents("example-distribution", expectedDistribution)
          .putComponents("example-samples", expectedSamples)
          .putComponents("example-strings", expectedStrings)
          .build();

  Components actual = tomlReader.read(new StringReader(toml2), new TypeReference<>() {});

  assertThat(actual.components()).containsAllEntriesOf(expected.components());
}
}
Loading

0 comments on commit 59b61ef

Please sign in to comment.