Make Redshift to S3 authentication mechanisms mutually exclusive
This patch makes a breaking change to how Redshift to S3 communication is authenticated. Previously, the implicit default behavior was to forward Spark's S3 credentials to Redshift, and this default would be used unless `aws_iam_role` or the `temporary_aws_*` options were set. This behavior was error-prone because a typo in the IAM settings (e.g. using the parameter `redshift_iam_role` instead of the correct `aws_iam_role`) would silently fall back to the default authentication mechanism instead of failing.

To close that gap, this patch changes the library so that Spark's S3 credentials are forwarded to Redshift only if `forward_spark_s3_credentials` is set to `true`. This option is mutually exclusive with the `aws_iam_role` and `temporary_aws_*` options and defaults to `false`. The net effect is that users who were already using `aws_iam_role` or the `temporary_aws_*` options are unaffected, while users relying on the old default behavior will need to set `forward_spark_s3_credentials` to `true` in order to keep using that authentication scheme.
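
For illustration, here is roughly what opting back in to the old behavior looks like for a read; the JDBC URL, table, and bucket below are placeholders rather than values from this repository:

```scala
// Hypothetical example: explicitly forward Spark's S3 credentials to Redshift.
val df = sqlContext.read
  .format("com.databricks.spark.redshift")
  .option("url", "jdbc:redshift://example-host:5439/db?user=user&password=password")
  .option("dbtable", "example_table")
  .option("tempdir", "s3n://example-bucket/tmp/")
  // Required after this change unless aws_iam_role or the temporary_aws_* options are used instead.
  .option("forward_spark_s3_credentials", "true")
  .load()
```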

I have updated the README with a new section explaining the different connections involved in using this library and the authentication mechanisms available for each. I also added a new section describing how to enable encryption of data in transit and at rest.

Because of the backwards-incompatible nature of this change, I'm bumping the version number to `3.0.0-preview1-SNAPSHOT`.

Author: Josh Rosen <[email protected]>
Author: Josh Rosen <[email protected]>

Closes #291 from JoshRosen/credential-mechanism-enforcement.
JoshRosen committed Nov 1, 2016
1 parent 9ed18a0 commit 8afde06
Showing 13 changed files with 298 additions and 109 deletions.
223 changes: 184 additions & 39 deletions README.md

Large diffs are not rendered by default.

@@ -40,11 +40,8 @@ class CrossRegionIntegrationSuite extends IntegrationSuiteBase {
StructType(StructField("foo", IntegerType) :: Nil))
val tableName = s"roundtrip_save_and_load_$randomSuffix"
try {
df.write
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
write(df)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("extracopyoptions", s"region '$bucketRegion'")
.save()
// Check that the table exists. It appears that creating a table in one connection then
@@ -42,12 +42,7 @@ class DecimalIntegrationSuite extends IntegrationSuiteBase {
}
conn.commit()
assert(DefaultJDBCWrapper.tableExists(conn, tableName))
val loadedDf = sqlContext.read
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.load()
val loadedDf = read.option("dbtable", tableName).load()
checkAnswer(loadedDf, expectedRows)
checkAnswer(loadedDf.selectExpr("x + 0"), expectedRows)
} finally {
@@ -33,21 +33,17 @@ class IAMIntegrationSuite extends IntegrationSuiteBase {
val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
StructType(StructField("a", IntegerType) :: Nil))
try {
df.write
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
write(df)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("forward_spark_s3_credentials", "false")
.option("aws_iam_role", IAM_ROLE_ARN)
.mode(SaveMode.ErrorIfExists)
.save()

assert(DefaultJDBCWrapper.tableExists(conn, tableName))
val loadedDf = sqlContext.read
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
val loadedDf = read
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("forward_spark_s3_credentials", "false")
.option("aws_iam_role", IAM_ROLE_ARN)
.load()
assert(loadedDf.schema.length === 1)
@@ -65,11 +61,9 @@ class IAMIntegrationSuite extends IntegrationSuiteBase {
val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
StructType(StructField("a", IntegerType) :: Nil))
val err = intercept[SQLException] {
df.write
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
write(df)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("forward_spark_s3_credentials", "false")
.option("aws_iam_role", IAM_ROLE_ARN + "-some-bogus-suffix")
.mode(SaveMode.ErrorIfExists)
.save()
@@ -130,6 +130,7 @@ trait IntegrationSuiteBase
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
.option("tempdir", tempDir)
.option("forward_spark_s3_credentials", "true")
}
/**
* Create a new DataFrameWriter using common options for writing to Redshift.
@@ -139,6 +140,7 @@
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
.option("tempdir", tempDir)
.option("forward_spark_s3_credentials", "true")
}

protected def createTestDataInRedshift(tableName: String): Unit = {
@@ -30,22 +30,18 @@ class RedshiftCredentialsInConfIntegrationSuite extends IntegrationSuiteBase {
StructType(StructField("foo", IntegerType) :: Nil))
val tableName = s"roundtrip_save_and_load_$randomSuffix"
try {
df.write
.format("com.databricks.spark.redshift")
write(df)
.option("url", AWS_REDSHIFT_JDBC_URL)
.option("user", AWS_REDSHIFT_USER)
.option("password", AWS_REDSHIFT_PASSWORD)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.save()
assert(DefaultJDBCWrapper.tableExists(conn, tableName))
val loadedDf = sqlContext.read
.format("com.databricks.spark.redshift")
val loadedDf = read
.option("url", AWS_REDSHIFT_JDBC_URL)
.option("user", AWS_REDSHIFT_USER)
.option("password", AWS_REDSHIFT_PASSWORD)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.load()
assert(loadedDf.schema === df.schema)
checkAnswer(loadedDf, df.collect())
@@ -52,23 +52,19 @@ class STSIntegrationSuite extends IntegrationSuiteBase {
val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1))),
StructType(StructField("a", IntegerType) :: Nil))
try {
df.write
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
write(df)
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("forward_spark_s3_credentials", "false")
.option("temporary_aws_access_key_id", STS_ACCESS_KEY_ID)
.option("temporary_aws_secret_access_key", STS_SECRET_ACCESS_KEY)
.option("temporary_aws_session_token", STS_SESSION_TOKEN)
.mode(SaveMode.ErrorIfExists)
.save()

assert(DefaultJDBCWrapper.tableExists(conn, tableName))
val loadedDf = sqlContext.read
.format("com.databricks.spark.redshift")
.option("url", jdbcUrl)
val loadedDf = read
.option("dbtable", tableName)
.option("tempdir", tempDir)
.option("forward_spark_s3_credentials", "false")
.option("temporary_aws_access_key_id", STS_ACCESS_KEY_ID)
.option("temporary_aws_secret_access_key", STS_SECRET_ACCESS_KEY)
.option("temporary_aws_session_token", STS_SESSION_TOKEN)
@@ -29,19 +29,29 @@ private[redshift] object AWSCredentialsUtils {
* Generates a credentials string for use in Redshift COPY and UNLOAD statements.
* Favors a configured `aws_iam_role` if available in the parameters.
*/
def getRedshiftCredentialsString(params: MergedParameters,
awsCredentials: AWSCredentials): String = {
params.iamRole
.map { role => s"aws_iam_role=$role" }
.getOrElse(
awsCredentials match {
case creds: AWSSessionCredentials =>
s"aws_access_key_id=${creds.getAWSAccessKeyId};" +
s"aws_secret_access_key=${creds.getAWSSecretKey};token=${creds.getSessionToken}"
case creds =>
s"aws_access_key_id=${creds.getAWSAccessKeyId};" +
s"aws_secret_access_key=${creds.getAWSSecretKey}"
})
def getRedshiftCredentialsString(
params: MergedParameters,
sparkAwsCredentials: AWSCredentials): String = {

def awsCredsToString(credentials: AWSCredentials): String = {
credentials match {
case creds: AWSSessionCredentials =>
s"aws_access_key_id=${creds.getAWSAccessKeyId};" +
s"aws_secret_access_key=${creds.getAWSSecretKey};token=${creds.getSessionToken}"
case creds =>
s"aws_access_key_id=${creds.getAWSAccessKeyId};" +
s"aws_secret_access_key=${creds.getAWSSecretKey}"
}
}
if (params.iamRole.isDefined) {
s"aws_iam_role=${params.iamRole.get}"
} else if (params.temporaryAWSCredentials.isDefined) {
awsCredsToString(params.temporaryAWSCredentials.get.getCredentials)
} else if (params.forwardSparkS3Credentials) {
awsCredsToString(sparkAwsCredentials)
} else {
throw new IllegalStateException("No Redshift S3 authentication mechanism was specified")
}
}

def staticCredentialsProvider(credentials: AWSCredentials): AWSCredentialsProvider = {
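
As a rough standalone sketch of the new resolution logic in `getRedshiftCredentialsString` (placeholder values; this assumes code compiled into the `com.databricks.spark.redshift` package, since the helpers involved are package-private):

```scala
package com.databricks.spark.redshift

import com.amazonaws.auth.BasicAWSCredentials

object CredentialsStringSketch {
  // Minimal parameter map in which forward_spark_s3_credentials is the only auth mechanism named.
  private val params = Parameters.mergeParameters(Map(
    "tempdir" -> "s3n://example-bucket/tmp/",
    "dbtable" -> "example_table",
    "url" -> "jdbc:redshift://example-host:5439/db?user=user&password=password",
    "forward_spark_s3_credentials" -> "true"))

  def main(args: Array[String]): Unit = {
    // Spark's discovered S3 credentials (stand-ins here) are rendered as a COPY/UNLOAD key pair.
    val sparkCreds = new BasicAWSCredentials("ACCESSKEYID", "SECRETKEY")
    val rendered = AWSCredentialsUtils.getRedshiftCredentialsString(params, sparkCreds)
    assert(rendered == "aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRETKEY")
    // Had aws_iam_role or the temporary_aws_* options been set instead, the IAM-role or
    // session-token form would be produced.
  }
}
```
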
22 changes: 20 additions & 2 deletions src/main/scala/com/databricks/spark/redshift/Parameters.scala
@@ -31,6 +31,7 @@ private[redshift] object Parameters {
// * distkey has no default, but is optional unless using diststyle KEY
// * jdbcdriver has no default, but is optional

"forward_spark_s3_credentials" -> "false",
"tempformat" -> "AVRO",
"csvnullstring" -> "@NULL@",
"overwrite" -> "false",
@@ -87,6 +88,18 @@
*/
case class MergedParameters(parameters: Map[String, String]) {

require(temporaryAWSCredentials.isDefined || iamRole.isDefined || forwardSparkS3Credentials,
  "You must specify a method for authenticating Redshift's connection to S3 (aws_iam_role," +
  " forward_spark_s3_credentials, or temporary_aws_*). For a discussion of the differences" +
  " between these options, please see the README.")

require(Seq(
  temporaryAWSCredentials.isDefined,
  iamRole.isDefined,
  forwardSparkS3Credentials).count(_ == true) == 1,
  "The aws_iam_role, forward_spark_s3_credentials, and temporary_aws_* options are " +
  "mutually-exclusive; please specify only one.")

/**
* A root directory to be used for intermediate data exchange, expected to be on S3, or
* somewhere that can be written to and read from by Redshift. Make sure that AWS credentials
@@ -247,11 +260,16 @@
def postActions: Array[String] = parameters("postactions").split(";")

/**
* The IAM role to assume for Redshift COPY/UNLOAD operations. This takes precedence over
* other forms of authentication.
* The IAM role that Redshift should assume for COPY/UNLOAD operations.
*/
def iamRole: Option[String] = parameters.get("aws_iam_role")

/**
* If true then this library will automatically discover the credentials that Spark is
* using to connect to S3 and will forward those credentials to Redshift over JDBC.
*/
def forwardSparkS3Credentials: Boolean = parameters("forward_spark_s3_credentials").toBoolean

/**
* Temporary AWS credentials which are passed to Redshift. These only need to be supplied by
* the user when Hadoop is configured to authenticate to S3 via IAM roles assigned to EC2
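
A small sketch of the new fail-fast validation in `MergedParameters` (hypothetical values; again assuming code in the `com.databricks.spark.redshift` package):

```scala
package com.databricks.spark.redshift

import scala.util.Try

object MutualExclusionSketch {
  def main(args: Array[String]): Unit = {
    // Naming two authentication mechanisms at once is rejected as soon as the parameters
    // are merged, because the second `require` in MergedParameters fails.
    val conflicting = Map(
      "tempdir" -> "s3n://example-bucket/tmp/",
      "dbtable" -> "example_table",
      "url" -> "jdbc:redshift://example-host:5439/db?user=user&password=password",
      "forward_spark_s3_credentials" -> "true",
      "aws_iam_role" -> "arn:aws:iam::123456789000:role/redshift_iam_role")
    assert(Try(Parameters.mergeParameters(conflicting)).isFailure) // IllegalArgumentException
  }
}
```
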
@@ -18,7 +18,6 @@ package com.databricks.spark.redshift

import scala.language.implicitConversions

import com.amazonaws.AmazonClientException
import com.amazonaws.auth.{AWSSessionCredentials, BasicSessionCredentials, BasicAWSCredentials}
import org.apache.hadoop.conf.Configuration
import org.scalatest.FunSuite
@@ -27,29 +26,39 @@ import com.databricks.spark.redshift.Parameters.MergedParameters

class AWSCredentialsUtilsSuite extends FunSuite {

val baseParams = Map(
"tempdir" -> "s3://foo/bar",
"dbtable" -> "test_schema.test_table",
"url" -> "jdbc:redshift://foo/bar?user=user&password=password")

private implicit def string2Params(tempdir: String): MergedParameters = {
MergedParameters(Map("tempdir" -> tempdir))
Parameters.mergeParameters(baseParams ++ Map(
"tempdir" -> tempdir,
"forward_spark_s3_credentials" -> "true"))
}

test("credentialsString with regular keys") {
val creds = new BasicAWSCredentials("ACCESSKEYID", "SECRET/KEY/WITH/SLASHES")
val params = MergedParameters(Map.empty)
val params =
Parameters.mergeParameters(baseParams ++ Map("forward_spark_s3_credentials" -> "true"))
assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, creds) ===
"aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRET/KEY/WITH/SLASHES")
}

test("credentialsString with STS temporary keys") {
val creds = new BasicSessionCredentials("ACCESSKEYID", "SECRET/KEY", "SESSION/Token")
val params = MergedParameters(Map.empty)
assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, creds) ===
val params = Parameters.mergeParameters(baseParams ++ Map(
"temporary_aws_access_key_id" -> "ACCESSKEYID",
"temporary_aws_secret_access_key" -> "SECRET/KEY",
"temporary_aws_session_token" -> "SESSION/Token"))
assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, null) ===
"aws_access_key_id=ACCESSKEYID;aws_secret_access_key=SECRET/KEY;token=SESSION/Token")
}

test("Configured IAM roles should take precedence") {
val creds = new BasicSessionCredentials("ACCESSKEYID", "SECRET/KEY", "SESSION/Token")
val iamRole = "arn:aws:iam::123456789000:role/redshift_iam_role"
val params = MergedParameters(Map("aws_iam_role" -> iamRole))
assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, creds) ===
val params = Parameters.mergeParameters(baseParams ++ Map("aws_iam_role" -> iamRole))
assert(AWSCredentialsUtils.getRedshiftCredentialsString(params, null) ===
s"aws_iam_role=$iamRole")
}

@@ -58,7 +67,7 @@ class AWSCredentialsUtilsSuite extends FunSuite {
conf.set("fs.s3.awsAccessKeyId", "CONFID")
conf.set("fs.s3.awsSecretAccessKey", "CONFKEY")

val params = MergedParameters(Map(
val params = Parameters.mergeParameters(baseParams ++ Map(
"tempdir" -> "s3://URIID:URIKEY@bucket/path",
"temporary_aws_access_key_id" -> "key_id",
"temporary_aws_secret_access_key" -> "secret",