Skip to content

Commit

Permalink
Fix comparison of numeric jsonb columns (#22)
Browse files Browse the repository at this point in the history
Since we use `->>` on `jsonb` fields we always get a string back.
Comparisons such as `<` and `>` were done on string values which gives
unexpected results.

I have tried various other approaches that failed.
- Casting everything to a `jsonb` doesn't work because you can't cast
query params to `jsonb`.
- Using a `CASE WHEN` with `jsonb_typeof` doesn't work because each
`WHEN` of a `CASE WHEN` needs to return the same type.
- There are also complications with calling `Convert` recursively for
`$elemMatch` where you then don't know the column type anymore.
  • Loading branch information
erikdubbelboer authored Aug 7, 2024
1 parent 947a271 commit 9ed74d2
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 28 deletions.
2 changes: 1 addition & 1 deletion examples/readme_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ func ExampleNewConverter_readme() {
fmt.Println(conditions)
fmt.Printf("%#v\n", values)
// Output:
// ((("meta"->>'map' ~* $1) OR ("meta"->>'map' ~* $2)) AND ("meta"->>'password' = $3) AND (("meta"->>'playerCount' >= $4) AND ("meta"->>'playerCount' < $5)))
// ((("meta"->>'map' ~* $1) OR ("meta"->>'map' ~* $2)) AND ("meta"->>'password' = $3) AND ((("meta"->>'playerCount')::numeric >= $4) AND (("meta"->>'playerCount')::numeric < $5)))
// []interface {}{"aztec", "nuke", "", 2, 10}
}
70 changes: 45 additions & 25 deletions filter/converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@ import (
"sync"
)

var basicOperatorMap = map[string]string{
"$gt": ">",
"$gte": ">=",
"$lt": "<",
"$lte": "<=",
var numericOperatorMap = map[string]string{
"$gt": ">",
"$gte": ">=",
"$lt": "<",
"$lte": "<=",
}

var textOperatorMap = map[string]string{
"$eq": "=",
"$ne": "!=",
"$regex": "~*",
Expand Down Expand Up @@ -200,14 +203,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
values = append(values, v[operator])
case "$exists":
// $exists only works on jsonb columns, so we need to check if the key is in the JSONB data first.
isNestedColumn := c.nestedColumn != ""
for _, exemption := range c.nestedExemptions {
if exemption == key {
isNestedColumn = false
break
}
}
if !isNestedColumn {
if !c.isNestedColumn(key) {
// There is no way in Postgres to check if a column exists on a table.
return "", nil, fmt.Errorf("$exists operator not supported on non-nested jsonb columns")
}
Expand All @@ -217,20 +213,14 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
}
inner = append(inner, fmt.Sprintf("(%sjsonb_path_match(%s, 'exists($.%s)'))", neg, c.nestedColumn, key))
case "$elemMatch":
// $elemMatch needs a different implementation depending on if the column is in JSONB or not.
isNestedColumn := c.nestedColumn != ""
for _, exemption := range c.nestedExemptions {
if exemption == key {
isNestedColumn = false
break
}
}
innerConditions, innerValues, err := c.convertFilter(map[string]any{c.placeholderName: v[operator]}, paramIndex)
if err != nil {
return "", nil, err
}
paramIndex += len(innerValues)
if isNestedColumn {

// $elemMatch needs a different implementation depending on if the column is in JSONB or not.
if c.isNestedColumn(key) {
// This will for example become:
//
// EXISTS (SELECT 1 FROM jsonb_array_elements("meta"->'foo') AS __filter_placeholder WHERE ("__filter_placeholder"::text = $1))
Expand All @@ -247,11 +237,27 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
values = append(values, innerValues...)
default:
value := v[operator]
op, ok := basicOperatorMap[operator]
isNumericOperator := false
op, ok := textOperatorMap[operator]
if !ok {
return "", nil, fmt.Errorf("unknown operator: %s", operator)
op, ok = numericOperatorMap[operator]
if !ok {
return "", nil, fmt.Errorf("unknown operator: %s", operator)
}
isNumericOperator = true
}

// Prevent cryptic errors like:
// unexpected error: sql: converting argument $1 type: unsupported type []interface {}, a slice of interface
if !isScalar(value) {
return "", nil, fmt.Errorf("invalid comparison value (must be a primitive): %v", value)
}

if isNumericOperator && isNumeric(value) && c.isNestedColumn(key) {
inner = append(inner, fmt.Sprintf("((%s)::numeric %s $%d)", c.columnName(key), op, paramIndex))
} else {
inner = append(inner, fmt.Sprintf("(%s %s $%d)", c.columnName(key), op, paramIndex))
}
inner = append(inner, fmt.Sprintf("(%s %s $%d)", c.columnName(key), op, paramIndex))
paramIndex++
values = append(values, value)
}
Expand All @@ -277,6 +283,8 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
conditions = append(conditions, fmt.Sprintf("(%s IS NULL)", c.columnName(key)))
}
default:
// Prevent cryptic errors like:
// unexpected error: sql: converting argument $1 type: unsupported type []interface {}, a slice of interface
if !isScalar(value) {
return "", nil, fmt.Errorf("invalid comparison value (must be a primitive): %v", value)
}
Expand Down Expand Up @@ -308,3 +316,15 @@ func (c *Converter) columnName(column string) string {
}
return fmt.Sprintf(`%q->>'%s'`, c.nestedColumn, column)
}

func (c *Converter) isNestedColumn(column string) bool {
if c.nestedColumn == "" {
return false
}
for _, exemption := range c.nestedExemptions {
if exemption == column {
return false
}
}
return true
}
24 changes: 24 additions & 0 deletions filter/converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,30 @@ func TestConverter_Convert(t *testing.T) {
[]any{float64(18)},
nil,
},
{
"numeric comparison bug with jsonb column",
filter.WithNestedJSONB("meta"),
`{"foo": {"$gt": 0}}`,
`(("meta"->>'foo')::numeric > $1)`,
[]any{float64(0)},
nil,
},
{
"numeric comparison against null with jsonb column",
filter.WithNestedJSONB("meta"),
`{"foo": {"$gt": null}}`,
`("meta"->>'foo' > $1)`,
[]any{nil},
nil,
},
{
"compare with non scalar",
nil,
`{"name": {"$eq": [1, 2]}}`,
``,
nil,
fmt.Errorf("invalid comparison value (must be a primitive): [1 2]"),
},
}

for _, tt := range tests {
Expand Down
7 changes: 7 additions & 0 deletions filter/util.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package filter

func isNumeric(v any) bool {
// json.Unmarshal returns float64 for all numbers
// so we only need to check for float64.
_, ok := v.(float64)
return ok
}

func isScalar(v any) bool {
if v == nil {
return true
Expand Down
22 changes: 20 additions & 2 deletions integration/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,10 +286,16 @@ func TestIntegration_BasicOperators(t *testing.T) {
nil,
},
{
`invalid value`,
`invalid value type int`,
`{"level": "town1"}`, // Level is an integer column, but the value is a string.
nil,
errors.New("pq: invalid input syntax for type integer: \"town1\""),
errors.New(`pq: invalid input syntax for type integer: "town1"`),
},
{
`invalid value type string`,
`{"name": 123}`, // Name is a string column, but the value is an integer.
[]int{},
nil,
},
{
`empty object`,
Expand Down Expand Up @@ -381,6 +387,18 @@ func TestIntegration_BasicOperators(t *testing.T) {
[]int{3},
nil,
},
{
"$lt bug with jsonb column",
`{"guild_id": {"$lt": 100}}`,
[]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
nil,
},
{
"$lt with null and jsonb column",
`{"guild_id": {"$lt": null}}`,
[]int{},
nil,
},
}

for _, tt := range tests {
Expand Down

0 comments on commit 9ed74d2

Please sign in to comment.