[SPARK-48697][SQL] Add collation aware string filters
### What changes were proposed in this pull request?

Adds new classes of filters which are collation aware.

### Why are the changes needed?

apache#46760 added predicate widening for collated column references, but widening completely replaces the original filter: if the original expression is not re-evaluated by Spark after the scan, we can end up with wrong results. It also means data sources can never actually support these filters, since all they ever see is AlwaysTrue.
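
For illustration, a hypothetical repro sketch (the table, data, and query are invented for the example, and depending on the build, collation support may need to be enabled explicitly):

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical scenario: a filter on a column with a case-insensitive
// collation. Under pure predicate widening the data source sees only
// AlwaysTrue, so unless Spark re-evaluates `name = 'abc'` on top of the
// scan output, the row 'xyz' would incorrectly survive the filter.
object WideningSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    spark.sql("CREATE TABLE t (name STRING COLLATE UTF8_LCASE) USING parquet")
    spark.sql("INSERT INTO t VALUES ('ABC'), ('xyz')")
    // Expected under UTF8_LCASE: exactly one row, 'ABC'.
    spark.sql("SELECT * FROM t WHERE name = 'abc'").show()
    spark.stop()
  }
}
```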

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New UTs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47059 from stefankandic/fixPredicateWidening.

Authored-by: Stefan Kandic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
stefankandic authored and cloud-fan committed Jul 1, 2024
1 parent 48eb4d5 commit 703b076
Showing 11 changed files with 446 additions and 189 deletions.
@@ -156,6 +156,8 @@ object StructFilters {
Some(Literal(true, BooleanType))
case sources.AlwaysFalse() =>
Some(Literal(false, BooleanType))
case _: sources.CollatedFilter =>
None
}
translate(filter)
}
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate}
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -381,3 +381,87 @@ case class AlwaysFalse() extends Filter {
@Evolving
object AlwaysFalse extends AlwaysFalse {
}

/**
* Base class for collation aware string filters.
*/
@Evolving
abstract class CollatedFilter() extends Filter {

/** The corresponding non-collation aware filter. */
def correspondingFilter: Filter
def dataType: DataType

override def references: Array[String] = correspondingFilter.references
override def toV2: Predicate = correspondingFilter.toV2
}

/** Collation aware equivalent of [[EqualTo]]. */
@Evolving
case class CollatedEqualTo(attribute: String, value: Any, dataType: DataType)
extends CollatedFilter {
override def correspondingFilter: Filter = EqualTo(attribute, value)
}

/** Collation aware equivalent of [[EqualNullSafe]]. */
@Evolving
case class CollatedEqualNullSafe(attribute: String, value: Any, dataType: DataType)
extends CollatedFilter {
override def correspondingFilter: Filter = EqualNullSafe(attribute, value)
}

/** Collation aware equivalent of [[GreaterThan]]. */
@Evolving
case class CollatedGreaterThan(attribute: String, value: Any, dataType: DataType)
extends CollatedFilter {
override def correspondingFilter: Filter = GreaterThan(attribute, value)
}

/** Collation aware equivalent of [[GreaterThanOrEqual]]. */
@Evolving
case class CollatedGreaterThanOrEqual(attribute: String, value: Any, dataType: DataType)
extends CollatedFilter {
override def correspondingFilter: Filter = GreaterThanOrEqual(attribute, value)
}

/** Collation aware equivalent of [[LessThan]]. */
@Evolving
case class CollatedLessThan(attribute: String, value: Any, dataType: DataType)
extends CollatedFilter {
override def correspondingFilter: Filter = LessThan(attribute, value)
}

/** Collation aware equivalent of [[LessThanOrEqual]]. */
@Evolving
case class CollatedLessThanOrEqual(attribute: String, value: Any, dataType: DataType)
extends CollatedFilter {
override def correspondingFilter: Filter = LessThanOrEqual(attribute, value)
}

/** Collation aware equivalent of [[In]]. */
@Evolving
case class CollatedIn(attribute: String, values: Array[Any], dataType: DataType)
extends CollatedFilter {
override def correspondingFilter: Filter = In(attribute, values)
}

/** Collation aware equivalent of [[StringStartsWith]]. */
@Evolving
case class CollatedStringStartsWith(attribute: String, value: String, dataType: DataType)
extends CollatedFilter {
override def correspondingFilter: Filter = StringStartsWith(attribute, value)
}

/** Collation aware equivalent of [[StringEndsWith]]. */
@Evolving
case class CollatedStringEndsWith(attribute: String, value: String, dataType: DataType)
extends CollatedFilter {
override def correspondingFilter: Filter = StringEndsWith(attribute, value)
}

/** Collation aware equivalent of [[StringContains]]. */
@Evolving
case class CollatedStringContains(attribute: String, value: String, dataType: DataType)
extends CollatedFilter {
override def correspondingFilter: Filter = StringContains(attribute, value)
}
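
A hedged usage sketch of the new classes (assumptions: `StringType("UTF8_LCASE")` is a valid factory for a collated string type, and the `splitFilters` helper is invented): each wrapper carries the collated data type while delegating `references` and `toV2` to its non-collated counterpart, and a connector that is not collation aware can reject the wrappers wholesale instead of seeing them as AlwaysTrue:

```scala
import org.apache.spark.sql.sources.{CollatedEqualTo, CollatedFilter, EqualTo, Filter}
import org.apache.spark.sql.types.StringType

object CollatedFilterSketch {
  // Hypothetical connector-side helper: split pushed filters into those
  // evaluable under plain binary semantics and those to hand back to Spark.
  def splitFilters(filters: Seq[Filter]): (Seq[Filter], Seq[Filter]) =
    filters.partition {
      case _: CollatedFilter => false // needs collation-aware comparison
      case _ => true
    }

  def main(args: Array[String]): Unit = {
    val lcase = StringType("UTF8_LCASE") // assumed collated-type factory
    val collated = CollatedEqualTo("name", "abc", lcase)
    // The wrapper delegates to the plain filter it corresponds to.
    assert(collated.correspondingFilter == EqualTo("name", "abc"))
    assert(collated.references.sameElements(EqualTo("name", "abc").references))

    val (supported, postScan) = splitFilters(Seq(EqualTo("id", 1), collated))
    assert(supported == Seq(EqualTo("id", 1)) && postScan == Seq(collated))
  }
}
```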
@@ -500,63 +500,97 @@ object DataSourceStrategy
}
}

/**
* Wraps the given filter into a collation aware one if the input data type is a string
* with a non-default collation; otherwise returns the filter unchanged.
*/
private def collationAwareFilter(filter: sources.Filter, dataType: DataType): Filter = {
if (!SchemaUtils.hasNonUTF8BinaryCollation(dataType)) {
return filter
}

filter match {
case sources.EqualTo(attribute, value) =>
CollatedEqualTo(attribute, value, dataType)
case sources.EqualNullSafe(attribute, value) =>
CollatedEqualNullSafe(attribute, value, dataType)
case sources.GreaterThan(attribute, value) =>
CollatedGreaterThan(attribute, value, dataType)
case sources.GreaterThanOrEqual(attribute, value) =>
CollatedGreaterThanOrEqual(attribute, value, dataType)
case sources.LessThan(attribute, value) =>
CollatedLessThan(attribute, value, dataType)
case sources.LessThanOrEqual(attribute, value) =>
CollatedLessThanOrEqual(attribute, value, dataType)
case sources.In(attribute, values) =>
CollatedIn(attribute, values, dataType)
case sources.StringStartsWith(attribute, value) =>
CollatedStringStartsWith(attribute, value, dataType)
case sources.StringEndsWith(attribute, value) =>
CollatedStringEndsWith(attribute, value, dataType)
case sources.StringContains(attribute, value) =>
CollatedStringContains(attribute, value, dataType)
case other =>
other
}
}
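
To make the gate concrete, here is a minimal standalone mirror of the logic above, specialized to the EqualTo case (assuming `SchemaUtils.hasNonUTF8BinaryCollation` is accessible on the classpath and behaves as its name suggests):

```scala
import org.apache.spark.sql.sources.{CollatedEqualTo, EqualTo, Filter}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.util.SchemaUtils

// Sketch: plain UTF8_BINARY strings pass through untouched; only strings
// with a non-default collation are wrapped, so the data source can see
// the real data type of the column being compared.
object CollationAwareWrap {
  def wrapEqualTo(attribute: String, value: Any, dt: DataType): Filter =
    if (SchemaUtils.hasNonUTF8BinaryCollation(dt)) {
      CollatedEqualTo(attribute, value, dt)
    } else {
      EqualTo(attribute, value)
    }
}
```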

private def translateLeafNodeFilter(
predicate: Expression,
pushableColumn: PushableColumnBase): Option[Filter] = predicate match {
case expressions.EqualTo(pushableColumn(name), Literal(v, t)) =>
Some(sources.EqualTo(name, convertToScala(v, t)))
case expressions.EqualTo(Literal(v, t), pushableColumn(name)) =>
Some(sources.EqualTo(name, convertToScala(v, t)))

case expressions.EqualNullSafe(pushableColumn(name), Literal(v, t)) =>
Some(sources.EqualNullSafe(name, convertToScala(v, t)))
case expressions.EqualNullSafe(Literal(v, t), pushableColumn(name)) =>
Some(sources.EqualNullSafe(name, convertToScala(v, t)))

case expressions.GreaterThan(pushableColumn(name), Literal(v, t)) =>
Some(sources.GreaterThan(name, convertToScala(v, t)))
case expressions.GreaterThan(Literal(v, t), pushableColumn(name)) =>
Some(sources.LessThan(name, convertToScala(v, t)))

case expressions.LessThan(pushableColumn(name), Literal(v, t)) =>
Some(sources.LessThan(name, convertToScala(v, t)))
case expressions.LessThan(Literal(v, t), pushableColumn(name)) =>
Some(sources.GreaterThan(name, convertToScala(v, t)))

case expressions.GreaterThanOrEqual(pushableColumn(name), Literal(v, t)) =>
Some(sources.GreaterThanOrEqual(name, convertToScala(v, t)))
case expressions.GreaterThanOrEqual(Literal(v, t), pushableColumn(name)) =>
Some(sources.LessThanOrEqual(name, convertToScala(v, t)))

case expressions.LessThanOrEqual(pushableColumn(name), Literal(v, t)) =>
Some(sources.LessThanOrEqual(name, convertToScala(v, t)))
case expressions.LessThanOrEqual(Literal(v, t), pushableColumn(name)) =>
Some(sources.GreaterThanOrEqual(name, convertToScala(v, t)))
case expressions.EqualTo(e @ pushableColumn(name), Literal(v, t)) =>
Some(collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType))
case expressions.EqualTo(Literal(v, t), e @ pushableColumn(name)) =>
Some(collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType))

case expressions.EqualNullSafe(e @ pushableColumn(name), Literal(v, t)) =>
Some(collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType))
case expressions.EqualNullSafe(Literal(v, t), e @ pushableColumn(name)) =>
Some(collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType))

case expressions.GreaterThan(e @ pushableColumn(name), Literal(v, t)) =>
Some(collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType))
case expressions.GreaterThan(Literal(v, t), e @ pushableColumn(name)) =>
Some(collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType))

case expressions.LessThan(e @ pushableColumn(name), Literal(v, t)) =>
Some(collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType))
case expressions.LessThan(Literal(v, t), e @ pushableColumn(name)) =>
Some(collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType))

case expressions.GreaterThanOrEqual(e @ pushableColumn(name), Literal(v, t)) =>
Some(collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType))
case expressions.GreaterThanOrEqual(Literal(v, t), e @ pushableColumn(name)) =>
Some(collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType))

case expressions.LessThanOrEqual(e @ pushableColumn(name), Literal(v, t)) =>
Some(collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType))
case expressions.LessThanOrEqual(Literal(v, t), e @ pushableColumn(name)) =>
Some(collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType))

case expressions.InSet(e @ pushableColumn(name), set) =>
val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
Some(sources.In(name, set.toArray.map(toScala)))
Some(collationAwareFilter(sources.In(name, set.toArray.map(toScala)), e.dataType))

// We only convert In to InSet in the Optimizer when there are more than a certain
// number of items, so it is possible we still get an In expression here that needs
// to be pushed down.
case expressions.In(e @ pushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) =>
val hSet = list.map(_.eval(EmptyRow))
val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
Some(sources.In(name, hSet.toArray.map(toScala)))
Some(collationAwareFilter(sources.In(name, hSet.toArray.map(toScala)), e.dataType))

case expressions.IsNull(pushableColumn(name)) =>
Some(sources.IsNull(name))
case expressions.IsNotNull(pushableColumn(name)) =>
Some(sources.IsNotNull(name))
case expressions.StartsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) =>
Some(sources.StringStartsWith(name, v.toString))
case expressions.StartsWith(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) =>
Some(collationAwareFilter(sources.StringStartsWith(name, v.toString), e.dataType))

case expressions.EndsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) =>
Some(sources.StringEndsWith(name, v.toString))
case expressions.EndsWith(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) =>
Some(collationAwareFilter(sources.StringEndsWith(name, v.toString), e.dataType))

case expressions.Contains(pushableColumn(name), Literal(v: UTF8String, StringType)) =>
Some(sources.StringContains(name, v.toString))
case expressions.Contains(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) =>
Some(collationAwareFilter(sources.StringContains(name, v.toString), e.dataType))

case expressions.Literal(true, BooleanType) =>
Some(sources.AlwaysTrue)
@@ -595,16 +629,6 @@ object DataSourceStrategy
translatedFilterToExpr: Option[mutable.HashMap[sources.Filter, Expression]],
nestedPredicatePushdownEnabled: Boolean)
: Option[Filter] = {

def translateAndRecordLeafNodeFilter(filter: Expression): Option[Filter] = {
val translatedFilter =
translateLeafNodeFilter(filter, PushableColumn(nestedPredicatePushdownEnabled))
if (translatedFilter.isDefined && translatedFilterToExpr.isDefined) {
translatedFilterToExpr.get(translatedFilter.get) = predicate
}
translatedFilter
}

predicate match {
case expressions.And(left, right) =>
// See SPARK-12218 for detailed discussion
@@ -631,25 +655,16 @@
right, translatedFilterToExpr, nestedPredicatePushdownEnabled)
} yield sources.Or(leftFilter, rightFilter)

case notNull @ expressions.IsNotNull(_: AttributeReference) =>
// Not null filters on attribute references can always be pushed, even for collated columns.
translateAndRecordLeafNodeFilter(notNull)

case isNull @ expressions.IsNull(_: AttributeReference) =>
// Is null filters on attribute references can always be pushed, even for collated columns.
translateAndRecordLeafNodeFilter(isNull)

case p if p.references.exists(ref => SchemaUtils.hasNonUTF8BinaryCollation(ref.dataType)) =>
// The filter cannot be pushed, so we widen it to AlwaysTrue(). This is only valid if
// the filter's result is not negated by an enclosing Not expression.
translateAndRecordLeafNodeFilter(Literal.TrueLiteral)

case expressions.Not(child) =>
translateFilterWithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled)
.map(sources.Not)

case other =>
translateAndRecordLeafNodeFilter(other)
val filter = translateLeafNodeFilter(other, PushableColumn(nestedPredicatePushdownEnabled))
if (filter.isDefined && translatedFilterToExpr.isDefined) {
translatedFilterToExpr.get(filter.get) = predicate
}
filter
}
}

@@ -26,17 +26,17 @@ import org.json4s.{Formats, NoTypeHints}
import org.json4s.jackson.Serialization

import org.apache.spark.{SparkException, SparkUpgradeException}
import org.apache.spark.sql.{SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY}
import org.apache.spark.sql.{sources, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY}
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, ExpressionSet, GetStructField, PredicateHelper}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper}
import org.apache.spark.sql.catalyst.util.RebaseDateTime
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.Utils


@@ -280,22 +280,15 @@ object DataSourceUtils extends PredicateHelper {
(ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq, dataFilters)
}

/**
* Determines whether a filter should be pushed down to the data source or not.
*
* @param expression The filter expression to be evaluated.
* @param isCollationPushDownSupported Whether the data source supports collation push down.
* @return A boolean indicating whether the filter should be pushed down or not.
*/
def shouldPushFilter(expression: Expression, isCollationPushDownSupported: Boolean): Boolean = {
if (!expression.deterministic) return false

isCollationPushDownSupported || !expression.exists {
case childExpression @ (_: Attribute | _: GetStructField) =>
// don't push down filters for types with non-binary sortable collation
// as it could lead to incorrect results
SchemaUtils.hasNonUTF8BinaryCollation(childExpression.dataType)

def containsFiltersWithCollation(filter: sources.Filter): Boolean = {
filter match {
case sources.And(left, right) =>
containsFiltersWithCollation(left) || containsFiltersWithCollation(right)
case sources.Or(left, right) =>
containsFiltersWithCollation(left) || containsFiltersWithCollation(right)
case sources.Not(child) =>
containsFiltersWithCollation(child)
case _: sources.CollatedFilter => true
case _ => false
}
}
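
A usage sketch (the call site and values are hypothetical, and `StringType("UTF8_LCASE")` is assumed as a collated-type factory): because the check recurses through And, Or, and Not, a collated leaf anywhere in a translated filter tree marks the whole tree as needing re-evaluation by Spark after the scan:

```scala
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StringType

object ContainsCollationSketch {
  def main(args: Array[String]): Unit = {
    val lcase = StringType("UTF8_LCASE") // assumed collated-type factory
    val plain = sources.IsNotNull("name")
    val mixed = sources.And(
      sources.GreaterThan("age", 18),
      sources.CollatedEqualTo("name", "abc", lcase))
    assert(!DataSourceUtils.containsFiltersWithCollation(plain))
    // The collated leaf on the right marks the whole conjunction.
    assert(DataSourceUtils.containsFiltersWithCollation(mixed))
  }
}
```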
@@ -223,12 +223,6 @@ trait FileFormat {
*/
def fileConstantMetadataExtractors: Map[String, PartitionedFile => Any] =
FileFormat.BASE_METADATA_EXTRACTORS

/**
* Returns whether the file format supports filter push down
* for non-UTF8_BINARY collated columns.
*/
def supportsCollationPushDown: Boolean = false
}

object FileFormat {
@@ -160,11 +160,8 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
// - filters that need to be evaluated again after the scan
val filterSet = ExpressionSet(filters)

val filtersToPush = filters.filter(f =>
DataSourceUtils.shouldPushFilter(f, fsRelation.fileFormat.supportsCollationPushDown))

val normalizedFilters = DataSourceStrategy.normalizeExprs(
filtersToPush, l.output)
filters.filter(_.deterministic), l.output)

val partitionColumns =
l.resolve(
@@ -63,8 +63,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
_))
if filters.nonEmpty && fsRelation.partitionSchema.nonEmpty =>
val normalizedFilters = DataSourceStrategy.normalizeExprs(
filters.filter(f => !SubqueryExpression.hasSubquery(f) &&
DataSourceUtils.shouldPushFilter(f, fsRelation.fileFormat.supportsCollationPushDown)),
filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)),
logicalRelation.output)
val (partitionKeyFilters, _) = DataSourceUtils
.getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters)