Skip to content

Commit

Permalink
Fix array index OOB ex for vectors with dictionary and nulls
Browse files Browse the repository at this point in the history
  • Loading branch information
gene-db committed Apr 26, 2024
1 parent d540786 commit f66452b
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ public byte[] getBytes(int rowId, int count) {
Platform.copyMemory(null, data + rowId, array, Platform.BYTE_ARRAY_OFFSET, count);
} else {
for (int i = 0; i < count; i++) {
array[i] = (byte) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
if (!isNullAt(rowId + i)) {
array[i] = (byte) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
}
}
}
return array;
Expand Down Expand Up @@ -279,7 +281,9 @@ public short[] getShorts(int rowId, int count) {
Platform.copyMemory(null, data + rowId * 2L, array, Platform.SHORT_ARRAY_OFFSET, count * 2L);
} else {
for (int i = 0; i < count; i++) {
array[i] = (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
if (!isNullAt(rowId + i)) {
array[i] = (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
}
}
}
return array;
Expand Down Expand Up @@ -345,7 +349,9 @@ public int[] getInts(int rowId, int count) {
Platform.copyMemory(null, data + rowId * 4L, array, Platform.INT_ARRAY_OFFSET, count * 4L);
} else {
for (int i = 0; i < count; i++) {
array[i] = dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
if (!isNullAt(rowId + i)) {
array[i] = dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
}
}
}
return array;
Expand Down Expand Up @@ -423,7 +429,9 @@ public long[] getLongs(int rowId, int count) {
Platform.copyMemory(null, data + rowId * 8L, array, Platform.LONG_ARRAY_OFFSET, count * 8L);
} else {
for (int i = 0; i < count; i++) {
array[i] = dictionary.decodeToLong(dictionaryIds.getDictId(rowId + i));
if (!isNullAt(rowId + i)) {
array[i] = dictionary.decodeToLong(dictionaryIds.getDictId(rowId + i));
}
}
}
return array;
Expand Down Expand Up @@ -487,7 +495,9 @@ public float[] getFloats(int rowId, int count) {
Platform.copyMemory(null, data + rowId * 4L, array, Platform.FLOAT_ARRAY_OFFSET, count * 4L);
} else {
for (int i = 0; i < count; i++) {
array[i] = dictionary.decodeToFloat(dictionaryIds.getDictId(rowId + i));
if (!isNullAt(rowId + i)) {
array[i] = dictionary.decodeToFloat(dictionaryIds.getDictId(rowId + i));
}
}
}
return array;
Expand Down Expand Up @@ -553,7 +563,9 @@ public double[] getDoubles(int rowId, int count) {
count * 8L);
} else {
for (int i = 0; i < count; i++) {
array[i] = dictionary.decodeToDouble(dictionaryIds.getDictId(rowId + i));
if (!isNullAt(rowId + i)) {
array[i] = dictionary.decodeToDouble(dictionaryIds.getDictId(rowId + i));
}
}
}
return array;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ public byte[] getBytes(int rowId, int count) {
System.arraycopy(byteData, rowId, array, 0, count);
} else {
for (int i = 0; i < count; i++) {
array[i] = (byte) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
if (!isNullAt(rowId + i)) {
array[i] = (byte) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
}
}
}
return array;
Expand Down Expand Up @@ -276,7 +278,9 @@ public short[] getShorts(int rowId, int count) {
System.arraycopy(shortData, rowId, array, 0, count);
} else {
for (int i = 0; i < count; i++) {
array[i] = (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
if (!isNullAt(rowId + i)) {
array[i] = (short) dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
}
}
}
return array;
Expand Down Expand Up @@ -337,7 +341,9 @@ public int[] getInts(int rowId, int count) {
System.arraycopy(intData, rowId, array, 0, count);
} else {
for (int i = 0; i < count; i++) {
array[i] = dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
if (!isNullAt(rowId + i)) {
array[i] = dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
}
}
}
return array;
Expand Down Expand Up @@ -409,7 +415,9 @@ public long[] getLongs(int rowId, int count) {
System.arraycopy(longData, rowId, array, 0, count);
} else {
for (int i = 0; i < count; i++) {
array[i] = dictionary.decodeToLong(dictionaryIds.getDictId(rowId + i));
if (!isNullAt(rowId + i)) {
array[i] = dictionary.decodeToLong(dictionaryIds.getDictId(rowId + i));
}
}
}
return array;
Expand Down Expand Up @@ -466,7 +474,9 @@ public float[] getFloats(int rowId, int count) {
System.arraycopy(floatData, rowId, array, 0, count);
} else {
for (int i = 0; i < count; i++) {
array[i] = dictionary.decodeToFloat(dictionaryIds.getDictId(rowId + i));
if (!isNullAt(rowId + i)) {
array[i] = dictionary.decodeToFloat(dictionaryIds.getDictId(rowId + i));
}
}
}
return array;
Expand Down Expand Up @@ -525,7 +535,9 @@ public double[] getDoubles(int rowId, int count) {
System.arraycopy(doubleData, rowId, array, 0, count);
} else {
for (int i = 0; i < count; i++) {
array[i] = dictionary.decodeToDouble(dictionaryIds.getDictId(rowId + i));
if (!isNullAt(rowId + i)) {
array[i] = dictionary.decodeToDouble(dictionaryIds.getDictId(rowId + i));
}
}
}
return array;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,175 @@ class ColumnVectorSuite extends SparkFunSuite with SQLHelper {
assert(testVector.getDoubles(0, 3)(2) == 1342.17729d)
}

/**
 * Verifies that `testVector` holds the values in `expected`, position by position.
 *
 * Every non-null expected value is compared against both the single-row getter
 * (e.g. `getInt`) and the matching bulk getter (e.g. `getInts`) invoked with
 * `rowId = 0` and `count = testVector.capacity`. A `null` entry asserts that the
 * row at that position is reported as null by the vector.
 *
 * Supported element types are Integer, Short, Byte, Long, Float and Double; any
 * other non-null element fails the test.
 *
 * @param expected   expected values in row order; `null` marks a row that must be null
 * @param testVector the vector under test
 */
def check(expected: Seq[Any], testVector: WritableColumnVector): Unit = {
  expected.zipWithIndex.foreach {
    case (v: Integer, idx) =>
      assert(testVector.getInt(idx) == v)
      assert(testVector.getInts(0, testVector.capacity)(idx) == v)
    case (v: Short, idx) =>
      assert(testVector.getShort(idx) == v)
      assert(testVector.getShorts(0, testVector.capacity)(idx) == v)
    case (v: Byte, idx) =>
      assert(testVector.getByte(idx) == v)
      assert(testVector.getBytes(0, testVector.capacity)(idx) == v)
    case (v: Long, idx) =>
      assert(testVector.getLong(idx) == v)
      assert(testVector.getLongs(0, testVector.capacity)(idx) == v)
    case (v: Float, idx) =>
      assert(testVector.getFloat(idx) == v)
      assert(testVector.getFloats(0, testVector.capacity)(idx) == v)
    case (v: Double, idx) =>
      assert(testVector.getDouble(idx) == v)
      assert(testVector.getDoubles(0, testVector.capacity)(idx) == v)
    case (null, idx) =>
      // The null check must be asserted; a bare isNullAt call would discard its
      // result and let a non-null row pass unnoticed.
      assert(testVector.isNullAt(idx), s"Expected null at $idx")
    case (_, idx) => assert(false, s"Unexpected value at $idx")
  }
}

testVectors("getInts with dictionary and nulls", 3, IntegerType) { testVector =>
  // Phase 1: plain storage (no dictionary) with a null in the middle row.
  val plainValues: Seq[Any] = Seq(1, null, 3)
  plainValues.foreach {
    case value: Integer => testVector.appendInt(value)
    case _ => testVector.appendNull()
  }
  check(plainValues, testVector)

  // Phase 2: attach a dictionary. Its layout is two filler slots followed by the
  // decoded values (null mapped to -1), so ids 2 and 4 resolve to 7 and 9.
  val decodedValues: Seq[Any] = Seq(7, null, 9)
  val decodedInts = decodedValues.map {
    case value: Integer => value.toInt
    case _ => -1
  }
  val dictionaryEntries: Array[Int] = Array(-1, -1) ++ decodedInts
  testVector.setDictionary(new ColumnDictionary(dictionaryEntries))
  testVector.reserveDictionaryIds(3)
  // Row 1 is a null, so its dictionary id (-1) should be ignored.
  Seq(0 -> 2, 1 -> -1, 2 -> 4).foreach { case (rowId, dictId) =>
    testVector.getDictionaryIds.putInt(rowId, dictId)
  }
  check(decodedValues, testVector)
}

testVectors("getShorts with dictionary and nulls", 3, ShortType) { testVector =>
  // Phase 1: plain storage (no dictionary) with a null in the middle row.
  val plainValues: Seq[Any] = Seq(1.toShort, null, 3.toShort)
  plainValues.foreach {
    case value: Short => testVector.appendShort(value)
    case _ => testVector.appendNull()
  }
  check(plainValues, testVector)

  // Phase 2: attach an int-backed dictionary. Its layout is two filler slots
  // followed by the decoded values (null mapped to -1), so ids 2 and 4 resolve
  // to 7 and 9.
  val decodedValues: Seq[Any] = Seq(7.toShort, null, 9.toShort)
  val decodedInts = decodedValues.map {
    case value: Short => value.toInt
    case _ => -1
  }
  val dictionaryEntries: Array[Int] = Array(-1, -1) ++ decodedInts
  testVector.setDictionary(new ColumnDictionary(dictionaryEntries))
  testVector.reserveDictionaryIds(3)
  // Row 1 is a null, so its dictionary id (-1) should be ignored.
  Seq(0 -> 2, 1 -> -1, 2 -> 4).foreach { case (rowId, dictId) =>
    testVector.getDictionaryIds.putInt(rowId, dictId)
  }
  check(decodedValues, testVector)
}

testVectors("getBytes with dictionary and nulls", 3, ByteType) { testVector =>
  // Phase 1: plain storage (no dictionary) with a null in the middle row.
  val plainValues: Seq[Any] = Seq(1.toByte, null, 3.toByte)
  plainValues.foreach {
    case value: Byte => testVector.appendByte(value)
    case _ => testVector.appendNull()
  }
  check(plainValues, testVector)

  // Phase 2: attach an int-backed dictionary. Its layout is two filler slots
  // followed by the decoded values (null mapped to -1), so ids 2 and 4 resolve
  // to 7 and 9.
  val decodedValues: Seq[Any] = Seq(7.toByte, null, 9.toByte)
  val decodedInts = decodedValues.map {
    case value: Byte => value.toInt
    case _ => -1
  }
  val dictionaryEntries: Array[Int] = Array(-1, -1) ++ decodedInts
  testVector.setDictionary(new ColumnDictionary(dictionaryEntries))
  testVector.reserveDictionaryIds(3)
  // Row 1 is a null, so its dictionary id (-1) should be ignored.
  Seq(0 -> 2, 1 -> -1, 2 -> 4).foreach { case (rowId, dictId) =>
    testVector.getDictionaryIds.putInt(rowId, dictId)
  }
  check(decodedValues, testVector)
}

testVectors("getLongs with dictionary and nulls", 3, LongType) { testVector =>
  // Phase 1: plain storage (no dictionary) with a null in the middle row.
  val plainValues: Seq[Any] = Seq(2147483L, null, 2147485L)
  plainValues.foreach {
    case value: Long => testVector.appendLong(value)
    case _ => testVector.appendNull()
  }
  check(plainValues, testVector)

  // Phase 2: attach a long-backed dictionary. Its layout is two filler slots
  // followed by the decoded values (null mapped to -1L), so ids 2 and 4 resolve
  // to the first and third expected values.
  val decodedValues: Seq[Any] = Seq(2147483648L, null, 2147483650L)
  val decodedLongs = decodedValues.map {
    case value: Long => value
    case _ => -1L
  }
  val dictionaryEntries: Array[Long] = Array(-1L, -1L) ++ decodedLongs
  testVector.setDictionary(new ColumnDictionary(dictionaryEntries))
  testVector.reserveDictionaryIds(3)
  // Row 1 is a null, so its dictionary id (-1) should be ignored.
  Seq(0 -> 2, 1 -> -1, 2 -> 4).foreach { case (rowId, dictId) =>
    testVector.getDictionaryIds.putInt(rowId, dictId)
  }
  check(decodedValues, testVector)
}

testVectors("getFloats with dictionary and nulls", 3, FloatType) { testVector =>
  // Phase 1: plain storage (no dictionary) with a null in the middle row.
  val plainValues: Seq[Any] = Seq(1.1f, null, 3.3f)
  plainValues.foreach {
    case value: Float => testVector.appendFloat(value)
    case _ => testVector.appendNull()
  }
  check(plainValues, testVector)

  // Phase 2: attach a float-backed dictionary. Its layout is two filler slots
  // followed by the decoded values (null mapped to -1f), so ids 2 and 4 resolve
  // to 0.1f and 0.3f.
  val decodedValues: Seq[Any] = Seq(0.1f, null, 0.3f)
  val decodedFloats = decodedValues.map {
    case value: Float => value
    case _ => -1f
  }
  val dictionaryEntries: Array[Float] = Array(-1f, -1f) ++ decodedFloats
  testVector.setDictionary(new ColumnDictionary(dictionaryEntries))
  testVector.reserveDictionaryIds(3)
  // Row 1 is a null, so its dictionary id (-1) should be ignored.
  Seq(0 -> 2, 1 -> -1, 2 -> 4).foreach { case (rowId, dictId) =>
    testVector.getDictionaryIds.putInt(rowId, dictId)
  }
  check(decodedValues, testVector)
}

testVectors("getDoubles with dictionary and nulls", 3, DoubleType) { testVector =>
  // Phase 1: plain storage (no dictionary) with a null in the middle row.
  val plainValues: Seq[Any] = Seq(1.1d, null, 3.3d)
  plainValues.foreach {
    case value: Double => testVector.appendDouble(value)
    case _ => testVector.appendNull()
  }
  check(plainValues, testVector)

  // Phase 2: attach a double-backed dictionary. Its layout is two filler slots
  // followed by the decoded values (null mapped to -1d), so ids 2 and 4 resolve
  // to the first and third expected values.
  val decodedValues: Seq[Any] = Seq(1342.17727d, null, 1342.17729d)
  val decodedDoubles = decodedValues.map {
    case value: Double => value
    case _ => -1d
  }
  val dictionaryEntries: Array[Double] = Array(-1d, -1d) ++ decodedDoubles
  testVector.setDictionary(new ColumnDictionary(dictionaryEntries))
  testVector.reserveDictionaryIds(3)
  // Row 1 is a null, so its dictionary id (-1) should be ignored.
  Seq(0 -> 2, 1 -> -1, 2 -> 4).foreach { case (rowId, dictId) =>
    testVector.getDictionaryIds.putInt(rowId, dictId)
  }
  check(decodedValues, testVector)
}

test("[SPARK-22092] off-heap column vector reallocation corrupts array data") {
withVector(new OffHeapColumnVector(8, arrayType)) { testVector =>
val data = testVector.arrayData()
Expand Down

0 comments on commit f66452b

Please sign in to comment.