Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce sqlAll, sqlAny and related state types #112

Merged
merged 6 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# hpqtypes-extras-1.17.0.0 (2023-??-??)
* Add an optional check that all foreign keys have an index.
* Add support for NULLS NOT DISTINCT in unique indexes.
* Add `sqlAll` and `sqlAny` to allow creating `SQL` expressions with
nested `AND` and `OR` conditions.
* Add `SqlWhereAll` and `SqlWhereAny` so they can be used in signatures.

# hpqtypes-extras-1.16.4.4 (2023-08-23)
* Switch from `cryptonite` to `crypton`.
Expand Down
7 changes: 3 additions & 4 deletions src/Database/PostgreSQL/PQTypes/Checks.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ import Database.PostgreSQL.PQTypes.Migrate
import Database.PostgreSQL.PQTypes.Model
import Database.PostgreSQL.PQTypes.SQL.Builder
import Database.PostgreSQL.PQTypes.Versions
import Database.PostgreSQL.PQTypes.Utils.NubList

headExc :: String -> [a] -> a
headExc s [] = error s
Expand Down Expand Up @@ -429,7 +428,7 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version)
runQuery_ $ sqlGetForeignKeys table
fkeys <- fetchMany fetchForeignKey
triggers <- getDBTriggers tblName
checkedOverlaps <- checkOverlappingIndexes
checkedOverlaps <- checkOverlappingIndexes
return $ mconcat [
checkColumns 1 tblColumns desc
, checkPrimaryKey tblPrimaryKey pk
Expand Down Expand Up @@ -575,7 +574,7 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version)

allCoverage :: [[RawSQL ()]]
allCoverage = maybe [] pkColumns pkey:allIndexes

-- A foreign key is covered if it is a prefix of a list of indices.
-- So a FK on a is covered by an index on (a, b) but not an index on (b, a).
coveredFK :: ForeignKey -> [[RawSQL ()]] -> Bool
Expand Down Expand Up @@ -604,7 +603,7 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version)
if eoCheckOverlappingIndexes options
then go
else pure mempty
where
where
go = do
let handleOverlap (contained, contains) =
mconcat
Expand Down
84 changes: 62 additions & 22 deletions src/Database/PostgreSQL/PQTypes/SQL/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ module Database.PostgreSQL.PQTypes.SQL.Builder
, sqlDelete
, SqlDelete(..)

, SqlWhereAll(..)
, sqlAll
, SqlWhereAny(..)
, sqlAny
, sqlWhereAny

, SqlResult
Expand Down Expand Up @@ -291,10 +295,19 @@ data SqlDelete = SqlDelete
, sqlDeleteRecursiveWith :: Recursive
}

-- | This is not exported and is used as an implementation detail in
-- 'sqlWhereAll'.
newtype SqlAll = SqlAll
{ sqlAllWhere :: [SqlCondition]

-- | Type representing a set of conditions that are joined by 'AND'.
--
-- When no conditions are given, the result is 'TRUE'.
newtype SqlWhereAll = SqlWhereAll
{ sqlWhereAllWhere :: [SqlCondition]
}

-- | Type representing a set of conditions that are joined by 'OR'.
arybczak marked this conversation as resolved.
Show resolved Hide resolved
--
-- When no conditions are given, the result is 'FALSE'.
newtype SqlWhereAny = SqlWhereAny
{ sqlWhereAnyWhere :: [SqlCondition]
}

instance Show SqlSelect where
Expand All @@ -312,7 +325,10 @@ instance Show SqlUpdate where
instance Show SqlDelete where
show = show . toSQLCommand

instance Show SqlAll where
instance Show SqlWhereAll where
show = show . toSQLCommand

instance Show SqlWhereAny where
show = show . toSQLCommand

emitClause :: Sqlable sql => SQL -> sql -> SQL
Expand Down Expand Up @@ -464,7 +480,7 @@ materializedClause Materialized = if isWithMaterializedSupported then "MATERIALI
materializedClause NonMaterialized = if isWithMaterializedSupported then "NOT MATERIALIZED" else ""

recursiveClause :: Recursive -> SQL
recursiveClause Recursive = "WITH RECURSIVE"
recursiveClause Recursive = "WITH RECURSIVE"
recursiveClause NonRecursive = "WITH"

instance Sqlable SqlUpdate where
Expand All @@ -486,10 +502,17 @@ instance Sqlable SqlDelete where
emitClausesSep "WHERE" "AND" (map toSQLCommand $ sqlDeleteWhere cmd) <+>
emitClausesSepComma "RETURNING" (sqlDeleteResult cmd)

instance Sqlable SqlAll where
toSQLCommand cmd | null (sqlAllWhere cmd) = "TRUE"
toSQLCommand cmd =
"(" <+> smintercalate "AND" (map (parenthesize . toSQLCommand) (sqlAllWhere cmd)) <+> ")"
instance Sqlable SqlWhereAll where
toSQLCommand cmd = case sqlWhereAllWhere cmd of
[] -> "TRUE"
[cond] -> toSQLCommand cond
conds -> parenthesize $ smintercalate "AND" (map (parenthesize . toSQLCommand) conds)

instance Sqlable SqlWhereAny where
toSQLCommand cmd = case sqlWhereAnyWhere cmd of
[] -> "FALSE"
[cond] -> toSQLCommand cond
conds -> parenthesize $ smintercalate "OR" (map (parenthesize . toSQLCommand) conds)

sqlSelect :: SQL -> State SqlSelect () -> SqlSelect
sqlSelect table refine =
Expand Down Expand Up @@ -546,7 +569,7 @@ data Recursive = Recursive | NonRecursive
instance Semigroup Recursive where
_ <> Recursive = Recursive
Recursive <> _ = Recursive
_ <> _ = NonRecursive
_ <> _ = NonRecursive

class SqlWith a where
sqlWith1 :: a -> SQL -> SQL -> Materialized -> Recursive -> a
Expand Down Expand Up @@ -605,9 +628,13 @@ instance SqlWhere SqlDelete where
sqlWhere1 cmd cond = cmd { sqlDeleteWhere = sqlDeleteWhere cmd ++ [cond] }
sqlGetWhereConditions = sqlDeleteWhere

instance SqlWhere SqlAll where
sqlWhere1 cmd cond = cmd { sqlAllWhere = sqlAllWhere cmd ++ [cond] }
sqlGetWhereConditions = sqlAllWhere
instance SqlWhere SqlWhereAll where
sqlWhere1 cmd cond = cmd { sqlWhereAllWhere = sqlWhereAllWhere cmd ++ [cond] }
sqlGetWhereConditions = sqlWhereAllWhere

instance SqlWhere SqlWhereAny where
sqlWhere1 cmd cond = cmd { sqlWhereAnyWhere = sqlWhereAnyWhere cmd ++ [cond] }
sqlGetWhereConditions = sqlWhereAnyWhere

-- | The @WHERE@ part of an SQL query. See above for a usage
-- example. See also 'SqlCondition'.
Expand Down Expand Up @@ -664,16 +691,29 @@ sqlWhereIsNULL col = sqlWhere $ col <+> "IS NULL"
sqlWhereIsNotNULL :: (MonadState v m, SqlWhere v) => SQL -> m ()
sqlWhereIsNotNULL col = sqlWhere $ col <+> "IS NOT NULL"

-- | Run monad that joins all conditions using 'AND' operator.
--
-- When no conditions are given, the result is 'TRUE'.
--
-- Note: This is usally not needed as `SqlSelect`, `SqlUpdate` and `SqlDelete`
-- already join conditions using 'AND' by default, but it can be useful when
-- nested in `sqlAny`.
sqlAll :: State SqlWhereAll () -> SQL
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is it for? Apart from being used in sqlWhereAny. Tests?

Because sqlWhere . sqlAll is redundant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well for nesting ANDs in OR's and parity with sqlAny. How else are you going to write (a OR (b AND c))?

I mean, you could use sqlWhereAny, but if you want to go with sqlWhere . sqlAny $ ..., then at one point you have to use sqlWhere . sqlAll $ ... inside.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. The comment you added to the code makes it clear, thanks.

sqlAll = toSQLCommand . (`execState` SqlWhereAll [])

-- | Run monad that joins all conditions using 'OR' operator.
--
-- When no conditions are given, the result is 'FALSE'.
sqlAny :: State SqlWhereAny () -> SQL
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we should deprecate sqlWhereAny and start using sqlWhere . sqlAny $ do ... instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depends if you'd like to remove it eventually. But if you want to deprecate sqlWhereAny, then I'd like to deprecate sqlOR, sqlConcatAND and sqlConcatOR as well.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it comes to kontrakcja, sqlOR is used 1 time, sqlConcatAND 5 times and sqlConcatOR 4 times, so it seems feasible. Let's do it a separate PR though.

sqlAny = toSQLCommand . (`execState` SqlWhereAny [])

-- | Add a condition in the WHERE statement that holds if any of the given
-- condition holds.
sqlWhereAny :: (MonadState v m, SqlWhere v) => [State SqlAll ()] -> m ()
sqlWhereAny = sqlWhere . sqlWhereAnyImpl

sqlWhereAnyImpl :: [State SqlAll ()] -> SQL
sqlWhereAnyImpl [] = "FALSE"
sqlWhereAnyImpl l =
"(" <+> smintercalate "OR" (map (parenthesize . toSQLCommand
. flip execState (SqlAll [])) l) <+> ")"
--
-- These conditions are joined with 'OR' operator.
-- When no conditions are given, the result is 'FALSE'.
sqlWhereAny :: (MonadState v m, SqlWhere v) => [State SqlWhereAll ()] -> m ()
sqlWhereAny = sqlWhere . sqlAny . mapM_ (sqlWhere . sqlAll)

class SqlFrom a where
sqlFrom1 :: a -> SQL -> a
Expand Down
86 changes: 74 additions & 12 deletions test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@
{-# HLINT ignore "Use head" #-}
module Main where

import Control.Monad.Catch
import Control.Monad (forM_)
import Control.Monad.Catch
import Control.Monad.IO.Class
import Data.Either
import Data.List (zip4)
import Data.Monoid
import Prelude
import Data.Typeable
import Data.UUID.Types
import qualified Data.Set as Set
import qualified Data.Text as T
import Data.Typeable
import Data.UUID.Types

import Data.Monoid.Utils
import Database.PostgreSQL.PQTypes
Expand Down Expand Up @@ -1284,7 +1282,7 @@ testSqlWithRecursive :: HasCallStack => (String -> TestM ()) -> TestM ()
testSqlWithRecursive step = do
step "Running WITH RECURSIVE tests"
testPass
where
where
migrate tables migrations = do
migrateDatabase defaultExtrasOptions ["pgcrypto"] [] [] tables migrations
checkDatabase defaultExtrasOptions [] [] tables
Expand All @@ -1310,7 +1308,7 @@ testSqlWithRecursive step = do
sqlResult "root.cartel_member_id"
sqlResult "root.cartel_boss_id"
sqlWhere "root.cartel_boss_id IS NULL"
sqlUnionAll
sqlUnionAll
[ sqlSelect "cartel child" $ do
sqlResult "child.cartel_member_id"
sqlResult "child.cartel_boss_id"
Expand Down Expand Up @@ -1701,7 +1699,7 @@ foreignKeyIndexesTests connSource =
let options = defaultExtrasOptions { eoCheckForeignKeysIndexes = True }
assertException "Foreign keys are missing" $ migrateDatabase options ["pgcrypto"] [] [] [table1, table2, table3]
[createTableMigration table1, createTableMigration table2, createTableMigration table3]

step "Multi column indexes covering a FK pass the checks"
do
let options = defaultExtrasOptions { eoCheckForeignKeysIndexes = True }
Expand Down Expand Up @@ -1818,8 +1816,8 @@ overlapingIndexesTests connSource = do
, tblColumn { colName = "idx3", colType = UuidT }
]
, tblPrimaryKey = pkOnColumn "id"
, tblIndexes =
[ indexOnColumns [ indexColumn "idx1", indexColumn "idx2" ]
, tblIndexes =
[ indexOnColumns [ indexColumn "idx1", indexColumn "idx2" ]
, indexOnColumns [ indexColumn "idx1" ]
]
}
Expand All @@ -1834,7 +1832,7 @@ nullsNotDistinctTests connSource = do
migrateDatabase defaultExtrasOptions ["pgcrypto"] [] [] [nullTableTest1, nullTableTest2]
[createTableMigration nullTableTest1, createTableMigration nullTableTest2]
checkDatabase defaultExtrasOptions [] [] [nullTableTest1, nullTableTest2]

step "Insert two NULLs on a column with a default UNIQUE index"
do
runQuery_ . sqlInsert "nulltests1" $ do
Expand All @@ -1848,7 +1846,7 @@ nullsNotDistinctTests connSource = do
sqlSet "content" (Nothing @T.Text)
assertDBException "Cannot insert twice a null value with NULLS NOT DISTINCT" $ runQuery_ . sqlInsert "nulltests2" $ do
sqlSet "content" (Nothing @T.Text)

where
nullTableTest1 = tblTable
{ tblName = "nulltests1"
Expand Down Expand Up @@ -1876,6 +1874,69 @@ nullsNotDistinctTests connSource = do
]
}

sqlAnyAllTests :: TestTree
sqlAnyAllTests = testGroup "SQL ANY/ALL tests"
[ testCase "sqlAll produces correct queries" $ do
assertSqlEqual "empty sqlAll is TRUE" "TRUE" . sqlAll $ pure ()
assertSqlEqual "sigle condition is emmited as is" "cond" $ sqlAll $ sqlWhere "cond"
assertSqlEqual "each condition as well as entire condition is parenthesized"
"((cond1) AND (cond2))" $ sqlAll $ do
sqlWhere "cond1"
sqlWhere "cond2"

assertSqlEqual "sqlAll can be nested"
"((cond1) AND (cond2) AND (((cond3) AND (cond4))))" $
sqlAll $ do
sqlWhere "cond1"
sqlWhere "cond2"
sqlWhere . sqlAll $ do
sqlWhere "cond3"
sqlWhere "cond4"
, testCase "sqlAny produces correct queries" $ do
assertSqlEqual "empty sqlAny is FALSE" "FALSE" . sqlAny $ pure ()
assertSqlEqual "sigle condition is emmited as is" "cond" $ sqlAny $ sqlWhere "cond"
assertSqlEqual "each condition as well as entire condition is parenthesized"
"((cond1) OR (cond2))" $ sqlAny $ do
sqlWhere "cond1"
sqlWhere "cond2"

assertSqlEqual "sqlAny can be nested"
"((cond1) OR (cond2) OR (((cond3) OR (cond4))))" $
sqlAny $ do
sqlWhere "cond1"
sqlWhere "cond2"
sqlWhere . sqlAny $ do
sqlWhere "cond3"
sqlWhere "cond4"
, testCase "mixing sqlAny and all produces correct queries" $ do
assertSqlEqual "sqlAny and sqlAll can be mixed"
"((((cond1) OR (cond2))) AND (((cond3) OR (cond4))))" $
sqlAll $ do
sqlWhere . sqlAny $ do
sqlWhere "cond1"
sqlWhere "cond2"
sqlWhere . sqlAny $ do
sqlWhere "cond3"
sqlWhere "cond4"
, testCase "sqlWhereAny produces correct queries" $ do
-- `sqlWhereAny` has to be wrapped in `sqlAll` to disambiguate the `SqlWhere` monad.
assertSqlEqual "empty sqlWhereAny is FALSE" "FALSE" . sqlAll $ sqlWhereAny []
assertSqlEqual "each condition as well as entire condition is parenthesized and joined with OR"
"((cond1) OR (cond2))" . sqlAll $ sqlWhereAny [sqlWhere "cond1", sqlWhere "cond2"]
assertSqlEqual "nested multi-conditions are parenthesized and joined with AND"
"((cond1) OR (((cond2) AND (cond3))) OR (cond4))" . sqlAll $ sqlWhereAny
[ sqlWhere "cond1"
, do
sqlWhere "cond2"
sqlWhere "cond3"
, sqlWhere "cond4"
]
]
where
assertSqlEqual :: (Sqlable a) => String -> a -> a -> Assertion
assertSqlEqual msg a b = assertEqual msg
(show $ toSQLCommand a) (show $ toSQLCommand b)

assertNoException :: String -> TestM () -> TestM ()
assertNoException t = eitherExc
(const $ liftIO $ assertFailure ("Exception thrown for: " ++ t))
Expand Down Expand Up @@ -1924,6 +1985,7 @@ main = do
, foreignKeyIndexesTests connSource
, overlapingIndexesTests connSource
, nullsNotDistinctTests connSource
, sqlAnyAllTests
]
where
ings =
Expand Down
Loading