Skip to content

Commit

Permalink
Add support for NULLS NOT DISTINCT
Browse files Browse the repository at this point in the history
  • Loading branch information
Raveline committed Sep 5, 2024
1 parent 1299379 commit 96bcc55
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/haskell-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
image: buildpack-deps:jammy
services:
postgres:
image: postgres:14
image: postgres:15
env:
POSTGRES_PASSWORD: postgres
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
Expand Down
24 changes: 18 additions & 6 deletions src/Database/PostgreSQL/PQTypes/Checks.hs
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,13 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version)
sqlWhereEqSql "a.attrelid" $ sqlGetTableID table
sqlOrderBy "a.attnum"
desc <- fetchMany fetchTableColumn

isAbove15 <- checkVersionIsAtLeast15
-- get info about constraints from pg_catalog
pk <- sqlGetPrimaryKey table
runQuery_ $ sqlGetChecks table
checks <- fetchMany fetchTableCheck
runQuery_ $ sqlGetIndexes table
runQuery_ $ sqlGetIndexes isAbove15 table
indexes <- fetchMany fetchTableIndex
runQuery_ $ sqlGetForeignKeys table
fkeys <- fetchMany fetchForeignKey
Expand Down Expand Up @@ -1072,6 +1074,12 @@ checkDBConsistency options domains tablesWithVersions migrations = do
-- | Type synonym for a list of tables along with their database versions.
type TablesWithVersions = [(Table, Int32)]

-- The server_version_num has been there since 8.2
checkVersionIsAtLeast15 :: (MonadDB m, MonadThrow m) => m Bool
checkVersionIsAtLeast15 = do
runSQL01_ "select current_setting('server_version_num',true)::int >= 150000;"
fetchOne runIdentity

-- | Associate each table in the list with its version as it exists in
-- the DB, or 0 if it's missing from the DB.
getTableVersions :: (MonadDB m, MonadThrow m) => [Table] -> m TablesWithVersions
Expand Down Expand Up @@ -1196,15 +1204,18 @@ fetchTableCheck (name, condition, validated) = Check {
}

-- *** INDEXES ***

sqlGetIndexes :: Table -> SQL
sqlGetIndexes table = toSQLCommand . sqlSelect "pg_catalog.pg_class c" $ do
sqlGetIndexes :: Bool -> Table -> SQL
sqlGetIndexes nullsNotDistinctSupported table = toSQLCommand . sqlSelect "pg_catalog.pg_class c" $ do
sqlResult "c.relname::text" -- index name
sqlResult $ "ARRAY(" <> selectCoordinates "0" "i.indnkeyatts" <> ")" -- array of key columns in the index
sqlResult $ "ARRAY(" <> selectCoordinates "i.indnkeyatts" "i.indnatts" <> ")" -- array of included columns in the index
sqlResult "am.amname::text" -- the method used (btree, gin etc)
sqlResult "i.indisunique" -- is it unique?
sqlResult "i.indisvalid" -- is it valid?
-- does it have NULLS NOT DISTINCT ?
if nullsNotDistinctSupported
then sqlResult "i.indnullsnotdistinct"
else sqlResult "false"
-- if partial, get constraint def
sqlResult "pg_catalog.pg_get_expr(i.indpred, i.indrelid, true)"
sqlJoinOn "pg_catalog.pg_index i" "c.oid = i.indexrelid"
Expand All @@ -1227,16 +1238,17 @@ sqlGetIndexes table = toSQLCommand . sqlSelect "pg_catalog.pg_class c" $ do
]

fetchTableIndex
:: (String, Array1 String, Array1 String, String, Bool, Bool, Maybe String)
:: (String, Array1 String, Array1 String, String, Bool, Bool, Bool, Maybe String)
-> (TableIndex, RawSQL ())
fetchTableIndex (name, Array1 keyColumns, Array1 includeColumns, method, unique, valid, mconstraint) =
fetchTableIndex (name, Array1 keyColumns, Array1 includeColumns, method, unique, valid, nullsNotDistinct, mconstraint) =
(TableIndex
{ idxColumns = map (indexColumn . unsafeSQL) keyColumns
, idxInclude = map unsafeSQL includeColumns
, idxMethod = read method
, idxUnique = unique
, idxValid = valid
, idxWhere = unsafeSQL <$> mconstraint
, idxNotDistinctNulls = nullsNotDistinct
}
, unsafeSQL name)

Expand Down
27 changes: 19 additions & 8 deletions src/Database/PostgreSQL/PQTypes/Model/Index.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,20 @@ import qualified Data.Text as T
import qualified Data.Text.Encoding as T

data TableIndex = TableIndex {
idxColumns :: [IndexColumn]
, idxInclude :: [RawSQL ()]
, idxMethod :: IndexMethod
, idxUnique :: Bool
, idxValid :: Bool -- ^ If creation of index with CONCURRENTLY fails, index
-- will be marked as invalid. Set it to 'False' if such
-- situation is expected.
, idxWhere :: Maybe (RawSQL ())
idxColumns :: [IndexColumn]
, idxInclude :: [RawSQL ()]
, idxMethod :: IndexMethod
, idxUnique :: Bool
, idxValid :: Bool
-- ^ If creation of index with CONCURRENTLY fails, index
-- will be marked as invalid. Set it to 'False' if such
-- situation is expected.
, idxWhere :: Maybe (RawSQL ())
, idxNotDistinctNulls :: Bool
-- ^ Adds NULL NOT DISTINCT on the index, meaning that
-- ^ only one NULL value will be accepted; other NULLs
-- ^ will be perceived as a violation of the constraint.
-- ^ NB: will only be used if idxUnique is set to True
} deriving (Eq, Ord, Show)

data IndexColumn
Expand Down Expand Up @@ -91,6 +97,7 @@ tblIndex = TableIndex {
, idxUnique = False
, idxValid = True
, idxWhere = Nothing
, idxNotDistinctNulls = False
}

indexOnColumn :: IndexColumn -> TableIndex
Expand Down Expand Up @@ -122,6 +129,7 @@ uniqueIndexOnColumn column = TableIndex {
, idxUnique = True
, idxValid = True
, idxWhere = Nothing
, idxNotDistinctNulls = False
}

uniqueIndexOnColumns :: [IndexColumn] -> TableIndex
Expand All @@ -132,6 +140,7 @@ uniqueIndexOnColumns columns = TableIndex {
, idxUnique = True
, idxValid = True
, idxWhere = Nothing
, idxNotDistinctNulls = False
}

uniqueIndexOnColumnWithCondition :: IndexColumn -> RawSQL () -> TableIndex
Expand All @@ -142,6 +151,7 @@ uniqueIndexOnColumnWithCondition column whereC = TableIndex {
, idxUnique = True
, idxValid = True
, idxWhere = Just whereC
, idxNotDistinctNulls = False
}

indexName :: RawSQL () -> TableIndex -> RawSQL ()
Expand Down Expand Up @@ -203,6 +213,7 @@ sqlCreateIndex_ concurrently tname idx@TableIndex{..} = mconcat [
, if null idxInclude
then ""
else " INCLUDE (" <> mintercalate ", " idxInclude <> ")"
, if idxUnique && idxNotDistinctNulls then " NULLS NOT DISTINCT" else ""
, maybe "" (" WHERE" <+>) idxWhere
]

Expand Down
53 changes: 53 additions & 0 deletions test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,58 @@ overlapingIndexesTests connSource = do
]
}

nullsNotDistinctTests :: ConnectionSourceM (LogT IO) -> TestTree
nullsNotDistinctTests connSource = do
testCaseSteps' "NULLS NOT DISTINCT tests" connSource $ \step -> do
freshTestDB step

step "Create a database with indexes"
do
migrateDatabase defaultExtrasOptions ["pgcrypto"] [] [] [nullTableTest1, nullTableTest2]
[createTableMigration nullTableTest1, createTableMigration nullTableTest2]
checkDatabase defaultExtrasOptions [] [] [nullTableTest1, nullTableTest2]

step "Insert two NULLs on a column with a default UNIQUE index"
do
runQuery_ . sqlInsert "nullTests1" $ do
sqlSet "someUniqueText" (Nothing @T.Text)
runQuery_ . sqlInsert "nullTests1" $ do
sqlSet "someUniqueText" (Nothing @T.Text)

step "Insert NULLs on a column with a NULLS NOT DISTINCT index"
do
runQuery_ . sqlInsert "nullTests2" $ do
sqlSet "someUniqueText" (Nothing @T.Text)
assertDBException "Cannot insert twice a null value with NULLS NOT DISTINCT" $ runQuery_ . sqlInsert "nullTests2" $ do
sqlSet "someUniqueText" (Nothing @T.Text)

where
nullTableTest1 = tblTable
{ tblName = "nullTests1"
, tblVersion = 1
, tblColumns =
[ tblColumn { colName = "id", colType = UuidT, colNullable = False, colDefault = Just "gen_random_uuid()" }
, tblColumn { colName = "someUniqueText", colType = TextT, colNullable = True }
]
, tblPrimaryKey = pkOnColumn "id"
, tblIndexes =
[ uniqueIndexOnColumn "someUniqueText"
]
}

nullTableTest2 = tblTable
{ tblName = "nullTests2"
, tblVersion = 1
, tblColumns =
[ tblColumn { colName = "id", colType = UuidT, colNullable = False, colDefault = Just "gen_random_uuid()" }
, tblColumn { colName = "someUniqueText", colType = TextT, colNullable = True }
]
, tblPrimaryKey = pkOnColumn "id"
, tblIndexes =
[ (uniqueIndexOnColumn "someUniqueText") { idxNotDistinctNulls = True }
]
}

assertNoException :: String -> TestM () -> TestM ()
assertNoException t = eitherExc
(const $ liftIO $ assertFailure ("Exception thrown for: " ++ t))
Expand Down Expand Up @@ -1871,6 +1923,7 @@ main = do
, sqlWithRecursiveTests connSource
, foreignKeyIndexesTests connSource
, overlapingIndexesTests connSource
, nullsNotDistinctTests connSource
]
where
ings =
Expand Down

0 comments on commit 96bcc55

Please sign in to comment.