From 573a865984b0913f87c82aa180b32db8af37834f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ond=C5=99ej=20Jano=C5=A1=C3=ADk?= <5196749+zlondrej@users.noreply.github.com> Date: Mon, 30 Sep 2024 16:02:27 +0200 Subject: [PATCH] Introduce `sqlAll`, `sqlAny` and related state types (#112) --- CHANGELOG.md | 3 + src/Database/PostgreSQL/PQTypes/Checks.hs | 7 +- .../PostgreSQL/PQTypes/SQL/Builder.hs | 84 +++++++++++++----- test/Main.hs | 86 ++++++++++++++++--- 4 files changed, 142 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e5c09ed..c20ca35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,9 @@ # hpqtypes-extras-1.17.0.0 (2023-??-??) * Add an optional check that all foreign keys have an index. * Add support for NULLS NOT DISTINCT in unique indexes. +* Add `sqlAll` and `sqlAny` to allow creating `SQL` expressions with + nested `AND` and `OR` conditions. +* Add `SqlWhereAll` and `SqlWhereAny` so they can be used in signatures. # hpqtypes-extras-1.16.4.4 (2023-08-23) * Switch from `cryptonite` to `crypton`. diff --git a/src/Database/PostgreSQL/PQTypes/Checks.hs b/src/Database/PostgreSQL/PQTypes/Checks.hs index 6dc0f08..24e260d 100644 --- a/src/Database/PostgreSQL/PQTypes/Checks.hs +++ b/src/Database/PostgreSQL/PQTypes/Checks.hs @@ -42,7 +42,6 @@ import Database.PostgreSQL.PQTypes.Migrate import Database.PostgreSQL.PQTypes.Model import Database.PostgreSQL.PQTypes.SQL.Builder import Database.PostgreSQL.PQTypes.Versions -import Database.PostgreSQL.PQTypes.Utils.NubList headExc :: String -> [a] -> a headExc s [] = error s @@ -429,7 +428,7 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version) runQuery_ $ sqlGetForeignKeys table fkeys <- fetchMany fetchForeignKey triggers <- getDBTriggers tblName - checkedOverlaps <- checkOverlappingIndexes + checkedOverlaps <- checkOverlappingIndexes return $ mconcat [ checkColumns 1 tblColumns desc , checkPrimaryKey tblPrimaryKey pk @@ -575,7 +574,7 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version) allCoverage :: [[RawSQL ()]] allCoverage = maybe [] pkColumns pkey:allIndexes - + -- A foreign key is covered if it is a prefix of a list of indices. -- So a FK on a is covered by an index on (a, b) but not an index on (b, a). coveredFK :: ForeignKey -> [[RawSQL ()]] -> Bool @@ -604,7 +603,7 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version) if eoCheckOverlappingIndexes options then go else pure mempty - where + where go = do let handleOverlap (contained, contains) = mconcat diff --git a/src/Database/PostgreSQL/PQTypes/SQL/Builder.hs b/src/Database/PostgreSQL/PQTypes/SQL/Builder.hs index b1983af..f224601 100644 --- a/src/Database/PostgreSQL/PQTypes/SQL/Builder.hs +++ b/src/Database/PostgreSQL/PQTypes/SQL/Builder.hs @@ -145,6 +145,10 @@ module Database.PostgreSQL.PQTypes.SQL.Builder , sqlDelete , SqlDelete(..) + , SqlWhereAll(..) + , sqlAll + , SqlWhereAny(..) + , sqlAny , sqlWhereAny , SqlResult @@ -291,10 +295,19 @@ data SqlDelete = SqlDelete , sqlDeleteRecursiveWith :: Recursive } --- | This is not exported and is used as an implementation detail in --- 'sqlWhereAll'. -newtype SqlAll = SqlAll - { sqlAllWhere :: [SqlCondition] + +-- | Type representing a set of conditions that are joined by 'AND'. +-- +-- When no conditions are given, the result is 'TRUE'. +newtype SqlWhereAll = SqlWhereAll + { sqlWhereAllWhere :: [SqlCondition] + } + +-- | Type representing a set of conditions that are joined by 'OR'. +-- +-- When no conditions are given, the result is 'FALSE'. +newtype SqlWhereAny = SqlWhereAny + { sqlWhereAnyWhere :: [SqlCondition] } instance Show SqlSelect where @@ -312,7 +325,10 @@ instance Show SqlUpdate where instance Show SqlDelete where show = show . toSQLCommand -instance Show SqlAll where +instance Show SqlWhereAll where + show = show . toSQLCommand + +instance Show SqlWhereAny where show = show . toSQLCommand emitClause :: Sqlable sql => SQL -> sql -> SQL @@ -464,7 +480,7 @@ materializedClause Materialized = if isWithMaterializedSupported then "MATERIALI materializedClause NonMaterialized = if isWithMaterializedSupported then "NOT MATERIALIZED" else "" recursiveClause :: Recursive -> SQL -recursiveClause Recursive = "WITH RECURSIVE" +recursiveClause Recursive = "WITH RECURSIVE" recursiveClause NonRecursive = "WITH" instance Sqlable SqlUpdate where @@ -486,10 +502,17 @@ instance Sqlable SqlDelete where emitClausesSep "WHERE" "AND" (map toSQLCommand $ sqlDeleteWhere cmd) <+> emitClausesSepComma "RETURNING" (sqlDeleteResult cmd) -instance Sqlable SqlAll where - toSQLCommand cmd | null (sqlAllWhere cmd) = "TRUE" - toSQLCommand cmd = - "(" <+> smintercalate "AND" (map (parenthesize . toSQLCommand) (sqlAllWhere cmd)) <+> ")" +instance Sqlable SqlWhereAll where + toSQLCommand cmd = case sqlWhereAllWhere cmd of + [] -> "TRUE" + [cond] -> toSQLCommand cond + conds -> parenthesize $ smintercalate "AND" (map (parenthesize . toSQLCommand) conds) + +instance Sqlable SqlWhereAny where + toSQLCommand cmd = case sqlWhereAnyWhere cmd of + [] -> "FALSE" + [cond] -> toSQLCommand cond + conds -> parenthesize $ smintercalate "OR" (map (parenthesize . toSQLCommand) conds) sqlSelect :: SQL -> State SqlSelect () -> SqlSelect sqlSelect table refine = @@ -546,7 +569,7 @@ data Recursive = Recursive | NonRecursive instance Semigroup Recursive where _ <> Recursive = Recursive Recursive <> _ = Recursive - _ <> _ = NonRecursive + _ <> _ = NonRecursive class SqlWith a where sqlWith1 :: a -> SQL -> SQL -> Materialized -> Recursive -> a @@ -605,9 +628,13 @@ instance SqlWhere SqlDelete where sqlWhere1 cmd cond = cmd { sqlDeleteWhere = sqlDeleteWhere cmd ++ [cond] } sqlGetWhereConditions = sqlDeleteWhere -instance SqlWhere SqlAll where - sqlWhere1 cmd cond = cmd { sqlAllWhere = sqlAllWhere cmd ++ [cond] } - sqlGetWhereConditions = sqlAllWhere +instance SqlWhere SqlWhereAll where + sqlWhere1 cmd cond = cmd { sqlWhereAllWhere = sqlWhereAllWhere cmd ++ [cond] } + sqlGetWhereConditions = sqlWhereAllWhere + +instance SqlWhere SqlWhereAny where + sqlWhere1 cmd cond = cmd { sqlWhereAnyWhere = sqlWhereAnyWhere cmd ++ [cond] } + sqlGetWhereConditions = sqlWhereAnyWhere -- | The @WHERE@ part of an SQL query. See above for a usage -- example. See also 'SqlCondition'. @@ -664,16 +691,29 @@ sqlWhereIsNULL col = sqlWhere $ col <+> "IS NULL" sqlWhereIsNotNULL :: (MonadState v m, SqlWhere v) => SQL -> m () sqlWhereIsNotNULL col = sqlWhere $ col <+> "IS NOT NULL" +-- | Run monad that joins all conditions using 'AND' operator. +-- +-- When no conditions are given, the result is 'TRUE'. +-- +-- Note: This is usally not needed as `SqlSelect`, `SqlUpdate` and `SqlDelete` +-- already join conditions using 'AND' by default, but it can be useful when +-- nested in `sqlAny`. +sqlAll :: State SqlWhereAll () -> SQL +sqlAll = toSQLCommand . (`execState` SqlWhereAll []) + +-- | Run monad that joins all conditions using 'OR' operator. +-- +-- When no conditions are given, the result is 'FALSE'. +sqlAny :: State SqlWhereAny () -> SQL +sqlAny = toSQLCommand . (`execState` SqlWhereAny []) + -- | Add a condition in the WHERE statement that holds if any of the given -- condition holds. -sqlWhereAny :: (MonadState v m, SqlWhere v) => [State SqlAll ()] -> m () -sqlWhereAny = sqlWhere . sqlWhereAnyImpl - -sqlWhereAnyImpl :: [State SqlAll ()] -> SQL -sqlWhereAnyImpl [] = "FALSE" -sqlWhereAnyImpl l = - "(" <+> smintercalate "OR" (map (parenthesize . toSQLCommand - . flip execState (SqlAll [])) l) <+> ")" +-- +-- These conditions are joined with 'OR' operator. +-- When no conditions are given, the result is 'FALSE'. +sqlWhereAny :: (MonadState v m, SqlWhere v) => [State SqlWhereAll ()] -> m () +sqlWhereAny = sqlWhere . sqlAny . mapM_ (sqlWhere . sqlAll) class SqlFrom a where sqlFrom1 :: a -> SQL -> a diff --git a/test/Main.hs b/test/Main.hs index 9d8c4d0..25b7f51 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -2,17 +2,15 @@ {-# HLINT ignore "Use head" #-} module Main where -import Control.Monad.Catch import Control.Monad (forM_) +import Control.Monad.Catch import Control.Monad.IO.Class import Data.Either import Data.List (zip4) -import Data.Monoid -import Prelude -import Data.Typeable -import Data.UUID.Types import qualified Data.Set as Set import qualified Data.Text as T +import Data.Typeable +import Data.UUID.Types import Data.Monoid.Utils import Database.PostgreSQL.PQTypes @@ -1284,7 +1282,7 @@ testSqlWithRecursive :: HasCallStack => (String -> TestM ()) -> TestM () testSqlWithRecursive step = do step "Running WITH RECURSIVE tests" testPass - where + where migrate tables migrations = do migrateDatabase defaultExtrasOptions ["pgcrypto"] [] [] tables migrations checkDatabase defaultExtrasOptions [] [] tables @@ -1310,7 +1308,7 @@ testSqlWithRecursive step = do sqlResult "root.cartel_member_id" sqlResult "root.cartel_boss_id" sqlWhere "root.cartel_boss_id IS NULL" - sqlUnionAll + sqlUnionAll [ sqlSelect "cartel child" $ do sqlResult "child.cartel_member_id" sqlResult "child.cartel_boss_id" @@ -1701,7 +1699,7 @@ foreignKeyIndexesTests connSource = let options = defaultExtrasOptions { eoCheckForeignKeysIndexes = True } assertException "Foreign keys are missing" $ migrateDatabase options ["pgcrypto"] [] [] [table1, table2, table3] [createTableMigration table1, createTableMigration table2, createTableMigration table3] - + step "Multi column indexes covering a FK pass the checks" do let options = defaultExtrasOptions { eoCheckForeignKeysIndexes = True } @@ -1818,8 +1816,8 @@ overlapingIndexesTests connSource = do , tblColumn { colName = "idx3", colType = UuidT } ] , tblPrimaryKey = pkOnColumn "id" - , tblIndexes = - [ indexOnColumns [ indexColumn "idx1", indexColumn "idx2" ] + , tblIndexes = + [ indexOnColumns [ indexColumn "idx1", indexColumn "idx2" ] , indexOnColumns [ indexColumn "idx1" ] ] } @@ -1834,7 +1832,7 @@ nullsNotDistinctTests connSource = do migrateDatabase defaultExtrasOptions ["pgcrypto"] [] [] [nullTableTest1, nullTableTest2] [createTableMigration nullTableTest1, createTableMigration nullTableTest2] checkDatabase defaultExtrasOptions [] [] [nullTableTest1, nullTableTest2] - + step "Insert two NULLs on a column with a default UNIQUE index" do runQuery_ . sqlInsert "nulltests1" $ do @@ -1848,7 +1846,7 @@ nullsNotDistinctTests connSource = do sqlSet "content" (Nothing @T.Text) assertDBException "Cannot insert twice a null value with NULLS NOT DISTINCT" $ runQuery_ . sqlInsert "nulltests2" $ do sqlSet "content" (Nothing @T.Text) - + where nullTableTest1 = tblTable { tblName = "nulltests1" @@ -1876,6 +1874,69 @@ nullsNotDistinctTests connSource = do ] } +sqlAnyAllTests :: TestTree +sqlAnyAllTests = testGroup "SQL ANY/ALL tests" + [ testCase "sqlAll produces correct queries" $ do + assertSqlEqual "empty sqlAll is TRUE" "TRUE" . sqlAll $ pure () + assertSqlEqual "sigle condition is emmited as is" "cond" $ sqlAll $ sqlWhere "cond" + assertSqlEqual "each condition as well as entire condition is parenthesized" + "((cond1) AND (cond2))" $ sqlAll $ do + sqlWhere "cond1" + sqlWhere "cond2" + + assertSqlEqual "sqlAll can be nested" + "((cond1) AND (cond2) AND (((cond3) AND (cond4))))" $ + sqlAll $ do + sqlWhere "cond1" + sqlWhere "cond2" + sqlWhere . sqlAll $ do + sqlWhere "cond3" + sqlWhere "cond4" + , testCase "sqlAny produces correct queries" $ do + assertSqlEqual "empty sqlAny is FALSE" "FALSE" . sqlAny $ pure () + assertSqlEqual "sigle condition is emmited as is" "cond" $ sqlAny $ sqlWhere "cond" + assertSqlEqual "each condition as well as entire condition is parenthesized" + "((cond1) OR (cond2))" $ sqlAny $ do + sqlWhere "cond1" + sqlWhere "cond2" + + assertSqlEqual "sqlAny can be nested" + "((cond1) OR (cond2) OR (((cond3) OR (cond4))))" $ + sqlAny $ do + sqlWhere "cond1" + sqlWhere "cond2" + sqlWhere . sqlAny $ do + sqlWhere "cond3" + sqlWhere "cond4" + , testCase "mixing sqlAny and all produces correct queries" $ do + assertSqlEqual "sqlAny and sqlAll can be mixed" + "((((cond1) OR (cond2))) AND (((cond3) OR (cond4))))" $ + sqlAll $ do + sqlWhere . sqlAny $ do + sqlWhere "cond1" + sqlWhere "cond2" + sqlWhere . sqlAny $ do + sqlWhere "cond3" + sqlWhere "cond4" + , testCase "sqlWhereAny produces correct queries" $ do + -- `sqlWhereAny` has to be wrapped in `sqlAll` to disambiguate the `SqlWhere` monad. + assertSqlEqual "empty sqlWhereAny is FALSE" "FALSE" . sqlAll $ sqlWhereAny [] + assertSqlEqual "each condition as well as entire condition is parenthesized and joined with OR" + "((cond1) OR (cond2))" . sqlAll $ sqlWhereAny [sqlWhere "cond1", sqlWhere "cond2"] + assertSqlEqual "nested multi-conditions are parenthesized and joined with AND" + "((cond1) OR (((cond2) AND (cond3))) OR (cond4))" . sqlAll $ sqlWhereAny + [ sqlWhere "cond1" + , do + sqlWhere "cond2" + sqlWhere "cond3" + , sqlWhere "cond4" + ] + ] + where + assertSqlEqual :: (Sqlable a) => String -> a -> a -> Assertion + assertSqlEqual msg a b = assertEqual msg + (show $ toSQLCommand a) (show $ toSQLCommand b) + assertNoException :: String -> TestM () -> TestM () assertNoException t = eitherExc (const $ liftIO $ assertFailure ("Exception thrown for: " ++ t)) @@ -1924,6 +1985,7 @@ main = do , foreignKeyIndexesTests connSource , overlapingIndexesTests connSource , nullsNotDistinctTests connSource + , sqlAnyAllTests ] where ings =