Skip to content

Commit

Permalink
Introduce sqlAll, sqlAny and related state types (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
zlondrej authored Sep 30, 2024
1 parent ac8008a commit 573a865
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 38 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# hpqtypes-extras-1.17.0.0 (2023-??-??)
* Add an optional check that all foreign keys have an index.
* Add support for NULLS NOT DISTINCT in unique indexes.
* Add `sqlAll` and `sqlAny` to allow creating `SQL` expressions with
nested `AND` and `OR` conditions.
* Add `SqlWhereAll` and `SqlWhereAny` so they can be used in signatures.

# hpqtypes-extras-1.16.4.4 (2023-08-23)
* Switch from `cryptonite` to `crypton`.
Expand Down
7 changes: 3 additions & 4 deletions src/Database/PostgreSQL/PQTypes/Checks.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ import Database.PostgreSQL.PQTypes.Migrate
import Database.PostgreSQL.PQTypes.Model
import Database.PostgreSQL.PQTypes.SQL.Builder
import Database.PostgreSQL.PQTypes.Versions
import Database.PostgreSQL.PQTypes.Utils.NubList

headExc :: String -> [a] -> a
headExc s [] = error s
Expand Down Expand Up @@ -429,7 +428,7 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version)
runQuery_ $ sqlGetForeignKeys table
fkeys <- fetchMany fetchForeignKey
triggers <- getDBTriggers tblName
checkedOverlaps <- checkOverlappingIndexes
checkedOverlaps <- checkOverlappingIndexes
return $ mconcat [
checkColumns 1 tblColumns desc
, checkPrimaryKey tblPrimaryKey pk
Expand Down Expand Up @@ -575,7 +574,7 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version)

allCoverage :: [[RawSQL ()]]
allCoverage = maybe [] pkColumns pkey:allIndexes

-- A foreign key is covered if it is a prefix of a list of indices.
-- So a FK on a is covered by an index on (a, b) but not an index on (b, a).
coveredFK :: ForeignKey -> [[RawSQL ()]] -> Bool
Expand Down Expand Up @@ -604,7 +603,7 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version)
if eoCheckOverlappingIndexes options
then go
else pure mempty
where
where
go = do
let handleOverlap (contained, contains) =
mconcat
Expand Down
84 changes: 62 additions & 22 deletions src/Database/PostgreSQL/PQTypes/SQL/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ module Database.PostgreSQL.PQTypes.SQL.Builder
, sqlDelete
, SqlDelete(..)

, SqlWhereAll(..)
, sqlAll
, SqlWhereAny(..)
, sqlAny
, sqlWhereAny

, SqlResult
Expand Down Expand Up @@ -291,10 +295,19 @@ data SqlDelete = SqlDelete
, sqlDeleteRecursiveWith :: Recursive
}

-- | This is not exported and is used as an implementation detail in
-- 'sqlWhereAll'.
newtype SqlAll = SqlAll
{ sqlAllWhere :: [SqlCondition]

-- | Type representing a set of conditions that are joined by 'AND'.
--
-- When no conditions are given, the result is 'TRUE'.
newtype SqlWhereAll = SqlWhereAll
{ sqlWhereAllWhere :: [SqlCondition]
}

-- | Type representing a set of conditions that are joined by 'OR'.
--
-- When no conditions are given, the result is 'FALSE'.
newtype SqlWhereAny = SqlWhereAny
{ sqlWhereAnyWhere :: [SqlCondition]
}

instance Show SqlSelect where
Expand All @@ -312,7 +325,10 @@ instance Show SqlUpdate where
instance Show SqlDelete where
show = show . toSQLCommand

instance Show SqlAll where
instance Show SqlWhereAll where
show = show . toSQLCommand

instance Show SqlWhereAny where
show = show . toSQLCommand

emitClause :: Sqlable sql => SQL -> sql -> SQL
Expand Down Expand Up @@ -464,7 +480,7 @@ materializedClause Materialized = if isWithMaterializedSupported then "MATERIALI
materializedClause NonMaterialized = if isWithMaterializedSupported then "NOT MATERIALIZED" else ""

recursiveClause :: Recursive -> SQL
recursiveClause Recursive = "WITH RECURSIVE"
recursiveClause Recursive = "WITH RECURSIVE"
recursiveClause NonRecursive = "WITH"

instance Sqlable SqlUpdate where
Expand All @@ -486,10 +502,17 @@ instance Sqlable SqlDelete where
emitClausesSep "WHERE" "AND" (map toSQLCommand $ sqlDeleteWhere cmd) <+>
emitClausesSepComma "RETURNING" (sqlDeleteResult cmd)

instance Sqlable SqlAll where
toSQLCommand cmd | null (sqlAllWhere cmd) = "TRUE"
toSQLCommand cmd =
"(" <+> smintercalate "AND" (map (parenthesize . toSQLCommand) (sqlAllWhere cmd)) <+> ")"
instance Sqlable SqlWhereAll where
toSQLCommand cmd = case sqlWhereAllWhere cmd of
[] -> "TRUE"
[cond] -> toSQLCommand cond
conds -> parenthesize $ smintercalate "AND" (map (parenthesize . toSQLCommand) conds)

instance Sqlable SqlWhereAny where
toSQLCommand cmd = case sqlWhereAnyWhere cmd of
[] -> "FALSE"
[cond] -> toSQLCommand cond
conds -> parenthesize $ smintercalate "OR" (map (parenthesize . toSQLCommand) conds)

sqlSelect :: SQL -> State SqlSelect () -> SqlSelect
sqlSelect table refine =
Expand Down Expand Up @@ -546,7 +569,7 @@ data Recursive = Recursive | NonRecursive
instance Semigroup Recursive where
_ <> Recursive = Recursive
Recursive <> _ = Recursive
_ <> _ = NonRecursive
_ <> _ = NonRecursive

class SqlWith a where
sqlWith1 :: a -> SQL -> SQL -> Materialized -> Recursive -> a
Expand Down Expand Up @@ -605,9 +628,13 @@ instance SqlWhere SqlDelete where
sqlWhere1 cmd cond = cmd { sqlDeleteWhere = sqlDeleteWhere cmd ++ [cond] }
sqlGetWhereConditions = sqlDeleteWhere

instance SqlWhere SqlAll where
sqlWhere1 cmd cond = cmd { sqlAllWhere = sqlAllWhere cmd ++ [cond] }
sqlGetWhereConditions = sqlAllWhere
instance SqlWhere SqlWhereAll where
sqlWhere1 cmd cond = cmd { sqlWhereAllWhere = sqlWhereAllWhere cmd ++ [cond] }
sqlGetWhereConditions = sqlWhereAllWhere

instance SqlWhere SqlWhereAny where
sqlWhere1 cmd cond = cmd { sqlWhereAnyWhere = sqlWhereAnyWhere cmd ++ [cond] }
sqlGetWhereConditions = sqlWhereAnyWhere

-- | The @WHERE@ part of an SQL query. See above for a usage
-- example. See also 'SqlCondition'.
Expand Down Expand Up @@ -664,16 +691,29 @@ sqlWhereIsNULL col = sqlWhere $ col <+> "IS NULL"
sqlWhereIsNotNULL :: (MonadState v m, SqlWhere v) => SQL -> m ()
sqlWhereIsNotNULL col = sqlWhere $ col <+> "IS NOT NULL"

-- | Run monad that joins all conditions using 'AND' operator.
--
-- When no conditions are given, the result is 'TRUE'.
--
-- Note: This is usally not needed as `SqlSelect`, `SqlUpdate` and `SqlDelete`
-- already join conditions using 'AND' by default, but it can be useful when
-- nested in `sqlAny`.
sqlAll :: State SqlWhereAll () -> SQL
sqlAll = toSQLCommand . (`execState` SqlWhereAll [])

-- | Run monad that joins all conditions using 'OR' operator.
--
-- When no conditions are given, the result is 'FALSE'.
sqlAny :: State SqlWhereAny () -> SQL
sqlAny = toSQLCommand . (`execState` SqlWhereAny [])

-- | Add a condition in the WHERE statement that holds if any of the given
-- condition holds.
sqlWhereAny :: (MonadState v m, SqlWhere v) => [State SqlAll ()] -> m ()
sqlWhereAny = sqlWhere . sqlWhereAnyImpl

sqlWhereAnyImpl :: [State SqlAll ()] -> SQL
sqlWhereAnyImpl [] = "FALSE"
sqlWhereAnyImpl l =
"(" <+> smintercalate "OR" (map (parenthesize . toSQLCommand
. flip execState (SqlAll [])) l) <+> ")"
--
-- These conditions are joined with 'OR' operator.
-- When no conditions are given, the result is 'FALSE'.
sqlWhereAny :: (MonadState v m, SqlWhere v) => [State SqlWhereAll ()] -> m ()
sqlWhereAny = sqlWhere . sqlAny . mapM_ (sqlWhere . sqlAll)

class SqlFrom a where
sqlFrom1 :: a -> SQL -> a
Expand Down
86 changes: 74 additions & 12 deletions test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@
{-# HLINT ignore "Use head" #-}
module Main where

import Control.Monad.Catch
import Control.Monad (forM_)
import Control.Monad.Catch
import Control.Monad.IO.Class
import Data.Either
import Data.List (zip4)
import Data.Monoid
import Prelude
import Data.Typeable
import Data.UUID.Types
import qualified Data.Set as Set
import qualified Data.Text as T
import Data.Typeable
import Data.UUID.Types

import Data.Monoid.Utils
import Database.PostgreSQL.PQTypes
Expand Down Expand Up @@ -1284,7 +1282,7 @@ testSqlWithRecursive :: HasCallStack => (String -> TestM ()) -> TestM ()
testSqlWithRecursive step = do
step "Running WITH RECURSIVE tests"
testPass
where
where
migrate tables migrations = do
migrateDatabase defaultExtrasOptions ["pgcrypto"] [] [] tables migrations
checkDatabase defaultExtrasOptions [] [] tables
Expand All @@ -1310,7 +1308,7 @@ testSqlWithRecursive step = do
sqlResult "root.cartel_member_id"
sqlResult "root.cartel_boss_id"
sqlWhere "root.cartel_boss_id IS NULL"
sqlUnionAll
sqlUnionAll
[ sqlSelect "cartel child" $ do
sqlResult "child.cartel_member_id"
sqlResult "child.cartel_boss_id"
Expand Down Expand Up @@ -1701,7 +1699,7 @@ foreignKeyIndexesTests connSource =
let options = defaultExtrasOptions { eoCheckForeignKeysIndexes = True }
assertException "Foreign keys are missing" $ migrateDatabase options ["pgcrypto"] [] [] [table1, table2, table3]
[createTableMigration table1, createTableMigration table2, createTableMigration table3]

step "Multi column indexes covering a FK pass the checks"
do
let options = defaultExtrasOptions { eoCheckForeignKeysIndexes = True }
Expand Down Expand Up @@ -1818,8 +1816,8 @@ overlapingIndexesTests connSource = do
, tblColumn { colName = "idx3", colType = UuidT }
]
, tblPrimaryKey = pkOnColumn "id"
, tblIndexes =
[ indexOnColumns [ indexColumn "idx1", indexColumn "idx2" ]
, tblIndexes =
[ indexOnColumns [ indexColumn "idx1", indexColumn "idx2" ]
, indexOnColumns [ indexColumn "idx1" ]
]
}
Expand All @@ -1834,7 +1832,7 @@ nullsNotDistinctTests connSource = do
migrateDatabase defaultExtrasOptions ["pgcrypto"] [] [] [nullTableTest1, nullTableTest2]
[createTableMigration nullTableTest1, createTableMigration nullTableTest2]
checkDatabase defaultExtrasOptions [] [] [nullTableTest1, nullTableTest2]

step "Insert two NULLs on a column with a default UNIQUE index"
do
runQuery_ . sqlInsert "nulltests1" $ do
Expand All @@ -1848,7 +1846,7 @@ nullsNotDistinctTests connSource = do
sqlSet "content" (Nothing @T.Text)
assertDBException "Cannot insert twice a null value with NULLS NOT DISTINCT" $ runQuery_ . sqlInsert "nulltests2" $ do
sqlSet "content" (Nothing @T.Text)

where
nullTableTest1 = tblTable
{ tblName = "nulltests1"
Expand Down Expand Up @@ -1876,6 +1874,69 @@ nullsNotDistinctTests connSource = do
]
}

sqlAnyAllTests :: TestTree
sqlAnyAllTests = testGroup "SQL ANY/ALL tests"
[ testCase "sqlAll produces correct queries" $ do
assertSqlEqual "empty sqlAll is TRUE" "TRUE" . sqlAll $ pure ()
assertSqlEqual "sigle condition is emmited as is" "cond" $ sqlAll $ sqlWhere "cond"
assertSqlEqual "each condition as well as entire condition is parenthesized"
"((cond1) AND (cond2))" $ sqlAll $ do
sqlWhere "cond1"
sqlWhere "cond2"

assertSqlEqual "sqlAll can be nested"
"((cond1) AND (cond2) AND (((cond3) AND (cond4))))" $
sqlAll $ do
sqlWhere "cond1"
sqlWhere "cond2"
sqlWhere . sqlAll $ do
sqlWhere "cond3"
sqlWhere "cond4"
, testCase "sqlAny produces correct queries" $ do
assertSqlEqual "empty sqlAny is FALSE" "FALSE" . sqlAny $ pure ()
assertSqlEqual "sigle condition is emmited as is" "cond" $ sqlAny $ sqlWhere "cond"
assertSqlEqual "each condition as well as entire condition is parenthesized"
"((cond1) OR (cond2))" $ sqlAny $ do
sqlWhere "cond1"
sqlWhere "cond2"

assertSqlEqual "sqlAny can be nested"
"((cond1) OR (cond2) OR (((cond3) OR (cond4))))" $
sqlAny $ do
sqlWhere "cond1"
sqlWhere "cond2"
sqlWhere . sqlAny $ do
sqlWhere "cond3"
sqlWhere "cond4"
, testCase "mixing sqlAny and all produces correct queries" $ do
assertSqlEqual "sqlAny and sqlAll can be mixed"
"((((cond1) OR (cond2))) AND (((cond3) OR (cond4))))" $
sqlAll $ do
sqlWhere . sqlAny $ do
sqlWhere "cond1"
sqlWhere "cond2"
sqlWhere . sqlAny $ do
sqlWhere "cond3"
sqlWhere "cond4"
, testCase "sqlWhereAny produces correct queries" $ do
-- `sqlWhereAny` has to be wrapped in `sqlAll` to disambiguate the `SqlWhere` monad.
assertSqlEqual "empty sqlWhereAny is FALSE" "FALSE" . sqlAll $ sqlWhereAny []
assertSqlEqual "each condition as well as entire condition is parenthesized and joined with OR"
"((cond1) OR (cond2))" . sqlAll $ sqlWhereAny [sqlWhere "cond1", sqlWhere "cond2"]
assertSqlEqual "nested multi-conditions are parenthesized and joined with AND"
"((cond1) OR (((cond2) AND (cond3))) OR (cond4))" . sqlAll $ sqlWhereAny
[ sqlWhere "cond1"
, do
sqlWhere "cond2"
sqlWhere "cond3"
, sqlWhere "cond4"
]
]
where
assertSqlEqual :: (Sqlable a) => String -> a -> a -> Assertion
assertSqlEqual msg a b = assertEqual msg
(show $ toSQLCommand a) (show $ toSQLCommand b)

assertNoException :: String -> TestM () -> TestM ()
assertNoException t = eitherExc
(const $ liftIO $ assertFailure ("Exception thrown for: " ++ t))
Expand Down Expand Up @@ -1924,6 +1985,7 @@ main = do
, foreignKeyIndexesTests connSource
, overlapingIndexesTests connSource
, nullsNotDistinctTests connSource
, sqlAnyAllTests
]
where
ings =
Expand Down

0 comments on commit 573a865

Please sign in to comment.