diff --git a/core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java b/core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java index 363a80cd..fbb97032 100644 --- a/core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java +++ b/core/src/main/java/com/tdunning/math/stats/AVLTreeDigest.java @@ -17,6 +17,9 @@ package com.tdunning.math.stats; +import com.tdunning.math.stats.serde.AVLTreeDigestCompactSerde; +import com.tdunning.math.stats.serde.DigestModelDefaultSerde; + import java.nio.ByteBuffer; import java.util.Collection; import java.util.Collections; @@ -315,7 +318,7 @@ public double compression() { */ @Override public int byteSize() { - return 32 + summary.size() * 12; + return DigestModelDefaultSerde.byteSize(summary.size()); } /** @@ -324,53 +327,20 @@ public int byteSize() { */ @Override public int smallByteSize() { - int bound = byteSize(); - ByteBuffer buf = ByteBuffer.allocate(bound); - asSmallBytes(buf); - return buf.position(); + return AVLTreeDigestCompactSerde.byteSize(this); } - private final static int VERBOSE_ENCODING = 1; - private final static int SMALL_ENCODING = 2; - /** * Outputs a histogram as bytes using a particularly cheesy encoding. */ @Override public void asBytes(ByteBuffer buf) { - buf.putInt(VERBOSE_ENCODING); - buf.putDouble(min); - buf.putDouble(max); - buf.putDouble((float) compression()); - buf.putInt(summary.size()); - for (Centroid centroid : summary) { - buf.putDouble(centroid.mean()); - } - - for (Centroid centroid : summary) { - buf.putInt(centroid.count()); - } + DigestModelDefaultSerde.serialize(toModel(), buf); } @Override public void asSmallBytes(ByteBuffer buf) { - buf.putInt(SMALL_ENCODING); - buf.putDouble(min); - buf.putDouble(max); - buf.putDouble(compression()); - buf.putInt(summary.size()); - - double x = 0; - for (Centroid centroid : summary) { - double delta = centroid.mean() - x; - x = centroid.mean(); - buf.putFloat((float) delta); - } - - for (Centroid centroid : summary) { - int n = centroid.count(); - encode(buf, n); - } + AVLTreeDigestCompactSerde.serialize(this, buf); } /** @@ -381,45 +351,36 @@ public void asSmallBytes(ByteBuffer buf) { */ @SuppressWarnings("WeakerAccess") public static AVLTreeDigest fromBytes(ByteBuffer buf) { - int encoding = buf.getInt(); - if (encoding == VERBOSE_ENCODING) { - double min = buf.getDouble(); - double max = buf.getDouble(); - double compression = buf.getDouble(); - AVLTreeDigest r = new AVLTreeDigest(compression); - r.setMinMax(min, max); - int n = buf.getInt(); - double[] means = new double[n]; - for (int i = 0; i < n; i++) { - means[i] = buf.getDouble(); - } - for (int i = 0; i < n; i++) { - r.add(means[i], buf.getInt()); - } - return r; - } else if (encoding == SMALL_ENCODING) { - double min = buf.getDouble(); - double max = buf.getDouble(); - double compression = buf.getDouble(); - AVLTreeDigest r = new AVLTreeDigest(compression); - r.setMinMax(min, max); - int n = buf.getInt(); - double[] means = new double[n]; - double x = 0; - for (int i = 0; i < n; i++) { - double delta = buf.getFloat(); - x += delta; - means[i] = x; - } + try { + DigestModel digestModel = DigestModelDefaultSerde.deserialize(buf); + return fromModel(digestModel); + } catch (IllegalArgumentException ex) { + buf.rewind(); //reset the buf position to enable read from start + return AVLTreeDigestCompactSerde.deserialize(buf); + } + } - for (int i = 0; i < n; i++) { - int z = decode(buf); - r.add(means[i], z); - } - return r; - } else { - throw new IllegalStateException("Invalid format for serialized histogram"); + public DigestModel toModel() { + double[] positions = new double[summary.size()]; + double[] weights = new double[summary.size()]; + int i = 0; + for (Centroid centroid : summary) { + positions[i] = centroid.mean(); + weights[i] = centroid.count(); + i++; } + return new DigestModel(compression, min, max, i, positions, weights); } + public static AVLTreeDigest fromModel(DigestModel model) { + AVLTreeDigest r = new AVLTreeDigest(model.compression()); + r.setMinMax(model.min(), model.max()); + double[] mean = model.centroidPositions(); + double[] weight = model.centroidWeights(); + for (int i = 0; i < model.centroidCount(); i++) { + r.add(mean[i], (int) weight[i]); + } + + return r; + } } diff --git a/core/src/main/java/com/tdunning/math/stats/AbstractTDigest.java b/core/src/main/java/com/tdunning/math/stats/AbstractTDigest.java index e75d09e1..e8c5ebc1 100644 --- a/core/src/main/java/com/tdunning/math/stats/AbstractTDigest.java +++ b/core/src/main/java/com/tdunning/math/stats/AbstractTDigest.java @@ -57,35 +57,6 @@ static double interpolate(double x, double x0, double x1) { return (x - x0) / (x1 - x0); } - static void encode(ByteBuffer buf, int n) { - int k = 0; - while (n < 0 || n > 0x7f) { - byte b = (byte) (0x80 | (0x7f & n)); - buf.put(b); - n = n >>> 7; - k++; - if (k >= 6) { - throw new IllegalStateException("Size is implausibly large"); - } - } - buf.put((byte) n); - } - - static int decode(ByteBuffer buf) { - int v = buf.get(); - int z = 0x7f & v; - int shift = 7; - while ((v & 0x80) != 0) { - if (shift > 28) { - throw new IllegalStateException("Shift too large in decode"); - } - v = buf.get(); - z += (v & 0x7f) << shift; - shift += 7; - } - return z; - } - abstract void add(double x, int w, Centroid base); /** diff --git a/core/src/main/java/com/tdunning/math/stats/DigestModel.java b/core/src/main/java/com/tdunning/math/stats/DigestModel.java new file mode 100644 index 00000000..78728710 --- /dev/null +++ b/core/src/main/java/com/tdunning/math/stats/DigestModel.java @@ -0,0 +1,71 @@ +package com.tdunning.math.stats; + +public class DigestModel { + private final double compression; + private final double min; + private final double max; + private final int centroidCount; + private final double[] centroidPositions; + private final double[] centroidWeights; + + //For compact encoding of MergingDigest + private Integer mainBufferSize = null; + private Integer tempBufferSize = null; + + private boolean compactEncoding = false; + + public DigestModel(double compression, double min, double max, int centroidCount, double[] centroidPositions, double[] centroidWeights) { + this.compression = compression; + this.min = min; + this.max = max; + this.centroidCount = centroidCount; + this.centroidPositions = centroidPositions; + this.centroidWeights = centroidWeights; + } + + public DigestModel(double compression, double min, double max, int centroidCount, double[] centroidPositions, double[] centroidWeights, int mainBufferSize, int tempBufferSize) { + this(compression, min, max, centroidCount, centroidPositions, centroidWeights); + this.mainBufferSize = mainBufferSize; + this.tempBufferSize = tempBufferSize; + } + + public void setCompactEncoding(boolean compactEncoding) { + this.compactEncoding = compactEncoding; + } + + public double compression() { + return compression; + } + + public double min() { + return min; + } + + public double max() { + return max; + } + + public int centroidCount() { + return centroidCount; + } + + public double[] centroidPositions() { + return centroidPositions; + } + + public double[] centroidWeights() { + return centroidWeights; + } + + public boolean compactEncoding() { + return compactEncoding; + } + + public Integer mainBufferSize() { + return mainBufferSize; + } + + public Integer tempBufferSize() { + return tempBufferSize; + } +} diff --git a/core/src/main/java/com/tdunning/math/stats/MergingDigest.java b/core/src/main/java/com/tdunning/math/stats/MergingDigest.java index d43f0c6f..0eb5f9f9 100644 --- a/core/src/main/java/com/tdunning/math/stats/MergingDigest.java +++ b/core/src/main/java/com/tdunning/math/stats/MergingDigest.java @@ -17,6 +17,9 @@ package com.tdunning.math.stats; +import com.tdunning.math.stats.serde.DigestModelDefaultSerde; +import com.tdunning.math.stats.serde.MergingDigestCompactSerde; + import java.nio.ByteBuffer; import java.util.AbstractCollection; import java.util.ArrayList; @@ -752,96 +755,60 @@ public double compression() { @Override public int byteSize() { compress(); - // format code, compression(float), buffer-size(int), temp-size(int), #centroids-1(int), - // then two doubles per centroid - return lastUsedCell * 16 + 32; + return DigestModelDefaultSerde.byteSize(lastUsedCell); } @Override public int smallByteSize() { compress(); - // format code(int), compression(float), buffer-size(short), temp-size(short), #centroids-1(short), - // then two floats per centroid - return lastUsedCell * 8 + 30; - } - - public enum Encoding { - VERBOSE_ENCODING(1), SMALL_ENCODING(2); - - private final int code; - - Encoding(int code) { - this.code = code; - } + return MergingDigestCompactSerde.byteSize(lastUsedCell); } @Override public void asBytes(ByteBuffer buf) { compress(); - buf.putInt(Encoding.VERBOSE_ENCODING.code); - buf.putDouble(min); - buf.putDouble(max); - buf.putDouble(compression); - buf.putInt(lastUsedCell); - for (int i = 0; i < lastUsedCell; i++) { - buf.putDouble(weight[i]); - buf.putDouble(mean[i]); - } + DigestModelDefaultSerde.serialize(toModel(), buf); } @Override public void asSmallBytes(ByteBuffer buf) { compress(); - buf.putInt(Encoding.SMALL_ENCODING.code); // 4 - buf.putDouble(min); // + 8 - buf.putDouble(max); // + 8 - buf.putFloat((float) compression); // + 4 - buf.putShort((short) mean.length); // + 2 - buf.putShort((short) tempMean.length); // + 2 - buf.putShort((short) lastUsedCell); // + 2 = 30 - for (int i = 0; i < lastUsedCell; i++) { - buf.putFloat((float) weight[i]); - buf.putFloat((float) mean[i]); - } + MergingDigestCompactSerde.serialize(this, buf); } @SuppressWarnings("WeakerAccess") public static MergingDigest fromBytes(ByteBuffer buf) { - int encoding = buf.getInt(); - if (encoding == Encoding.VERBOSE_ENCODING.code) { - double min = buf.getDouble(); - double max = buf.getDouble(); - double compression = buf.getDouble(); - int n = buf.getInt(); - MergingDigest r = new MergingDigest(compression); - r.setMinMax(min, max); - r.lastUsedCell = n; - for (int i = 0; i < n; i++) { - r.weight[i] = buf.getDouble(); - r.mean[i] = buf.getDouble(); - - r.totalWeight += r.weight[i]; - } - return r; - } else if (encoding == Encoding.SMALL_ENCODING.code) { - double min = buf.getDouble(); - double max = buf.getDouble(); - double compression = buf.getFloat(); - int n = buf.getShort(); - int bufferSize = buf.getShort(); - MergingDigest r = new MergingDigest(compression, bufferSize, n); - r.setMinMax(min, max); - r.lastUsedCell = buf.getShort(); - for (int i = 0; i < r.lastUsedCell; i++) { - r.weight[i] = buf.getFloat(); - r.mean[i] = buf.getFloat(); - - r.totalWeight += r.weight[i]; - } - return r; + try { + DigestModel digestModel = DigestModelDefaultSerde.deserialize(buf); + return fromModel(digestModel); + } catch (IllegalArgumentException ex) { + buf.rewind(); //reset the buf position to enable read from start + return MergingDigestCompactSerde.deserialize(buf); + } + } + + public DigestModel toModel() { + compress(); + return new DigestModel(compression, min, max, lastUsedCell, mean, weight, mean.length, tempMean.length); + } + + public static MergingDigest fromModel(DigestModel model) { + MergingDigest r; + if(model.compactEncoding()) { + r = new MergingDigest(model.compression(), model.tempBufferSize(), model.mainBufferSize()); } else { - throw new IllegalStateException("Invalid format for serialized histogram"); + r = new MergingDigest(model.compression()); } + r.setMinMax(model.min(), model.max()); + r.lastUsedCell = model.centroidCount(); + double[] mean = model.centroidPositions(); + double[] weight = model.centroidWeights(); + for (int i = 0;i < model.centroidCount();i++) { + r.mean[i] = mean[i]; + r.weight[i] = weight[i]; + r.totalWeight += weight[i]; + } + return r; } } diff --git a/core/src/main/java/com/tdunning/math/stats/TDigest.java b/core/src/main/java/com/tdunning/math/stats/TDigest.java index ed663dcb..015ecbaf 100644 --- a/core/src/main/java/com/tdunning/math/stats/TDigest.java +++ b/core/src/main/java/com/tdunning/math/stats/TDigest.java @@ -184,6 +184,13 @@ final void checkValue(double x) { */ public abstract void asSmallBytes(ByteBuffer buf); + /** + * Returns final representation of the digest as a plain data structure + * Can be used for serde, inter-conversion, etc + * @return DigestModel which captures minimum required fields for representing the digest accurately + */ + public abstract DigestModel toModel(); + /** * Tell this TDigest to record the original data as much as possible for test * purposes. diff --git a/core/src/main/java/com/tdunning/math/stats/serde/AVLTreeDigestCompactSerde.java b/core/src/main/java/com/tdunning/math/stats/serde/AVLTreeDigestCompactSerde.java new file mode 100644 index 00000000..ae31dc33 --- /dev/null +++ b/core/src/main/java/com/tdunning/math/stats/serde/AVLTreeDigestCompactSerde.java @@ -0,0 +1,91 @@ +package com.tdunning.math.stats.serde; + +import com.tdunning.math.stats.AVLTreeDigest; +import com.tdunning.math.stats.DigestModel; + +import java.nio.ByteBuffer; + +public class AVLTreeDigestCompactSerde { + + public static int byteSize(AVLTreeDigest digest) { + DigestModel model = digest.toModel(); + int bound = DigestModelDefaultSerde.byteSize(model.centroidCount()); + ByteBuffer buf = ByteBuffer.allocate(bound); + serialize(digest, buf); + return buf.position(); + } + + public static void serialize(AVLTreeDigest digest, ByteBuffer buf) { + DigestModel model = digest.toModel(); + buf.putInt(Encoding.COMPACT.code()); + buf.putDouble(model.min()); + buf.putDouble(model.max()); + buf.putDouble(model.compression()); + buf.putInt(model.centroidCount()); + + double[] position = model.centroidPositions(); + double[] weight = model.centroidWeights(); + double x = 0; + for(int i = 0;i < model.centroidCount(); i++) { + double delta = position[i] - x; + x = position[i]; + buf.putFloat((float) delta); + encodeInt(buf, (int) weight[i]); + } + } + + public static AVLTreeDigest deserialize(ByteBuffer buf) { + boolean compactEncoding = buf.getInt() == Encoding.COMPACT.code(); + if (!compactEncoding) { + throw new IllegalArgumentException("Serialization was not done using compact encoding, cannot deserialize"); + } + + double min = buf.getDouble(); + double max = buf.getDouble(); + double compression = buf.getDouble(); + int centroidCount = buf.getInt(); + + double[] position = new double[centroidCount]; + double[] weight = new double[centroidCount]; + double x = 0; + for (int i = 0; i < centroidCount; i++) { + double delta = buf.getFloat(); + x += delta; + position[i] = x; + weight[i] = decodeInt(buf); + } + + DigestModel model = new DigestModel(compression, min, max, centroidCount, position, weight); + model.setCompactEncoding(true); + return AVLTreeDigest.fromModel(model); + } + + public static void encodeInt(ByteBuffer buf, int n) { + int k = 0; + while (n < 0 || n > 0x7f) { + byte b = (byte) (0x80 | (0x7f & n)); + buf.put(b); + n = n >>> 7; + k++; + if (k >= 6) { + throw new IllegalStateException("Size is implausibly large"); + } + } + buf.put((byte) n); + } + + public static int decodeInt(ByteBuffer buf) { + int v = buf.get(); + int z = 0x7f & v; + int shift = 7; + while ((v & 0x80) != 0) { + if (shift > 28) { + throw new IllegalStateException("Shift too large in decode"); + } + v = buf.get(); + z += (v & 0x7f) << shift; + shift += 7; + } + return z; + } +} diff --git a/core/src/main/java/com/tdunning/math/stats/serde/DigestModelDefaultSerde.java b/core/src/main/java/com/tdunning/math/stats/serde/DigestModelDefaultSerde.java new file mode 100644 index 00000000..7959725f --- /dev/null +++ b/core/src/main/java/com/tdunning/math/stats/serde/DigestModelDefaultSerde.java @@ -0,0 +1,47 @@ +package com.tdunning.math.stats.serde; + +import com.tdunning.math.stats.DigestModel; + +import java.nio.ByteBuffer; + +public class DigestModelDefaultSerde { + + public static int byteSize(int centroids) { + return (centroids * 16) + 32; + } + + public static void serialize(DigestModel model, ByteBuffer buf) { + buf.putInt(Encoding.VERBOSE.code()); + buf.putDouble(model.min()); + buf.putDouble(model.max()); + buf.putDouble(model.compression()); + buf.putInt(model.centroidCount()); + double[] position = model.centroidPositions(); + double[] weight = model.centroidWeights(); + for (int i = 0; i < model.centroidCount(); i++) { + buf.putDouble(position[i]); + buf.putDouble(weight[i]); + } + } + + public static DigestModel deserialize(ByteBuffer buf) { + boolean verboseEncoding = buf.getInt() == Encoding.VERBOSE.code(); + if (!verboseEncoding) { + throw new IllegalArgumentException("Serialization was not done using verbose encoding, cannot deserialize"); + } + + double min = buf.getDouble(); + double max = buf.getDouble(); + double compression = buf.getDouble(); + int centroidCount = buf.getInt(); + double[] position = new double[centroidCount]; + double[] weight = new double[centroidCount]; + for (int i = 0; i < centroidCount; i++) { + position[i] = buf.getDouble(); + weight[i] = buf.getDouble(); + } + + return new DigestModel(compression, min, max, centroidCount, position, weight); + } + +} diff --git a/core/src/main/java/com/tdunning/math/stats/serde/Encoding.java b/core/src/main/java/com/tdunning/math/stats/serde/Encoding.java new file mode 100644 index 00000000..a48e5abf --- /dev/null +++ b/core/src/main/java/com/tdunning/math/stats/serde/Encoding.java @@ -0,0 +1,15 @@ +package com.tdunning.math.stats.serde; + +public enum Encoding { + VERBOSE(1), COMPACT(2); + + private final int code; + + Encoding(int code) { + this.code = code; + } + + public int code() { + return this.code; + } +} \ No newline at end of file diff --git a/core/src/main/java/com/tdunning/math/stats/serde/MergingDigestCompactSerde.java b/core/src/main/java/com/tdunning/math/stats/serde/MergingDigestCompactSerde.java new file mode 100644 index 00000000..0cad94e6 --- /dev/null +++ b/core/src/main/java/com/tdunning/math/stats/serde/MergingDigestCompactSerde.java @@ -0,0 +1,58 @@ +package com.tdunning.math.stats.serde; + +import com.tdunning.math.stats.DigestModel; +import com.tdunning.math.stats.MergingDigest; + +import java.nio.ByteBuffer; + + +public class MergingDigestCompactSerde { + + public static int byteSize(int centroids) { + return (centroids * 8) + 30; + } + + public static void serialize(MergingDigest digest, ByteBuffer buf) { + DigestModel model = digest.toModel(); + buf.putInt(Encoding.COMPACT.code()); + buf.putDouble(model.min()); // + 8 + buf.putDouble(model.max()); // + 8 + buf.putFloat((float) model.compression()); // + 4 + buf.putShort((short) model.mainBufferSize().intValue()); // + 2 + buf.putShort((short) model.tempBufferSize().intValue()); // + 2 + buf.putShort((short) model.centroidCount()); // + 2 = 30 + + double[] position = model.centroidPositions(); + double[] weight = model.centroidWeights(); + for (int i = 0; i < model.centroidCount(); i++) { + buf.putFloat((float) position[i]); + buf.putFloat((float) weight[i]); + } + } + + public static MergingDigest deserialize(ByteBuffer buf) { + boolean compactEncoding = buf.getInt() == Encoding.COMPACT.code(); + if (!compactEncoding) { + throw new IllegalArgumentException("Serialization was not done using compact encoding, cannot deserialize"); + } + + double min = buf.getDouble(); + double max = buf.getDouble(); + double compression = buf.getFloat(); + int mainBufferSize = buf.getShort(); + int tempBufferSize = buf.getShort(); + int centroidCount = buf.getShort(); + + double[] position = new double[centroidCount]; + double[] weight = new double[centroidCount]; + for (int i = 0; i < centroidCount; i++) { + position[i] = buf.getFloat(); + weight[i] = buf.getFloat(); + } + + DigestModel model = new DigestModel(compression, min, max, centroidCount, position, weight, mainBufferSize, tempBufferSize); + model.setCompactEncoding(true); + return MergingDigest.fromModel(model); + } + +} diff --git a/core/src/test/java/com/tdunning/math/stats/TDigestTest.java b/core/src/test/java/com/tdunning/math/stats/TDigestTest.java index ea08ed4d..999599b5 100644 --- a/core/src/test/java/com/tdunning/math/stats/TDigestTest.java +++ b/core/src/test/java/com/tdunning/math/stats/TDigestTest.java @@ -20,6 +20,7 @@ import com.clearspring.analytics.stream.quantile.QDigest; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; +import com.tdunning.math.stats.serde.AVLTreeDigestCompactSerde; import org.apache.mahout.common.RandomUtils; import org.apache.mahout.math.jet.random.AbstractContinousDistribution; import org.apache.mahout.math.jet.random.Gamma; @@ -732,13 +733,13 @@ public void testIntEncoding() { int n = gen.nextInt(); n = n >>> (i / 100); ref.add(n); - AbstractTDigest.encode(buf, n); + AVLTreeDigestCompactSerde.encodeInt(buf, n); } buf.flip(); for (int i = 0; i < 3000; i++) { - int n = AbstractTDigest.decode(buf); + int n = AVLTreeDigestCompactSerde.decodeInt(buf); assertEquals(String.format("%d:", i), ref.get(i).intValue(), n); } } diff --git a/core/src/test/java/com/tdunning/math/stats/TDigestUtilTest.java b/core/src/test/java/com/tdunning/math/stats/TDigestUtilTest.java index c750333d..5b74c89b 100644 --- a/core/src/test/java/com/tdunning/math/stats/TDigestUtilTest.java +++ b/core/src/test/java/com/tdunning/math/stats/TDigestUtilTest.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Random; +import com.tdunning.math.stats.serde.AVLTreeDigestCompactSerde; import org.junit.Test; import com.google.common.collect.Lists; @@ -36,13 +37,13 @@ public void testIntEncoding() { int n = gen.nextInt(); n = n >>> (i / 100); ref.add(n); - AbstractTDigest.encode(buf, n); + AVLTreeDigestCompactSerde.encodeInt(buf, n); } buf.flip(); for (int i = 0; i < 3000; i++) { - int n = AbstractTDigest.decode(buf); + int n = AVLTreeDigestCompactSerde.decodeInt(buf); assertEquals(String.format("%d:", i), ref.get(i).intValue(), n); } }