Use FileFormat-based data source instead of HadoopRDD for reads
This patch refactors this library's read path to use a Spark 2.0 `FileFormat`-based data source to read unloaded Redshift output from S3. This approach has a few advantages over our existing `HadoopRDD`-based approach:

- It will benefit from performance improvements in `FileScanRDD` and `HadoopFsRelation`, including automatic coalescing.
- We don't have to create a separate RDD per partition and union them together, making the RDD DAG smaller.

The bulk of the diff consists of helper classes copied from Spark and `spark-avro` and inlined here for API compatibility / stability purposes. Some of the new classes implemented here are likely to become incompatible with future releases of Spark, but note that `spark-avro` itself relies on similarly unstable / experimental APIs, so this library is already vulnerable to changes in those APIs (in other words, this change does not make our compatibility story significantly worse).
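
For context, a minimal sketch of what the new read path boils down to (simplified; `sqlContext`, `filesToRead`, and `prunedSchema` stand in for values that `RedshiftRelation` computes after the UNLOAD completes, as shown in the diff below):

    // Hedged sketch: one logical FileFormat-based scan over the unloaded S3 files
    // replaces the previous union of per-file HadoopRDDs.
    val df = sqlContext.read
      .format(classOf[RedshiftFileFormat].getName) // internal format for UNLOAD output
      .schema(prunedSchema)                        // schema is always supplied, never inferred
      .load(filesToRead: _*)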

Author: Josh Rosen <[email protected]>
Author: Josh Rosen <[email protected]>

Closes #289 from JoshRosen/use-fileformat-for-reads.
JoshRosen committed Oct 25, 2016
1 parent 6cc49da commit 9ed18a0
Showing 11 changed files with 405 additions and 70 deletions.
2 changes: 2 additions & 0 deletions project/SparkRedshiftBuild.scala
@@ -70,6 +70,8 @@ object SparkRedshiftBuild extends Build {
} else {
"org.apache.avro" % "avro-mapred" % "1.7.7" % "provided" classifier "hadoop2" exclude("org.mortbay.jetty", "servlet-api")
},
// Kryo is provided by Spark, but we need this here in order to be able to import KryoSerializable
"com.esotericsoftware" % "kryo-shaded" % "3.0.3" % "provided",
// A Redshift-compatible JDBC driver must be present on the classpath for spark-redshift to work.
// For testing, we use an Amazon driver, which is available from
// http://docs.aws.amazon.com/redshift/latest/mgmt/configure-jdbc-connection.html
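
The kryo-shaded dependency added above exists only so that the following import (used by the new `SerializableConfiguration` class later in this diff) resolves at compile time; Kryo itself is provided by Spark at runtime:

    // Enabled by the "provided" kryo-shaded dependency above.
    import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
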
14 changes: 14 additions & 0 deletions src/it/scala/com/databricks/spark/redshift/RedshiftReadSuite.scala
@@ -223,4 +223,18 @@ class RedshiftReadSuite extends IntegrationSuiteBase {
conn.commit()
}
}

test("read records containing escaped characters") {
withTempRedshiftTable("records_with_escaped_characters") { tableName =>
conn.createStatement().executeUpdate(
s"CREATE TABLE $tableName (x text)")
conn.createStatement().executeUpdate(
s"""INSERT INTO $tableName VALUES ('a\\nb'), ('\\\\'), ('"')""")
conn.commit()
assert(DefaultJDBCWrapper.tableExists(conn, tableName))
checkAnswer(
read.option("dbtable", tableName).load(),
Seq("a\nb", "\\", "\"").map(x => Row.apply(x)))
}
}
}
20 changes: 11 additions & 9 deletions src/main/scala/com/databricks/spark/redshift/Conversions.scala
@@ -20,10 +20,10 @@ import java.sql.Timestamp
import java.text.{DecimalFormat, DecimalFormatSymbols, SimpleDateFormat}
import java.util.Locale

import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

/**
* Data type conversions for Redshift unloaded data
@@ -78,7 +78,7 @@ private[redshift] object Conversions {
*
* Note that instances of this function are NOT thread-safe.
*/
def createRowConverter(schema: StructType): (Array[String]) => Row = {
def createRowConverter(schema: StructType): Array[String] => InternalRow = {
val dateFormat = createRedshiftDateFormat()
val decimalFormat = createRedshiftDecimalFormat()
val conversionFunctions: Array[String => Any] = schema.fields.map { field =>
@@ -108,16 +108,18 @@
case _ => (data: String) => data
}
}
// As a performance optimization, re-use the same mutable Seq:
val converted: mutable.IndexedSeq[Any] = mutable.IndexedSeq.fill(schema.length)(null)
(fields: Array[String]) => {
// As a performance optimization, re-use the same mutable row / array:
val converted: Array[Any] = Array.fill(schema.length)(null)
val externalRow = new GenericRow(converted)
val encoder = RowEncoder(schema)
(inputRow: Array[String]) => {
var i = 0
while (i < schema.length) {
val data = fields(i)
val data = inputRow(i)
converted(i) = if (data == null || data.isEmpty) null else conversionFunctions(i)(data)
i += 1
}
Row.fromSeq(converted)
encoder.toRow(externalRow)
}
}
}
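
A hedged usage sketch of the new converter contract (the schema and field values are made up for illustration, and the `com.databricks.spark.redshift` package is assumed to be in scope; in the real read path the string arrays come from the record reader):

    import org.apache.spark.sql.types._

    // Array[String] in, InternalRow out; the converter reuses an internal
    // mutable row as a performance optimization and is not thread-safe.
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))
    val convert = Conversions.createRowConverter(schema)
    val row1 = convert(Array("1", "alice"))
    val row2 = convert(Array("2", null)) // null / empty fields become SQL NULL
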
70 changes: 70 additions & 0 deletions src/main/scala/com/databricks/spark/redshift/RecordReaderIterator.scala
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.databricks.spark.redshift

import java.io.Closeable

import org.apache.hadoop.mapreduce.RecordReader

/**
* An adaptor from a Hadoop [[RecordReader]] to an [[Iterator]] over the values returned.
*
* This is copied from Apache Spark and is inlined here to avoid depending on Spark internals
* in this external library.
*/
private[redshift] class RecordReaderIterator[T](
private[this] var rowReader: RecordReader[_, T]) extends Iterator[T] with Closeable {
private[this] var havePair = false
private[this] var finished = false

override def hasNext: Boolean = {
if (!finished && !havePair) {
finished = !rowReader.nextKeyValue
if (finished) {
// Close and release the reader here; close() will also be called when the task
// completes, but for tasks that read from many files, it helps to release the
// resources early.
close()
}
havePair = !finished
}
!finished
}

override def next(): T = {
if (!hasNext) {
throw new java.util.NoSuchElementException("End of stream")
}
havePair = false
rowReader.getCurrentValue
}

override def close(): Unit = {
if (rowReader != null) {
try {
// Close the reader and release it. Note: it's very important that we don't close the
// reader more than once, since that exposes us to MAPREDUCE-5918 when running against
// older Hadoop 2.x releases. That bug can lead to non-deterministic corruption issues
// when reading compressed input.
rowReader.close()
} finally {
rowReader = null
}
}
}
}
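
A hedged sketch of how this adaptor is used (it mirrors `RedshiftFileFormat.buildReader` below; the file split and task attempt context are assumed to be set up as in that method):

    // Wrap the record reader so downstream code sees a plain Iterator[Array[String]]
    // and resources are released as early as possible.
    val reader = new RedshiftRecordReader
    reader.initialize(fileSplit, hadoopAttemptContext)
    val iter = new RecordReaderIterator[Array[String]](reader)
    try {
      iter.foreach(fields => println(fields.mkString("|"))) // e.g. convert / consume each record
    } finally {
      iter.close() // idempotent: safe even if hasNext already closed the underlying reader
    }
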
102 changes: 102 additions & 0 deletions src/main/scala/com/databricks/spark/redshift/RedshiftFileFormat.scala
@@ -0,0 +1,102 @@
/*
* Copyright 2016 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.databricks.spark.redshift

import java.net.URI

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.FileSplit
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

/**
* Internal data source used for reading Redshift UNLOAD files.
*
* This is not intended for public consumption / use outside of this package and therefore
* no API stability is guaranteed.
*/
private[redshift] class RedshiftFileFormat extends FileFormat {
override def inferSchema(
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
// Schema is provided by caller.
None
}

override def prepareWrite(
sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
throw new UnsupportedOperationException(s"prepareWrite is not supported for $this")
}

override def isSplitable(
sparkSession: SparkSession,
options: Map[String, String],
path: Path): Boolean = {
// Our custom InputFormat handles split records properly
true
}

override def buildReader(
sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {

require(partitionSchema.isEmpty)
require(filters.isEmpty)
require(dataSchema == requiredSchema)

val broadcastedConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

(file: PartitionedFile) => {
val conf = broadcastedConf.value.value

val fileSplit = new FileSplit(
new Path(new URI(file.filePath)),
file.start,
file.length,
// TODO: Implement Locality
Array.empty)
val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
val reader = new RedshiftRecordReader
reader.initialize(fileSplit, hadoopAttemptContext)
val iter = new RecordReaderIterator[Array[String]](reader)
// Ensure that the record reader is closed upon task completion. It will ordinarily
// be closed once it is completely iterated, but this is necessary to guard against
// resource leaks in case the task fails or is interrupted.
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
val converter = Conversions.createRowConverter(requiredSchema)
iter.map(converter)
}
}
}
34 changes: 16 additions & 18 deletions src/main/scala/com/databricks/spark/redshift/RedshiftRelation.scala
@@ -17,9 +17,10 @@
package com.databricks.spark.redshift

import java.io.InputStreamReader
import java.lang
import java.net.URI

import org.apache.spark.sql.catalyst.encoders.RowEncoder

import scala.collection.JavaConverters._

import com.amazonaws.auth.AWSCredentialsProvider
@@ -115,8 +116,11 @@ private[redshift] case class RedshiftRelation(
if (results.next()) {
val numRows = results.getLong(1)
val parallelism = sqlContext.getConf("spark.sql.shuffle.partitions", "200").toInt
val emptyRow = Row.empty
sqlContext.sparkContext.parallelize(1L to numRows, parallelism).map(_ => emptyRow)
val emptyRow = RowEncoder(StructType(Seq.empty)).toRow(Row(Seq.empty))
sqlContext.sparkContext
.parallelize(1L to numRows, parallelism)
.map(_ => emptyRow)
.asInstanceOf[RDD[Row]]
} else {
throw new IllegalStateException("Could not read count from Redshift")
}
@@ -155,25 +159,19 @@ private[redshift] case class RedshiftRelation(
tempDir.stripSuffix("/") + '/' + file.stripPrefix(cleanedTempDirUri).stripPrefix("/")
}
}
// Create a DataFrame to read the unloaded data:
val rdd: RDD[(lang.Long, Array[String])] = {
val rdds = filesToRead.map { file =>
sqlContext.sparkContext.newAPIHadoopFile(
file,
classOf[RedshiftInputFormat],
classOf[java.lang.Long],
classOf[Array[String]])
}.toArray
sqlContext.sparkContext.union(rdds)
}

val prunedSchema = pruneSchema(schema, requiredColumns)
rdd.values.mapPartitions { iter =>
val converter: Array[String] => Row = Conversions.createRowConverter(prunedSchema)
iter.map(converter)
}

sqlContext.read
.format(classOf[RedshiftFileFormat].getName)
.schema(prunedSchema)
.load(filesToRead: _*)
.queryExecution.executedPlan.execute().asInstanceOf[RDD[Row]]
}
}

override def needConversion: Boolean = false

private def buildUnloadStmt(
requiredColumns: Array[String],
filters: Array[Filter],
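
A note on the cast in `buildScan` above: since this relation now overrides `needConversion` to return false, Spark SQL assumes the rows it produces are already in the internal representation, so the RDD from the FileFormat scan can be handed back without an extra conversion pass. A hedged sketch of that shape (`df` stands in for the DataFrame built from `RedshiftFileFormat`):

    // With needConversion = false, Spark skips its external-to-internal row
    // conversion, so an RDD[InternalRow] can be returned typed as RDD[Row].
    val internalRows = df.queryExecution.executedPlan.execute()
    internalRows.asInstanceOf[RDD[Row]]
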
65 changes: 65 additions & 0 deletions src/main/scala/com/databricks/spark/redshift/SerializableConfiguration.scala
@@ -0,0 +1,65 @@
/*
* Copyright 2016 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.databricks.spark.redshift

import java.io._

import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import org.apache.hadoop.conf.Configuration
import org.slf4j.LoggerFactory

import scala.util.control.NonFatal

class SerializableConfiguration(@transient var value: Configuration)
extends Serializable with KryoSerializable {
@transient private[redshift] lazy val log = LoggerFactory.getLogger(getClass)

private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException {
out.defaultWriteObject()
value.write(out)
}

private def readObject(in: ObjectInputStream): Unit = tryOrIOException {
value = new Configuration(false)
value.readFields(in)
}

private def tryOrIOException[T](block: => T): T = {
try {
block
} catch {
case e: IOException =>
log.error("Exception encountered", e)
throw e
case NonFatal(e) =>
log.error("Exception encountered", e)
throw new IOException(e)
}
}

def write(kryo: Kryo, out: Output): Unit = {
val dos = new DataOutputStream(out)
value.write(dos)
dos.flush()
}

def read(kryo: Kryo, in: Input): Unit = {
value = new Configuration(false)
value.readFields(new DataInputStream(in))
}
}
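
A hedged sketch of how this wrapper is used in `RedshiftFileFormat.buildReader` above: the Hadoop `Configuration` is wrapped so it can be broadcast to executors once and unwrapped inside each per-file task:

    // Driver side: wrap and broadcast the (non-Serializable) Hadoop Configuration.
    val broadcastedConf =
      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
    // Executor side, inside the per-file reader function:
    val conf: Configuration = broadcastedConf.value.value
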
8 changes: 7 additions & 1 deletion src/main/scala/com/databricks/spark/redshift/Utils.scala
@@ -105,10 +105,16 @@ private[redshift] object Utils {
uri.getFragment)
}

// Visible for testing
private[redshift] var lastTempPathGenerated: String = null

/**
* Creates a randomly named temp directory path for intermediate data
*/
def makeTempPath(tempRoot: String): String = Utils.joinUrls(tempRoot, UUID.randomUUID().toString)
def makeTempPath(tempRoot: String): String = {
lastTempPathGenerated = Utils.joinUrls(tempRoot, UUID.randomUUID().toString)
lastTempPathGenerated
}

/**
* Checks whether the S3 bucket for the given URI has an object lifecycle configuration to
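
A hedged sketch of how the testing hook above might be used (hypothetical test code; the actual tests are not part of this diff):

    // makeTempPath records the last generated path so tests can inspect
    // the temp directory that a read or write actually used.
    val tempPath = Utils.makeTempPath("s3n://test-bucket/temp-dir")
    assert(tempPath == Utils.lastTempPathGenerated)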