Fixes NPE when rewriting column without column index
ConeyLiu committed Oct 17, 2023
1 parent 5ee5133 commit a2a9714
Showing 7 changed files with 138 additions and 57 deletions.
@@ -451,13 +451,16 @@ int sizeOf(Object value) {
}
};

private PrimitiveType type;
private final BooleanList nullPages = new BooleanArrayList();
private final LongList nullCounts = new LongArrayList();
private long minMaxSize;
private final IntList pageIndexes = new IntArrayList();

private PrimitiveType type;
private long minMaxSize;
private int nextPageIndex;

protected boolean invalid;

/**
* @return a no-op builder that does not collect statistics objects and therefore returns {@code null} at
* {@link #build()}.
@@ -543,6 +546,11 @@ public static ColumnIndex build(
* the statistics to be added
*/
public void add(Statistics<?> stats) {
if (stats.isEmpty()) {
invalid = true;
return;
}

if (stats.hasNonNullValue()) {
nullPages.add(false);
Object min = stats.genericGetMin();
@@ -603,7 +611,7 @@ public ColumnIndex build() {
}

private ColumnIndexBase<?> build(PrimitiveType type) {
if (nullPages.isEmpty()) {
if (nullPages.isEmpty() || invalid) {
return null;
}
ColumnIndexBase<?> columnIndex = createColumnIndex(type);
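For context, a minimal sketch (not part of the commit) of the behaviour introduced above, assuming the parquet-column ColumnIndexBuilder and Statistics APIs that appear in this diff: once a page with empty statistics is added, the builder is marked invalid and build() returns null instead of producing a misleading column index.

import org.apache.parquet.column.statistics.Statistics;
import org.apache.parquet.internal.column.columnindex.ColumnIndex;
import org.apache.parquet.internal.column.columnindex.ColumnIndexBuilder;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.Types;

public class EmptyStatsColumnIndexSketch {
  public static void main(String[] args) {
    PrimitiveType type =
        Types.optional(PrimitiveType.PrimitiveTypeName.DOUBLE).named("DoubleFraction");
    ColumnIndexBuilder builder = ColumnIndexBuilder.getBuilder(type, Integer.MAX_VALUE);

    // A page without any stored statistics is represented by an empty Statistics
    // object (the same construct the rewriter uses below); adding it marks the
    // whole builder as invalid.
    Statistics<?> emptyStats = Statistics.getBuilderForReading(type).build();
    builder.add(emptyStats);

    // With the change above, build() returns null, so no column index is written.
    ColumnIndex columnIndex = builder.build();
    System.out.println("column index: " + columnIndex); // prints "column index: null"
  }
}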
@@ -84,7 +84,6 @@ int compareValueToMax(int arrayIndex) {

private final DoubleList minValues = new DoubleArrayList();
private final DoubleList maxValues = new DoubleArrayList();
private boolean invalid;

private static double convert(ByteBuffer buffer) {
return buffer.order(LITTLE_ENDIAN).getDouble(0);
@@ -84,7 +84,6 @@ int compareValueToMax(int arrayIndex) {

private final FloatList minValues = new FloatArrayList();
private final FloatList maxValues = new FloatArrayList();
private boolean invalid;

private static float convert(ByteBuffer buffer) {
return buffer.order(LITTLE_ENDIAN).getFloat(0);
@@ -612,13 +612,13 @@ public void writeDataPage(
* @throws IOException if any I/O error occurs during writing the file
*/
public void writeDataPage(
int valueCount, int uncompressedPageSize,
BytesInput bytes,
Statistics<?> statistics,
long rowCount,
Encoding rlEncoding,
Encoding dlEncoding,
Encoding valuesEncoding) throws IOException {
int valueCount, int uncompressedPageSize,
BytesInput bytes,
Statistics<?> statistics,
long rowCount,
Encoding rlEncoding,
Encoding dlEncoding,
Encoding valuesEncoding) throws IOException {
writeDataPage(valueCount, uncompressedPageSize, bytes, statistics, rowCount, rlEncoding, dlEncoding, valuesEncoding, null, null);
}

@@ -988,6 +988,21 @@ void writeColumnChunk(ColumnDescriptor descriptor,
endColumn();
}

/**
* Overwrite the column total statistics. This is used when the column total statistics are
* known while all the page statistics are invalid, for example when rewriting the column.
*
* @param totalStatistics the column total statistics
* @throws IOException if there is an error while writing
*/
public void endColumn(Statistics<?> totalStatistics) throws IOException {
Preconditions.checkArgument(totalStatistics != null, "Column total statistics can not be null");
currentStatistics = totalStatistics;
// Invalidate the ColumnIndex
columnIndexBuilder = ColumnIndexBuilder.getNoOpBuilder();
endColumn();
}

/**
* end a column (once all rep, def and data have been written)
* @throws IOException if there is an error while writing
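The expected call pattern for this overload, roughly as the rewriter below uses it (a hedged sketch, not code from the commit; PagePayload and the argument values are invented here, while startColumn, writeDataPage and endColumn are the ParquetFileWriter methods shown in this diff): pages whose statistics were lost are written with empty statistics, and the known chunk-level statistics are handed over once at the end, which also swaps in the no-op column index builder.

// Hypothetical helper; ParquetFileWriter, ColumnDescriptor, CompressionCodecName,
// Statistics and Encoding are real parquet-mr types, PagePayload is an invented holder.
private void rewriteChunkWithKnownTotals(
    ParquetFileWriter writer,
    ColumnDescriptor descriptor,
    long totalValueCount,
    CompressionCodecName codec,
    List<PagePayload> pages,
    Statistics<?> emptyStatistics,
    Statistics<?> chunkStatistics) throws IOException {
  writer.startColumn(descriptor, totalValueCount, codec);
  for (PagePayload page : pages) {
    // Each page is written with empty statistics because none are available.
    writer.writeDataPage(page.valueCount, page.uncompressedPageSize, page.bytes,
        emptyStatistics, page.rowCount, page.rlEncoding, page.dlEncoding, page.valuesEncoding);
  }
  // Overwrite the chunk totals and invalidate the column index in one call.
  writer.endColumn(chunkStatistics);
}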
@@ -314,8 +314,13 @@ private void processBlocksFromReader() throws IOException {

// Translate compression and/or encryption
writer.startColumn(descriptor, crStore.getColumnReader(descriptor).getTotalValueCount(), newCodecName);
processChunk(chunk, newCodecName, columnChunkEncryptorRunTime, encryptColumn);
writer.endColumn();
boolean needOverwriteStatistics = processChunk(chunk, newCodecName, columnChunkEncryptorRunTime, encryptColumn);
if (needOverwriteStatistics) {
// All the page statistics are invalid, so we need to overwrite the column statistics
writer.endColumn(chunk.getStatistics());
} else {
writer.endColumn();
}
} else {
// Nothing changed, simply copy the binary data.
BloomFilter bloomFilter = reader.readBloomFilter(chunk);
@@ -335,10 +340,15 @@
}
}

private void processChunk(ColumnChunkMetaData chunk,
CompressionCodecName newCodecName,
ColumnChunkEncryptorRunTime columnChunkEncryptorRunTime,
boolean encryptColumn) throws IOException {
/**
* Rewrite a single column with the given new compression codec or new encryptor
*
* @return whether the column statistics should be overwritten
*/
private boolean processChunk(ColumnChunkMetaData chunk,
CompressionCodecName newCodecName,
ColumnChunkEncryptorRunTime columnChunkEncryptorRunTime,
boolean encryptColumn) throws IOException {
CompressionCodecFactory codecFactory = HadoopCodecs.newFactory(0);
CompressionCodecFactory.BytesInputDecompressor decompressor = null;
CompressionCodecFactory.BytesInputCompressor compressor = null;
@@ -375,6 +385,7 @@ private void processChunk(ColumnChunkMetaData chunk,
DictionaryPage dictionaryPage = null;
long readValues = 0;
Statistics<?> statistics = null;
Statistics<?> emptyStatistics = Statistics.getBuilderForReading(chunk.getPrimitiveType()).build();
ParquetMetadataConverter converter = new ParquetMetadataConverter();
int pageOrdinal = 0;
long totalChunkValues = chunk.getValueCount();
@@ -420,16 +431,28 @@
encryptColumn,
dataEncryptor,
dataPageAAD);
statistics = convertStatistics(
Statistics<?> v1PageStatistics = convertStatistics(
originalCreatedBy, chunk.getPrimitiveType(), headerV1.getStatistics(), columnIndex, pageOrdinal, converter);
if (v1PageStatistics == null) {
// Reaching here means both the columnIndex and the page-header statistics are null
if (statistics != null) {
// Mixing null and non-null page statistics is not allowed
throw new IOException("Detected mixed null page statistics and non-null page statistics");
}
// Pass empty page statistics to the writer and overwrite the column statistics at the end
v1PageStatistics = emptyStatistics;
} else {
statistics = v1PageStatistics;
}

readValues += headerV1.getNum_values();
if (offsetIndex != null) {
long rowCount = 1 + offsetIndex.getLastRowIndex(
pageOrdinal, totalChunkValues) - offsetIndex.getFirstRowIndex(pageOrdinal);
writer.writeDataPage(toIntWithCheck(headerV1.getNum_values()),
pageHeader.getUncompressed_page_size(),
BytesInput.from(pageLoad),
statistics,
v1PageStatistics,
toIntWithCheck(rowCount),
converter.getEncoding(headerV1.getRepetition_level_encoding()),
converter.getEncoding(headerV1.getDefinition_level_encoding()),
@@ -440,7 +463,7 @@ private void processChunk(ColumnChunkMetaData chunk,
writer.writeDataPage(toIntWithCheck(headerV1.getNum_values()),
pageHeader.getUncompressed_page_size(),
BytesInput.from(pageLoad),
statistics,
v1PageStatistics,
converter.getEncoding(headerV1.getRepetition_level_encoding()),
converter.getEncoding(headerV1.getDefinition_level_encoding()),
converter.getEncoding(headerV1.getEncoding()),
@@ -471,8 +494,19 @@
encryptColumn,
dataEncryptor,
dataPageAAD);
statistics = convertStatistics(
Statistics<?> v2PageStatistics = convertStatistics(
originalCreatedBy, chunk.getPrimitiveType(), headerV2.getStatistics(), columnIndex, pageOrdinal, converter);
if (v2PageStatistics == null) {
// Reaching here means both the columnIndex and the page-header statistics are null
if (statistics != null) {
// Mixing null and non-null page statistics is not allowed
throw new IOException("Detected mixed null page statistics and non-null page statistics");
}
// Pass empty page statistics to the writer and overwrite the column statistics at the end
v2PageStatistics = emptyStatistics;
} else {
statistics = v2PageStatistics;
}
readValues += headerV2.getNum_values();
writer.writeDataPageV2(headerV2.getNum_rows(),
headerV2.getNum_nulls(),
@@ -482,7 +516,7 @@
converter.getEncoding(headerV2.getEncoding()),
BytesInput.from(pageLoad),
rawDataLength,
statistics,
v2PageStatistics,
metaEncryptor,
dataPageHeaderAAD);
pageOrdinal++;
@@ -492,6 +526,8 @@
break;
}
}

return statistics == null;
}

private Statistics<?> convertStatistics(String createdBy,
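To see the whole path end to end, here is a hedged usage sketch (not part of the commit) of the rewriter driven through the options API that the tests below rely on; the file paths and codec choice are placeholders, and the assumption is that the input file contains float/double columns whose NaN values left it without a column index.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.hadoop.metadata.CompressionCodecName;
import org.apache.parquet.hadoop.rewrite.ParquetRewriter;
import org.apache.parquet.hadoop.rewrite.RewriteOptions;

public class RewriteWithoutColumnIndexSketch {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    RewriteOptions options = new RewriteOptions.Builder(
            conf, new Path("/tmp/input.parquet"), new Path("/tmp/output.parquet"))
        .transform(CompressionCodecName.ZSTD)  // force the pages to be rewritten
        .build();
    ParquetRewriter rewriter = new ParquetRewriter(options);
    // Before this fix, columns without a column index (for example NaN-heavy
    // float/double columns) hit an NPE here; now their pages are written with
    // empty statistics and the chunk totals go through endColumn(totalStatistics).
    rewriter.processBlocks();
    rewriter.close();
  }
}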
@@ -79,6 +79,8 @@
import java.util.stream.Collectors;

import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FLOAT;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
import static org.apache.parquet.schema.Type.Repetition.OPTIONAL;
import static org.apache.parquet.schema.Type.Repetition.REPEATED;
@@ -123,17 +125,7 @@ private void testPruneSingleColumnTranslateCodec(List<Path> inputPaths) throws E
rewriter.close();

// Verify the schema are not changed for the columns not pruned
ParquetMetadata pmd = ParquetFileReader.readFooter(conf, new Path(outputFile), ParquetMetadataConverter.NO_FILTER);
MessageType schema = pmd.getFileMetaData().getSchema();
List<Type> fields = schema.getFields();
assertEquals(fields.size(), 3);
assertEquals(fields.get(0).getName(), "DocId");
assertEquals(fields.get(1).getName(), "Name");
assertEquals(fields.get(2).getName(), "Links");
List<Type> subFields = fields.get(2).asGroupType().getFields();
assertEquals(subFields.size(), 2);
assertEquals(subFields.get(0).getName(), "Backward");
assertEquals(subFields.get(1).getName(), "Forward");
validateSchema();

// Verify codec has been translated
verifyCodec(outputFile, new HashSet<CompressionCodecName>() {{
@@ -194,17 +186,7 @@ private void testPruneNullifyTranslateCodec(List<Path> inputPaths) throws Except
rewriter.close();

// Verify the schema are not changed for the columns not pruned
ParquetMetadata pmd = ParquetFileReader.readFooter(conf, new Path(outputFile), ParquetMetadataConverter.NO_FILTER);
MessageType schema = pmd.getFileMetaData().getSchema();
List<Type> fields = schema.getFields();
assertEquals(fields.size(), 3);
assertEquals(fields.get(0).getName(), "DocId");
assertEquals(fields.get(1).getName(), "Name");
assertEquals(fields.get(2).getName(), "Links");
List<Type> subFields = fields.get(2).asGroupType().getFields();
assertEquals(subFields.size(), 2);
assertEquals(subFields.get(0).getName(), "Backward");
assertEquals(subFields.get(1).getName(), "Forward");
validateSchema();

// Verify codec has been translated
verifyCodec(outputFile, new HashSet<CompressionCodecName>() {{
@@ -269,17 +251,7 @@ private void testPruneEncryptTranslateCodec(List<Path> inputPaths) throws Except
rewriter.close();

// Verify the schema are not changed for the columns not pruned
ParquetMetadata pmd = ParquetFileReader.readFooter(conf, new Path(outputFile), ParquetMetadataConverter.NO_FILTER);
MessageType schema = pmd.getFileMetaData().getSchema();
List<Type> fields = schema.getFields();
assertEquals(fields.size(), 3);
assertEquals(fields.get(0).getName(), "DocId");
assertEquals(fields.get(1).getName(), "Name");
assertEquals(fields.get(2).getName(), "Links");
List<Type> subFields = fields.get(2).asGroupType().getFields();
assertEquals(subFields.size(), 2);
assertEquals(subFields.get(0).getName(), "Backward");
assertEquals(subFields.get(1).getName(), "Forward");
validateSchema();

// Verify codec has been translated
FileDecryptionProperties fileDecryptionProperties = EncDecProperties.getFileDecryptionProperties();
@@ -660,6 +632,8 @@ private MessageType createSchema() {
new PrimitiveType(OPTIONAL, INT64, "DocId"),
new PrimitiveType(REQUIRED, BINARY, "Name"),
new PrimitiveType(OPTIONAL, BINARY, "Gender"),
new PrimitiveType(REPEATED, FLOAT, "FloatFraction"),
new PrimitiveType(OPTIONAL, DOUBLE, "DoubleFraction"),
new GroupType(OPTIONAL, "Links",
new PrimitiveType(REPEATED, BINARY, "Backward"),
new PrimitiveType(REPEATED, BINARY, "Forward")));
@@ -701,6 +675,16 @@ private void validateColumnData(Set<String> prunePaths,
expectGroup.getBinary("Gender", 0).getBytes());
}

if (!prunePaths.contains("FloatFraction") && !nullifiedPaths.contains("FloatFraction")) {
assertEquals(group.getFloat("FloatFraction", 0),
expectGroup.getFloat("FloatFraction", 0), 0);
}

if (!prunePaths.contains("DoubleFraction") && !nullifiedPaths.contains("DoubleFraction")) {
assertEquals(group.getDouble("DoubleFraction", 0),
expectGroup.getDouble("DoubleFraction", 0), 0);
}

Group subGroup = group.getGroup("Links", 0);

if (!prunePaths.contains("Links.Backward") && !nullifiedPaths.contains("Links.Backward")) {
@@ -937,4 +921,20 @@ private Map<ColumnPath, List<BloomFilter>> allBloomFilters(

return allBloomFilters;
}

private void validateSchema() throws IOException {
ParquetMetadata pmd = ParquetFileReader.readFooter(conf, new Path(outputFile), ParquetMetadataConverter.NO_FILTER);
MessageType schema = pmd.getFileMetaData().getSchema();
List<Type> fields = schema.getFields();
assertEquals(fields.size(), 5);
assertEquals(fields.get(0).getName(), "DocId");
assertEquals(fields.get(1).getName(), "Name");
assertEquals(fields.get(2).getName(), "FloatFraction");
assertEquals(fields.get(3).getName(), "DoubleFraction");
assertEquals(fields.get(4).getName(), "Links");
List<Type> subFields = fields.get(4).asGroupType().getFields();
assertEquals(subFields.size(), 2);
assertEquals(subFields.get(0).getName(), "Backward");
assertEquals(subFields.get(1).getName(), "Forward");
}
}
@@ -36,6 +36,8 @@
import java.util.concurrent.ThreadLocalRandom;

import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.FLOAT;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;

@@ -175,7 +177,13 @@ else if (primitiveType.getPrimitiveTypeName().equals(INT64)) {
else if (primitiveType.getPrimitiveTypeName().equals(BINARY)) {
g.add(type.getName(), getString());
}
// Only support 3 types now, more can be added later
else if (primitiveType.getPrimitiveTypeName().equals(FLOAT)) {
g.add(type.getName(), getFloat());
}
else if (primitiveType.getPrimitiveTypeName().equals(DOUBLE)) {
g.add(type.getName(), getDouble());
}
// Only support 5 types now, more can be added later
}
else {
GroupType groupType = (GroupType) type;
@@ -206,6 +214,22 @@ private static String getString()
return sb.toString();
}

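// NaN is returned for roughly half of the generated values so that the float/double
// columns produce pages without usable min/max statistics (and hence no column
// index), exercising the rewrite path fixed by this commit.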
private static float getFloat()
{
if (ThreadLocalRandom.current().nextBoolean()) {
return Float.NaN;
}
return ThreadLocalRandom.current().nextFloat();
}

private static double getDouble()
{
if (ThreadLocalRandom.current().nextBoolean()) {
return Double.NaN;
}
return ThreadLocalRandom.current().nextDouble();
}

public static String createTempFile(String prefix)
{
try {
