diff --git a/.github/workflows/fourmolu.yaml b/.github/workflows/fourmolu.yaml new file mode 100644 index 0000000..753ac08 --- /dev/null +++ b/.github/workflows/fourmolu.yaml @@ -0,0 +1,11 @@ +name: Fourmolu +on: push +jobs: + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: haskell-actions/run-fourmolu@v10 + with: + version: "0.15.0.0" + diff --git a/.github/workflows/haskell-ci.yml b/.github/workflows/haskell-ci.yml index d6c5a41..89b3141 100644 --- a/.github/workflows/haskell-ci.yml +++ b/.github/workflows/haskell-ci.yml @@ -38,6 +38,16 @@ jobs: strategy: matrix: include: + - compiler: ghc-9.10.1 + compilerKind: ghc + compilerVersion: 9.10.1 + setup-method: ghcup + allow-failure: false + - compiler: ghc-9.8.2 + compilerKind: ghc + compilerVersion: 9.8.2 + setup-method: ghcup + allow-failure: false - compiler: ghc-9.6.2 compilerKind: ghc compilerVersion: 9.6.2 @@ -63,11 +73,6 @@ jobs: compilerVersion: 8.10.7 setup-method: ghcup allow-failure: false - - compiler: ghc-8.8.4 - compilerKind: ghc - compilerVersion: 8.8.4 - setup-method: ghcup - allow-failure: false fail-fast: false steps: - name: apt diff --git a/Setup.hs b/Setup.hs deleted file mode 100644 index 9a994af..0000000 --- a/Setup.hs +++ /dev/null @@ -1,2 +0,0 @@ -import Distribution.Simple -main = defaultMain diff --git a/benchmark/Main.hs b/benchmark/Main.hs index 10a1acd..3b844b5 100644 --- a/benchmark/Main.hs +++ b/benchmark/Main.hs @@ -7,22 +7,71 @@ import Test.Tasty.Bench import Database.PostgreSQL.PQTypes.Deriving main :: IO () -main = defaultMain - [ bgroup "enum" - [ bench "encode" $ nf (encodeEnum @T) T42 - , bench "decode" $ nf (decodeEnum @T) 42 +main = + defaultMain + [ bgroup + "enum" + [ bench "encode" $ nf (encodeEnum @T) T42 + , bench "decode" $ nf (decodeEnum @T) 42 + ] + , bgroup + "enum-text" + [ bench "encode" $ nf (encodeEnumAsText @S) S42 + , bench "decode" $ nf (decodeEnumAsText @S) "text_42" + ] ] - , bgroup "enum-text" - [ bench "encode" $ nf (encodeEnumAsText @S) S42 - , bench "decode" $ nf (decodeEnumAsText @S) "text_42" - ] - ] -data T = T01 | T02 | T03 | T04 | T05 | T06 | T07 | T08 | T09 | T10 - | T11 | T12 | T13 | T14 | T15 | T16 | T17 | T18 | T19 | T20 - | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 - | T31 | T32 | T33 | T34 | T35 | T36 | T37 | T38 | T39 | T40 - | T41 | T42 | T43 | T44 | T45 | T46 | T47 | T48 | T49 | T50 +data T + = T01 + | T02 + | T03 + | T04 + | T05 + | T06 + | T07 + | T08 + | T09 + | T10 + | T11 + | T12 + | T13 + | T14 + | T15 + | T16 + | T17 + | T18 + | T19 + | T20 + | T21 + | T22 + | T23 + | T24 + | T25 + | T26 + | T27 + | T28 + | T29 + | T30 + | T31 + | T32 + | T33 + | T34 + | T35 + | T36 + | T37 + | T38 + | T39 + | T40 + | T41 + | T42 + | T43 + | T44 + | T45 + | T46 + | T47 + | T48 + | T49 + | T50 deriving (Eq, Show, Enum, Bounded) instance NFData T where @@ -33,26 +82,112 @@ instance NFData T where -- >>> isInjective (encodeEnum @T) -- True instance EnumEncoding T where - type EnumBase T = Int16 - encodeEnum = \case - T01 -> 1; T02 -> 2; T03 -> 3; T04 -> 4; T05 -> 5 - T06 -> 6; T07 -> 7; T08 -> 8; T09 -> 9; T10 -> 10 - T11 -> 11; T12 -> 12; T13 -> 13; T14 -> 14; T15 -> 15 - T16 -> 16; T17 -> 17; T18 -> 18; T19 -> 19; T20 -> 20 - T21 -> 21; T22 -> 22; T23 -> 23; T24 -> 24; T25 -> 25 - T26 -> 26; T27 -> 27; T28 -> 28; T29 -> 29; T30 -> 30 - T31 -> 31; T32 -> 32; T33 -> 33; T34 -> 34; T35 -> 35 - T36 -> 36; T37 -> 37; T38 -> 38; T39 -> 39; T40 -> 40 - T41 -> 41; T42 -> 42; T43 -> 43; T44 -> 44; T45 -> 45 - T46 -> 46; T47 -> 47; T48 -> 48; T49 -> 49; T50 -> 50 + type EnumBase T = Int16 + encodeEnum = \case + T01 -> 1 + T02 -> 2 + T03 -> 3 + T04 -> 4 + T05 -> 5 + T06 -> 6 + T07 -> 7 + T08 -> 8 + T09 -> 9 + T10 -> 10 + T11 -> 11 + T12 -> 12 + T13 -> 13 + T14 -> 14 + T15 -> 15 + T16 -> 16 + T17 -> 17 + T18 -> 18 + T19 -> 19 + T20 -> 20 + T21 -> 21 + T22 -> 22 + T23 -> 23 + T24 -> 24 + T25 -> 25 + T26 -> 26 + T27 -> 27 + T28 -> 28 + T29 -> 29 + T30 -> 30 + T31 -> 31 + T32 -> 32 + T33 -> 33 + T34 -> 34 + T35 -> 35 + T36 -> 36 + T37 -> 37 + T38 -> 38 + T39 -> 39 + T40 -> 40 + T41 -> 41 + T42 -> 42 + T43 -> 43 + T44 -> 44 + T45 -> 45 + T46 -> 46 + T47 -> 47 + T48 -> 48 + T49 -> 49 + T50 -> 50 ---------------------------------------- -data S = S01 | S02 | S03 | S04 | S05 | S06 | S07 | S08 | S09 | S10 - | S11 | S12 | S13 | S14 | S15 | S16 | S17 | S18 | S19 | S20 - | S21 | S22 | S23 | S24 | S25 | S26 | S27 | S28 | S29 | S30 - | S31 | S32 | S33 | S34 | S35 | S36 | S37 | S38 | S39 | S40 - | S41 | S42 | S43 | S44 | S45 | S46 | S47 | S48 | S49 | S50 +data S + = S01 + | S02 + | S03 + | S04 + | S05 + | S06 + | S07 + | S08 + | S09 + | S10 + | S11 + | S12 + | S13 + | S14 + | S15 + | S16 + | S17 + | S18 + | S19 + | S20 + | S21 + | S22 + | S23 + | S24 + | S25 + | S26 + | S27 + | S28 + | S29 + | S30 + | S31 + | S32 + | S33 + | S34 + | S35 + | S36 + | S37 + | S38 + | S39 + | S40 + | S41 + | S42 + | S43 + | S44 + | S45 + | S46 + | S47 + | S48 + | S49 + | S50 deriving (Eq, Show, Enum, Bounded) instance NFData S where @@ -64,16 +199,53 @@ instance NFData S where -- True instance EnumAsTextEncoding S where encodeEnumAsText = \case - S01 -> "text_01"; S02 -> "text_02"; S03 -> "text_03"; S04 -> "text_04" - S05 -> "text_05"; S06 -> "text_06"; S07 -> "text_07"; S08 -> "text_08" - S09 -> "text_09"; S10 -> "text_10"; S11 -> "text_11"; S12 -> "text_12" - S13 -> "text_13"; S14 -> "text_14"; S15 -> "text_15"; S16 -> "text_16" - S17 -> "text_17"; S18 -> "text_18"; S19 -> "text_19"; S20 -> "text_20" - S21 -> "text_21"; S22 -> "text_22"; S23 -> "text_23"; S24 -> "text_24" - S25 -> "text_25"; S26 -> "text_26"; S27 -> "text_27"; S28 -> "text_28" - S29 -> "text_29"; S30 -> "text_30"; S31 -> "text_31"; S32 -> "text_32" - S33 -> "text_33"; S34 -> "text_34"; S35 -> "text_35"; S36 -> "text_36" - S37 -> "text_37"; S38 -> "text_38"; S39 -> "text_39"; S40 -> "text_40" - S41 -> "text_41"; S42 -> "text_42"; S43 -> "text_43"; S44 -> "text_44" - S45 -> "text_45"; S46 -> "text_46"; S47 -> "text_47"; S48 -> "text_48" - S49 -> "text_49"; S50 -> "text_50"; + S01 -> "text_01" + S02 -> "text_02" + S03 -> "text_03" + S04 -> "text_04" + S05 -> "text_05" + S06 -> "text_06" + S07 -> "text_07" + S08 -> "text_08" + S09 -> "text_09" + S10 -> "text_10" + S11 -> "text_11" + S12 -> "text_12" + S13 -> "text_13" + S14 -> "text_14" + S15 -> "text_15" + S16 -> "text_16" + S17 -> "text_17" + S18 -> "text_18" + S19 -> "text_19" + S20 -> "text_20" + S21 -> "text_21" + S22 -> "text_22" + S23 -> "text_23" + S24 -> "text_24" + S25 -> "text_25" + S26 -> "text_26" + S27 -> "text_27" + S28 -> "text_28" + S29 -> "text_29" + S30 -> "text_30" + S31 -> "text_31" + S32 -> "text_32" + S33 -> "text_33" + S34 -> "text_34" + S35 -> "text_35" + S36 -> "text_36" + S37 -> "text_37" + S38 -> "text_38" + S39 -> "text_39" + S40 -> "text_40" + S41 -> "text_41" + S42 -> "text_42" + S43 -> "text_43" + S44 -> "text_44" + S45 -> "text_45" + S46 -> "text_46" + S47 -> "text_47" + S48 -> "text_48" + S49 -> "text_49" + S50 -> "text_50" diff --git a/fourmolu.yaml b/fourmolu.yaml new file mode 100644 index 0000000..b47fe72 --- /dev/null +++ b/fourmolu.yaml @@ -0,0 +1,53 @@ +# Number of spaces per indentation step +indentation: 2 + +# Max line length for automatic line breaking +column-limit: none + +# Styling of arrows in type signatures (choices: trailing, leading, or leading-args) +function-arrows: leading + +# How to place commas in multi-line lists, records, etc. (choices: leading or trailing) +comma-style: leading + +# Styling of import/export lists (choices: leading, trailing, or diff-friendly) +import-export-style: leading + +# Whether to full-indent or half-indent 'where' bindings past the preceding body +indent-wheres: true + +# Whether to leave a space before an opening record brace +record-brace-space: true + +# Number of spaces between top-level declarations +newlines-between-decls: 1 + +# How to print Haddock comments (choices: single-line, multi-line, or multi-line-compact) +haddock-style: single-line + +# How to print module docstring +haddock-style-module: null + +# Styling of let blocks (choices: auto, inline, newline, or mixed) +let-style: inline + +# How to align the 'in' keyword with respect to the 'let' keyword (choices: left-align, right-align, or no-space) +in-style: no-space + +# Whether to put parentheses around a single constraint (choices: auto, always, or never) +single-constraint-parens: never + +# Whether to put parentheses around a single deriving class (choices: auto, always, or never) +single-deriving-parens: always + +# Output Unicode syntax (choices: detect, always, or never) +unicode: never + +# Give the programmer more choice on where to insert blank lines +respectful: true + +# Fixity information for operators +fixities: [] + +# Module reexports Fourmolu should know about +reexports: [] diff --git a/hpqtypes-extras.cabal b/hpqtypes-extras.cabal index 852e4e0..549a216 100644 --- a/hpqtypes-extras.cabal +++ b/hpqtypes-extras.cabal @@ -1,4 +1,4 @@ -cabal-version: 2.2 +cabal-version: 3.0 name: hpqtypes-extras version: 1.16.4.4 synopsis: Extra utilities for hpqtypes library @@ -20,7 +20,7 @@ maintainer: Andrzej Rybczak , copyright: Scrive AB category: Database build-type: Simple -tested-with: GHC ==8.8.4 || ==8.10.7 || ==9.0.2 || ==9.2.8 || ==9.4.6 || ==9.6.2 +tested-with: GHC == { 8.10.7, 9.0.2, 9.2.8, 9.4.6, 9.6.2, 9.8.2, 9.10.1 } Source-repository head Type: git @@ -33,6 +33,7 @@ common common-stanza , ExistentialQuantification , FlexibleContexts , GeneralizedNewtypeDeriving + , ImportQualifiedPost , LambdaCase , MultiWayIf , OverloadedStrings @@ -45,6 +46,7 @@ common common-stanza , TypeFamilies , UndecidableInstances , ViewPatterns + ghc-options: -Werror=prepositive-qualified-module library import: common-stanza diff --git a/src/Database/PostgreSQL/PQTypes/Checks.hs b/src/Database/PostgreSQL/PQTypes/Checks.hs index 24e260d..9b40f1a 100644 --- a/src/Database/PostgreSQL/PQTypes/Checks.hs +++ b/src/Database/PostgreSQL/PQTypes/Checks.hs @@ -1,15 +1,15 @@ -module Database.PostgreSQL.PQTypes.Checks ( - -- * Checks +module Database.PostgreSQL.PQTypes.Checks + ( -- * Checks checkDatabase , createTable , createDomain - -- * Options - , ExtrasOptions(..) + -- * Options + , ExtrasOptions (..) , defaultExtrasOptions - , ObjectsValidationMode(..) + , ObjectsValidationMode (..) - -- * Migrations + -- * Migrations , migrateDatabase ) where @@ -18,34 +18,34 @@ import Control.Concurrent (threadDelay) import Control.Monad import Control.Monad.Catch import Control.Monad.Reader -import Data.Int import Data.Foldable (foldMap') import Data.Function +import Data.Int import Data.List (partition) +import Data.List qualified as L +import Data.Map qualified as M import Data.Maybe import Data.Monoid.Utils -import Data.Typeable (cast) -import qualified Data.String +import Data.Set qualified as S +import Data.String qualified import Data.Text (Text) +import Data.Text qualified as T +import Data.Typeable (cast) import Database.PostgreSQL.PQTypes import GHC.Stack (HasCallStack) import Log import TextShow -import qualified Data.List as L -import qualified Data.Map as M -import qualified Data.Set as S -import qualified Data.Text as T -import Database.PostgreSQL.PQTypes.ExtrasOptions import Database.PostgreSQL.PQTypes.Checks.Util +import Database.PostgreSQL.PQTypes.ExtrasOptions import Database.PostgreSQL.PQTypes.Migrate import Database.PostgreSQL.PQTypes.Model import Database.PostgreSQL.PQTypes.SQL.Builder import Database.PostgreSQL.PQTypes.Versions headExc :: String -> [a] -> a -headExc s [] = error s -headExc _ (x:_) = x +headExc s [] = error s +headExc _ (x : _) = x ---------------------------------------- @@ -59,35 +59,43 @@ migrateDatabase -> [Table] -> [Migration m] -> m () -migrateDatabase options - extensions composites domains tables migrations = do - setDBTimeZoneToUTC - mapM_ checkExtension extensions - tablesWithVersions <- getTableVersions (tableVersions : tables) - -- 'checkDBConsistency' also performs migrations. - checkDBConsistency options domains tablesWithVersions migrations - resultCheck =<< checkCompositesStructure tablesWithVersions - CreateCompositesIfDatabaseEmpty - (eoObjectsValidationMode options) - composites - resultCheck =<< checkDomainsStructure domains - resultCheck =<< checkDBStructure options tablesWithVersions - resultCheck =<< checkTablesWereDropped migrations - - when (eoObjectsValidationMode options == DontAllowUnknownObjects) $ do - resultCheck =<< checkUnknownTables tables - resultCheck =<< checkExistenceOfVersionsForTables (tableVersions : tables) - - -- After migrations are done make sure the table versions are correct. - resultCheck . checkVersions options =<< getTableVersions (tableVersions : tables) - - -- everything is OK, commit changes - commit +migrateDatabase + options + extensions + composites + domains + tables + migrations = do + setDBTimeZoneToUTC + mapM_ checkExtension extensions + tablesWithVersions <- getTableVersions (tableVersions : tables) + -- 'checkDBConsistency' also performs migrations. + checkDBConsistency options domains tablesWithVersions migrations + resultCheck + =<< checkCompositesStructure + tablesWithVersions + CreateCompositesIfDatabaseEmpty + (eoObjectsValidationMode options) + composites + resultCheck =<< checkDomainsStructure domains + resultCheck =<< checkDBStructure options tablesWithVersions + resultCheck =<< checkTablesWereDropped migrations + + when (eoObjectsValidationMode options == DontAllowUnknownObjects) $ do + resultCheck =<< checkUnknownTables tables + resultCheck =<< checkExistenceOfVersionsForTables (tableVersions : tables) + + -- After migrations are done make sure the table versions are correct. + resultCheck . checkVersions options =<< getTableVersions (tableVersions : tables) + + -- everything is OK, commit changes + commit -- | Run checks on the database structure and whether the database needs to be -- migrated. Will do a full check of DB structure. checkDatabase - :: forall m . (MonadDB m, MonadLog m, MonadThrow m) + :: forall m + . (MonadDB m, MonadLog m, MonadThrow m) => ExtrasOptions -> [CompositeType] -> [Domain] @@ -96,10 +104,12 @@ checkDatabase checkDatabase options composites domains tables = do tablesWithVersions <- getTableVersions (tableVersions : tables) resultCheck $ checkVersions options tablesWithVersions - resultCheck =<< checkCompositesStructure tablesWithVersions - DontCreateComposites - (eoObjectsValidationMode options) - composites + resultCheck + =<< checkCompositesStructure + tablesWithVersions + DontCreateComposites + (eoObjectsValidationMode options) + composites resultCheck =<< checkDomainsStructure domains resultCheck =<< checkDBStructure options tablesWithVersions when (eoObjectsValidationMode options == DontAllowUnknownObjects) $ do @@ -109,18 +119,21 @@ checkDatabase options composites domains tables = do -- Check initial setups only after database structure is considered -- consistent as before that some of the checks may fail internally. resultCheck =<< checkInitialSetups tables - where checkInitialSetups :: [Table] -> m ValidationResult checkInitialSetups = fmap mconcat . mapM checkInitialSetup' checkInitialSetup' :: Table -> m ValidationResult - checkInitialSetup' t@Table{..} = case tblInitialSetup of + checkInitialSetup' t@Table {..} = case tblInitialSetup of Nothing -> return mempty - Just is -> checkInitialSetup is >>= \case - True -> return mempty - False -> return . validationError $ "Initial setup for table '" - <> tblNameText t <> "' is not valid" + Just is -> + checkInitialSetup is >>= \case + True -> return mempty + False -> + return . validationError $ + "Initial setup for table '" + <> tblNameText t + <> "' is not valid" -- | Return SQL fragment of current catalog within quotes currentCatalog :: (MonadDB m, MonadThrow m) => m (RawSQL ()) @@ -153,15 +166,17 @@ setDBTimeZoneToUTC = do timezone :: String <- fetchOne runIdentity when (timezone /= "UTC") $ do dbname <- currentCatalog - logInfo_ $ "Setting '" <> unRawSQL dbname - <> "' database to return timestamps in UTC" + logInfo_ $ + "Setting '" + <> unRawSQL dbname + <> "' database to return timestamps in UTC" runQuery_ $ "ALTER DATABASE" <+> dbname <+> "SET TIMEZONE = 'UTC'" -- Setting the database timezone doesn't change the session timezone. runSQL_ "SET timezone = 'UTC'" -- | Get the names of all user-defined tables that actually exist in -- the DB. -getDBTableNames :: (MonadDB m) => m [Text] +getDBTableNames :: MonadDB m => m [Text] getDBTableNames = do runQuery_ $ sqlSelect "information_schema.tables" $ do sqlResult "table_name::text" @@ -176,142 +191,184 @@ checkVersions :: ExtrasOptions -> TablesWithVersions -> ValidationResult checkVersions options = mconcat . map checkVersion where checkVersion :: (Table, Int32) -> ValidationResult - checkVersion (t@Table{..}, v) + checkVersion (t@Table {..}, v) | if eoAllowHigherTableVersions options - then tblVersion <= v - else tblVersion == v = mempty - | v == 0 = validationError $ - "Table '" <> tblNameText t <> "' must be created" - | otherwise = validationError $ - "Table '" <> tblNameText t - <> "' must be migrated" <+> showt v <+> "->" - <+> showt tblVersion + then tblVersion <= v + else tblVersion == v = + mempty + | v == 0 = + validationError $ + "Table '" <> tblNameText t <> "' must be created" + | otherwise = + validationError $ + "Table '" + <> tblNameText t + <> "' must be migrated" + <+> showt v + <+> "->" + <+> showt tblVersion -- | Check that there's a 1-to-1 correspondence between the list of -- 'Table's and what's actually in the database. checkUnknownTables :: (MonadDB m, MonadLog m) => [Table] -> m ValidationResult checkUnknownTables tables = do - dbTableNames <- getDBTableNames + dbTableNames <- getDBTableNames let tableNames = map (unRawSQL . tblName) tables - absent = dbTableNames L.\\ tableNames - notPresent = tableNames L.\\ dbTableNames + absent = dbTableNames L.\\ tableNames + notPresent = tableNames L.\\ dbTableNames if (not . null $ absent) || (not . null $ notPresent) then do - mapM_ (logInfo_ . (<+>) "Unknown table:") absent - mapM_ (logInfo_ . (<+>) "Table not present in the database:") notPresent - return $ - validateIsNull "Unknown tables:" absent <> - validateIsNull "Tables not present in the database:" notPresent + mapM_ (logInfo_ . (<+>) "Unknown table:") absent + mapM_ (logInfo_ . (<+>) "Table not present in the database:") notPresent + return $ + validateIsNull "Unknown tables:" absent + <> validateIsNull "Tables not present in the database:" notPresent else return mempty validateIsNull :: Text -> [Text] -> ValidationResult -validateIsNull _ [] = mempty +validateIsNull _ [] = mempty validateIsNull msg ts = validationError $ msg <+> T.intercalate ", " ts -- | Check that there's a 1-to-1 correspondence between the list of -- 'Table's and what's actually in the table 'table_versions'. checkExistenceOfVersionsForTables :: (MonadDB m, MonadLog m) - => [Table] -> m ValidationResult + => [Table] + -> m ValidationResult checkExistenceOfVersionsForTables tables = do runQuery_ $ sqlSelect "table_versions" $ do sqlResult "name::text" (existingTableNames :: [Text]) <- fetchMany runIdentity let tableNames = map (unRawSQL . tblName) tables - absent = existingTableNames L.\\ tableNames - notPresent = tableNames L.\\ existingTableNames + absent = existingTableNames L.\\ tableNames + notPresent = tableNames L.\\ existingTableNames if (not . null $ absent) || (not . null $ notPresent) then do - mapM_ (logInfo_ . (<+>) "Unknown entry in 'table_versions':") absent - mapM_ (logInfo_ . (<+>) "Table not present in the 'table_versions':") - notPresent - return $ - validateIsNull "Unknown entry in table_versions':" absent <> - validateIsNull "Tables not present in the 'table_versions':" notPresent + mapM_ (logInfo_ . (<+>) "Unknown entry in 'table_versions':") absent + mapM_ + (logInfo_ . (<+>) "Table not present in the 'table_versions':") + notPresent + return $ + validateIsNull "Unknown entry in table_versions':" absent + <> validateIsNull "Tables not present in the 'table_versions':" notPresent else return mempty - -checkDomainsStructure :: (MonadDB m, MonadThrow m) - => [Domain] -> m ValidationResult +checkDomainsStructure + :: (MonadDB m, MonadThrow m) + => [Domain] + -> m ValidationResult checkDomainsStructure defs = fmap mconcat . forM defs $ \def -> do runQuery_ . sqlSelect "pg_catalog.pg_type t1" $ do sqlResult "t1.typname::text" -- name - sqlResult "(SELECT pg_catalog.format_type(t2.oid, t2.typtypmod) \ - \FROM pg_catalog.pg_type t2 \ - \WHERE t2.oid = t1.typbasetype)" -- type + sqlResult + "(SELECT pg_catalog.format_type(t2.oid, t2.typtypmod) \ + \FROM pg_catalog.pg_type t2 \ + \WHERE t2.oid = t1.typbasetype)" -- type sqlResult "NOT t1.typnotnull" -- nullable sqlResult "t1.typdefault" -- default value - sqlResult "ARRAY(SELECT c.conname::text FROM pg_catalog.pg_constraint c \ - \WHERE c.contypid = t1.oid ORDER by c.oid)" -- constraint names - sqlResult "ARRAY(SELECT regexp_replace(pg_get_constraintdef(c.oid, true), '\ - \CHECK \\((.*)\\)', '\\1') FROM pg_catalog.pg_constraint c \ - \WHERE c.contypid = t1.oid \ - \ORDER by c.oid)" -- constraint definitions - sqlResult "ARRAY(SELECT c.convalidated FROM pg_catalog.pg_constraint c \ - \WHERE c.contypid = t1.oid \ - \ORDER by c.oid)" -- are constraints validated? + sqlResult + "ARRAY(SELECT c.conname::text FROM pg_catalog.pg_constraint c \ + \WHERE c.contypid = t1.oid ORDER by c.oid)" -- constraint names + sqlResult + "ARRAY(SELECT regexp_replace(pg_get_constraintdef(c.oid, true), '\ + \CHECK \\((.*)\\)', '\\1') FROM pg_catalog.pg_constraint c \ + \WHERE c.contypid = t1.oid \ + \ORDER by c.oid)" -- constraint definitions + sqlResult + "ARRAY(SELECT c.convalidated FROM pg_catalog.pg_constraint c \ + \WHERE c.contypid = t1.oid \ + \ORDER by c.oid)" -- are constraints validated? sqlWhereEq "t1.typname" $ unRawSQL $ domName def mdom <- fetchMaybe $ \(dname, dtype, nullable, defval, cnames, conds, valids) -> Domain - { domName = unsafeSQL dname - , domType = dtype - , domNullable = nullable - , domDefault = unsafeSQL <$> defval - , domChecks = - mkChecks $ zipWith3 - (\cname cond validated -> - Check - { chkName = unsafeSQL cname - , chkCondition = unsafeSQL cond - , chkValidated = validated - }) (unArray1 cnames) (unArray1 conds) (unArray1 valids) - } + { domName = unsafeSQL dname + , domType = dtype + , domNullable = nullable + , domDefault = unsafeSQL <$> defval + , domChecks = + mkChecks $ + zipWith3 + ( \cname cond validated -> + Check + { chkName = unsafeSQL cname + , chkCondition = unsafeSQL cond + , chkValidated = validated + } + ) + (unArray1 cnames) + (unArray1 conds) + (unArray1 valids) + } return $ case mdom of Just dom - | dom /= def -> topMessage "domain" (unRawSQL $ domName dom) $ mconcat [ - compareAttr dom def "name" domName - , compareAttr dom def "type" domType - , compareAttr dom def "nullable" domNullable - , compareAttr dom def "default" domDefault - , compareAttr dom def "checks" domChecks - ] + | dom /= def -> + topMessage "domain" (unRawSQL $ domName dom) $ + mconcat + [ compareAttr dom def "name" domName + , compareAttr dom def "type" domType + , compareAttr dom def "nullable" domNullable + , compareAttr dom def "default" domDefault + , compareAttr dom def "checks" domChecks + ] | otherwise -> mempty - Nothing -> validationError $ "Domain '" <> unRawSQL (domName def) - <> "' doesn't exist in the database" + Nothing -> + validationError $ + "Domain '" + <> unRawSQL (domName def) + <> "' doesn't exist in the database" where - compareAttr :: (Eq a, Show a) - => Domain -> Domain -> Text -> (Domain -> a) -> ValidationResult + compareAttr + :: (Eq a, Show a) + => Domain + -> Domain + -> Text + -> (Domain -> a) + -> ValidationResult compareAttr dom def attrname attr | attr dom == attr def = mempty - | otherwise = validationError $ - "Attribute '" <> attrname - <> "' does not match (database:" <+> T.pack (show $ attr dom) - <> ", definition:" <+> T.pack (show $ attr def) <> ")" + | otherwise = + validationError $ + "Attribute '" + <> attrname + <> "' does not match (database:" + <+> T.pack (show $ attr dom) + <> ", definition:" + <+> T.pack (show $ attr def) + <> ")" -- | Check that the tables that must have been dropped are actually -- missing from the DB. -checkTablesWereDropped :: (MonadDB m, MonadThrow m) => - [Migration m] -> m ValidationResult +checkTablesWereDropped + :: (MonadDB m, MonadThrow m) + => [Migration m] + -> m ValidationResult checkTablesWereDropped mgrs = do - let droppedTableNames = [ mgrTableName mgr - | mgr <- mgrs, isDropTableMigration mgr ] + let droppedTableNames = + [ mgrTableName mgr + | mgr <- mgrs + , isDropTableMigration mgr + ] fmap mconcat . forM droppedTableNames $ \tblName -> do mver <- checkTableVersion (T.unpack . unRawSQL $ tblName) - return $ if isNothing mver - then mempty - else validationError $ "The table '" <> unRawSQL tblName - <> "' that must have been dropped" - <> " is still present in the database." + return $ + if isNothing mver + then mempty + else + validationError $ + "The table '" + <> unRawSQL tblName + <> "' that must have been dropped" + <> " is still present in the database." data CompositesCreationMode = CreateCompositesIfDatabaseEmpty | DontCreateComposites - deriving Eq + deriving (Eq) -- | Check that there is 1 to 1 correspondence between composite types in the -- database and the list of their code definitions. @@ -322,65 +379,87 @@ checkCompositesStructure -> ObjectsValidationMode -> [CompositeType] -> m ValidationResult -checkCompositesStructure tablesWithVersions ccm ovm compositeList = getDBCompositeTypes >>= \case - [] | noTablesPresent tablesWithVersions && ccm == CreateCompositesIfDatabaseEmpty -> do - -- DB is not initialized, create composites if there are any defined. - mapM_ (runQuery_ . sqlCreateComposite) compositeList - return mempty - dbCompositeTypes -> pure $ mconcat - [ checkNotPresentComposites - , checkDatabaseComposites - ] - where - compositeMap = M.fromList $ - map ((unRawSQL . ctName) &&& ctColumns) compositeList - - checkNotPresentComposites = - let notPresent = S.toList $ M.keysSet compositeMap - S.\\ S.fromList (map (unRawSQL . ctName) dbCompositeTypes) - in validateIsNull "Composite types not present in the database:" notPresent - - checkDatabaseComposites = mconcat . (`map` dbCompositeTypes) $ \dbComposite -> - let cname = unRawSQL $ ctName dbComposite - in case cname `M.lookup` compositeMap of - Just columns -> topMessage "composite type" cname $ - checkColumns 1 columns (ctColumns dbComposite) - Nothing -> case ovm of - AllowUnknownObjects -> mempty - DontAllowUnknownObjects -> validationError $ mconcat - [ "Composite type '" - , T.pack $ show dbComposite - , "' from the database doesn't have a corresponding code definition" - ] - where - checkColumns - :: Int -> [CompositeColumn] -> [CompositeColumn] -> ValidationResult - checkColumns _ [] [] = mempty - checkColumns _ rest [] = validationError $ - objectHasLess "Composite type" "columns" rest - checkColumns _ [] rest = validationError $ - objectHasMore "Composite type" "columns" rest - checkColumns !n (d:defs) (c:cols) = mconcat [ - validateNames $ ccName d == ccName c - , validateTypes $ ccType d == ccType c - , checkColumns (n+1) defs cols - ] - where - validateNames True = mempty - validateNames False = validationError $ - errorMsg ("no. " <> showt n) "names" (unRawSQL . ccName) - - validateTypes True = mempty - validateTypes False = validationError $ - errorMsg (unRawSQL $ ccName d) "types" (T.pack . show . ccType) - - errorMsg ident attr f = - "Column '" <> ident <> "' differs in" - <+> attr <+> "(database:" <+> f c <> ", definition:" <+> f d <> ")." +checkCompositesStructure tablesWithVersions ccm ovm compositeList = + getDBCompositeTypes >>= \case + [] | noTablesPresent tablesWithVersions && ccm == CreateCompositesIfDatabaseEmpty -> do + -- DB is not initialized, create composites if there are any defined. + mapM_ (runQuery_ . sqlCreateComposite) compositeList + return mempty + dbCompositeTypes -> + pure $ + mconcat + [ checkNotPresentComposites + , checkDatabaseComposites + ] + where + compositeMap = + M.fromList $ + map ((unRawSQL . ctName) &&& ctColumns) compositeList + + checkNotPresentComposites = + let notPresent = + S.toList $ + M.keysSet compositeMap + S.\\ S.fromList (map (unRawSQL . ctName) dbCompositeTypes) + in validateIsNull "Composite types not present in the database:" notPresent + + checkDatabaseComposites = mconcat . (`map` dbCompositeTypes) $ \dbComposite -> + let cname = unRawSQL $ ctName dbComposite + in case cname `M.lookup` compositeMap of + Just columns -> + topMessage "composite type" cname $ + checkColumns 1 columns (ctColumns dbComposite) + Nothing -> case ovm of + AllowUnknownObjects -> mempty + DontAllowUnknownObjects -> + validationError $ + mconcat + [ "Composite type '" + , T.pack $ show dbComposite + , "' from the database doesn't have a corresponding code definition" + ] + where + checkColumns + :: Int -> [CompositeColumn] -> [CompositeColumn] -> ValidationResult + checkColumns _ [] [] = mempty + checkColumns _ rest [] = + validationError $ + objectHasLess "Composite type" "columns" rest + checkColumns _ [] rest = + validationError $ + objectHasMore "Composite type" "columns" rest + checkColumns !n (d : defs) (c : cols) = + mconcat + [ validateNames $ ccName d == ccName c + , validateTypes $ ccType d == ccType c + , checkColumns (n + 1) defs cols + ] + where + validateNames True = mempty + validateNames False = + validationError $ + errorMsg ("no. " <> showt n) "names" (unRawSQL . ccName) + + validateTypes True = mempty + validateTypes False = + validationError $ + errorMsg (unRawSQL $ ccName d) "types" (T.pack . show . ccType) + + errorMsg ident attr f = + "Column '" + <> ident + <> "' differs in" + <+> attr + <+> "(database:" + <+> f c + <> ", definition:" + <+> f d + <> ")." -- | Checks whether the database is consistent. checkDBStructure - :: forall m. (MonadDB m, MonadThrow m) + :: forall m + . (MonadDB m, MonadThrow m) => ExtrasOptions -> TablesWithVersions -> m ValidationResult @@ -388,12 +467,13 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version) result <- topMessage "table" (tblNameText table) <$> checkTableStructure table -- If we allow higher table versions in the database, show inconsistencies as -- info messages only. - return $ if eoAllowHigherTableVersions options && tblVersion table < version - then validationErrorsToInfos result - else result + return $ + if eoAllowHigherTableVersions options && tblVersion table < version + then validationErrorsToInfos result + else result where checkTableStructure :: Table -> m ValidationResult - checkTableStructure table@Table{..} = do + checkTableStructure table@Table {..} = do -- get table description from pg_catalog as describeTable -- mechanism from HDBC doesn't give accurate results runQuery_ $ sqlSelect "pg_catalog.pg_attribute a" $ do @@ -429,93 +509,120 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version) fkeys <- fetchMany fetchForeignKey triggers <- getDBTriggers tblName checkedOverlaps <- checkOverlappingIndexes - return $ mconcat [ - checkColumns 1 tblColumns desc - , checkPrimaryKey tblPrimaryKey pk - , checkChecks tblChecks checks - , checkIndexes tblIndexes indexes - , checkForeignKeys tblForeignKeys fkeys - , checkForeignKeyIndexes tblPrimaryKey tblForeignKeys tblIndexes - , checkTriggers tblTriggers triggers - , checkedOverlaps - ] + return $ + mconcat + [ checkColumns 1 tblColumns desc + , checkPrimaryKey tblPrimaryKey pk + , checkChecks tblChecks checks + , checkIndexes tblIndexes indexes + , checkForeignKeys tblForeignKeys fkeys + , checkForeignKeyIndexes tblPrimaryKey tblForeignKeys tblIndexes + , checkTriggers tblTriggers triggers + , checkedOverlaps + ] where fetchTableColumn :: (String, ColumnType, Maybe Text, Bool, Maybe String) -> TableColumn - fetchTableColumn (name, ctype, collation, nullable, mdefault) = TableColumn { - colName = unsafeSQL name - , colType = ctype - , colCollation = flip rawSQL () <$> collation - , colNullable = nullable - , colDefault = unsafeSQL <$> mdefault - } + fetchTableColumn (name, ctype, collation, nullable, mdefault) = + TableColumn + { colName = unsafeSQL name + , colType = ctype + , colCollation = flip rawSQL () <$> collation + , colNullable = nullable + , colDefault = unsafeSQL <$> mdefault + } checkColumns :: Int -> [TableColumn] -> [TableColumn] -> ValidationResult checkColumns _ [] [] = mempty - checkColumns _ rest [] = validationError $ - objectHasLess "Table" "columns" rest - checkColumns _ [] rest = validationError $ - objectHasMore "Table" "columns" rest - checkColumns !n (d:defs) (c:cols) = mconcat [ - validateNames $ colName d == colName c - -- bigserial == bigint + autoincrement and there is no - -- distinction between them after table is created. - , validateTypes $ colType d == colType c || - (colType d == BigSerialT && colType c == BigIntT) - -- There is a problem with default values determined by - -- sequences as they're implicitly specified by db, so - -- let's omit them in such case. - , validateDefaults $ colDefault d == colDefault c || - (isNothing (colDefault d) - && (T.isPrefixOf "nextval('" . unRawSQL <$> colDefault c) - == Just True) - , validateNullables $ colNullable d == colNullable c - , checkColumns (n+1) defs cols - ] + checkColumns _ rest [] = + validationError $ + objectHasLess "Table" "columns" rest + checkColumns _ [] rest = + validationError $ + objectHasMore "Table" "columns" rest + checkColumns !n (d : defs) (c : cols) = + mconcat + [ validateNames $ colName d == colName c + , -- bigserial == bigint + autoincrement and there is no + -- distinction between them after table is created. + validateTypes $ + colType d == colType c + || (colType d == BigSerialT && colType c == BigIntT) + , -- There is a problem with default values determined by + -- sequences as they're implicitly specified by db, so + -- let's omit them in such case. + validateDefaults $ + colDefault d == colDefault c + || ( isNothing (colDefault d) + && (T.isPrefixOf "nextval('" . unRawSQL <$> colDefault c) + == Just True + ) + , validateNullables $ colNullable d == colNullable c + , checkColumns (n + 1) defs cols + ] where - validateNames True = mempty - validateNames False = validationError $ - errorMsg ("no. " <> showt n) "names" (unRawSQL . colName) - - validateTypes True = mempty - validateTypes False = validationError $ - errorMsg cname "types" (T.pack . show . colType) - <+> sqlHint ("TYPE" <+> columnTypeToSQL (colType d)) - - validateNullables True = mempty - validateNullables False = validationError $ - errorMsg cname "nullables" (showt . colNullable) - <+> sqlHint ((if colNullable d then "DROP" else "SET") - <+> "NOT NULL") - - validateDefaults True = mempty - validateDefaults False = validationError $ - errorMsg cname "defaults" (showt . fmap unRawSQL . colDefault) - <+> sqlHint set_default + validateNames True = mempty + validateNames False = + validationError $ + errorMsg ("no. " <> showt n) "names" (unRawSQL . colName) + + validateTypes True = mempty + validateTypes False = + validationError $ + errorMsg cname "types" (T.pack . show . colType) + <+> sqlHint ("TYPE" <+> columnTypeToSQL (colType d)) + + validateNullables True = mempty + validateNullables False = + validationError $ + errorMsg cname "nullables" (showt . colNullable) + <+> sqlHint + ( (if colNullable d then "DROP" else "SET") + <+> "NOT NULL" + ) + + validateDefaults True = mempty + validateDefaults False = + validationError $ + errorMsg cname "defaults" (showt . fmap unRawSQL . colDefault) + <+> sqlHint set_default where set_default = case colDefault d of - Just v -> "SET DEFAULT" <+> v + Just v -> "SET DEFAULT" <+> v Nothing -> "DROP DEFAULT" cname = unRawSQL $ colName d errorMsg ident attr f = - "Column '" <> ident <> "' differs in" - <+> attr <+> "(table:" <+> f c <> ", definition:" <+> f d <> ")." + "Column '" + <> ident + <> "' differs in" + <+> attr + <+> "(table:" + <+> f c + <> ", definition:" + <+> f d + <> ")." sqlHint sql = "(HINT: SQL for making the change is: ALTER TABLE" - <+> tblNameText table <+> "ALTER COLUMN" <+> unRawSQL (colName d) - <+> unRawSQL sql <> ")" - - checkPrimaryKey :: Maybe PrimaryKey -> Maybe (PrimaryKey, RawSQL ()) - -> ValidationResult - checkPrimaryKey mdef mpk = mconcat [ - checkEquality "PRIMARY KEY" def (map fst pk) - , checkNames (const (pkName tblName)) pk - , if eoEnforcePKs options - then checkPKPresence tblName mdef mpk - else mempty - ] + <+> tblNameText table + <+> "ALTER COLUMN" + <+> unRawSQL (colName d) + <+> unRawSQL sql + <> ")" + + checkPrimaryKey + :: Maybe PrimaryKey + -> Maybe (PrimaryKey, RawSQL ()) + -> ValidationResult + checkPrimaryKey mdef mpk = + mconcat + [ checkEquality "PRIMARY KEY" def (map fst pk) + , checkNames (const (pkName tblName)) pk + , if eoEnforcePKs options + then checkPKPresence tblName mdef mpk + else mempty + ] where def = maybeToList mdef pk = maybeToList mpk @@ -524,85 +631,94 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version) checkChecks defs checks = mapValidationResult id mapErrs (checkEquality "CHECKs" defs checks) where - mapErrs [] = [] - mapErrs errmsgs = errmsgs <> - [ " (HINT: If checks are equal modulo number of \ - \ parentheses/whitespaces used in conditions, \ - \ just copy and paste expected output into source code)" - ] - - checkIndexes :: [TableIndex] -> [(TableIndex, RawSQL ())] - -> ValidationResult - checkIndexes defs allIndexes = mconcat - $ checkEquality "INDEXes" defs (map fst indexes) - : checkNames (indexName tblName) indexes - : map localIndexInfo localIndexes + mapErrs [] = [] + mapErrs errmsgs = + errmsgs + <> [ " (HINT: If checks are equal modulo number of \ + \ parentheses/whitespaces used in conditions, \ + \ just copy and paste expected output into source code)" + ] + + checkIndexes + :: [TableIndex] + -> [(TableIndex, RawSQL ())] + -> ValidationResult + checkIndexes defs allIndexes = + mconcat $ + checkEquality "INDEXes" defs (map fst indexes) + : checkNames (indexName tblName) indexes + : map localIndexInfo localIndexes where - localIndexInfo (index, name) = validationInfo $ T.concat - [ "Found a local index '" - , unRawSQL name - , "': " - , T.pack (show index) - ] + localIndexInfo (index, name) = + validationInfo $ + T.concat + [ "Found a local index '" + , unRawSQL name + , "': " + , T.pack (show index) + ] (localIndexes, indexes) = (`partition` allIndexes) $ \(_, name) -> -- Manually created indexes for ad-hoc improvements. - "local_" `T.isPrefixOf` unRawSQL name - -- Indexes related to the REINDEX operation, see - -- https://www.postgresql.org/docs/15/sql-reindex.html - || "_ccnew" `T.isSuffixOf` unRawSQL name - || "_ccold" `T.isSuffixOf` unRawSQL name - - checkForeignKeys :: [ForeignKey] -> [(ForeignKey, RawSQL ())] - -> ValidationResult - checkForeignKeys defs fkeys = mconcat [ - checkEquality "FOREIGN KEYs" defs (map fst fkeys) - , checkNames (fkName tblName) fkeys - ] + "local_" `T.isPrefixOf` unRawSQL name + -- Indexes related to the REINDEX operation, see + -- https://www.postgresql.org/docs/15/sql-reindex.html + || "_ccnew" `T.isSuffixOf` unRawSQL name + || "_ccold" `T.isSuffixOf` unRawSQL name + + checkForeignKeys + :: [ForeignKey] + -> [(ForeignKey, RawSQL ())] + -> ValidationResult + checkForeignKeys defs fkeys = + mconcat + [ checkEquality "FOREIGN KEYs" defs (map fst fkeys) + , checkNames (fkName tblName) fkeys + ] checkForeignKeyIndexes :: Maybe PrimaryKey -> [ForeignKey] -> [TableIndex] -> ValidationResult checkForeignKeyIndexes pkey foreignKeys indexes = if eoCheckForeignKeysIndexes options - then foldMap' go foreignKeys - else mempty + then foldMap' go foreignKeys + else mempty where - -- Map index on the given table name to a list of list of names -- so that index on a and index on (b, c) becomes [[a], [b, c,]]. allIndexes :: [[RawSQL ()]] allIndexes = fmap (fmap indexColumnName . idxColumns) . filter (isNothing . idxWhere) $ indexes allCoverage :: [[RawSQL ()]] - allCoverage = maybe [] pkColumns pkey:allIndexes - + allCoverage = maybe [] pkColumns pkey : allIndexes -- A foreign key is covered if it is a prefix of a list of indices. -- So a FK on a is covered by an index on (a, b) but not an index on (b, a). coveredFK :: ForeignKey -> [[RawSQL ()]] -> Bool coveredFK fk = any (\idx -> fkColumns fk `L.isPrefixOf` idx) go :: ForeignKey -> ValidationResult - go fk = let columns = map unRawSQL (fkColumns fk) - in if coveredFK fk allCoverage - then mempty - else validationError $ mconcat ["\n ● Foreign key '(", T.intercalate "," columns, ")' is missing an index"] + go fk = + let columns = map unRawSQL (fkColumns fk) + in if coveredFK fk allCoverage + then mempty + else validationError $ mconcat ["\n ● Foreign key '(", T.intercalate "," columns, ")' is missing an index"] checkTriggers :: [Trigger] -> [(Trigger, RawSQL ())] -> ValidationResult checkTriggers defs triggers = mapValidationResult id mapErrs $ checkEquality "TRIGGERs" defs' triggers where defs' = map (\t -> (t, triggerFunctionMakeName $ triggerName t)) defs - mapErrs [] = [] - mapErrs errmsgs = errmsgs <> - [ "(HINT: If WHEN clauses are equal modulo number of parentheses, whitespace, \ - \case of variables or type casts used in conditions, just copy and paste \ - \expected output into source code.)" - ] - - checkOverlappingIndexes :: (MonadDB m) => m ValidationResult + mapErrs [] = [] + mapErrs errmsgs = + errmsgs + <> [ "(HINT: If WHEN clauses are equal modulo number of parentheses, whitespace, \ + \case of variables or type casts used in conditions, just copy and paste \ + \expected output into source code.)" + ] + + checkOverlappingIndexes :: MonadDB m => m ValidationResult checkOverlappingIndexes = if eoCheckOverlappingIndexes options - then go - else pure mempty + then go + else pure mempty where go = do let handleOverlap (contained, contains) = @@ -614,9 +730,10 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version) ] runSQL_ checkOverlappingIndexesQuery overlaps <- fetchMany handleOverlap - pure $ if null overlaps - then mempty - else validationError . T.unlines $ "Some indexes are overlapping" : overlaps + pure $ + if null overlaps + then mempty + else validationError . T.unlines $ "Some indexes are overlapping" : overlaps -- | Checks whether database is consistent, performing migrations if -- necessary. Requires all table names to be in lower case. @@ -627,8 +744,12 @@ checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version) -- * all 'mgrFrom' are less than table version number of the table in -- the 'tables' list checkDBConsistency - :: forall m. (MonadIO m, MonadDB m, MonadLog m, MonadMask m) - => ExtrasOptions -> [Domain] -> TablesWithVersions -> [Migration m] + :: forall m + . (MonadIO m, MonadDB m, MonadLog m, MonadMask m) + => ExtrasOptions + -> [Domain] + -> TablesWithVersions + -> [Migration m] -> m () checkDBConsistency options domains tablesWithVersions migrations = do autoTransaction <- tsAutoTransaction <$> getTransactionSettings @@ -642,7 +763,6 @@ checkDBConsistency options domains tablesWithVersions migrations = do dbTablesWithVersions <- getDBTableVersions if noTablesPresent tablesWithVersions - -- No tables are present, create everything from scratch. then do createDBSchema @@ -651,85 +771,105 @@ checkDBConsistency options domains tablesWithVersions migrations = do -- Migration mode. else do -- Additional validity checks for the migrations list. - validateMigrationsAgainstDB [ (tblName table, tblVersion table, actualVer) - | (table, actualVer) <- tablesWithVersions ] + validateMigrationsAgainstDB + [ (tblName table, tblVersion table, actualVer) + | (table, actualVer) <- tablesWithVersions + ] validateDropTableMigrationsAgainstDB dbTablesWithVersions -- Run migrations, if necessary. runMigrations dbTablesWithVersions - where tables = map fst tablesWithVersions errorInvalidMigrations :: HasCallStack => [RawSQL ()] -> a errorInvalidMigrations tblNames = - error $ "checkDBConsistency: invalid migrations for tables" - <+> L.intercalate ", " (map (T.unpack . unRawSQL) tblNames) + error $ + "checkDBConsistency: invalid migrations for tables" + <+> L.intercalate ", " (map (T.unpack . unRawSQL) tblNames) checkMigrationsListValidity :: Table -> [Int32] -> [Int32] -> m () - checkMigrationsListValidity table presentMigrationVersions + checkMigrationsListValidity + table + presentMigrationVersions expectedMigrationVersions = do - when (presentMigrationVersions /= expectedMigrationVersions) $ do - logAttention "Migrations are invalid" $ object [ - "table" .= tblNameText table - , "migration_versions" .= presentMigrationVersions - , "expected_migration_versions" .= expectedMigrationVersions - ] - errorInvalidMigrations [tblName table] + when (presentMigrationVersions /= expectedMigrationVersions) $ do + logAttention "Migrations are invalid" $ + object + [ "table" .= tblNameText table + , "migration_versions" .= presentMigrationVersions + , "expected_migration_versions" .= expectedMigrationVersions + ] + errorInvalidMigrations [tblName table] validateMigrations :: m () validateMigrations = forM_ tables $ \table -> do -- FIXME: https://github.com/scrive/hpqtypes-extras/issues/73 - let presentMigrationVersions - = [ mgrFrom | Migration{..} <- migrations - , mgrTableName == tblName table ] - expectedMigrationVersions - = reverse $ take (length presentMigrationVersions) $ - reverse [0 .. tblVersion table - 1] - checkMigrationsListValidity table presentMigrationVersions + let presentMigrationVersions = + [ mgrFrom | Migration {..} <- migrations, mgrTableName == tblName table + ] + expectedMigrationVersions = + reverse $ + take (length presentMigrationVersions) $ + reverse [0 .. tblVersion table - 1] + checkMigrationsListValidity + table + presentMigrationVersions expectedMigrationVersions validateDropTableMigrations :: m () validateDropTableMigrations = do let droppedTableNames = - [ mgrTableName mgr | mgr <- migrations - , isDropTableMigration mgr ] + [ mgrTableName mgr | mgr <- migrations, isDropTableMigration mgr + ] tableNames = - [ tblName tbl | tbl <- tables ] + [tblName tbl | tbl <- tables] -- Check that the intersection between the 'tables' list and -- dropped tables is empty. let intersection = L.intersect droppedTableNames tableNames unless (null intersection) $ do - logAttention ("The intersection between tables " - <> "and dropped tables is not empty") - $ object - [ "intersection" .= map unRawSQL intersection ] - errorInvalidMigrations [ tblName tbl - | tbl <- tables - , tblName tbl `elem` intersection ] + logAttention + ( "The intersection between tables " + <> "and dropped tables is not empty" + ) + $ object + ["intersection" .= map unRawSQL intersection] + errorInvalidMigrations + [ tblName tbl + | tbl <- tables + , tblName tbl `elem` intersection + ] -- Check that if a list of migrations for a given table has a -- drop table migration, it is unique and is the last migration -- in the list. - let migrationsByTable = L.groupBy ((==) `on` mgrTableName) - migrations - dropMigrationLists = [ mgrs | mgrs <- migrationsByTable - , any isDropTableMigration mgrs ] + let migrationsByTable = + L.groupBy + ((==) `on` mgrTableName) + migrations + dropMigrationLists = + [ mgrs | mgrs <- migrationsByTable, any isDropTableMigration mgrs + ] invalidMigrationLists = - [ mgrs | mgrs <- dropMigrationLists - , (not . isDropTableMigration . last $ mgrs) || - (length . filter isDropTableMigration $ mgrs) > 1 ] + [ mgrs | mgrs <- dropMigrationLists, (not . isDropTableMigration . last $ mgrs) + || (length . filter isDropTableMigration $ mgrs) > 1 + ] unless (null invalidMigrationLists) $ do let tablesWithInvalidMigrationLists = - [ mgrTableName mgr | mgrs <- invalidMigrationLists - , let mgr = head mgrs ] - logAttention ("Migration lists for some tables contain " - <> "either multiple drop table migrations or " - <> "a drop table migration in non-tail position.") - $ object [ "tables" .= - [ unRawSQL tblName - | tblName <- tablesWithInvalidMigrationLists ] ] + [ mgrTableName mgr | mgrs <- invalidMigrationLists, let mgr = head mgrs + ] + logAttention + ( "Migration lists for some tables contain " + <> "either multiple drop table migrations or " + <> "a drop table migration in non-tail position." + ) + $ object + [ "tables" + .= [ unRawSQL tblName + | tblName <- tablesWithInvalidMigrationLists + ] + ] errorInvalidMigrations tablesWithInvalidMigrationLists createDBSchema :: m () @@ -753,24 +893,34 @@ checkDBConsistency options domains tablesWithVersions migrations = do initialSetup tis logInfo_ "Done." - -- | Input is a list of (table name, expected version, actual + -- \| Input is a list of (table name, expected version, actual -- version) triples. validateMigrationsAgainstDB :: [(RawSQL (), Int32, Int32)] -> m () - validateMigrationsAgainstDB tablesWithVersions_ - = forM_ tablesWithVersions_ $ \(tableName, expectedVer, actualVer) -> + validateMigrationsAgainstDB tablesWithVersions_ = + forM_ tablesWithVersions_ $ \(tableName, expectedVer, actualVer) -> when (expectedVer /= actualVer) $ - case [ m | m@Migration{..} <- migrations - , mgrTableName == tableName ] of - [] -> - error $ "checkDBConsistency: no migrations found for table '" - ++ (T.unpack . unRawSQL $ tableName) ++ "', cannot migrate " - ++ show actualVer ++ " -> " ++ show expectedVer - (m:_) | mgrFrom m > actualVer -> - error $ "checkDBConsistency: earliest migration for table '" - ++ (T.unpack . unRawSQL $ tableName) ++ "' is from version " - ++ show (mgrFrom m) ++ ", cannot migrate " - ++ show actualVer ++ " -> " ++ show expectedVer - | otherwise -> return () + case [ m | m@Migration {..} <- migrations, mgrTableName == tableName + ] of + [] -> + error $ + "checkDBConsistency: no migrations found for table '" + ++ (T.unpack . unRawSQL $ tableName) + ++ "', cannot migrate " + ++ show actualVer + ++ " -> " + ++ show expectedVer + (m : _) + | mgrFrom m > actualVer -> + error $ + "checkDBConsistency: earliest migration for table '" + ++ (T.unpack . unRawSQL $ tableName) + ++ "' is from version " + ++ show (mgrFrom m) + ++ ", cannot migrate " + ++ show actualVer + ++ " -> " + ++ show expectedVer + | otherwise -> return () validateDropTableMigrationsAgainstDB :: [(Text, Int32)] -> m () validateDropTableMigrationsAgainstDB dbTablesWithVersions = do @@ -780,7 +930,8 @@ checkDBConsistency options domains tablesWithVersions migrations = do , isDropTableMigration mgr , let tblName = mgrTableName mgr , let mver = lookup (unRawSQL tblName) dbTablesWithVersions - , isJust mver ] + , isJust mver + ] forM_ dbTablesToDropWithVersions $ \(tblName, fromVer, ver) -> when (fromVer /= ver) $ -- In case when the table we're going to drop is an old @@ -790,33 +941,39 @@ checkDBConsistency options domains tablesWithVersions migrations = do findMigrationsToRun :: [(Text, Int32)] -> [Migration m] findMigrationsToRun dbTablesWithVersions = - let tableNamesToDrop = [ mgrTableName mgr | mgr <- migrations - , isDropTableMigration mgr ] + let tableNamesToDrop = + [ mgrTableName mgr | mgr <- migrations, isDropTableMigration mgr + ] droppedEventually :: Migration m -> Bool droppedEventually mgr = mgrTableName mgr `elem` tableNamesToDrop lookupVer :: Migration m -> Maybe Int32 - lookupVer mgr = lookup (unRawSQL $ mgrTableName mgr) - dbTablesWithVersions + lookupVer mgr = + lookup + (unRawSQL $ mgrTableName mgr) + dbTablesWithVersions tableDoesNotExist = isNothing . lookupVer -- The idea here is that we find the first migration we need -- to run and then just run all migrations in order after -- that one. - migrationsToRun' = dropWhile - (\mgr -> - case lookupVer mgr of - -- Table doesn't exist in the DB. If it's a create - -- table migration and we're not going to drop the - -- table afterwards, this is our starting point. - Nothing -> not $ - (mgrFrom mgr == 0) && - (not . droppedEventually $ mgr) - -- Table exists in the DB. Run only those migrations - -- that have mgrFrom >= table version in the DB. - Just ver -> mgrFrom mgr < ver) - migrations + migrationsToRun' = + dropWhile + ( \mgr -> + case lookupVer mgr of + -- Table doesn't exist in the DB. If it's a create + -- table migration and we're not going to drop the + -- table afterwards, this is our starting point. + Nothing -> + not $ + (mgrFrom mgr == 0) + && (not . droppedEventually $ mgr) + -- Table exists in the DB. Run only those migrations + -- that have mgrFrom >= table version in the DB. + Just ver -> mgrFrom mgr < ver + ) + migrations -- Special case: also include migrations for tables that do -- not exist in the DB and ARE going to be dropped if they @@ -828,42 +985,44 @@ checkDBConsistency options domains tablesWithVersions migrations = do -- 'doSomethingTo t1', and that step depends on 't', -- 'doSomethingTo t1' will fail. So we include 'createTable -- t' and 'doSomethingTo t' as well. - l = length migrationsToRun' - initialMigrations = drop l $ reverse migrations - additionalMigrations' = takeWhile - (\mgr -> droppedEventually mgr && tableDoesNotExist mgr) - initialMigrations + l = length migrationsToRun' + initialMigrations = drop l $ reverse migrations + additionalMigrations' = + takeWhile + (\mgr -> droppedEventually mgr && tableDoesNotExist mgr) + initialMigrations -- Check that all extra migration chains we've chosen begin -- with 'createTable', otherwise skip adding them (to -- prevent raising an exception during the validation step). - additionalMigrations = - let ret = reverse additionalMigrations' + additionalMigrations = + let ret = reverse additionalMigrations' grps = L.groupBy ((==) `on` mgrTableName) ret in if any ((/=) 0 . mgrFrom . head) grps - then [] - else ret + then [] + else ret -- Also there's no point in adding these extra migrations if -- we're not running any migrations to begin with. - migrationsToRun = if not . null $ migrationsToRun' - then additionalMigrations ++ migrationsToRun' - else [] + migrationsToRun = + if not . null $ migrationsToRun' + then additionalMigrations ++ migrationsToRun' + else [] in migrationsToRun runMigration :: Migration m -> m () - runMigration Migration{..} = do + runMigration Migration {..} = do case mgrAction of StandardMigration mgrDo -> do logMigration mgrDo updateTableVersion - DropTableMigration mgrDropTableMode -> do logInfo_ $ arrListTable mgrTableName <> "drop table" - runQuery_ $ sqlDropTable mgrTableName - mgrDropTableMode + runQuery_ $ + sqlDropTable + mgrTableName + mgrDropTableMode runQuery_ $ sqlDelete "table_versions" $ do sqlWhereEq "name" (T.unpack . unRawSQL $ mgrTableName) - CreateIndexConcurrentlyMigration tname idx -> do logMigration -- We're in auto transaction mode (as ensured at the beginning of @@ -880,7 +1039,6 @@ checkDBConsistency options domains tablesWithVersions migrations = do runQuery_ $ "DROP INDEX CONCURRENTLY IF EXISTS" <+> indexName tname idx runQuery_ (sqlCreateIndexConcurrently tname idx) updateTableVersion - DropIndexConcurrentlyMigration tname idx -> do logMigration -- We're in auto transaction mode (as ensured at the beginning of @@ -891,7 +1049,6 @@ checkDBConsistency options domains tablesWithVersions migrations = do bracket_ (runSQL_ "COMMIT") (runSQL_ "BEGIN") $ do runQuery_ (sqlDropIndexConcurrently tname idx) updateTableVersion - ModifyColumnMigration tableName cursorSql updateSql batchSize -> do logMigration when (batchSize < 1000) $ do @@ -922,25 +1079,28 @@ checkDBConsistency options domains tablesWithVersions migrations = do unless (null primaryKeys) $ do updateSql primaryKeys if processed + batchSize >= vacuumThreshold - then do - bracket_ (runSQL_ "COMMIT") - (runSQL_ "BEGIN") - (runQuery_ $ "VACUUM" <+> tableName) - cursorLoop 0 - else do - commit - cursorLoop (processed + batchSize) + then do + bracket_ + (runSQL_ "COMMIT") + (runSQL_ "BEGIN") + (runQuery_ $ "VACUUM" <+> tableName) + cursorLoop 0 + else do + commit + cursorLoop (processed + batchSize) cursorLoop 0 updateTableVersion - where logMigration = do - logInfo_ $ arrListTable mgrTableName - <> showt mgrFrom <+> "->" <+> showt (succ mgrFrom) + logInfo_ $ + arrListTable mgrTableName + <> showt mgrFrom + <+> "->" + <+> showt (succ mgrFrom) updateTableVersion = do runQuery_ $ sqlUpdate "table_versions" $ do - sqlSet "version" (succ mgrFrom) + sqlSet "version" (succ mgrFrom) sqlWhereEq "name" (T.unpack . unRawSQL $ mgrTableName) -- Get the estimated number of rows of the given table. It might not @@ -977,27 +1137,30 @@ checkDBConsistency options domains tablesWithVersions migrations = do intToSQL = unsafeSQL . show lockNotAvailable :: DBException -> Maybe String - lockNotAvailable DBException{..} - | Just DetailedQueryError{..} <- cast dbeError - , qeErrorCode == LockNotAvailable = Just $ show dbeQueryContext - | otherwise = Nothing + lockNotAvailable DBException {..} + | Just DetailedQueryError {..} <- cast dbeError + , qeErrorCode == LockNotAvailable = + Just $ show dbeQueryContext + | otherwise = Nothing validateMigrationsToRun :: [Migration m] -> [(Text, Int32)] -> m () validateMigrationsToRun migrationsToRun dbTablesWithVersions = do - let migrationsToRunGrouped :: [[Migration m]] migrationsToRunGrouped = - L.groupBy ((==) `on` mgrTableName) . - L.sortOn mgrTableName $ -- NB: stable sort - migrationsToRun - - loc_common = "Database.PostgreSQL.PQTypes.Checks." - ++ "checkDBConsistency.validateMigrationsToRun" + L.groupBy ((==) `on` mgrTableName) + . L.sortOn mgrTableName + $ migrationsToRun -- NB: stable sort + loc_common = + "Database.PostgreSQL.PQTypes.Checks." + ++ "checkDBConsistency.validateMigrationsToRun" lookupDBTableVer :: [Migration m] -> Maybe Int32 lookupDBTableVer mgrGroup = - lookup (unRawSQL . mgrTableName . headExc head_err - $ mgrGroup) dbTablesWithVersions + lookup + ( unRawSQL . mgrTableName . headExc head_err $ + mgrGroup + ) + dbTablesWithVersions where head_err = loc_common ++ ".lookupDBTableVer: broken invariant" @@ -1009,8 +1172,9 @@ checkDBConsistency options domains tablesWithVersions migrations = do , dbTableVer /= (mgrFrom . headExc head_err $ mgrGroup) ] where - head_err = loc_common - ++ ".groupsWithWrongDBTableVersions: broken invariant" + head_err = + loc_common + ++ ".groupsWithWrongDBTableVersions: broken invariant" mgrGroupsNotInDB :: [[Migration m]] mgrGroupsNotInDB = @@ -1026,8 +1190,9 @@ checkDBConsistency options domains tablesWithVersions migrations = do , isDropTableMigration . headExc head_err $ mgrGroup ] where - head_err = loc_common - ++ ".groupsStartingWithDropTable: broken invariant" + head_err = + loc_common + ++ ".groupsStartingWithDropTable: broken invariant" groupsNotStartingWithCreateTable :: [[Migration m]] groupsNotStartingWithCreateTable = @@ -1036,38 +1201,41 @@ checkDBConsistency options domains tablesWithVersions migrations = do , mgrFrom (headExc head_err mgrGroup) /= 0 ] where - head_err = loc_common - ++ ".groupsNotStartingWithCreateTable: broken invariant" + head_err = + loc_common + ++ ".groupsNotStartingWithCreateTable: broken invariant" tblNames :: [[Migration m]] -> [RawSQL ()] tblNames grps = - [ mgrTableName . headExc head_err $ grp | grp <- grps ] + [mgrTableName . headExc head_err $ grp | grp <- grps] where head_err = loc_common ++ ".tblNames: broken invariant" unless (null groupsWithWrongDBTableVersions) $ do let tnms = tblNames . map fst $ groupsWithWrongDBTableVersions logAttention - ("There are migration chains selected for execution " - <> "that expect a different starting table version number " - <> "from the one in the database. " - <> "This likely means that the order of migrations is wrong.") - $ object [ "tables" .= map unRawSQL tnms ] + ( "There are migration chains selected for execution " + <> "that expect a different starting table version number " + <> "from the one in the database. " + <> "This likely means that the order of migrations is wrong." + ) + $ object ["tables" .= map unRawSQL tnms] errorInvalidMigrations tnms unless (null groupsStartingWithDropTable) $ do let tnms = tblNames groupsStartingWithDropTable - logAttention "There are drop table migrations for non-existing tables." - $ object [ "tables" .= map unRawSQL tnms ] + logAttention "There are drop table migrations for non-existing tables." $ + object ["tables" .= map unRawSQL tnms] errorInvalidMigrations tnms -- NB: the following check can break if we allow renaming tables. unless (null groupsNotStartingWithCreateTable) $ do let tnms = tblNames groupsNotStartingWithCreateTable logAttention - ("Some tables haven't been created yet, but" <> - "their migration lists don't start with a create table migration.") - $ object [ "tables" .= map unRawSQL tnms ] + ( "Some tables haven't been created yet, but" + <> "their migration lists don't start with a create table migration." + ) + $ object ["tables" .= map unRawSQL tnms] errorInvalidMigrations tnms -- | Type synonym for a list of tables along with their database versions. @@ -1084,8 +1252,9 @@ checkVersionIsAtLeast15 = do getTableVersions :: (MonadDB m, MonadThrow m) => [Table] -> m TablesWithVersions getTableVersions tbls = sequence - [ (\mver -> (tbl, fromMaybe 0 mver)) <$> checkTableVersion (tblNameString tbl) - | tbl <- tbls ] + [ (\mver -> (tbl, fromMaybe 0 mver)) <$> checkTableVersion (tblNameString tbl) + | tbl <- tbls + ] -- | Given a result of 'getTableVersions' check if no tables are present in the -- database. @@ -1099,7 +1268,8 @@ getDBTableVersions = do dbTableNames <- getDBTableNames sequence [ (\mver -> (name, fromMaybe 0 mver)) <$> checkTableVersion (T.unpack name) - | name <- dbTableNames ] + | name <- dbTableNames + ] -- | Check whether the table exists in the DB, and return 'Just' its -- version if it does, or 'Nothing' if it doesn't. @@ -1112,15 +1282,18 @@ checkTableVersion tblName = do sqlWhere "pg_catalog.pg_table_is_visible(c.oid)" if doesExist then do - runQuery_ $ "SELECT version FROM table_versions WHERE name =" - tblName + runQuery_ $ + "SELECT version FROM table_versions WHERE name =" + tblName mver <- fetchMaybe runIdentity case mver of Just ver -> return $ Just ver - Nothing -> error $ "checkTableVersion: table '" - ++ tblName - ++ "' is present in the database, " - ++ "but there is no corresponding version info in 'table_versions'." + Nothing -> + error $ + "checkTableVersion: table '" + ++ tblName + ++ "' is present in the database, " + ++ "but there is no corresponding version info in 'table_versions'." else do return Nothing @@ -1138,9 +1311,9 @@ sqlGetTableID table = parenthesize . toSQLCommand $ sqlGetPrimaryKey :: (MonadDB m, MonadThrow m) - => Table -> m (Maybe (PrimaryKey, RawSQL ())) + => Table + -> m (Maybe (PrimaryKey, RawSQL ())) sqlGetPrimaryKey table = do - (mColumnNumbers :: Maybe [Int16]) <- do runQuery_ . sqlSelect "pg_catalog.pg_constraint" $ do sqlResult "conkey" @@ -1154,14 +1327,13 @@ sqlGetPrimaryKey table = do columnNames <- do forM columnNumbers $ \k -> do runQuery_ . sqlSelect "pk_columns" $ do - sqlWith "key_series" . sqlSelect "pg_constraint as c2" $ do sqlResult "unnest(c2.conkey) as k" sqlWhereEqSql "c2.conrelid" $ sqlGetTableID table sqlWhereEq "c2.contype" 'p' sqlWith "pk_columns" . sqlSelect "key_series" $ do - sqlJoinOn "pg_catalog.pg_attribute as a" "a.attnum = key_series.k" + sqlJoinOn "pg_catalog.pg_attribute as a" "a.attnum = key_series.k" sqlResult "a.attname::text as column_name" sqlResult "key_series.k as column_order" sqlWhereEqSql "a.attrelid" $ sqlGetTableID table @@ -1175,32 +1347,36 @@ sqlGetPrimaryKey table = do sqlWhereEq "c.contype" 'p' sqlWhereEqSql "c.conrelid" $ sqlGetTableID table sqlResult "c.conname::text" - sqlResult $ Data.String.fromString - ("array['" <> mintercalate "', '" columnNames <> "']::text[]") + sqlResult $ + Data.String.fromString + ("array['" <> mintercalate "', '" columnNames <> "']::text[]") join <$> fetchMaybe fetchPrimaryKey fetchPrimaryKey :: (String, Array1 String) -> Maybe (PrimaryKey, RawSQL ()) -fetchPrimaryKey (name, Array1 columns) = (, unsafeSQL name) - <$> pkOnColumns (map unsafeSQL columns) +fetchPrimaryKey (name, Array1 columns) = + (,unsafeSQL name) + <$> pkOnColumns (map unsafeSQL columns) -- *** CHECKS *** sqlGetChecks :: Table -> SQL sqlGetChecks table = toSQLCommand . sqlSelect "pg_catalog.pg_constraint c" $ do sqlResult "c.conname::text" - sqlResult "regexp_replace(pg_get_constraintdef(c.oid, true), \ - \'CHECK \\((.*)\\)', '\\1') AS body" -- check body + sqlResult + "regexp_replace(pg_get_constraintdef(c.oid, true), \ + \'CHECK \\((.*)\\)', '\\1') AS body" -- check body sqlResult "c.convalidated" -- validated? sqlWhereEq "c.contype" 'c' sqlWhereEqSql "c.conrelid" $ sqlGetTableID table fetchTableCheck :: (String, String, Bool) -> Check -fetchTableCheck (name, condition, validated) = Check { - chkName = unsafeSQL name -, chkCondition = unsafeSQL condition -, chkValidated = validated -} +fetchTableCheck (name, condition, validated) = + Check + { chkName = unsafeSQL name + , chkCondition = unsafeSQL condition + , chkValidated = validated + } -- *** INDEXES *** sqlGetIndexes :: Bool -> Table -> SQL @@ -1210,101 +1386,122 @@ sqlGetIndexes nullsNotDistinctSupported table = toSQLCommand . sqlSelect "pg_cat sqlResult $ "ARRAY(" <> selectCoordinates "i.indnkeyatts" "i.indnatts" <> ")" -- array of included columns in the index sqlResult "am.amname::text" -- the method used (btree, gin etc) sqlResult "i.indisunique" -- is it unique? - sqlResult "i.indisvalid" -- is it valid? + sqlResult "i.indisvalid" -- is it valid? -- does it have NULLS NOT DISTINCT ? if nullsNotDistinctSupported - then sqlResult "i.indnullsnotdistinct" + then sqlResult "i.indnullsnotdistinct" else sqlResult "false" -- if partial, get constraint def sqlResult "pg_catalog.pg_get_expr(i.indpred, i.indrelid, true)" sqlJoinOn "pg_catalog.pg_index i" "c.oid = i.indexrelid" sqlJoinOn "pg_catalog.pg_am am" "c.relam = am.oid" - sqlLeftJoinOn "pg_catalog.pg_constraint r" + sqlLeftJoinOn + "pg_catalog.pg_constraint r" "r.conrelid = i.indrelid AND r.conindid = i.indexrelid" sqlWhereEqSql "i.indrelid" $ sqlGetTableID table sqlWhereIsNULL "r.contype" -- fetch only "pure" indexes where -- Get all coordinates of the index. - selectCoordinates start end = smconcat [ - "WITH RECURSIVE coordinates(k, name) AS (" - , " VALUES (" <> start <> "::integer, NULL)" - , " UNION ALL" - , " SELECT k+1, pg_catalog.pg_get_indexdef(i.indexrelid, k+1, true)" - , " FROM coordinates" - , " WHERE k < " <> end - , ")" - , "SELECT name FROM coordinates WHERE name IS NOT NULL" - ] + selectCoordinates start end = + smconcat + [ "WITH RECURSIVE coordinates(k, name) AS (" + , " VALUES (" <> start <> "::integer, NULL)" + , " UNION ALL" + , " SELECT k+1, pg_catalog.pg_get_indexdef(i.indexrelid, k+1, true)" + , " FROM coordinates" + , " WHERE k < " <> end + , ")" + , "SELECT name FROM coordinates WHERE name IS NOT NULL" + ] fetchTableIndex :: (String, Array1 String, Array1 String, String, Bool, Bool, Bool, Maybe String) -> (TableIndex, RawSQL ()) fetchTableIndex (name, Array1 keyColumns, Array1 includeColumns, method, unique, valid, nullsNotDistinct, mconstraint) = - (TableIndex - { idxColumns = map (indexColumn . unsafeSQL) keyColumns - , idxInclude = map unsafeSQL includeColumns - , idxMethod = read method - , idxUnique = unique - , idxValid = valid - , idxWhere = unsafeSQL <$> mconstraint - , idxNotDistinctNulls = nullsNotDistinct - } - , unsafeSQL name) + ( TableIndex + { idxColumns = map (indexColumn . unsafeSQL) keyColumns + , idxInclude = map unsafeSQL includeColumns + , idxMethod = read method + , idxUnique = unique + , idxValid = valid + , idxWhere = unsafeSQL <$> mconstraint + , idxNotDistinctNulls = nullsNotDistinct + } + , unsafeSQL name + ) -- *** FOREIGN KEYS *** sqlGetForeignKeys :: Table -> SQL sqlGetForeignKeys table = toSQLCommand - . sqlSelect "pg_catalog.pg_constraint r" $ do - sqlResult "r.conname::text" -- fk name - sqlResult $ - "ARRAY(SELECT a.attname::text FROM pg_catalog.pg_attribute a JOIN (" - <> unnestWithOrdinality "r.conkey" - <> ") conkeys ON (a.attnum = conkeys.item) \ - \WHERE a.attrelid = r.conrelid \ - \ORDER BY conkeys.n)" -- constrained columns - sqlResult "c.relname::text" -- referenced table - sqlResult $ "ARRAY(SELECT a.attname::text \ - \FROM pg_catalog.pg_attribute a JOIN (" - <> unnestWithOrdinality "r.confkey" - <> ") confkeys ON (a.attnum = confkeys.item) \ - \WHERE a.attrelid = r.confrelid \ - \ORDER BY confkeys.n)" -- referenced columns - sqlResult "r.confupdtype" -- on update - sqlResult "r.confdeltype" -- on delete - sqlResult "r.condeferrable" -- deferrable? - sqlResult "r.condeferred" -- initially deferred? - sqlResult "r.convalidated" -- validated? - sqlJoinOn "pg_catalog.pg_class c" "c.oid = r.confrelid" - sqlWhereEqSql "r.conrelid" $ sqlGetTableID table - sqlWhereEq "r.contype" 'f' + . sqlSelect "pg_catalog.pg_constraint r" + $ do + sqlResult "r.conname::text" -- fk name + sqlResult $ + "ARRAY(SELECT a.attname::text FROM pg_catalog.pg_attribute a JOIN (" + <> unnestWithOrdinality "r.conkey" + <> ") conkeys ON (a.attnum = conkeys.item) \ + \WHERE a.attrelid = r.conrelid \ + \ORDER BY conkeys.n)" -- constrained columns + sqlResult "c.relname::text" -- referenced table + sqlResult $ + "ARRAY(SELECT a.attname::text \ + \FROM pg_catalog.pg_attribute a JOIN (" + <> unnestWithOrdinality "r.confkey" + <> ") confkeys ON (a.attnum = confkeys.item) \ + \WHERE a.attrelid = r.confrelid \ + \ORDER BY confkeys.n)" -- referenced columns + sqlResult "r.confupdtype" -- on update + sqlResult "r.confdeltype" -- on delete + sqlResult "r.condeferrable" -- deferrable? + sqlResult "r.condeferred" -- initially deferred? + sqlResult "r.convalidated" -- validated? + sqlJoinOn "pg_catalog.pg_class c" "c.oid = r.confrelid" + sqlWhereEqSql "r.conrelid" $ sqlGetTableID table + sqlWhereEq "r.contype" 'f' where unnestWithOrdinality :: RawSQL () -> SQL unnestWithOrdinality arr = - "SELECT n, " <> raw arr - <> "[n] AS item FROM generate_subscripts(" <> raw arr <> ", 1) AS n" + "SELECT n, " + <> raw arr + <> "[n] AS item FROM generate_subscripts(" + <> raw arr + <> ", 1) AS n" -fetchForeignKey :: - (String, Array1 String, String, Array1 String, Char, Char, Bool, Bool, Bool) +fetchForeignKey + :: (String, Array1 String, String, Array1 String, Char, Char, Bool, Bool, Bool) -> (ForeignKey, RawSQL ()) fetchForeignKey - ( name, Array1 columns, reftable, Array1 refcolumns - , on_update, on_delete, deferrable, deferred, validated ) = (ForeignKey { - fkColumns = map unsafeSQL columns -, fkRefTable = unsafeSQL reftable -, fkRefColumns = map unsafeSQL refcolumns -, fkOnUpdate = charToForeignKeyAction on_update -, fkOnDelete = charToForeignKeyAction on_delete -, fkDeferrable = deferrable -, fkDeferred = deferred -, fkValidated = validated -}, unsafeSQL name) - where - charToForeignKeyAction c = case c of - 'a' -> ForeignKeyNoAction - 'r' -> ForeignKeyRestrict - 'c' -> ForeignKeyCascade - 'n' -> ForeignKeySetNull - 'd' -> ForeignKeySetDefault - _ -> error $ "fetchForeignKey: invalid foreign key action code: " - ++ show c + ( name + , Array1 columns + , reftable + , Array1 refcolumns + , on_update + , on_delete + , deferrable + , deferred + , validated + ) = + ( ForeignKey + { fkColumns = map unsafeSQL columns + , fkRefTable = unsafeSQL reftable + , fkRefColumns = map unsafeSQL refcolumns + , fkOnUpdate = charToForeignKeyAction on_update + , fkOnDelete = charToForeignKeyAction on_delete + , fkDeferrable = deferrable + , fkDeferred = deferred + , fkValidated = validated + } + , unsafeSQL name + ) + where + charToForeignKeyAction c = case c of + 'a' -> ForeignKeyNoAction + 'r' -> ForeignKeyRestrict + 'c' -> ForeignKeyCascade + 'n' -> ForeignKeySetNull + 'd' -> ForeignKeySetDefault + _ -> + error $ + "fetchForeignKey: invalid foreign key action code: " + ++ show c diff --git a/src/Database/PostgreSQL/PQTypes/Checks/Util.hs b/src/Database/PostgreSQL/PQTypes/Checks/Util.hs index 30fad79..e53f953 100644 --- a/src/Database/PostgreSQL/PQTypes/Checks/Util.hs +++ b/src/Database/PostgreSQL/PQTypes/Checks/Util.hs @@ -1,75 +1,83 @@ {-# LANGUAGE CPP #-} -module Database.PostgreSQL.PQTypes.Checks.Util ( - ValidationResult, - validationError, - validationInfo, - mapValidationResult, - validationErrorsToInfos, - resultCheck, - topMessage, - tblNameText, - tblNameString, - checkEquality, - checkNames, - checkPKPresence, - objectHasLess, - objectHasMore, - arrListTable, - checkOverlappingIndexesQuery, + +module Database.PostgreSQL.PQTypes.Checks.Util + ( ValidationResult + , validationError + , validationInfo + , mapValidationResult + , validationErrorsToInfos + , resultCheck + , topMessage + , tblNameText + , tblNameString + , checkEquality + , checkNames + , checkPKPresence + , objectHasLess + , objectHasMore + , arrListTable + , checkOverlappingIndexesQuery ) where import Control.Monad.Catch #if !MIN_VERSION_base(4,11,0) import Data.Monoid #endif +import Data.List qualified as L import Data.Monoid.Utils +import Data.Semigroup qualified as SG import Data.Text (Text) +import Data.Text qualified as T import Log import TextShow -import qualified Data.List as L -import qualified Data.Text as T -import qualified Data.Semigroup as SG -import Database.PostgreSQL.PQTypes.Model import Database.PostgreSQL.PQTypes +import Database.PostgreSQL.PQTypes.Model -- | A (potentially empty) list of info/error messages. data ValidationResult = ValidationResult - { vrInfos :: [Text] + { vrInfos :: [Text] , vrErrors :: [Text] } validationError :: Text -> ValidationResult -validationError err = mempty { vrErrors = [err] } +validationError err = mempty {vrErrors = [err]} validationInfo :: Text -> ValidationResult -validationInfo msg = mempty { vrInfos = [msg] } +validationInfo msg = mempty {vrInfos = [msg]} -- | Downgrade all error messages in a ValidationResult to info messages. validationErrorsToInfos :: ValidationResult -> ValidationResult -validationErrorsToInfos ValidationResult{..} = - mempty { vrInfos = vrInfos <> vrErrors } +validationErrorsToInfos ValidationResult {..} = + mempty {vrInfos = vrInfos <> vrErrors} -mapValidationResult :: - ([Text] -> [Text]) -> ([Text] -> [Text]) -> ValidationResult -> ValidationResult -mapValidationResult mapInfos mapErrs ValidationResult{..} = - mempty { vrInfos = mapInfos vrInfos, vrErrors = mapErrs vrErrors } +mapValidationResult + :: ([Text] -> [Text]) -> ([Text] -> [Text]) -> ValidationResult -> ValidationResult +mapValidationResult mapInfos mapErrs ValidationResult {..} = + mempty {vrInfos = mapInfos vrInfos, vrErrors = mapErrs vrErrors} instance SG.Semigroup ValidationResult where - (ValidationResult infos0 errs0) <> (ValidationResult infos1 errs1) - = ValidationResult (infos0 <> infos1) (errs0 <> errs1) + (ValidationResult infos0 errs0) <> (ValidationResult infos1 errs1) = + ValidationResult (infos0 <> infos1) (errs0 <> errs1) instance Monoid ValidationResult where - mempty = ValidationResult [] [] + mempty = ValidationResult [] [] mappend = (SG.<>) topMessage :: Text -> Text -> ValidationResult -> ValidationResult -topMessage objtype objname vr@ValidationResult{..} = +topMessage objtype objname vr@ValidationResult {..} = case vrErrors of [] -> vr - es -> ValidationResult vrInfos - ("There are problems with the" <+> - objtype <+> "'" <> objname <> "'" : es) + es -> + ValidationResult + vrInfos + ( "There are problems with the" + <+> objtype + <+> "'" + <> objname + <> "'" + : es + ) -- | Log all messages in a 'ValidationResult', and fail if any of them -- were errors. @@ -77,10 +85,10 @@ resultCheck :: (MonadLog m, MonadThrow m) => ValidationResult -> m () -resultCheck ValidationResult{..} = do +resultCheck ValidationResult {..} = do mapM_ logInfo_ vrInfos case vrErrors of - [] -> return () + [] -> return () msgs -> do mapM_ logAttention_ msgs error "resultCheck: validation failed" @@ -96,19 +104,21 @@ tblNameString = T.unpack . tblNameText checkEquality :: (Eq t, Show t) => Text -> [t] -> [t] -> ValidationResult checkEquality pname defs props = case (defs L.\\ props, props L.\\ defs) of ([], []) -> mempty - (def_diff, db_diff) -> validationError $ mconcat [ - "Table and its definition have diverged and have " - , showt $ length db_diff - , " and " - , showt $ length def_diff - , " different " - , pname - , " each, respectively:\n" - , " ● table:" - , showDiff db_diff - , "\n ● definition:" - , showDiff def_diff - ] + (def_diff, db_diff) -> + validationError $ + mconcat + [ "Table and its definition have diverged and have " + , showt $ length db_diff + , " and " + , showt $ length def_diff + , " different " + , pname + , " each, respectively:\n" + , " ● table:" + , showDiff db_diff + , "\n ● definition:" + , showDiff def_diff + ] where showDiff = mconcat . map (("\n ○ " <>) . T.pack . show) @@ -118,52 +128,64 @@ checkNames prop_name = mconcat . map check check (prop, name) = case prop_name prop of pname | pname == name -> mempty - | otherwise -> validationError . mconcat $ [ - "Property " - , T.pack $ show prop - , " has invalid name (expected: " - , unRawSQL pname - , ", given: " - , unRawSQL name - , ")." - ] + | otherwise -> + validationError . mconcat $ + [ "Property " + , T.pack $ show prop + , " has invalid name (expected: " + , unRawSQL pname + , ", given: " + , unRawSQL name + , ")." + ] -- | Check presence of primary key on the named table. We cover all the cases so -- this could be used standalone, but note that the those where the table source -- definition and the table in the database differ in this respect is also -- covered by @checkEquality@. -checkPKPresence :: RawSQL () - -- ^ The name of the table to check for presence of primary key - -> Maybe PrimaryKey - -- ^ A possible primary key gotten from the table data structure - -> Maybe (PrimaryKey, RawSQL ()) - -- ^ A possible primary key as retrieved from database along - -- with its name - -> ValidationResult +checkPKPresence + :: RawSQL () + -- ^ The name of the table to check for presence of primary key + -> Maybe PrimaryKey + -- ^ A possible primary key gotten from the table data structure + -> Maybe (PrimaryKey, RawSQL ()) + -- ^ A possible primary key as retrieved from database along + -- with its name + -> ValidationResult checkPKPresence tableName mdef mpk = case (mdef, mpk) of (Nothing, Nothing) -> valRes [noSrc, noTbl] - (Nothing, Just _) -> valRes [noSrc] - (Just _, Nothing) -> valRes [noTbl] - _ -> mempty + (Nothing, Just _) -> valRes [noSrc] + (Just _, Nothing) -> valRes [noTbl] + _ -> mempty where noSrc = "no source definition" noTbl = "no table definition" valRes msgs = - validationError . mconcat $ - [ "Table ", unRawSQL tableName + validationError . mconcat $ + [ "Table " + , unRawSQL tableName , " has no primary key defined " - , " (" <> mintercalate ", " msgs <> ")"] + , " (" <> mintercalate ", " msgs <> ")" + ] objectHasLess :: Show t => Text -> Text -> t -> Text objectHasLess otype ptype missing = - otype <+> "in the database has *less*" <+> ptype <+> - "than its definition (missing:" <+> T.pack (show missing) <> ")" + otype + <+> "in the database has *less*" + <+> ptype + <+> "than its definition (missing:" + <+> T.pack (show missing) + <> ")" objectHasMore :: Show t => Text -> Text -> t -> Text objectHasMore otype ptype extra = - otype <+> "in the database has *more*" <+> ptype <+> - "than its definition (extra:" <+> T.pack (show extra) <> ")" + otype + <+> "in the database has *more*" + <+> ptype + <+> "than its definition (extra:" + <+> T.pack (show extra) + <> ")" arrListTable :: RawSQL () -> Text arrListTable tableName = " ->" <+> unRawSQL tableName <> ": " @@ -171,38 +193,38 @@ arrListTable tableName = " ->" <+> unRawSQL tableName <> ": " checkOverlappingIndexesQuery :: SQL checkOverlappingIndexesQuery = smconcat - [ "WITH", - -- get predicates (WHERE clause) definition in text format (ugly but the parsed version + [ "WITH" + , -- get predicates (WHERE clause) definition in text format (ugly but the parsed version -- can differ even if the predicate is the same), ignore functional indexes at the same time -- as that would make this query very ugly - " indexdata1 AS (SELECT *", - " , ((regexp_match(pg_get_indexdef(indexrelid)", - " , 'WHERE (.*)$')))[1] AS preddef", - " FROM pg_index", - " WHERE indexprs IS NULL)", - -- add the rest of metadata and do the join - " , indexdata2 AS (SELECT t1.*", - " , pg_get_indexdef(t1.indexrelid) AS contained", - " , pg_get_indexdef(t2.indexrelid) AS contains", - " , array_to_string(t1.indkey, '+') AS colindex", - " , array_to_string(t2.indkey, '+') AS colotherindex", - " , t2.indexrelid AS other_index", - " , t2.indisunique AS other_indisunique", - " , t2.preddef AS other_preddef", - -- cross join all indexes on the same table to try all combination (except oneself) - " FROM indexdata1 AS t1", - " INNER JOIN indexdata1 AS t2 ON t1.indrelid = t2.indrelid", - " AND t1.indexrelid <> t2.indexrelid)", - " SELECT contained", - " , contains", - " FROM indexdata2", - -- The indexes are the same or the "other" is larger than us - " WHERE (colotherindex = colindex", - " OR colotherindex LIKE colindex || '+%')", - -- and we have the same predicate - " AND other_preddef IS NOT DISTINCT FROM preddef", - -- and either the other is unique (so better than us) or none of us is unique - " AND (other_indisunique", - " OR (NOT other_indisunique", - " AND NOT indisunique));" + " indexdata1 AS (SELECT *" + , " , ((regexp_match(pg_get_indexdef(indexrelid)" + , " , 'WHERE (.*)$')))[1] AS preddef" + , " FROM pg_index" + , " WHERE indexprs IS NULL)" + , -- add the rest of metadata and do the join + " , indexdata2 AS (SELECT t1.*" + , " , pg_get_indexdef(t1.indexrelid) AS contained" + , " , pg_get_indexdef(t2.indexrelid) AS contains" + , " , array_to_string(t1.indkey, '+') AS colindex" + , " , array_to_string(t2.indkey, '+') AS colotherindex" + , " , t2.indexrelid AS other_index" + , " , t2.indisunique AS other_indisunique" + , " , t2.preddef AS other_preddef" + , -- cross join all indexes on the same table to try all combination (except oneself) + " FROM indexdata1 AS t1" + , " INNER JOIN indexdata1 AS t2 ON t1.indrelid = t2.indrelid" + , " AND t1.indexrelid <> t2.indexrelid)" + , " SELECT contained" + , " , contains" + , " FROM indexdata2" + , -- The indexes are the same or the "other" is larger than us + " WHERE (colotherindex = colindex" + , " OR colotherindex LIKE colindex || '+%')" + , -- and we have the same predicate + " AND other_preddef IS NOT DISTINCT FROM preddef" + , -- and either the other is unique (so better than us) or none of us is unique + " AND (other_indisunique" + , " OR (NOT other_indisunique" + , " AND NOT indisunique));" ] diff --git a/src/Database/PostgreSQL/PQTypes/Deriving.hs b/src/Database/PostgreSQL/PQTypes/Deriving.hs index f26400a..fbf358e 100644 --- a/src/Database/PostgreSQL/PQTypes/Deriving.hs +++ b/src/Database/PostgreSQL/PQTypes/Deriving.hs @@ -1,22 +1,24 @@ {-# LANGUAGE AllowAmbiguousTypes #-} -module Database.PostgreSQL.PQTypes.Deriving ( - -- * Helpers, to be used with @deriving via@ (@-XDerivingVia@). - SQLEnum(..) - , EnumEncoding(..) - , SQLEnumAsText(..) - , EnumAsTextEncoding(..) + +module Database.PostgreSQL.PQTypes.Deriving + ( -- * Helpers, to be used with @deriving via@ (@-XDerivingVia@). + SQLEnum (..) + , EnumEncoding (..) + , SQLEnumAsText (..) + , EnumAsTextEncoding (..) + -- * For use in doctests. , isInjective ) where -import Control.Exception (SomeException(..), throwIO) +import Control.Exception (SomeException (..), throwIO) import Data.List.Extra (enumerate, nubSort) import Data.Map.Strict (Map) +import Data.Map.Strict qualified as Map import Data.Text (Text) import Data.Typeable import Database.PostgreSQL.PQTypes import Foreign.Storable -import qualified Data.Map.Strict as Map -- | Helper newtype to be used with @deriving via@ to derive @(PQFormat, ToSQL, -- FromSQL)@ instances for enums, given an instance of 'EnumEncoding'. @@ -59,11 +61,14 @@ class ( -- The semantic type needs to be finitely enumerable. Enum a , Bounded a - -- The base type needs to be enumerable and ordered. - , Enum (EnumBase a) + , -- The base type needs to be enumerable and ordered. + Enum (EnumBase a) , Ord (EnumBase a) - ) => EnumEncoding a where + ) => + EnumEncoding a + where type EnumBase a + -- | Encode @a@ as a base type. encodeEnum :: a -> EnumBase a @@ -73,13 +78,14 @@ class -- /Note:/ The default implementation looks up values in 'decodeEnumMap' and -- can be overwritten for performance if necessary. decodeEnum :: EnumBase a -> Either [(EnumBase a, EnumBase a)] a - decodeEnum b = maybe (Left . intervals $ Map.keys (decodeEnumMap @a)) Right - $ Map.lookup b (decodeEnumMap @a) + decodeEnum b = + maybe (Left . intervals $ Map.keys (decodeEnumMap @a)) Right $ + Map.lookup b (decodeEnumMap @a) -- | Include the inverse map as a top-level part of the 'EnumEncoding' -- instance to ensure it is only computed once by GHC. decodeEnumMap :: Map (EnumBase a) a - decodeEnumMap = Map.fromList [ (encodeEnum a, a) | a <- enumerate ] + decodeEnumMap = Map.fromList [(encodeEnum a, a) | a <- enumerate] instance PQFormat (EnumBase a) => PQFormat (SQLEnum a) where pqFormat = pqFormat @(EnumBase a) @@ -88,7 +94,9 @@ instance ( EnumEncoding a , PQFormat (EnumBase a) , ToSQL (EnumBase a) - ) => ToSQL (SQLEnum a) where + ) + => ToSQL (SQLEnum a) + where type PQDest (SQLEnum a) = PQDest (EnumBase a) toSQL (SQLEnum a) = toSQL $ encodeEnum a @@ -99,15 +107,20 @@ instance , FromSQL (EnumBase a) , Show (EnumBase a) , Typeable (EnumBase a) - ) => FromSQL (SQLEnum a) where + ) + => FromSQL (SQLEnum a) + where type PQBase (SQLEnum a) = PQBase (EnumBase a) fromSQL base = do b <- fromSQL base case decodeEnum b of - Left validRange -> throwIO $ SomeException RangeError - { reRange = validRange - , reValue = b - } + Left validRange -> + throwIO $ + SomeException + RangeError + { reRange = validRange + , reValue = b + } Right a -> return $ SQLEnum a -- | A special case of 'SQLEnum', where the enum is to be encoded as text @@ -152,13 +165,14 @@ class (Enum a, Bounded a) => EnumAsTextEncoding a where -- /Note:/ The default implementation looks up values in 'decodeEnumAsTextMap' -- and can be overwritten for performance if necessary. decodeEnumAsText :: Text -> Either [Text] a - decodeEnumAsText text = maybe (Left $ Map.keys (decodeEnumAsTextMap @a)) Right - $ Map.lookup text (decodeEnumAsTextMap @a) + decodeEnumAsText text = + maybe (Left $ Map.keys (decodeEnumAsTextMap @a)) Right $ + Map.lookup text (decodeEnumAsTextMap @a) -- | Include the inverse map as a top-level part of the 'SQLEnumTextEncoding' -- instance to ensure it is only computed once by GHC. decodeEnumAsTextMap :: Map Text a - decodeEnumAsTextMap = Map.fromList [ (encodeEnumAsText a, a) | a <- enumerate ] + decodeEnumAsTextMap = Map.fromList [(encodeEnumAsText a, a) | a <- enumerate] instance EnumAsTextEncoding a => PQFormat (SQLEnumAsText a) where pqFormat = pqFormat @Text @@ -172,10 +186,13 @@ instance EnumAsTextEncoding a => FromSQL (SQLEnumAsText a) where fromSQL base = do text <- fromSQL base case decodeEnumAsText text of - Left validValues -> throwIO $ SomeException InvalidValue - { ivValue = text - , ivValidValues = Just validValues - } + Left validValues -> + throwIO $ + SomeException + InvalidValue + { ivValue = text + , ivValidValues = Just validValues + } Right a -> return $ SQLEnumAsText a -- | To be used in doctests to prove injectivity of encoding functions. @@ -186,7 +203,7 @@ instance EnumAsTextEncoding a => FromSQL (SQLEnumAsText a) where -- >>> isInjective (\(_ :: Bool) -> False) -- False isInjective :: (Enum a, Bounded a, Eq a, Eq b) => (a -> b) -> Bool -isInjective f = null [ (a, b) | a <- enumerate, b <- enumerate, a /= b, f a == f b ] +isInjective f = null [(a, b) | a <- enumerate, b <- enumerate, a /= b, f a == f b] -- | Internal helper: given a list of values, decompose it into a list of -- intervals. @@ -195,16 +212,17 @@ isInjective f = null [ (a, b) | a <- enumerate, b <- enumerate, a /= b, f a == f -- [(-1,3),(42,43),(88,88)] -- -- prop> nubSort xs == concatMap (\(l,r) -> [l .. r]) (intervals xs) -intervals :: forall a . (Enum a, Ord a) => [a] -> [(a, a)] +intervals :: forall a. (Enum a, Ord a) => [a] -> [(a, a)] intervals as = case nubSort as of [] -> [] (first : ascendingRest) -> accumIntervals (first, first) ascendingRest where accumIntervals :: (a, a) -> [a] -> [(a, a)] accumIntervals (lower, upper) [] = [(lower, upper)] - accumIntervals (lower, upper) (first' : ascendingRest') = if succ upper == first' - then accumIntervals (lower, first') ascendingRest' - else (lower, upper) : accumIntervals (first', first') ascendingRest' + accumIntervals (lower, upper) (first' : ascendingRest') = + if succ upper == first' + then accumIntervals (lower, first') ascendingRest' + else (lower, upper) : accumIntervals (first', first') ascendingRest' -- $setup -- >>> import Data.Int diff --git a/src/Database/PostgreSQL/PQTypes/ExtrasOptions.hs b/src/Database/PostgreSQL/PQTypes/ExtrasOptions.hs index 1727231..e36955d 100644 --- a/src/Database/PostgreSQL/PQTypes/ExtrasOptions.hs +++ b/src/Database/PostgreSQL/PQTypes/ExtrasOptions.hs @@ -1,34 +1,36 @@ module Database.PostgreSQL.PQTypes.ExtrasOptions - ( ExtrasOptions(..) + ( ExtrasOptions (..) , defaultExtrasOptions - , ObjectsValidationMode(..) + , ObjectsValidationMode (..) ) where -data ExtrasOptions = - ExtrasOptions - { eoLockTimeoutMs :: !(Maybe Int) - , eoEnforcePKs :: !Bool - -- ^ Validate that every handled table has a primary key - , eoObjectsValidationMode :: !ObjectsValidationMode - -- ^ Validation mode for unknown tables and composite types. - , eoAllowHigherTableVersions :: !Bool - -- ^ Whether to allow tables in the database to have higher versions than - -- the one in the code definition. - , eoCheckForeignKeysIndexes :: !Bool - -- ^ Check if all foreign keys have indexes. - , eoCheckOverlappingIndexes :: !Bool - -- ^ Check if some indexes are redundant - } deriving Eq +data ExtrasOptions + = ExtrasOptions + { eoLockTimeoutMs :: !(Maybe Int) + , eoEnforcePKs :: !Bool + -- ^ Validate that every handled table has a primary key + , eoObjectsValidationMode :: !ObjectsValidationMode + -- ^ Validation mode for unknown tables and composite types. + , eoAllowHigherTableVersions :: !Bool + -- ^ Whether to allow tables in the database to have higher versions than + -- the one in the code definition. + , eoCheckForeignKeysIndexes :: !Bool + -- ^ Check if all foreign keys have indexes. + , eoCheckOverlappingIndexes :: !Bool + -- ^ Check if some indexes are redundant + } + deriving (Eq) defaultExtrasOptions :: ExtrasOptions -defaultExtrasOptions = ExtrasOptions - { eoLockTimeoutMs = Nothing - , eoEnforcePKs = False - , eoObjectsValidationMode = DontAllowUnknownObjects - , eoAllowHigherTableVersions = False - , eoCheckForeignKeysIndexes = False - , eoCheckOverlappingIndexes = False - } +defaultExtrasOptions = + ExtrasOptions + { eoLockTimeoutMs = Nothing + , eoEnforcePKs = False + , eoObjectsValidationMode = DontAllowUnknownObjects + , eoAllowHigherTableVersions = False + , eoCheckForeignKeysIndexes = False + , eoCheckOverlappingIndexes = False + } data ObjectsValidationMode = AllowUnknownObjects | DontAllowUnknownObjects - deriving Eq + deriving (Eq) diff --git a/src/Database/PostgreSQL/PQTypes/Migrate.hs b/src/Database/PostgreSQL/PQTypes/Migrate.hs index 6ce7b19..330a3b4 100644 --- a/src/Database/PostgreSQL/PQTypes/Migrate.hs +++ b/src/Database/PostgreSQL/PQTypes/Migrate.hs @@ -1,12 +1,12 @@ -module Database.PostgreSQL.PQTypes.Migrate ( - createDomain, - createTable, - createTableConstraints, - createTableTriggers +module Database.PostgreSQL.PQTypes.Migrate + ( createDomain + , createTable + , createTableConstraints + , createTableTriggers ) where import Control.Monad -import qualified Data.Foldable as F +import Data.Foldable qualified as F import Database.PostgreSQL.PQTypes import Database.PostgreSQL.PQTypes.Checks.Util @@ -14,14 +14,14 @@ import Database.PostgreSQL.PQTypes.Model import Database.PostgreSQL.PQTypes.SQL.Builder createDomain :: MonadDB m => Domain -> m () -createDomain dom@Domain{..} = do +createDomain dom@Domain {..} = do -- create the domain runQuery_ $ sqlCreateDomain dom -- add constraint checks to the domain F.forM_ domChecks $ runQuery_ . sqlAlterDomain domName . sqlAddValidCheckMaybeDowntime createTable :: MonadDB m => Bool -> Table -> m () -createTable withConstraints table@Table{..} = do +createTable withConstraints table@Table {..} = do -- Create empty table and add the columns. runQuery_ $ sqlCreateTable tblName runQuery_ $ sqlAlterTable tblName $ map sqlAddColumn tblColumns @@ -39,11 +39,12 @@ createTable withConstraints table@Table{..} = do sqlSet "version" tblVersion createTableConstraints :: MonadDB m => Table -> m () -createTableConstraints Table{..} = unless (null addConstraints) $ do +createTableConstraints Table {..} = unless (null addConstraints) $ do runQuery_ $ sqlAlterTable tblName addConstraints where - addConstraints = map sqlAddValidCheckMaybeDowntime tblChecks - ++ map (sqlAddValidFKMaybeDowntime tblName) tblForeignKeys + addConstraints = + map sqlAddValidCheckMaybeDowntime tblChecks + ++ map (sqlAddValidFKMaybeDowntime tblName) tblForeignKeys createTableTriggers :: MonadDB m => Table -> m () createTableTriggers = mapM_ createTrigger . tblTriggers diff --git a/src/Database/PostgreSQL/PQTypes/Model.hs b/src/Database/PostgreSQL/PQTypes/Model.hs index 978ea3c..f2569d9 100644 --- a/src/Database/PostgreSQL/PQTypes/Model.hs +++ b/src/Database/PostgreSQL/PQTypes/Model.hs @@ -1,5 +1,5 @@ -module Database.PostgreSQL.PQTypes.Model ( - module Database.PostgreSQL.PQTypes.Model.Check +module Database.PostgreSQL.PQTypes.Model + ( module Database.PostgreSQL.PQTypes.Model.Check , module Database.PostgreSQL.PQTypes.Model.ColumnType , module Database.PostgreSQL.PQTypes.Model.CompositeType , module Database.PostgreSQL.PQTypes.Model.Domain diff --git a/src/Database/PostgreSQL/PQTypes/Model/Check.hs b/src/Database/PostgreSQL/PQTypes/Model/Check.hs index 8a2bcd5..72a00e5 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/Check.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/Check.hs @@ -1,5 +1,5 @@ -module Database.PostgreSQL.PQTypes.Model.Check ( - Check(..) +module Database.PostgreSQL.PQTypes.Model.Check + ( Check (..) , tblCheck , sqlAddValidCheckMaybeDowntime , sqlAddNotValidCheck @@ -10,19 +10,22 @@ module Database.PostgreSQL.PQTypes.Model.Check ( import Data.Monoid.Utils import Database.PostgreSQL.PQTypes -data Check = Check { - chkName :: RawSQL () -, chkCondition :: RawSQL () -, chkValidated :: Bool -- ^ Set to 'False' if check is created as NOT VALID and - -- left in such state (for whatever reason). -} deriving (Eq, Ord, Show) +data Check = Check + { chkName :: RawSQL () + , chkCondition :: RawSQL () + , chkValidated :: Bool + -- ^ Set to 'False' if check is created as NOT VALID and + -- left in such state (for whatever reason). + } + deriving (Eq, Ord, Show) tblCheck :: Check -tblCheck = Check - { chkName = "" - , chkCondition = "" - , chkValidated = True - } +tblCheck = + Check + { chkName = "" + , chkCondition = "" + , chkValidated = True + } -- | Add valid check constraint. Warning: PostgreSQL acquires SHARE ROW -- EXCLUSIVE lock (that prevents updates) on modified table for the duration of @@ -42,14 +45,15 @@ sqlValidateCheck :: RawSQL () -> RawSQL () sqlValidateCheck checkName = "VALIDATE CONSTRAINT" <+> checkName sqlAddCheck_ :: Bool -> Check -> RawSQL () -sqlAddCheck_ valid Check{..} = smconcat [ - "ADD CONSTRAINT" - , chkName - , "CHECK (" - , chkCondition - , ")" - , if valid then "" else " NOT VALID" - ] +sqlAddCheck_ valid Check {..} = + smconcat + [ "ADD CONSTRAINT" + , chkName + , "CHECK (" + , chkCondition + , ")" + , if valid then "" else " NOT VALID" + ] sqlDropCheck :: RawSQL () -> RawSQL () sqlDropCheck name = "DROP CONSTRAINT" <+> name diff --git a/src/Database/PostgreSQL/PQTypes/Model/ColumnType.hs b/src/Database/PostgreSQL/PQTypes/Model/ColumnType.hs index e7bded6..76dfe3e 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/ColumnType.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/ColumnType.hs @@ -1,10 +1,10 @@ -module Database.PostgreSQL.PQTypes.Model.ColumnType ( - ColumnType(..) +module Database.PostgreSQL.PQTypes.Model.ColumnType + ( ColumnType (..) , columnTypeToSQL ) where +import Data.Text qualified as T import Database.PostgreSQL.PQTypes -import qualified Data.Text as T data ColumnType = BigIntT @@ -25,7 +25,7 @@ data ColumnType | XmlT | ArrayT !ColumnType | CustomT !(RawSQL ()) - deriving (Eq, Ord, Show) + deriving (Eq, Ord, Show) instance PQFormat ColumnType where pqFormat = pqFormat @T.Text @@ -55,21 +55,21 @@ instance FromSQL ColumnType where | otherwise -> CustomT $ rawSQL tname () columnTypeToSQL :: ColumnType -> RawSQL () -columnTypeToSQL BigIntT = "BIGINT" -columnTypeToSQL BigSerialT = "BIGSERIAL" -columnTypeToSQL BinaryT = "BYTEA" -columnTypeToSQL BoolT = "BOOLEAN" -columnTypeToSQL DateT = "DATE" -columnTypeToSQL DoubleT = "DOUBLE PRECISION" -columnTypeToSQL IntegerT = "INTEGER" -columnTypeToSQL UuidT = "UUID" -columnTypeToSQL IntervalT = "INTERVAL" -columnTypeToSQL JsonT = "JSON" -columnTypeToSQL JsonbT = "JSONB" -columnTypeToSQL SmallIntT = "SMALLINT" -columnTypeToSQL TextT = "TEXT" -columnTypeToSQL TSVectorT = "TSVECTOR" +columnTypeToSQL BigIntT = "BIGINT" +columnTypeToSQL BigSerialT = "BIGSERIAL" +columnTypeToSQL BinaryT = "BYTEA" +columnTypeToSQL BoolT = "BOOLEAN" +columnTypeToSQL DateT = "DATE" +columnTypeToSQL DoubleT = "DOUBLE PRECISION" +columnTypeToSQL IntegerT = "INTEGER" +columnTypeToSQL UuidT = "UUID" +columnTypeToSQL IntervalT = "INTERVAL" +columnTypeToSQL JsonT = "JSON" +columnTypeToSQL JsonbT = "JSONB" +columnTypeToSQL SmallIntT = "SMALLINT" +columnTypeToSQL TextT = "TEXT" +columnTypeToSQL TSVectorT = "TSVECTOR" columnTypeToSQL TimestampWithZoneT = "TIMESTAMPTZ" -columnTypeToSQL XmlT = "XML" -columnTypeToSQL (ArrayT t) = columnTypeToSQL t <> "[]" -columnTypeToSQL (CustomT tname) = tname +columnTypeToSQL XmlT = "XML" +columnTypeToSQL (ArrayT t) = columnTypeToSQL t <> "[]" +columnTypeToSQL (CustomT tname) = tname diff --git a/src/Database/PostgreSQL/PQTypes/Model/CompositeType.hs b/src/Database/PostgreSQL/PQTypes/Model/CompositeType.hs index 48f87c6..42d97c4 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/CompositeType.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/CompositeType.hs @@ -1,30 +1,32 @@ -module Database.PostgreSQL.PQTypes.Model.CompositeType ( - CompositeType(..) - , CompositeColumn(..) +module Database.PostgreSQL.PQTypes.Model.CompositeType + ( CompositeType (..) + , CompositeColumn (..) , compositeTypePqFormat , sqlCreateComposite , sqlDropComposite , getDBCompositeTypes ) where +import Data.ByteString qualified as BS import Data.Int import Data.Monoid.Utils +import Data.Text.Encoding qualified as T import Database.PostgreSQL.PQTypes -import qualified Data.ByteString as BS -import qualified Data.Text.Encoding as T import Database.PostgreSQL.PQTypes.Model.ColumnType import Database.PostgreSQL.PQTypes.SQL.Builder -data CompositeType = CompositeType { - ctName :: !(RawSQL ()) -, ctColumns :: ![CompositeColumn] -} deriving (Eq, Ord, Show) +data CompositeType = CompositeType + { ctName :: !(RawSQL ()) + , ctColumns :: ![CompositeColumn] + } + deriving (Eq, Ord, Show) -data CompositeColumn = CompositeColumn { - ccName :: !(RawSQL ()) -, ccType :: ColumnType -} deriving (Eq, Ord, Show) +data CompositeColumn = CompositeColumn + { ccName :: !(RawSQL ()) + , ccType :: ColumnType + } + deriving (Eq, Ord, Show) -- | Convenience function for converting CompositeType definition to -- corresponding 'pqFormat' definition. @@ -33,15 +35,16 @@ compositeTypePqFormat ct = "%" `BS.append` T.encodeUtf8 (unRawSQL $ ctName ct) -- | Make SQL query that creates a composite type. sqlCreateComposite :: CompositeType -> RawSQL () -sqlCreateComposite CompositeType{..} = smconcat [ - "CREATE TYPE" - , ctName - , "AS (" - , mintercalate ", " $ map columnToSQL ctColumns - , ")" - ] +sqlCreateComposite CompositeType {..} = + smconcat + [ "CREATE TYPE" + , ctName + , "AS (" + , mintercalate ", " $ map columnToSQL ctColumns + , ")" + ] where - columnToSQL CompositeColumn{..} = ccName <+> columnTypeToSQL ccType + columnToSQL CompositeColumn {..} = ccName <+> columnTypeToSQL ccType -- | Make SQL query that drops a composite type. sqlDropComposite :: RawSQL () -> RawSQL () @@ -68,8 +71,8 @@ getDBCompositeTypes = do sqlWhereEq "a.attrelid" oid sqlOrderBy "a.attnum" columns <- fetchMany fetch - return CompositeType { ctName = unsafeSQL name, ctColumns = columns } + return CompositeType {ctName = unsafeSQL name, ctColumns = columns} where fetch :: (String, ColumnType) -> CompositeColumn fetch (cname, ctype) = - CompositeColumn { ccName = unsafeSQL cname, ccType = ctype } + CompositeColumn {ccName = unsafeSQL cname, ccType = ctype} diff --git a/src/Database/PostgreSQL/PQTypes/Model/Domain.hs b/src/Database/PostgreSQL/PQTypes/Model/Domain.hs index f65b11d..7c0cca6 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/Domain.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/Domain.hs @@ -1,5 +1,5 @@ -module Database.PostgreSQL.PQTypes.Model.Domain ( - Domain(..) +module Database.PostgreSQL.PQTypes.Model.Domain + ( Domain (..) , mkChecks , sqlCreateDomain , sqlAlterDomain @@ -49,31 +49,33 @@ import Database.PostgreSQL.PQTypes.Model.ColumnType -- and edit old migrations), whereas the current solution makes the -- transition trivial. -data Domain = Domain { - -- | Name of the domain. - domName :: RawSQL () - -- | Type of the domain. -, domType :: ColumnType - -- | Defines whether the domain value can be NULL. +data Domain = Domain + { domName :: RawSQL () + -- ^ Name of the domain. + , domType :: ColumnType + -- ^ Type of the domain. + , domNullable :: Bool + -- ^ Defines whether the domain value can be NULL. -- *Cannot* be superseded by a table column definition. -, domNullable :: Bool - -- Default value for the domain. *Can* be - -- superseded by a table column definition. -, domDefault :: Maybe (RawSQL ()) - -- Set of constraint checks on the domain. -, domChecks :: Set Check -} deriving (Eq, Ord, Show) + , -- Default value for the domain. *Can* be + -- superseded by a table column definition. + domDefault :: Maybe (RawSQL ()) + , -- Set of constraint checks on the domain. + domChecks :: Set Check + } + deriving (Eq, Ord, Show) mkChecks :: [Check] -> Set Check mkChecks = fromList sqlCreateDomain :: Domain -> RawSQL () -sqlCreateDomain Domain{..} = smconcat [ - "CREATE DOMAIN" <+> domName <+> "AS" - , columnTypeToSQL domType - , if domNullable then "NULL" else "NOT NULL" - , maybe "" ("DEFAULT" <+>) domDefault - ] +sqlCreateDomain Domain {..} = + smconcat + [ "CREATE DOMAIN" <+> domName <+> "AS" + , columnTypeToSQL domType + , if domNullable then "NULL" else "NOT NULL" + , maybe "" ("DEFAULT" <+>) domDefault + ] sqlAlterDomain :: RawSQL () -> RawSQL () -> RawSQL () sqlAlterDomain dname alter = "ALTER DOMAIN" <+> dname <+> alter diff --git a/src/Database/PostgreSQL/PQTypes/Model/Extension.hs b/src/Database/PostgreSQL/PQTypes/Model/Extension.hs index 9e75ce3..9703a74 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/Extension.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/Extension.hs @@ -1,5 +1,5 @@ -module Database.PostgreSQL.PQTypes.Model.Extension ( - Extension(..) +module Database.PostgreSQL.PQTypes.Model.Extension + ( Extension (..) , ununExtension ) where @@ -7,7 +7,7 @@ import Data.String import Data.Text (Text) import Database.PostgreSQL.PQTypes -newtype Extension = Extension { unExtension :: RawSQL () } +newtype Extension = Extension {unExtension :: RawSQL ()} deriving (Eq, Ord, Show, IsString) ununExtension :: Extension -> Text diff --git a/src/Database/PostgreSQL/PQTypes/Model/ForeignKey.hs b/src/Database/PostgreSQL/PQTypes/Model/ForeignKey.hs index d497372..3f59b5f 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/ForeignKey.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/ForeignKey.hs @@ -1,6 +1,6 @@ -module Database.PostgreSQL.PQTypes.Model.ForeignKey ( - ForeignKey(..) - , ForeignKeyAction(..) +module Database.PostgreSQL.PQTypes.Model.ForeignKey + ( ForeignKey (..) + , ForeignKeyAction (..) , fkOnColumn , fkOnColumns , fkName @@ -11,20 +11,22 @@ module Database.PostgreSQL.PQTypes.Model.ForeignKey ( ) where import Data.Monoid.Utils +import Data.Text qualified as T import Database.PostgreSQL.PQTypes -import qualified Data.Text as T -data ForeignKey = ForeignKey { - fkColumns :: [RawSQL ()] -, fkRefTable :: RawSQL () -, fkRefColumns :: [RawSQL ()] -, fkOnUpdate :: ForeignKeyAction -, fkOnDelete :: ForeignKeyAction -, fkDeferrable :: Bool -, fkDeferred :: Bool -, fkValidated :: Bool -- ^ Set to 'False' if foreign key is created as NOT - -- VALID and left in such state (for whatever reason). -} deriving (Eq, Ord, Show) +data ForeignKey = ForeignKey + { fkColumns :: [RawSQL ()] + , fkRefTable :: RawSQL () + , fkRefColumns :: [RawSQL ()] + , fkOnUpdate :: ForeignKeyAction + , fkOnDelete :: ForeignKeyAction + , fkDeferrable :: Bool + , fkDeferred :: Bool + , fkValidated :: Bool + -- ^ Set to 'False' if foreign key is created as NOT + -- VALID and left in such state (for whatever reason). + } + deriving (Eq, Ord, Show) data ForeignKeyAction = ForeignKeyNoAction @@ -39,26 +41,29 @@ fkOnColumn column reftable refcolumn = fkOnColumns [column] reftable [refcolumn] fkOnColumns :: [RawSQL ()] -> RawSQL () -> [RawSQL ()] -> ForeignKey -fkOnColumns columns reftable refcolumns = ForeignKey { - fkColumns = columns -, fkRefTable = reftable -, fkRefColumns = refcolumns -, fkOnUpdate = ForeignKeyCascade -, fkOnDelete = ForeignKeyNoAction -, fkDeferrable = True -, fkDeferred = False -, fkValidated = True -} +fkOnColumns columns reftable refcolumns = + ForeignKey + { fkColumns = columns + , fkRefTable = reftable + , fkRefColumns = refcolumns + , fkOnUpdate = ForeignKeyCascade + , fkOnDelete = ForeignKeyNoAction + , fkDeferrable = True + , fkDeferred = False + , fkValidated = True + } fkName :: RawSQL () -> ForeignKey -> RawSQL () -fkName tname ForeignKey{..} = shorten $ mconcat [ - "fk__" - , tname - , "__" - , mintercalate "__" fkColumns - , "__" - , fkRefTable - ] +fkName tname ForeignKey {..} = + shorten $ + mconcat + [ "fk__" + , tname + , "__" + , mintercalate "__" fkColumns + , "__" + , fkRefTable + ] where -- PostgreSQL's limit for identifier is 63 characters shorten = flip rawSQL () . T.take 63 . unRawSQL @@ -82,22 +87,23 @@ sqlValidateFK :: RawSQL () -> ForeignKey -> RawSQL () sqlValidateFK tname fk = "VALIDATE CONSTRAINT" <+> fkName tname fk sqlAddFK_ :: Bool -> RawSQL () -> ForeignKey -> RawSQL () -sqlAddFK_ valid tname fk@ForeignKey{..} = mconcat [ - "ADD CONSTRAINT" <+> fkName tname fk <+> "FOREIGN KEY (" - , mintercalate ", " fkColumns - , ") REFERENCES" <+> fkRefTable <+> "(" - , mintercalate ", " fkRefColumns - , ") ON UPDATE" <+> foreignKeyActionToSQL fkOnUpdate - , " ON DELETE" <+> foreignKeyActionToSQL fkOnDelete - , " " <> if fkDeferrable then "DEFERRABLE" else "NOT DEFERRABLE" - , " INITIALLY" <+> if fkDeferred then "DEFERRED" else "IMMEDIATE" - , if valid then "" else " NOT VALID" - ] +sqlAddFK_ valid tname fk@ForeignKey {..} = + mconcat + [ "ADD CONSTRAINT" <+> fkName tname fk <+> "FOREIGN KEY (" + , mintercalate ", " fkColumns + , ") REFERENCES" <+> fkRefTable <+> "(" + , mintercalate ", " fkRefColumns + , ") ON UPDATE" <+> foreignKeyActionToSQL fkOnUpdate + , " ON DELETE" <+> foreignKeyActionToSQL fkOnDelete + , " " <> if fkDeferrable then "DEFERRABLE" else "NOT DEFERRABLE" + , " INITIALLY" <+> if fkDeferred then "DEFERRED" else "IMMEDIATE" + , if valid then "" else " NOT VALID" + ] where - foreignKeyActionToSQL ForeignKeyNoAction = "NO ACTION" - foreignKeyActionToSQL ForeignKeyRestrict = "RESTRICT" - foreignKeyActionToSQL ForeignKeyCascade = "CASCADE" - foreignKeyActionToSQL ForeignKeySetNull = "SET NULL" + foreignKeyActionToSQL ForeignKeyNoAction = "NO ACTION" + foreignKeyActionToSQL ForeignKeyRestrict = "RESTRICT" + foreignKeyActionToSQL ForeignKeyCascade = "CASCADE" + foreignKeyActionToSQL ForeignKeySetNull = "SET NULL" foreignKeyActionToSQL ForeignKeySetDefault = "SET DEFAULT" sqlDropFK :: RawSQL () -> ForeignKey -> RawSQL () diff --git a/src/Database/PostgreSQL/PQTypes/Model/Index.hs b/src/Database/PostgreSQL/PQTypes/Model/Index.hs index 90ccf8c..f67dbf3 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/Index.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/Index.hs @@ -1,9 +1,9 @@ -module Database.PostgreSQL.PQTypes.Model.Index ( - TableIndex(..) - , IndexColumn(..) +module Database.PostgreSQL.PQTypes.Model.Index + ( TableIndex (..) + , IndexColumn (..) , indexColumn , indexColumnWithOperatorClass - , IndexMethod(..) + , IndexMethod (..) , tblIndex , indexOnColumn , indexOnColumns @@ -20,38 +20,39 @@ module Database.PostgreSQL.PQTypes.Model.Index ( , sqlDropIndexConcurrently ) where +import Crypto.Hash qualified as H +import Data.ByteArray qualified as BA +import Data.ByteString.Base16 qualified as B16 +import Data.ByteString.Char8 qualified as BS import Data.Char import Data.Function -import Data.String import Data.Monoid.Utils +import Data.String +import Data.Text qualified as T +import Data.Text.Encoding qualified as T import Database.PostgreSQL.PQTypes -import qualified Crypto.Hash as H -import qualified Data.ByteArray as BA -import qualified Data.ByteString.Base16 as B16 -import qualified Data.ByteString.Char8 as BS -import qualified Data.Text as T -import qualified Data.Text.Encoding as T - -data TableIndex = TableIndex { - idxColumns :: [IndexColumn] -, idxInclude :: [RawSQL ()] -, idxMethod :: IndexMethod -, idxUnique :: Bool -, idxValid :: Bool --- ^ If creation of index with CONCURRENTLY fails, index --- will be marked as invalid. Set it to 'False' if such --- situation is expected. -, idxWhere :: Maybe (RawSQL ()) -, idxNotDistinctNulls :: Bool --- ^ Adds NULL NOT DISTINCT on the index, meaning that --- ^ only one NULL value will be accepted; other NULLs --- ^ will be perceived as a violation of the constraint. --- ^ NB: will only be used if idxUnique is set to True -} deriving (Eq, Ord, Show) + +data TableIndex = TableIndex + { idxColumns :: [IndexColumn] + , idxInclude :: [RawSQL ()] + , idxMethod :: IndexMethod + , idxUnique :: Bool + , idxValid :: Bool + , -- \^ If creation of index with CONCURRENTLY fails, index + -- will be marked as invalid. Set it to 'False' if such + -- situation is expected. + idxWhere :: Maybe (RawSQL ()) + , idxNotDistinctNulls :: Bool + } + -- \^ Adds NULL NOT DISTINCT on the index, meaning that + -- \^ only one NULL value will be accepted; other NULLs + -- \^ will be perceived as a violation of the constraint. + -- \^ NB: will only be used if idxUnique is set to True + deriving (Eq, Ord, Show) data IndexColumn = IndexColumn (RawSQL ()) (Maybe (RawSQL ())) - deriving Show + deriving (Show) -- If one of the two columns doesn't specify the operator class, we just ignore -- it and still treat them as equivalent. @@ -75,111 +76,129 @@ indexColumnWithOperatorClass col opclass = IndexColumn col (Just opclass) indexColumnName :: IndexColumn -> RawSQL () indexColumnName (IndexColumn col _) = col -data IndexMethod = - BTree +data IndexMethod + = BTree | GIN deriving (Eq, Ord) instance Show IndexMethod where - show BTree = "btree" - show GIN = "gin" + show BTree = "btree" + show GIN = "gin" instance Read IndexMethod where - readsPrec _ (map toLower -> "btree") = [(BTree,"")] - readsPrec _ (map toLower -> "gin") = [(GIN,"")] - readsPrec _ _ = [] + readsPrec _ (map toLower -> "btree") = [(BTree, "")] + readsPrec _ (map toLower -> "gin") = [(GIN, "")] + readsPrec _ _ = [] tblIndex :: TableIndex -tblIndex = TableIndex { - idxColumns = [] -, idxInclude = [] -, idxMethod = BTree -, idxUnique = False -, idxValid = True -, idxWhere = Nothing -, idxNotDistinctNulls = False -} +tblIndex = + TableIndex + { idxColumns = [] + , idxInclude = [] + , idxMethod = BTree + , idxUnique = False + , idxValid = True + , idxWhere = Nothing + , idxNotDistinctNulls = False + } indexOnColumn :: IndexColumn -> TableIndex -indexOnColumn column = tblIndex { idxColumns = [column] } +indexOnColumn column = tblIndex {idxColumns = [column]} -- | Create an index on the given column with the specified method. No checks -- are made that the method is appropriate for the type of the column. indexOnColumnWithMethod :: IndexColumn -> IndexMethod -> TableIndex indexOnColumnWithMethod column method = - tblIndex { idxColumns = [column] - , idxMethod = method } + tblIndex + { idxColumns = [column] + , idxMethod = method + } indexOnColumns :: [IndexColumn] -> TableIndex -indexOnColumns columns = tblIndex { idxColumns = columns } +indexOnColumns columns = tblIndex {idxColumns = columns} -- | Create an index on the given columns with the specified method. No checks -- are made that the method is appropriate for the type of the column; -- cf. [the PostgreSQL manual](https://www.postgresql.org/docs/current/static/indexes-multicolumn.html). indexOnColumnsWithMethod :: [IndexColumn] -> IndexMethod -> TableIndex indexOnColumnsWithMethod columns method = - tblIndex { idxColumns = columns - , idxMethod = method } + tblIndex + { idxColumns = columns + , idxMethod = method + } uniqueIndexOnColumn :: IndexColumn -> TableIndex -uniqueIndexOnColumn column = TableIndex { - idxColumns = [column] -, idxInclude = [] -, idxMethod = BTree -, idxUnique = True -, idxValid = True -, idxWhere = Nothing -, idxNotDistinctNulls = False -} +uniqueIndexOnColumn column = + TableIndex + { idxColumns = [column] + , idxInclude = [] + , idxMethod = BTree + , idxUnique = True + , idxValid = True + , idxWhere = Nothing + , idxNotDistinctNulls = False + } uniqueIndexOnColumns :: [IndexColumn] -> TableIndex -uniqueIndexOnColumns columns = TableIndex { - idxColumns = columns -, idxInclude = [] -, idxMethod = BTree -, idxUnique = True -, idxValid = True -, idxWhere = Nothing -, idxNotDistinctNulls = False -} +uniqueIndexOnColumns columns = + TableIndex + { idxColumns = columns + , idxInclude = [] + , idxMethod = BTree + , idxUnique = True + , idxValid = True + , idxWhere = Nothing + , idxNotDistinctNulls = False + } uniqueIndexOnColumnWithCondition :: IndexColumn -> RawSQL () -> TableIndex -uniqueIndexOnColumnWithCondition column whereC = TableIndex { - idxColumns = [column] -, idxInclude = [] -, idxMethod = BTree -, idxUnique = True -, idxValid = True -, idxWhere = Just whereC -, idxNotDistinctNulls = False -} +uniqueIndexOnColumnWithCondition column whereC = + TableIndex + { idxColumns = [column] + , idxInclude = [] + , idxMethod = BTree + , idxUnique = True + , idxValid = True + , idxWhere = Just whereC + , idxNotDistinctNulls = False + } indexName :: RawSQL () -> TableIndex -> RawSQL () -indexName tname TableIndex{..} = flip rawSQL () $ T.take 63 . unRawSQL $ mconcat [ - if idxUnique then "unique_idx__" else "idx__" - , tname - , "__" - , mintercalate "__" $ map (asText sanitize . indexColumnName) idxColumns - , if null idxInclude - then "" - else "$$" <> mintercalate "__" (map (asText sanitize) idxInclude) - , maybe "" (("__" <>) . hashWhere) idxWhere - ] +indexName tname TableIndex {..} = + flip rawSQL () $ + T.take 63 . unRawSQL $ + mconcat + [ if idxUnique then "unique_idx__" else "idx__" + , tname + , "__" + , mintercalate "__" $ map (asText sanitize . indexColumnName) idxColumns + , if null idxInclude + then "" + else "$$" <> mintercalate "__" (map (asText sanitize) idxInclude) + , maybe "" (("__" <>) . hashWhere) idxWhere + ] where asText f = flip rawSQL () . f . unRawSQL -- See http://www.postgresql.org/docs/9.4/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS. -- Remove all unallowed characters and replace them by at most one adjacent dollar sign. sanitize = T.pack . foldr go [] . T.unpack where - go c acc = if isAlphaNum c || c == '_' - then c : acc - else case acc of - ('$':_) -> acc - _ -> '$' : acc + go c acc = + if isAlphaNum c || c == '_' + then c : acc + else case acc of + ('$' : _) -> acc + _ -> '$' : acc -- hash WHERE clause and add it to index name so that indexes -- with the same columns, but different constraints can coexist - hashWhere = asText $ T.decodeUtf8 . B16.encode . BS.take 10 - . BA.convert . H.hash @_ @H.RIPEMD160 . T.encodeUtf8 + hashWhere = + asText $ + T.decodeUtf8 + . B16.encode + . BS.take 10 + . BA.convert + . H.hash @_ @H.RIPEMD160 + . T.encodeUtf8 -- | Create an index. Warning: if the affected table is large, this will prevent -- the table from being modified during the creation. If this is not acceptable, @@ -194,30 +213,33 @@ sqlCreateIndexConcurrently :: RawSQL () -> TableIndex -> RawSQL () sqlCreateIndexConcurrently = sqlCreateIndex_ True sqlCreateIndex_ :: Bool -> RawSQL () -> TableIndex -> RawSQL () -sqlCreateIndex_ concurrently tname idx@TableIndex{..} = mconcat [ - "CREATE" - , if idxUnique then " UNIQUE" else "" - , " INDEX " - , if concurrently then "CONCURRENTLY " else "" - , indexName tname idx - , " ON" <+> tname - , " USING" <+> rawSQL (T.pack . show $ idxMethod) () <+> "(" - , mintercalate ", " - (map - (\case - IndexColumn col Nothing -> col - IndexColumn col (Just opclass) -> col <+> opclass +sqlCreateIndex_ concurrently tname idx@TableIndex {..} = + mconcat + [ "CREATE" + , if idxUnique then " UNIQUE" else "" + , " INDEX " + , if concurrently then "CONCURRENTLY " else "" + , indexName tname idx + , " ON" <+> tname + , " USING" <+> rawSQL (T.pack . show $ idxMethod) () <+> "(" + , mintercalate + ", " + ( map + ( \case + IndexColumn col Nothing -> col + IndexColumn col (Just opclass) -> col <+> opclass + ) + idxColumns ) - idxColumns) - , ")" - , if null idxInclude - then "" - else " INCLUDE (" <> mintercalate ", " idxInclude <> ")" - , if idxUnique && idxNotDistinctNulls - then " NULLS NOT DISTINCT" - else "" - , maybe "" (" WHERE" <+>) idxWhere - ] + , ")" + , if null idxInclude + then "" + else " INCLUDE (" <> mintercalate ", " idxInclude <> ")" + , if idxUnique && idxNotDistinctNulls + then " NULLS NOT DISTINCT" + else "" + , maybe "" (" WHERE" <+>) idxWhere + ] -- | Drop an index. Warning: if you don't want to lock out concurrent operations -- on the index's table, use 'DropIndexConcurrentlyMigration'. See diff --git a/src/Database/PostgreSQL/PQTypes/Model/Migration.hs b/src/Database/PostgreSQL/PQTypes/Model/Migration.hs index 52cdc11..87bb063 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/Migration.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/Migration.hs @@ -1,33 +1,32 @@ {-# LANGUAGE CPP #-} -{- | -Using migrations is fairly easy. After you've defined the lists of -migrations and tables, just run -'Database.PostgreSQL.PQTypes.Checks.migrateDatabase': - -@ -tables :: [Table] -tables = ... - -migrations :: [Migration] -migrations = ... - -migrateDatabase options extensions domains tables migrations -@ - -Migrations are run strictly in the order specified in the migrations -list, starting with the first migration for which the corresponding -table in the DB has the version number equal to the 'mgrFrom' field of -the migration. - --} - -module Database.PostgreSQL.PQTypes.Model.Migration ( - DropTableMode(..), - MigrationAction(..), - Migration(..), - isStandardMigration, isDropTableMigration - ) where +-- | +-- +-- Using migrations is fairly easy. After you've defined the lists of +-- migrations and tables, just run +-- 'Database.PostgreSQL.PQTypes.Checks.migrateDatabase': +-- +-- @ +-- tables :: [Table] +-- tables = ... +-- +-- migrations :: [Migration] +-- migrations = ... +-- +-- migrateDatabase options extensions domains tables migrations +-- @ +-- +-- Migrations are run strictly in the order specified in the migrations +-- list, starting with the first migration for which the corresponding +-- table in the DB has the version number equal to the 'mgrFrom' field of +-- the migration. +module Database.PostgreSQL.PQTypes.Model.Migration + ( DropTableMode (..) + , MigrationAction (..) + , Migration (..) + , isStandardMigration + , isDropTableMigration + ) where import Data.Int @@ -39,74 +38,73 @@ import Database.PostgreSQL.PQTypes.SQL.Raw -- | Migration action to run, either an arbitrary 'MonadDB' action, or -- something more fine-grained. -data MigrationAction m = - - -- | Standard migration, i.e. an arbitrary 'MonadDB' action. - StandardMigration (m ()) - - -- | Drop table migration. Parameter is the drop table mode - -- (@RESTRICT@/@CASCADE@). The 'Migration' record holds the name of - -- the table to drop. - | DropTableMigration DropTableMode - - -- | Migration for creating an index concurrently. - | CreateIndexConcurrentlyMigration - (RawSQL ()) -- ^ Table name - TableIndex -- ^ Index - - -- | Migration for dropping an index concurrently. - | DropIndexConcurrentlyMigration - (RawSQL ()) -- ^ Table name - TableIndex -- ^ Index - - -- | Migration for modifying columns. Parameters are: - -- - -- Name of the table that the cursor is associated with. It has to be the same as in the - -- cursor SQL, see the second parameter. - -- - -- SQL providing a list of primary keys from the associated table that will be used for the cursor. - -- - -- Function that takes a batch of primary keys provided by the cursor SQL and runs an arbitrary computation - -- within MonadDB. The function might be called repeatedly depending on the batch size and total number of - -- selected primary keys. See the last argument. - -- - -- Batch size of primary keys to be fetched at once by the cursor SQL and be given to the modification function. - -- To handle multi-column primary keys, the following needs to be done: - -- - -- 1. Get the list of tuples from PostgreSQL. - -- 2. Unzip them into a tuple of lists in Haskell. - -- 3. Pass the lists to PostgreSQL as separate parameters and zip them back in the SQL, - -- see https://stackoverflow.com/questions/12414750/is-there-something-like-a-zip-function-in-postgresql-that-combines-two-arrays for more details. - | forall t . FromRow t => ModifyColumnMigration (RawSQL ()) SQL ([t] -> m ()) Int +data MigrationAction m + = -- | Standard migration, i.e. an arbitrary 'MonadDB' action. + StandardMigration (m ()) + | -- | Drop table migration. Parameter is the drop table mode + -- (@RESTRICT@/@CASCADE@). The 'Migration' record holds the name of + -- the table to drop. + DropTableMigration DropTableMode + | -- | Migration for creating an index concurrently. + CreateIndexConcurrentlyMigration + (RawSQL ()) + -- ^ Table name + TableIndex + -- ^ Index + | -- | Migration for dropping an index concurrently. + DropIndexConcurrentlyMigration + (RawSQL ()) + -- ^ Table name + TableIndex + -- ^ Index + | -- | Migration for modifying columns. Parameters are: + -- + -- Name of the table that the cursor is associated with. It has to be the same as in the + -- cursor SQL, see the second parameter. + -- + -- SQL providing a list of primary keys from the associated table that will be used for the cursor. + -- + -- Function that takes a batch of primary keys provided by the cursor SQL and runs an arbitrary computation + -- within MonadDB. The function might be called repeatedly depending on the batch size and total number of + -- selected primary keys. See the last argument. + -- + -- Batch size of primary keys to be fetched at once by the cursor SQL and be given to the modification function. + -- To handle multi-column primary keys, the following needs to be done: + -- + -- 1. Get the list of tuples from PostgreSQL. + -- 2. Unzip them into a tuple of lists in Haskell. + -- 3. Pass the lists to PostgreSQL as separate parameters and zip them back in the SQL, + -- see https://stackoverflow.com/questions/12414750/is-there-something-like-a-zip-function-in-postgresql-that-combines-two-arrays for more details. + forall t. FromRow t => ModifyColumnMigration (RawSQL ()) SQL ([t] -> m ()) Int -- | Migration object. -data Migration m = - Migration { - -- | The name of the table you're migrating. - mgrTableName :: RawSQL () - -- | The version you're migrating *from* (you don't specify what +data Migration m + = Migration + { mgrTableName :: RawSQL () + -- ^ The name of the table you're migrating. + , mgrFrom :: Int32 + -- ^ The version you're migrating *from* (you don't specify what -- version you migrate TO, because version is always increased by 1, -- so if 'mgrFrom' is 2, that means that after that migration is run, -- table version will equal 3 -, mgrFrom :: Int32 - -- | Migration action. -, mgrAction :: MigrationAction m + , mgrAction :: MigrationAction m + -- ^ Migration action. } isStandardMigration :: Migration m -> Bool -isStandardMigration Migration{..} = +isStandardMigration Migration {..} = case mgrAction of - StandardMigration{} -> True - DropTableMigration{} -> False - CreateIndexConcurrentlyMigration{} -> False - DropIndexConcurrentlyMigration{} -> False - ModifyColumnMigration{} -> False + StandardMigration {} -> True + DropTableMigration {} -> False + CreateIndexConcurrentlyMigration {} -> False + DropIndexConcurrentlyMigration {} -> False + ModifyColumnMigration {} -> False isDropTableMigration :: Migration m -> Bool -isDropTableMigration Migration{..} = +isDropTableMigration Migration {..} = case mgrAction of - StandardMigration{} -> False - DropTableMigration{} -> True - CreateIndexConcurrentlyMigration{} -> False - DropIndexConcurrentlyMigration{} -> False - ModifyColumnMigration{} -> False + StandardMigration {} -> False + DropTableMigration {} -> True + CreateIndexConcurrentlyMigration {} -> False + DropIndexConcurrentlyMigration {} -> False + ModifyColumnMigration {} -> False diff --git a/src/Database/PostgreSQL/PQTypes/Model/PrimaryKey.hs b/src/Database/PostgreSQL/PQTypes/Model/PrimaryKey.hs index 7b07887..d8c5fee 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/PrimaryKey.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/PrimaryKey.hs @@ -1,5 +1,5 @@ -module Database.PostgreSQL.PQTypes.Model.PrimaryKey ( - PrimaryKey +module Database.PostgreSQL.PQTypes.Model.PrimaryKey + ( PrimaryKey , pkOnColumn , pkOnColumns , pkName @@ -21,7 +21,7 @@ pkOnColumn :: RawSQL () -> Maybe PrimaryKey pkOnColumn column = Just . PrimaryKey . toNubList $ [column] pkOnColumns :: [RawSQL ()] -> Maybe PrimaryKey -pkOnColumns [] = Nothing +pkOnColumns [] = Nothing pkOnColumns columns = Just . PrimaryKey . toNubList $ columns pkName :: RawSQL () -> RawSQL () @@ -31,25 +31,27 @@ pkColumns :: PrimaryKey -> [RawSQL ()] pkColumns (PrimaryKey columns) = fromNubList columns sqlAddPK :: RawSQL () -> PrimaryKey -> RawSQL () -sqlAddPK tname (PrimaryKey columns) = smconcat [ - "ADD CONSTRAINT" - , pkName tname - , "PRIMARY KEY (" - , mintercalate ", " $ fromNubList columns - , ")" - ] +sqlAddPK tname (PrimaryKey columns) = + smconcat + [ "ADD CONSTRAINT" + , pkName tname + , "PRIMARY KEY (" + , mintercalate ", " $ fromNubList columns + , ")" + ] -- | Convert a unique index into a primary key. Main usage is to build a unique -- index concurrently first (so that its creation doesn't conflict with table -- updates on the modified table) and then convert it into a primary key using -- this function. sqlAddPKUsing :: RawSQL () -> TableIndex -> RawSQL () -sqlAddPKUsing tname idx = smconcat - [ "ADD CONSTRAINT" - , pkName tname - , "PRIMARY KEY USING INDEX" - , indexName tname idx - ] +sqlAddPKUsing tname idx = + smconcat + [ "ADD CONSTRAINT" + , pkName tname + , "PRIMARY KEY USING INDEX" + , indexName tname idx + ] sqlDropPK :: RawSQL () -> RawSQL () sqlDropPK tname = "DROP CONSTRAINT" <+> pkName tname diff --git a/src/Database/PostgreSQL/PQTypes/Model/Table.hs b/src/Database/PostgreSQL/PQTypes/Model/Table.hs index 8fd53d9..a736fee 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/Table.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/Table.hs @@ -1,17 +1,17 @@ -module Database.PostgreSQL.PQTypes.Model.Table ( - TableColumn(..) +module Database.PostgreSQL.PQTypes.Model.Table + ( TableColumn (..) , tblColumn , sqlAddColumn , sqlAlterColumn , sqlDropColumn - , Rows(..) - , Table(..) + , Rows (..) + , Table (..) , tblTable , sqlCreateTable , sqlAlterTable - , DropTableMode(..) + , DropTableMode (..) , sqlDropTable - , TableInitialSetup(..) + , TableInitialSetup (..) ) where import Control.Monad.Catch @@ -27,32 +27,35 @@ import Database.PostgreSQL.PQTypes.Model.Index import Database.PostgreSQL.PQTypes.Model.PrimaryKey import Database.PostgreSQL.PQTypes.Model.Trigger -data TableColumn = TableColumn { - colName :: RawSQL () -, colType :: ColumnType -, colCollation :: Maybe (RawSQL ()) -, colNullable :: Bool -, colDefault :: Maybe (RawSQL ()) -} deriving Show +data TableColumn = TableColumn + { colName :: RawSQL () + , colType :: ColumnType + , colCollation :: Maybe (RawSQL ()) + , colNullable :: Bool + , colDefault :: Maybe (RawSQL ()) + } + deriving (Show) tblColumn :: TableColumn -tblColumn = TableColumn { - colName = error "tblColumn: column name must be specified" -, colType = error "tblColumn: column type must be specified" -, colCollation = Nothing -, colNullable = True -, colDefault = Nothing -} +tblColumn = + TableColumn + { colName = error "tblColumn: column name must be specified" + , colType = error "tblColumn: column type must be specified" + , colCollation = Nothing + , colNullable = True + , colDefault = Nothing + } sqlAddColumn :: TableColumn -> RawSQL () -sqlAddColumn TableColumn{..} = smconcat [ - "ADD COLUMN" - , colName - , columnTypeToSQL colType - , maybe "" (\c -> "COLLATE \"" <> c <> "\"") colCollation - , if colNullable then "NULL" else "NOT NULL" - , maybe "" ("DEFAULT" <+>) colDefault - ] +sqlAddColumn TableColumn {..} = + smconcat + [ "ADD COLUMN" + , colName + , columnTypeToSQL colType + , maybe "" (\c -> "COLLATE \"" <> c <> "\"") colCollation + , if colNullable then "NULL" else "NOT NULL" + , maybe "" ("DEFAULT" <+>) colDefault + ] sqlAlterColumn :: RawSQL () -> RawSQL () -> RawSQL () sqlAlterColumn cname alter = "ALTER COLUMN" <+> cname <+> alter @@ -64,56 +67,61 @@ sqlDropColumn cname = "DROP COLUMN" <+> cname data Rows = forall row. (Show row, ToRow row) => Rows [ByteString] [row] -data Table = - Table { - tblName :: RawSQL () -- ^ Must be in lower case. -, tblVersion :: Int32 -, tblColumns :: [TableColumn] -, tblPrimaryKey :: Maybe PrimaryKey -, tblChecks :: [Check] -, tblForeignKeys :: [ForeignKey] -, tblIndexes :: [TableIndex] -, tblTriggers :: [Trigger] -, tblInitialSetup :: Maybe TableInitialSetup -} - -data TableInitialSetup = TableInitialSetup { - checkInitialSetup :: forall m. (MonadDB m, MonadThrow m) => m Bool -, initialSetup :: forall m. (MonadDB m, MonadThrow m) => m () -} +data Table + = Table + { tblName :: RawSQL () + -- ^ Must be in lower case. + , tblVersion :: Int32 + , tblColumns :: [TableColumn] + , tblPrimaryKey :: Maybe PrimaryKey + , tblChecks :: [Check] + , tblForeignKeys :: [ForeignKey] + , tblIndexes :: [TableIndex] + , tblTriggers :: [Trigger] + , tblInitialSetup :: Maybe TableInitialSetup + } + +data TableInitialSetup = TableInitialSetup + { checkInitialSetup :: forall m. (MonadDB m, MonadThrow m) => m Bool + , initialSetup :: forall m. (MonadDB m, MonadThrow m) => m () + } tblTable :: Table -tblTable = Table { - tblName = error "tblTable: table name must be specified" -, tblVersion = error "tblTable: table version must be specified" -, tblColumns = error "tblTable: table columns must be specified" -, tblPrimaryKey = Nothing -, tblChecks = [] -, tblForeignKeys = [] -, tblIndexes = [] -, tblTriggers = [] -, tblInitialSetup = Nothing -} +tblTable = + Table + { tblName = error "tblTable: table name must be specified" + , tblVersion = error "tblTable: table version must be specified" + , tblColumns = error "tblTable: table columns must be specified" + , tblPrimaryKey = Nothing + , tblChecks = [] + , tblForeignKeys = [] + , tblIndexes = [] + , tblTriggers = [] + , tblInitialSetup = Nothing + } sqlCreateTable :: RawSQL () -> RawSQL () sqlCreateTable tname = "CREATE TABLE" <+> tname <+> "()" -- | Whether to also drop objects that depend on the table. -data DropTableMode = - -- | Automatically drop objects that depend on the table (such as views). - DropTableCascade | - -- | Refuse to drop the table if any objects depend on it. This is the default. - DropTableRestrict +data DropTableMode + = -- | Automatically drop objects that depend on the table (such as views). + DropTableCascade + | -- | Refuse to drop the table if any objects depend on it. This is the default. + DropTableRestrict sqlDropTable :: RawSQL () -> DropTableMode -> RawSQL () -sqlDropTable tname mode = "DROP TABLE" <+> tname - <+> case mode of - DropTableCascade -> "CASCADE" - DropTableRestrict -> "RESTRICT" +sqlDropTable tname mode = + "DROP TABLE" + <+> tname + <+> case mode of + DropTableCascade -> "CASCADE" + DropTableRestrict -> "RESTRICT" sqlAlterTable :: RawSQL () -> [RawSQL ()] -> RawSQL () -sqlAlterTable tname alter_statements = smconcat [ - "ALTER TABLE" - , tname - , mintercalate ", " alter_statements - ] +sqlAlterTable tname alter_statements = + smconcat + [ "ALTER TABLE" + , tname + , mintercalate ", " alter_statements + ] diff --git a/src/Database/PostgreSQL/PQTypes/Model/Trigger.hs b/src/Database/PostgreSQL/PQTypes/Model/Trigger.hs index 994eed3..6e1bbb6 100644 --- a/src/Database/PostgreSQL/PQTypes/Model/Trigger.hs +++ b/src/Database/PostgreSQL/PQTypes/Model/Trigger.hs @@ -6,11 +6,10 @@ -- created with no arguments and always @RETURN TRIGGER@. -- -- For details, see . - -module Database.PostgreSQL.PQTypes.Model.Trigger ( - -- * Triggers - TriggerEvent(..) - , Trigger(..) +module Database.PostgreSQL.PQTypes.Model.Trigger + ( -- * Triggers + TriggerEvent (..) + , Trigger (..) , triggerMakeName , triggerBaseName , sqlCreateTrigger @@ -18,7 +17,8 @@ module Database.PostgreSQL.PQTypes.Model.Trigger ( , createTrigger , dropTrigger , getDBTriggers - -- * Trigger functions + + -- * Trigger functions , sqlCreateTriggerFunction , sqlDropTriggerFunction , triggerFunctionMakeName @@ -29,59 +29,61 @@ import Data.Foldable (foldl') import Data.Int import Data.Monoid.Utils import Data.Set (Set) +import Data.Set qualified as Set import Data.Text (Text) +import Data.Text qualified as Text import Database.PostgreSQL.PQTypes import Database.PostgreSQL.PQTypes.SQL.Builder -import qualified Data.Set as Set -import qualified Data.Text as Text -- | Trigger event name. -- -- @since 1.15.0.0 data TriggerEvent - = TriggerInsert - -- ^ The @INSERT@ event. - | TriggerUpdate - -- ^ The @UPDATE@ event. - | TriggerUpdateOf [RawSQL ()] - -- ^ The @UPDATE OF column1 [, column2 ...]@ event. - | TriggerDelete - -- ^ The @DELETE@ event. + = -- | The @INSERT@ event. + TriggerInsert + | -- | The @UPDATE@ event. + TriggerUpdate + | -- | The @UPDATE OF column1 [, column2 ...]@ event. + TriggerUpdateOf [RawSQL ()] + | -- | The @DELETE@ event. + TriggerDelete deriving (Eq, Ord, Show) -- | Trigger. -- -- @since 1.15.0.0 -data Trigger = Trigger { - triggerTable :: RawSQL () - -- ^ The table that the trigger is associated with. - , triggerName :: RawSQL () - -- ^ The internal name without any prefixes. Trigger name must be unique among - -- triggers of same table. See 'triggerMakeName'. - , triggerEvents :: Set TriggerEvent - -- ^ The set of events. Corresponds to the @{ __event__ [ OR ... ] }@ in the trigger - -- definition. The order in which they are defined doesn't matter and there can - -- only be one of each. - , triggerDeferrable :: Bool - -- ^ Is the trigger @DEFERRABLE@ or @NOT DEFERRABLE@ ? +data Trigger = Trigger + { triggerTable :: RawSQL () + -- ^ The table that the trigger is associated with. + , triggerName :: RawSQL () + -- ^ The internal name without any prefixes. Trigger name must be unique among + -- triggers of same table. See 'triggerMakeName'. + , triggerEvents :: Set TriggerEvent + -- ^ The set of events. Corresponds to the @{ __event__ [ OR ... ] }@ in the trigger + -- definition. The order in which they are defined doesn't matter and there can + -- only be one of each. + , triggerDeferrable :: Bool + -- ^ Is the trigger @DEFERRABLE@ or @NOT DEFERRABLE@ ? , triggerInitiallyDeferred :: Bool - -- ^ Is the trigger @INITIALLY DEFERRED@ or @INITIALLY IMMEDIATE@ ? - , triggerWhen :: Maybe (RawSQL ()) - -- ^ The condition that specifies whether the trigger should fire. Corresponds to the - -- @WHEN ( __condition__ )@ in the trigger definition. - , triggerFunction :: RawSQL () - -- ^ The function to execute when the trigger fires. -} deriving (Show) + -- ^ Is the trigger @INITIALLY DEFERRED@ or @INITIALLY IMMEDIATE@ ? + , triggerWhen :: Maybe (RawSQL ()) + -- ^ The condition that specifies whether the trigger should fire. Corresponds to the + -- @WHEN ( __condition__ )@ in the trigger definition. + , triggerFunction :: RawSQL () + -- ^ The function to execute when the trigger fires. + } + deriving (Show) instance Eq Trigger where t1 == t2 = triggerTable t1 == triggerTable t2 - && triggerName t1 == triggerName t2 - && triggerEvents t1 == triggerEvents t2 - && triggerDeferrable t1 == triggerDeferrable t2 - && triggerInitiallyDeferred t1 == triggerInitiallyDeferred t2 - && triggerWhen t1 == triggerWhen t2 - -- Function source code is not guaranteed to be equal, so we ignore it. + && triggerName t1 == triggerName t2 + && triggerEvents t1 == triggerEvents t2 + && triggerDeferrable t1 == triggerDeferrable t2 + && triggerInitiallyDeferred t1 == triggerInitiallyDeferred t2 + && triggerWhen t1 == triggerWhen t2 + +-- Function source code is not guaranteed to be equal, so we ignore it. -- | Make a trigger name that can be used in SQL. -- @@ -107,9 +109,10 @@ triggerEventName :: TriggerEvent -> RawSQL () triggerEventName = \case TriggerInsert -> "INSERT" TriggerUpdate -> "UPDATE" - TriggerUpdateOf columns -> if null columns - then error "UPDATE OF must have columns." - else "UPDATE OF" <+> mintercalate ", " columns + TriggerUpdateOf columns -> + if null columns + then error "UPDATE OF must have columns." + else "UPDATE OF" <+> mintercalate ", " columns TriggerDelete -> "DELETE" -- | Build an SQL statement that creates a trigger. @@ -118,14 +121,18 @@ triggerEventName = \case -- -- @since 1.15.0 sqlCreateTrigger :: Trigger -> RawSQL () -sqlCreateTrigger Trigger{..} = - "CREATE CONSTRAINT TRIGGER" <+> trgName - <+> "AFTER" <+> trgEvents - <+> "ON" <+> triggerTable +sqlCreateTrigger Trigger {..} = + "CREATE CONSTRAINT TRIGGER" + <+> trgName + <+> "AFTER" + <+> trgEvents + <+> "ON" + <+> triggerTable <+> trgTiming <+> "FOR EACH ROW" <+> trgWhen - <+> "EXECUTE FUNCTION" <+> trgFunction + <+> "EXECUTE FUNCTION" + <+> trgFunction <+> "()" where trgName @@ -134,20 +141,21 @@ sqlCreateTrigger Trigger{..} = trgEvents | triggerEvents == Set.empty = error "Trigger must have at least one event." | otherwise = mintercalate " OR " . map triggerEventName $ Set.toList triggerEvents - trgTiming = let deferrable = (if triggerDeferrable then "" else "NOT") <+> "DEFERRABLE" - deferred = if triggerInitiallyDeferred - then "INITIALLY DEFERRED" - else "INITIALLY IMMEDIATE" - in deferrable <+> deferred + trgTiming = + let deferrable = (if triggerDeferrable then "" else "NOT") <+> "DEFERRABLE" + deferred = + if triggerInitiallyDeferred + then "INITIALLY DEFERRED" + else "INITIALLY IMMEDIATE" + in deferrable <+> deferred trgWhen = maybe "" (\w -> "WHEN (" <+> w <+> ")") triggerWhen trgFunction = triggerFunctionMakeName triggerName - -- | Build an SQL statement that drops a trigger. -- -- @since 1.15.0 sqlDropTrigger :: Trigger -> RawSQL () -sqlDropTrigger Trigger{..} = +sqlDropTrigger Trigger {..} = -- In theory, because the trigger is dependent on its function, it should be enough to -- 'DROP FUNCTION triggerFunction CASCADE'. However, let's make this safe and go with -- the default RESTRICT here. @@ -198,7 +206,7 @@ getDBTriggers tableName = do sqlResult "t.tgname::text" -- name sqlResult "t.tgtype" -- smallint == int2 => (2 bytes) sqlResult "t.tgdeferrable" -- boolean - sqlResult "t.tginitdeferred"-- boolean + sqlResult "t.tginitdeferred" -- boolean -- This gets the entire query that created this trigger. Note that it's decompiled and -- normalized, which means that it's likely not what the user actually typed. For -- example, if the original query had excessive whitespace in it, it won't be in this @@ -215,14 +223,15 @@ getDBTriggers tableName = do where getTrigger :: (String, Int16, Bool, Bool, String, String, String, String) -> (Trigger, RawSQL ()) getTrigger (tgname, tgtype, tgdeferrable, tginitdeferrable, triggerdef, proname, prosrc, tblName) = - ( Trigger { triggerTable = tableName' - , triggerName = triggerBaseName (unsafeSQL tgname) tableName' - , triggerEvents = trgEvents - , triggerDeferrable = tgdeferrable - , triggerInitiallyDeferred = tginitdeferrable - , triggerWhen = tgrWhen - , triggerFunction = unsafeSQL prosrc - } + ( Trigger + { triggerTable = tableName' + , triggerName = triggerBaseName (unsafeSQL tgname) tableName' + , triggerEvents = trgEvents + , triggerDeferrable = tgdeferrable + , triggerInitiallyDeferred = tginitdeferrable + , triggerWhen = tgrWhen + , triggerFunction = unsafeSQL prosrc + } , unsafeSQL proname ) where @@ -233,8 +242,8 @@ getDBTriggers tableName = do parseBetween left right = let (prefix, match) = Text.breakOnEnd left $ Text.pack triggerdef in if Text.null prefix - then Nothing - else Just $ (rawSQL . fst $ Text.breakOn right match) () + then Nothing + else Just $ (rawSQL . fst $ Text.breakOn right match) () -- Get the WHEN part of the query. Anything between WHEN and EXECUTE is what we -- want. The Postgres' grammar guarantees that WHEN and EXECUTE are always next to @@ -247,23 +256,24 @@ getDBTriggers tableName = do -- the same bit set in the underlying tgtype bit field. trgEvents :: Set TriggerEvent trgEvents = - foldl' (\set (mask, event) -> - if testBit tgtype mask - then - Set.insert - (if event == TriggerUpdate - then maybe event trgUpdateOf $ parseBetween "UPDATE OF " " ON" - else event - ) - set - else set - ) - Set.empty - -- Taken from PostgreSQL sources: src/include/catalog/pg_trigger.h: - [ (2, TriggerInsert) -- #define TRIGGER_TYPE_INSERT (1 << 2) - , (3, TriggerDelete) -- #define TRIGGER_TYPE_DELETE (1 << 3) - , (4, TriggerUpdate) -- #define TRIGGER_TYPE_UPDATE (1 << 4) - ] + foldl' + ( \set (mask, event) -> + if testBit tgtype mask + then + Set.insert + ( if event == TriggerUpdate + then maybe event trgUpdateOf $ parseBetween "UPDATE OF " " ON" + else event + ) + set + else set + ) + Set.empty + -- Taken from PostgreSQL sources: src/include/catalog/pg_trigger.h: + [ (2, TriggerInsert) -- #define TRIGGER_TYPE_INSERT (1 << 2) + , (3, TriggerDelete) -- #define TRIGGER_TYPE_DELETE (1 << 3) + , (4, TriggerUpdate) -- #define TRIGGER_TYPE_UPDATE (1 << 4) + ] trgUpdateOf :: RawSQL () -> TriggerEvent trgUpdateOf columnsSQL = @@ -277,10 +287,10 @@ getDBTriggers tableName = do -- -- @since 1.15.0.0 sqlCreateTriggerFunction :: Trigger -> RawSQL () -sqlCreateTriggerFunction Trigger{..} = +sqlCreateTriggerFunction Trigger {..} = "CREATE FUNCTION" <+> triggerFunctionMakeName triggerName - <> "()" + <> "()" <+> "RETURNS TRIGGER" <+> "AS $$" <+> triggerFunction @@ -293,7 +303,7 @@ sqlCreateTriggerFunction Trigger{..} = -- -- @since 1.15.0.0 sqlDropTriggerFunction :: Trigger -> RawSQL () -sqlDropTriggerFunction Trigger{..} = +sqlDropTriggerFunction Trigger {..} = "DROP FUNCTION" <+> triggerFunctionMakeName triggerName <+> "RESTRICT" -- | Make a trigger function name that can be used in SQL. @@ -305,4 +315,3 @@ sqlDropTriggerFunction Trigger{..} = -- @since 1.16.0.0 triggerFunctionMakeName :: RawSQL () -> RawSQL () triggerFunctionMakeName name = "trgfun__" <> name - diff --git a/src/Database/PostgreSQL/PQTypes/SQL/Builder.hs b/src/Database/PostgreSQL/PQTypes/SQL/Builder.hs index f224601..55c7901 100644 --- a/src/Database/PostgreSQL/PQTypes/SQL/Builder.hs +++ b/src/Database/PostgreSQL/PQTypes/SQL/Builder.hs @@ -1,91 +1,88 @@ -{- | - -Module "Database.PostgreSQL.PQTypes.SQL.Builder" offers a nice -monadic DSL for building SQL statements on the fly. Some examples: - ->>> :{ -sqlSelect "documents" $ do - sqlResult "id" - sqlResult "title" - sqlResult "mtime" - sqlOrderBy "documents.mtime DESC" - sqlWhereILike "documents.title" "%pattern%" -:} -SQL " SELECT id, title, mtime FROM documents WHERE (documents.title ILIKE <\"%pattern%\">) ORDER BY documents.mtime DESC " - -@SQL.Builder@ supports SELECT as 'sqlSelect' and data manipulation using -'sqlInsert', 'sqlInsertSelect', 'sqlDelete' and 'sqlUpdate'. - ->>> import Data.Time ->>> let title = "title" :: String ->>> let ctime = read "2020-01-01 00:00:00 UTC" :: UTCTime ->>> :{ -sqlInsert "documents" $ do - sqlSet "title" title - sqlSet "ctime" ctime - sqlResult "id" -:} -SQL " INSERT INTO documents (title, ctime) VALUES (<\"title\">, <2020-01-01 00:00:00 UTC>) RETURNING id" - -The 'sqlInsertSelect' is particulary interesting as it supports INSERT -of values taken from a SELECT clause from same or even different -tables. - -There is a possibility to do multiple inserts at once. Data given by -'sqlSetList' will be inserted multiple times, data given by 'sqlSet' -will be multiplied as many times as needed to cover all inserted rows -(it is common to all rows). If you use multiple 'sqlSetList' then -lists will be made equal in length by appending @DEFAULT@ as fill -element. - ->>> :{ -sqlInsert "documents" $ do - sqlSet "ctime" ctime - sqlSetList "title" ["title1", "title2", "title3"] - sqlResult "id" -:} -SQL " INSERT INTO documents (ctime, title) VALUES (<2020-01-01 00:00:00 UTC>, <\"title1\">) , (<2020-01-01 00:00:00 UTC>, <\"title2\">) , (<2020-01-01 00:00:00 UTC>, <\"title3\">) RETURNING id" - -The above will insert 3 new documents. - -@SQL.Builder@ provides quite a lot of SQL magic, including @ORDER BY@ as -'sqlOrderBy', @GROUP BY@ as 'sqlGroupBy'. - ->>> :{ -sqlSelect "documents" $ do - sqlResult "id" - sqlResult "title" - sqlResult "mtime" - sqlOrderBy "documents.mtime DESC" - sqlOrderBy "documents.title" - sqlGroupBy "documents.status" - sqlJoinOn "users" "documents.user_id = users.id" - sqlWhere $ mkSQL "documents.title ILIKE" "%pattern%" -:} -SQL " SELECT id, title, mtime FROM documents JOIN users ON documents.user_id = users.id WHERE (documents.title ILIKE <\"%pattern%\">) GROUP BY documents.status ORDER BY documents.mtime DESC, documents.title " - -Joins are done by 'sqlJoinOn', 'sqlLeftJoinOn', 'sqlRightJoinOn', -'sqlJoinOn', 'sqlFullJoinOn'. If everything fails use 'sqlJoin' and -'sqlFrom' to set join clause as string. Support for a join grammars as -some kind of abstract syntax data type is lacking. - ->>> :{ -sqlDelete "mails" $ do - sqlWhere "id > 67" -:} -SQL " DELETE FROM mails WHERE (id > 67) " - ->>> :{ -sqlUpdate "document_tags" $ do - sqlSet "value" (123 :: Int) - sqlWhere "name = 'abc'" -:} -SQL " UPDATE document_tags SET value=<123> WHERE (name = 'abc') " - --} - -- TODO: clean this up, add more documentation. +-- | +-- +-- Module "Database.PostgreSQL.PQTypes.SQL.Builder" offers a nice +-- monadic DSL for building SQL statements on the fly. Some examples: +-- +-- >>> :{ +-- sqlSelect "documents" $ do +-- sqlResult "id" +-- sqlResult "title" +-- sqlResult "mtime" +-- sqlOrderBy "documents.mtime DESC" +-- sqlWhereILike "documents.title" "%pattern%" +-- :} +-- SQL " SELECT id, title, mtime FROM documents WHERE (documents.title ILIKE <\"%pattern%\">) ORDER BY documents.mtime DESC " +-- +-- @SQL.Builder@ supports SELECT as 'sqlSelect' and data manipulation using +-- 'sqlInsert', 'sqlInsertSelect', 'sqlDelete' and 'sqlUpdate'. +-- +-- >>> import Data.Time +-- >>> let title = "title" :: String +-- >>> let ctime = read "2020-01-01 00:00:00 UTC" :: UTCTime +-- >>> :{ +-- sqlInsert "documents" $ do +-- sqlSet "title" title +-- sqlSet "ctime" ctime +-- sqlResult "id" +-- :} +-- SQL " INSERT INTO documents (title, ctime) VALUES (<\"title\">, <2020-01-01 00:00:00 UTC>) RETURNING id" +-- +-- The 'sqlInsertSelect' is particulary interesting as it supports INSERT +-- of values taken from a SELECT clause from same or even different +-- tables. +-- +-- There is a possibility to do multiple inserts at once. Data given by +-- 'sqlSetList' will be inserted multiple times, data given by 'sqlSet' +-- will be multiplied as many times as needed to cover all inserted rows +-- (it is common to all rows). If you use multiple 'sqlSetList' then +-- lists will be made equal in length by appending @DEFAULT@ as fill +-- element. +-- +-- >>> :{ +-- sqlInsert "documents" $ do +-- sqlSet "ctime" ctime +-- sqlSetList "title" ["title1", "title2", "title3"] +-- sqlResult "id" +-- :} +-- SQL " INSERT INTO documents (ctime, title) VALUES (<2020-01-01 00:00:00 UTC>, <\"title1\">) , (<2020-01-01 00:00:00 UTC>, <\"title2\">) , (<2020-01-01 00:00:00 UTC>, <\"title3\">) RETURNING id" +-- +-- The above will insert 3 new documents. +-- +-- @SQL.Builder@ provides quite a lot of SQL magic, including @ORDER BY@ as +-- 'sqlOrderBy', @GROUP BY@ as 'sqlGroupBy'. +-- +-- >>> :{ +-- sqlSelect "documents" $ do +-- sqlResult "id" +-- sqlResult "title" +-- sqlResult "mtime" +-- sqlOrderBy "documents.mtime DESC" +-- sqlOrderBy "documents.title" +-- sqlGroupBy "documents.status" +-- sqlJoinOn "users" "documents.user_id = users.id" +-- sqlWhere $ mkSQL "documents.title ILIKE" "%pattern%" +-- :} +-- SQL " SELECT id, title, mtime FROM documents JOIN users ON documents.user_id = users.id WHERE (documents.title ILIKE <\"%pattern%\">) GROUP BY documents.status ORDER BY documents.mtime DESC, documents.title " +-- +-- Joins are done by 'sqlJoinOn', 'sqlLeftJoinOn', 'sqlRightJoinOn', +-- 'sqlJoinOn', 'sqlFullJoinOn'. If everything fails use 'sqlJoin' and +-- 'sqlFrom' to set join clause as string. Support for a join grammars as +-- some kind of abstract syntax data type is lacking. +-- +-- >>> :{ +-- sqlDelete "mails" $ do +-- sqlWhere "id > 67" +-- :} +-- SQL " DELETE FROM mails WHERE (id > 67) " +-- +-- >>> :{ +-- sqlUpdate "document_tags" $ do +-- sqlSet "value" (123 :: Int) +-- sqlWhere "name = 'abc'" +-- :} +-- SQL " UPDATE document_tags SET value=<123> WHERE (name = 'abc') " module Database.PostgreSQL.PQTypes.SQL.Builder ( sqlWhere , sqlWhereEq @@ -102,7 +99,6 @@ module Database.PostgreSQL.PQTypes.SQL.Builder , sqlWhereILike , sqlWhereIsNULL , sqlWhereIsNotNULL - , sqlFrom , sqlJoin , sqlJoinOn @@ -132,25 +128,22 @@ module Database.PostgreSQL.PQTypes.SQL.Builder , sqlUnion , sqlUnionAll , checkAndRememberMaterializationSupport - , sqlSelect , sqlSelect2 - , SqlSelect(..) + , SqlSelect (..) , sqlInsert - , SqlInsert(..) + , SqlInsert (..) , sqlInsertSelect - , SqlInsertSelect(..) + , SqlInsertSelect (..) , sqlUpdate - , SqlUpdate(..) + , SqlUpdate (..) , sqlDelete - , SqlDelete(..) - - , SqlWhereAll(..) + , SqlDelete (..) + , SqlWhereAll (..) , sqlAll - , SqlWhereAny(..) + , SqlWhereAny (..) , sqlAny , sqlWhereAny - , SqlResult , SqlSet , SqlFrom @@ -160,25 +153,23 @@ module Database.PostgreSQL.PQTypes.SQL.Builder , SqlGroupByHaving , SqlOffsetLimit , SqlDistinct - - , SqlCondition(..) + , SqlCondition (..) , sqlGetWhereConditions - - , Sqlable(..) + , Sqlable (..) , sqlOR , sqlConcatComma , sqlConcatAND , sqlConcatOR , parenthesize - , AscDesc(..) + , AscDesc (..) ) - where +where import Control.Monad.Catch import Control.Monad.State import Data.Either -import Data.Int import Data.IORef +import Data.Int import Data.List import Data.Maybe import Data.Monoid.Utils @@ -226,76 +217,76 @@ data Multiplicity a = Single a | Many [a] -- structure of a condition. For now it seems that the only -- interesting case is EXISTS (SELECT ...), because that internal -- SELECT can have explainable clauses. -data SqlCondition = SqlPlainCondition SQL - | SqlExistsCondition SqlSelect - deriving (Typeable, Show) +data SqlCondition + = SqlPlainCondition SQL + | SqlExistsCondition SqlSelect + deriving (Typeable, Show) instance Sqlable SqlCondition where toSQLCommand (SqlPlainCondition a) = a - toSQLCommand (SqlExistsCondition a) = "EXISTS (" <> toSQLCommand (a { sqlSelectResult = ["TRUE"] }) <> ")" + toSQLCommand (SqlExistsCondition a) = "EXISTS (" <> toSQLCommand (a {sqlSelectResult = ["TRUE"]}) <> ")" data SqlSelect = SqlSelect - { sqlSelectFrom :: SQL - , sqlSelectUnion :: [SQL] - , sqlSelectUnionAll :: [SQL] - , sqlSelectDistinct :: Bool - , sqlSelectResult :: [SQL] - , sqlSelectWhere :: [SqlCondition] - , sqlSelectOrderBy :: [SQL] - , sqlSelectGroupBy :: [SQL] - , sqlSelectHaving :: [SQL] - , sqlSelectOffset :: Integer - , sqlSelectLimit :: Integer - , sqlSelectWith :: [(SQL, SQL, Materialized)] + { sqlSelectFrom :: SQL + , sqlSelectUnion :: [SQL] + , sqlSelectUnionAll :: [SQL] + , sqlSelectDistinct :: Bool + , sqlSelectResult :: [SQL] + , sqlSelectWhere :: [SqlCondition] + , sqlSelectOrderBy :: [SQL] + , sqlSelectGroupBy :: [SQL] + , sqlSelectHaving :: [SQL] + , sqlSelectOffset :: Integer + , sqlSelectLimit :: Integer + , sqlSelectWith :: [(SQL, SQL, Materialized)] , sqlSelectRecursiveWith :: Recursive } data SqlUpdate = SqlUpdate - { sqlUpdateWhat :: SQL - , sqlUpdateFrom :: SQL - , sqlUpdateWhere :: [SqlCondition] - , sqlUpdateSet :: [(SQL,SQL)] - , sqlUpdateResult :: [SQL] - , sqlUpdateWith :: [(SQL, SQL, Materialized)] + { sqlUpdateWhat :: SQL + , sqlUpdateFrom :: SQL + , sqlUpdateWhere :: [SqlCondition] + , sqlUpdateSet :: [(SQL, SQL)] + , sqlUpdateResult :: [SQL] + , sqlUpdateWith :: [(SQL, SQL, Materialized)] , sqlUpdateRecursiveWith :: Recursive } data SqlInsert = SqlInsert - { sqlInsertWhat :: SQL - , sqlInsertOnConflict :: Maybe (SQL, Maybe SQL) - , sqlInsertSet :: [(SQL, Multiplicity SQL)] - , sqlInsertResult :: [SQL] - , sqlInsertWith :: [(SQL, SQL, Materialized)] + { sqlInsertWhat :: SQL + , sqlInsertOnConflict :: Maybe (SQL, Maybe SQL) + , sqlInsertSet :: [(SQL, Multiplicity SQL)] + , sqlInsertResult :: [SQL] + , sqlInsertWith :: [(SQL, SQL, Materialized)] , sqlInsertRecursiveWith :: Recursive } data SqlInsertSelect = SqlInsertSelect - { sqlInsertSelectWhat :: SQL - , sqlInsertSelectOnConflict :: Maybe (SQL, Maybe SQL) - , sqlInsertSelectDistinct :: Bool - , sqlInsertSelectSet :: [(SQL, SQL)] - , sqlInsertSelectResult :: [SQL] - , sqlInsertSelectFrom :: SQL - , sqlInsertSelectWhere :: [SqlCondition] - , sqlInsertSelectOrderBy :: [SQL] - , sqlInsertSelectGroupBy :: [SQL] - , sqlInsertSelectHaving :: [SQL] - , sqlInsertSelectOffset :: Integer - , sqlInsertSelectLimit :: Integer - , sqlInsertSelectWith :: [(SQL, SQL, Materialized)] + { sqlInsertSelectWhat :: SQL + , sqlInsertSelectOnConflict :: Maybe (SQL, Maybe SQL) + , sqlInsertSelectDistinct :: Bool + , sqlInsertSelectSet :: [(SQL, SQL)] + , sqlInsertSelectResult :: [SQL] + , sqlInsertSelectFrom :: SQL + , sqlInsertSelectWhere :: [SqlCondition] + , sqlInsertSelectOrderBy :: [SQL] + , sqlInsertSelectGroupBy :: [SQL] + , sqlInsertSelectHaving :: [SQL] + , sqlInsertSelectOffset :: Integer + , sqlInsertSelectLimit :: Integer + , sqlInsertSelectWith :: [(SQL, SQL, Materialized)] , sqlInsertSelectRecursiveWith :: Recursive } data SqlDelete = SqlDelete - { sqlDeleteFrom :: SQL - , sqlDeleteUsing :: SQL - , sqlDeleteWhere :: [SqlCondition] - , sqlDeleteResult :: [SQL] - , sqlDeleteWith :: [(SQL, SQL, Materialized)] + { sqlDeleteFrom :: SQL + , sqlDeleteUsing :: SQL + , sqlDeleteWhere :: [SqlCondition] + , sqlDeleteResult :: [SQL] + , sqlDeleteWith :: [(SQL, SQL, Materialized)] , sqlDeleteRecursiveWith :: Recursive } - -- | Type representing a set of conditions that are joined by 'AND'. -- -- When no conditions are given, the result is 'TRUE'. @@ -335,7 +326,7 @@ emitClause :: Sqlable sql => SQL -> sql -> SQL emitClause name s = case toSQLCommand s of sql | isSqlEmpty sql -> "" - | otherwise -> name <+> sql + | otherwise -> name <+> sql emitClausesSep :: SQL -> SQL -> [SQL] -> SQL emitClausesSep _name _sep [] = mempty @@ -361,101 +352,109 @@ instance IsSQL SqlDelete where withSQL = withSQL . toSQLCommand instance Sqlable SqlSelect where - toSQLCommand cmd = smconcat - [ emitClausesSepComma (recursiveClause $ sqlSelectRecursiveWith cmd) $ - map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlSelectWith cmd) - , if hasUnion || hasUnionAll - then emitClausesSep "" unionKeyword (mainSelectClause : unionCmd) - else mainSelectClause - , emitClausesSepComma "GROUP BY" (sqlSelectGroupBy cmd) - , emitClausesSep "HAVING" "AND" (sqlSelectHaving cmd) - , orderByClause - , if sqlSelectOffset cmd > 0 - then unsafeSQL ("OFFSET " ++ show (sqlSelectOffset cmd)) - else "" - , if sqlSelectLimit cmd >= 0 - then limitClause - else "" - ] - where - mainSelectClause = smconcat - [ "SELECT" <+> (if sqlSelectDistinct cmd then "DISTINCT" else mempty) - , sqlConcatComma (sqlSelectResult cmd) - , emitClause "FROM" (sqlSelectFrom cmd) - , emitClausesSep "WHERE" "AND" (map toSQLCommand $ sqlSelectWhere cmd) - -- If there's a union, the result is sorted and has a limit, applying - -- the order and limit to the main subquery won't reduce the overall - -- query result, but might reduce its processing time. - , if hasUnion && not (null $ sqlSelectOrderBy cmd) && sqlSelectLimit cmd >= 0 - then smconcat [orderByClause, limitClause] + toSQLCommand cmd = + smconcat + [ emitClausesSepComma (recursiveClause $ sqlSelectRecursiveWith cmd) $ + map (\(name, command, mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlSelectWith cmd) + , if hasUnion || hasUnionAll + then emitClausesSep "" unionKeyword (mainSelectClause : unionCmd) + else mainSelectClause + , emitClausesSepComma "GROUP BY" (sqlSelectGroupBy cmd) + , emitClausesSep "HAVING" "AND" (sqlSelectHaving cmd) + , orderByClause + , if sqlSelectOffset cmd > 0 + then unsafeSQL ("OFFSET " ++ show (sqlSelectOffset cmd)) else "" - ] - - hasUnion = not . null $ sqlSelectUnion cmd - hasUnionAll = not . null $ sqlSelectUnionAll cmd + , if sqlSelectLimit cmd >= 0 + then limitClause + else "" + ] + where + mainSelectClause = + smconcat + [ "SELECT" <+> (if sqlSelectDistinct cmd then "DISTINCT" else mempty) + , sqlConcatComma (sqlSelectResult cmd) + , emitClause "FROM" (sqlSelectFrom cmd) + , emitClausesSep "WHERE" "AND" (map toSQLCommand $ sqlSelectWhere cmd) + , -- If there's a union, the result is sorted and has a limit, applying + -- the order and limit to the main subquery won't reduce the overall + -- query result, but might reduce its processing time. + if hasUnion && not (null $ sqlSelectOrderBy cmd) && sqlSelectLimit cmd >= 0 + then smconcat [orderByClause, limitClause] + else "" + ] + + hasUnion = not . null $ sqlSelectUnion cmd + hasUnionAll = not . null $ sqlSelectUnionAll cmd unionKeyword = case (hasUnion, hasUnionAll) of - (False, True) -> "UNION ALL" - (True, False) -> "UNION" - -- False, False is caught by the (hasUnion || hasUnionAll) above. - -- Hence, the catch-all is implicitly for (True, True). - _ -> error "Having both `sqlSelectUnion` and `sqlSelectUnionAll` is not supported at the moment." + (False, True) -> "UNION ALL" + (True, False) -> "UNION" + -- False, False is caught by the (hasUnion || hasUnionAll) above. + -- Hence, the catch-all is implicitly for (True, True). + _ -> error "Having both `sqlSelectUnion` and `sqlSelectUnionAll` is not supported at the moment." unionCmd = case (hasUnion, hasUnionAll) of - (False, True) -> sqlSelectUnionAll cmd - (True, False) -> sqlSelectUnion cmd - -- False, False is caught by the (hasUnion || hasUnionAll) above. - -- Hence, the catch-all is implicitly for (True, True). - _ -> error "Having both `sqlSelectUnion` and `sqlSelectUnionAll` is not supported at the moment." + (False, True) -> sqlSelectUnionAll cmd + (True, False) -> sqlSelectUnion cmd + -- False, False is caught by the (hasUnion || hasUnionAll) above. + -- Hence, the catch-all is implicitly for (True, True). + _ -> error "Having both `sqlSelectUnion` and `sqlSelectUnionAll` is not supported at the moment." orderByClause = emitClausesSepComma "ORDER BY" $ sqlSelectOrderBy cmd - limitClause = unsafeSQL $ "LIMIT" <+> show (sqlSelectLimit cmd) + limitClause = unsafeSQL $ "LIMIT" <+> show (sqlSelectLimit cmd) emitClauseOnConflictForInsert :: Maybe (SQL, Maybe SQL) -> SQL emitClauseOnConflictForInsert = \case - Nothing -> "" - Just (condition, maction) -> emitClause "ON CONFLICT" $ - condition <+> "DO" <+> fromMaybe "NOTHING" maction + Nothing -> "" + Just (condition, maction) -> + emitClause "ON CONFLICT" $ + condition <+> "DO" <+> fromMaybe "NOTHING" maction instance Sqlable SqlInsert where toSQLCommand cmd = - emitClausesSepComma (recursiveClause $ sqlInsertRecursiveWith cmd) - (map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlInsertWith cmd)) <+> - "INSERT INTO" <+> sqlInsertWhat cmd <+> - parenthesize (sqlConcatComma (map fst (sqlInsertSet cmd))) <+> - emitClausesSep "VALUES" "," (map sqlConcatComma (transpose (map (makeLongEnough . snd) (sqlInsertSet cmd)))) <+> - emitClauseOnConflictForInsert (sqlInsertOnConflict cmd) <+> - emitClausesSepComma "RETURNING" (sqlInsertResult cmd) - where - -- this is the longest list of values - longest = maximum (1 : map (lengthOfEither . snd) (sqlInsertSet cmd)) - lengthOfEither (Single _) = 1 - lengthOfEither (Many x) = length x - makeLongEnough (Single x) = replicate longest x - makeLongEnough (Many x) = take longest (x ++ repeat "DEFAULT") + emitClausesSepComma + (recursiveClause $ sqlInsertRecursiveWith cmd) + (map (\(name, command, mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlInsertWith cmd)) + <+> "INSERT INTO" + <+> sqlInsertWhat cmd + <+> parenthesize (sqlConcatComma (map fst (sqlInsertSet cmd))) + <+> emitClausesSep "VALUES" "," (map sqlConcatComma (transpose (map (makeLongEnough . snd) (sqlInsertSet cmd)))) + <+> emitClauseOnConflictForInsert (sqlInsertOnConflict cmd) + <+> emitClausesSepComma "RETURNING" (sqlInsertResult cmd) + where + -- this is the longest list of values + longest = maximum (1 : map (lengthOfEither . snd) (sqlInsertSet cmd)) + lengthOfEither (Single _) = 1 + lengthOfEither (Many x) = length x + makeLongEnough (Single x) = replicate longest x + makeLongEnough (Many x) = take longest (x ++ repeat "DEFAULT") instance Sqlable SqlInsertSelect where - toSQLCommand cmd = smconcat - -- WITH clause needs to be at the top level, so we emit it here and not - -- include it in the SqlSelect below. - [ emitClausesSepComma (recursiveClause $ sqlInsertSelectRecursiveWith cmd) $ - map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlInsertSelectWith cmd) - , "INSERT INTO" <+> sqlInsertSelectWhat cmd - , parenthesize . sqlConcatComma . map fst $ sqlInsertSelectSet cmd - , parenthesize . toSQLCommand $ SqlSelect { sqlSelectFrom = sqlInsertSelectFrom cmd - , sqlSelectUnion = [] - , sqlSelectUnionAll = [] - , sqlSelectDistinct = sqlInsertSelectDistinct cmd - , sqlSelectResult = snd <$> sqlInsertSelectSet cmd - , sqlSelectWhere = sqlInsertSelectWhere cmd - , sqlSelectOrderBy = sqlInsertSelectOrderBy cmd - , sqlSelectGroupBy = sqlInsertSelectGroupBy cmd - , sqlSelectHaving = sqlInsertSelectHaving cmd - , sqlSelectOffset = sqlInsertSelectOffset cmd - , sqlSelectLimit = sqlInsertSelectLimit cmd - , sqlSelectWith = [] - , sqlSelectRecursiveWith = NonRecursive - } - , emitClauseOnConflictForInsert (sqlInsertSelectOnConflict cmd) - , emitClausesSepComma "RETURNING" $ sqlInsertSelectResult cmd - ] + toSQLCommand cmd = + smconcat + -- WITH clause needs to be at the top level, so we emit it here and not + -- include it in the SqlSelect below. + [ emitClausesSepComma (recursiveClause $ sqlInsertSelectRecursiveWith cmd) $ + map (\(name, command, mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlInsertSelectWith cmd) + , "INSERT INTO" <+> sqlInsertSelectWhat cmd + , parenthesize . sqlConcatComma . map fst $ sqlInsertSelectSet cmd + , parenthesize . toSQLCommand $ + SqlSelect + { sqlSelectFrom = sqlInsertSelectFrom cmd + , sqlSelectUnion = [] + , sqlSelectUnionAll = [] + , sqlSelectDistinct = sqlInsertSelectDistinct cmd + , sqlSelectResult = snd <$> sqlInsertSelectSet cmd + , sqlSelectWhere = sqlInsertSelectWhere cmd + , sqlSelectOrderBy = sqlInsertSelectOrderBy cmd + , sqlSelectGroupBy = sqlInsertSelectGroupBy cmd + , sqlSelectHaving = sqlInsertSelectHaving cmd + , sqlSelectOffset = sqlInsertSelectOffset cmd + , sqlSelectLimit = sqlInsertSelectLimit cmd + , sqlSelectWith = [] + , sqlSelectRecursiveWith = NonRecursive + } + , emitClauseOnConflictForInsert (sqlInsertSelectOnConflict cmd) + , emitClausesSepComma "RETURNING" $ sqlInsertSelectResult cmd + ] -- This function has to be called as one of first things in your program -- for the library to make sure that it is aware if the "WITH MATERIALIZED" @@ -480,27 +479,32 @@ materializedClause Materialized = if isWithMaterializedSupported then "MATERIALI materializedClause NonMaterialized = if isWithMaterializedSupported then "NOT MATERIALIZED" else "" recursiveClause :: Recursive -> SQL -recursiveClause Recursive = "WITH RECURSIVE" +recursiveClause Recursive = "WITH RECURSIVE" recursiveClause NonRecursive = "WITH" instance Sqlable SqlUpdate where toSQLCommand cmd = - emitClausesSepComma (recursiveClause $ sqlUpdateRecursiveWith cmd) - (map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlUpdateWith cmd)) <+> - "UPDATE" <+> sqlUpdateWhat cmd <+> "SET" <+> - sqlConcatComma (map (\(name, command) -> name <> "=" <> command) (sqlUpdateSet cmd)) <+> - emitClause "FROM" (sqlUpdateFrom cmd) <+> - emitClausesSep "WHERE" "AND" (map toSQLCommand $ sqlUpdateWhere cmd) <+> - emitClausesSepComma "RETURNING" (sqlUpdateResult cmd) + emitClausesSepComma + (recursiveClause $ sqlUpdateRecursiveWith cmd) + (map (\(name, command, mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlUpdateWith cmd)) + <+> "UPDATE" + <+> sqlUpdateWhat cmd + <+> "SET" + <+> sqlConcatComma (map (\(name, command) -> name <> "=" <> command) (sqlUpdateSet cmd)) + <+> emitClause "FROM" (sqlUpdateFrom cmd) + <+> emitClausesSep "WHERE" "AND" (map toSQLCommand $ sqlUpdateWhere cmd) + <+> emitClausesSepComma "RETURNING" (sqlUpdateResult cmd) instance Sqlable SqlDelete where toSQLCommand cmd = - emitClausesSepComma (recursiveClause $ sqlDeleteRecursiveWith cmd) - (map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlDeleteWith cmd)) <+> - "DELETE FROM" <+> sqlDeleteFrom cmd <+> - emitClause "USING" (sqlDeleteUsing cmd) <+> - emitClausesSep "WHERE" "AND" (map toSQLCommand $ sqlDeleteWhere cmd) <+> - emitClausesSepComma "RETURNING" (sqlDeleteResult cmd) + emitClausesSepComma + (recursiveClause $ sqlDeleteRecursiveWith cmd) + (map (\(name, command, mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlDeleteWith cmd)) + <+> "DELETE FROM" + <+> sqlDeleteFrom cmd + <+> emitClause "USING" (sqlDeleteUsing cmd) + <+> emitClausesSep "WHERE" "AND" (map toSQLCommand $ sqlDeleteWhere cmd) + <+> emitClausesSepComma "RETURNING" (sqlDeleteResult cmd) instance Sqlable SqlWhereAll where toSQLCommand cmd = case sqlWhereAllWhere cmd of @@ -528,22 +532,25 @@ sqlInsert table refine = sqlInsertSelect :: SQL -> SQL -> State SqlInsertSelect () -> SqlInsertSelect sqlInsertSelect table from refine = - execState refine (SqlInsertSelect - { sqlInsertSelectWhat = table - , sqlInsertSelectOnConflict = Nothing - , sqlInsertSelectDistinct = False - , sqlInsertSelectSet = [] - , sqlInsertSelectResult = [] - , sqlInsertSelectFrom = from - , sqlInsertSelectWhere = [] - , sqlInsertSelectOrderBy = [] - , sqlInsertSelectGroupBy = [] - , sqlInsertSelectHaving = [] - , sqlInsertSelectOffset = 0 - , sqlInsertSelectLimit = -1 - , sqlInsertSelectWith = [] - , sqlInsertSelectRecursiveWith = NonRecursive - }) + execState + refine + ( SqlInsertSelect + { sqlInsertSelectWhat = table + , sqlInsertSelectOnConflict = Nothing + , sqlInsertSelectDistinct = False + , sqlInsertSelectSet = [] + , sqlInsertSelectResult = [] + , sqlInsertSelectFrom = from + , sqlInsertSelectWhere = [] + , sqlInsertSelectOrderBy = [] + , sqlInsertSelectGroupBy = [] + , sqlInsertSelectHaving = [] + , sqlInsertSelectOffset = 0 + , sqlInsertSelectLimit = -1 + , sqlInsertSelectWith = [] + , sqlInsertSelectRecursiveWith = NonRecursive + } + ) sqlUpdate :: SQL -> State SqlUpdate () -> SqlUpdate sqlUpdate table refine = @@ -551,14 +558,17 @@ sqlUpdate table refine = sqlDelete :: SQL -> State SqlDelete () -> SqlDelete sqlDelete table refine = - execState refine (SqlDelete { sqlDeleteFrom = table - , sqlDeleteUsing = mempty - , sqlDeleteWhere = [] - , sqlDeleteResult = [] - , sqlDeleteWith = [] - , sqlDeleteRecursiveWith = NonRecursive - }) - + execState + refine + ( SqlDelete + { sqlDeleteFrom = table + , sqlDeleteUsing = mempty + , sqlDeleteWhere = [] + , sqlDeleteResult = [] + , sqlDeleteWith = [] + , sqlDeleteRecursiveWith = NonRecursive + } + ) data Materialized = Materialized | NonMaterialized data Recursive = Recursive | NonRecursive @@ -569,22 +579,22 @@ data Recursive = Recursive | NonRecursive instance Semigroup Recursive where _ <> Recursive = Recursive Recursive <> _ = Recursive - _ <> _ = NonRecursive + _ <> _ = NonRecursive class SqlWith a where sqlWith1 :: a -> SQL -> SQL -> Materialized -> Recursive -> a instance SqlWith SqlSelect where - sqlWith1 cmd name sql mat recurse = cmd { sqlSelectWith = sqlSelectWith cmd ++ [(name,sql,mat)], sqlSelectRecursiveWith = recurse <> sqlSelectRecursiveWith cmd } + sqlWith1 cmd name sql mat recurse = cmd {sqlSelectWith = sqlSelectWith cmd ++ [(name, sql, mat)], sqlSelectRecursiveWith = recurse <> sqlSelectRecursiveWith cmd} instance SqlWith SqlInsertSelect where - sqlWith1 cmd name sql mat recurse = cmd { sqlInsertSelectWith = sqlInsertSelectWith cmd ++ [(name,sql,mat)], sqlInsertSelectRecursiveWith = recurse <> sqlInsertSelectRecursiveWith cmd } + sqlWith1 cmd name sql mat recurse = cmd {sqlInsertSelectWith = sqlInsertSelectWith cmd ++ [(name, sql, mat)], sqlInsertSelectRecursiveWith = recurse <> sqlInsertSelectRecursiveWith cmd} instance SqlWith SqlUpdate where - sqlWith1 cmd name sql mat recurse = cmd { sqlUpdateWith = sqlUpdateWith cmd ++ [(name,sql,mat)], sqlUpdateRecursiveWith = recurse <> sqlUpdateRecursiveWith cmd } + sqlWith1 cmd name sql mat recurse = cmd {sqlUpdateWith = sqlUpdateWith cmd ++ [(name, sql, mat)], sqlUpdateRecursiveWith = recurse <> sqlUpdateRecursiveWith cmd} instance SqlWith SqlDelete where - sqlWith1 cmd name sql mat recurse = cmd { sqlDeleteWith = sqlDeleteWith cmd ++ [(name,sql,mat)], sqlDeleteRecursiveWith = recurse <> sqlDeleteRecursiveWith cmd } + sqlWith1 cmd name sql mat recurse = cmd {sqlDeleteWith = sqlDeleteWith cmd ++ [(name, sql, mat)], sqlDeleteRecursiveWith = recurse <> sqlDeleteRecursiveWith cmd} sqlWith :: (MonadState v m, SqlWith v, Sqlable s) => SQL -> s -> m () sqlWith name sql = modify (\cmd -> sqlWith1 cmd name (toSQLCommand sql) NonMaterialized NonRecursive) @@ -599,41 +609,41 @@ sqlWithRecursive name sql = modify (\cmd -> sqlWith1 cmd name (toSQLCommand sql) -- | Note: WHERE clause of the main SELECT is treated specially, i.e. it only -- applies to the main SELECT, not the whole union. sqlUnion :: (MonadState SqlSelect m, Sqlable sql) => [sql] -> m () -sqlUnion sqls = modify (\cmd -> cmd { sqlSelectUnion = map toSQLCommand sqls }) +sqlUnion sqls = modify (\cmd -> cmd {sqlSelectUnion = map toSQLCommand sqls}) -- | Note: WHERE clause of the main SELECT is treated specially, i.e. it only -- applies to the main SELECT, not the whole union. -- -- @since 1.16.4.0 sqlUnionAll :: (MonadState SqlSelect m, Sqlable sql) => [sql] -> m () -sqlUnionAll sqls = modify (\cmd -> cmd { sqlSelectUnionAll = map toSQLCommand sqls }) +sqlUnionAll sqls = modify (\cmd -> cmd {sqlSelectUnionAll = map toSQLCommand sqls}) class SqlWhere a where sqlWhere1 :: a -> SqlCondition -> a sqlGetWhereConditions :: a -> [SqlCondition] instance SqlWhere SqlSelect where - sqlWhere1 cmd cond = cmd { sqlSelectWhere = sqlSelectWhere cmd ++ [cond] } + sqlWhere1 cmd cond = cmd {sqlSelectWhere = sqlSelectWhere cmd ++ [cond]} sqlGetWhereConditions = sqlSelectWhere instance SqlWhere SqlInsertSelect where - sqlWhere1 cmd cond = cmd { sqlInsertSelectWhere = sqlInsertSelectWhere cmd ++ [cond] } + sqlWhere1 cmd cond = cmd {sqlInsertSelectWhere = sqlInsertSelectWhere cmd ++ [cond]} sqlGetWhereConditions = sqlInsertSelectWhere instance SqlWhere SqlUpdate where - sqlWhere1 cmd cond = cmd { sqlUpdateWhere = sqlUpdateWhere cmd ++ [cond] } + sqlWhere1 cmd cond = cmd {sqlUpdateWhere = sqlUpdateWhere cmd ++ [cond]} sqlGetWhereConditions = sqlUpdateWhere instance SqlWhere SqlDelete where - sqlWhere1 cmd cond = cmd { sqlDeleteWhere = sqlDeleteWhere cmd ++ [cond] } + sqlWhere1 cmd cond = cmd {sqlDeleteWhere = sqlDeleteWhere cmd ++ [cond]} sqlGetWhereConditions = sqlDeleteWhere instance SqlWhere SqlWhereAll where - sqlWhere1 cmd cond = cmd { sqlWhereAllWhere = sqlWhereAllWhere cmd ++ [cond] } + sqlWhere1 cmd cond = cmd {sqlWhereAllWhere = sqlWhereAllWhere cmd ++ [cond]} sqlGetWhereConditions = sqlWhereAllWhere instance SqlWhere SqlWhereAny where - sqlWhere1 cmd cond = cmd { sqlWhereAnyWhere = sqlWhereAnyWhere cmd ++ [cond] } + sqlWhere1 cmd cond = cmd {sqlWhereAnyWhere = sqlWhereAnyWhere cmd ++ [cond]} sqlGetWhereConditions = sqlWhereAnyWhere -- | The @WHERE@ part of an SQL query. See above for a usage @@ -654,7 +664,7 @@ sqlWhereLike :: (MonadState v m, SqlWhere v, Show a, ToSQL a) => SQL -> a -> m ( sqlWhereLike name value = sqlWhere $ name <+> "LIKE" value sqlWhereILike :: (MonadState v m, SqlWhere v, Show a, ToSQL a) => SQL -> a -> m () -sqlWhereILike name value = sqlWhere $ name <+> "ILIKE" value +sqlWhereILike name value = sqlWhere $ name <+> "ILIKE" value -- | Similar to 'sqlWhereIn', but uses @ANY@ instead of @SELECT UNNEST@. sqlWhereEqualsAny :: (MonadState v m, SqlWhere v, Show a, ToSQL a) => SQL -> [a] -> m () @@ -683,7 +693,7 @@ sqlWhereExists sql = do sqlWhereNotExists :: (MonadState v m, SqlWhere v) => SqlSelect -> m () sqlWhereNotExists sqlSelectD = do - sqlWhere ("NOT EXISTS (" <+> toSQLCommand (sqlSelectD { sqlSelectResult = ["TRUE"] }) <+> ")") + sqlWhere ("NOT EXISTS (" <+> toSQLCommand (sqlSelectD {sqlSelectResult = ["TRUE"]}) <+> ")") sqlWhereIsNULL :: (MonadState v m, SqlWhere v) => SQL -> m () sqlWhereIsNULL col = sqlWhere $ col <+> "IS NULL" @@ -719,16 +729,16 @@ class SqlFrom a where sqlFrom1 :: a -> SQL -> a instance SqlFrom SqlSelect where - sqlFrom1 cmd sql = cmd { sqlSelectFrom = sqlSelectFrom cmd <+> sql } + sqlFrom1 cmd sql = cmd {sqlSelectFrom = sqlSelectFrom cmd <+> sql} instance SqlFrom SqlInsertSelect where - sqlFrom1 cmd sql = cmd { sqlInsertSelectFrom = sqlInsertSelectFrom cmd <+> sql } + sqlFrom1 cmd sql = cmd {sqlInsertSelectFrom = sqlInsertSelectFrom cmd <+> sql} instance SqlFrom SqlUpdate where - sqlFrom1 cmd sql = cmd { sqlUpdateFrom = sqlUpdateFrom cmd <+> sql } + sqlFrom1 cmd sql = cmd {sqlUpdateFrom = sqlUpdateFrom cmd <+> sql} instance SqlFrom SqlDelete where - sqlFrom1 cmd sql = cmd { sqlDeleteUsing = sqlDeleteUsing cmd <+> sql } + sqlFrom1 cmd sql = cmd {sqlDeleteUsing = sqlDeleteUsing cmd <+> sql} sqlFrom :: (MonadState v m, SqlFrom v) => SQL -> m () sqlFrom sql = modify (\cmd -> sqlFrom1 cmd sql) @@ -737,46 +747,58 @@ sqlJoin :: (MonadState v m, SqlFrom v) => SQL -> m () sqlJoin table = sqlFrom (", " <+> table) sqlJoinOn :: (MonadState v m, SqlFrom v) => SQL -> SQL -> m () -sqlJoinOn table condition = sqlFrom (" JOIN " <+> - table <+> - " ON " <+> - condition) +sqlJoinOn table condition = + sqlFrom + ( " JOIN " + <+> table + <+> " ON " + <+> condition + ) sqlLeftJoinOn :: (MonadState v m, SqlFrom v) => SQL -> SQL -> m () -sqlLeftJoinOn table condition = sqlFrom (" LEFT JOIN " <+> - table <+> - " ON " <+> - condition) +sqlLeftJoinOn table condition = + sqlFrom + ( " LEFT JOIN " + <+> table + <+> " ON " + <+> condition + ) sqlRightJoinOn :: (MonadState v m, SqlFrom v) => SQL -> SQL -> m () -sqlRightJoinOn table condition = sqlFrom (" RIGHT JOIN " <+> - table <+> - " ON " <+> - condition) +sqlRightJoinOn table condition = + sqlFrom + ( " RIGHT JOIN " + <+> table + <+> " ON " + <+> condition + ) sqlFullJoinOn :: (MonadState v m, SqlFrom v) => SQL -> SQL -> m () -sqlFullJoinOn table condition = sqlFrom (" FULL JOIN " <+> - table <+> - " ON " <+> - condition) +sqlFullJoinOn table condition = + sqlFrom + ( " FULL JOIN " + <+> table + <+> " ON " + <+> condition + ) class SqlSet a where sqlSet1 :: a -> SQL -> SQL -> a instance SqlSet SqlUpdate where - sqlSet1 cmd name v = cmd { sqlUpdateSet = sqlUpdateSet cmd ++ [(name, v)] } + sqlSet1 cmd name v = cmd {sqlUpdateSet = sqlUpdateSet cmd ++ [(name, v)]} instance SqlSet SqlInsert where - sqlSet1 cmd name v = cmd { sqlInsertSet = sqlInsertSet cmd ++ [(name, Single v)] } + sqlSet1 cmd name v = cmd {sqlInsertSet = sqlInsertSet cmd ++ [(name, Single v)]} instance SqlSet SqlInsertSelect where - sqlSet1 cmd name v = cmd { sqlInsertSelectSet = sqlInsertSelectSet cmd ++ [(name, v)] } + sqlSet1 cmd name v = cmd {sqlInsertSelectSet = sqlInsertSelectSet cmd ++ [(name, v)]} sqlSetCmd :: (MonadState v m, SqlSet v) => SQL -> SQL -> m () sqlSetCmd name sql = modify (\cmd -> sqlSet1 cmd name sql) -sqlSetCmdList :: (MonadState SqlInsert m) => SQL -> [SQL] -> m () -sqlSetCmdList name as = modify (\cmd -> cmd { sqlInsertSet = sqlInsertSet cmd ++ [(name, Many as)] }) +sqlSetCmdList :: MonadState SqlInsert m => SQL -> [SQL] -> m () +sqlSetCmdList name as = modify (\cmd -> cmd {sqlInsertSet = sqlInsertSet cmd ++ [(name, Many as)]}) sqlSet :: (MonadState v m, SqlSet v, Show a, ToSQL a) => SQL -> a -> m () sqlSet name a = sqlSetCmd name (sqlParam a) @@ -800,19 +822,19 @@ class SqlOnConflict a where instance SqlOnConflict SqlInsert where sqlOnConflictDoNothing1 cmd = - cmd { sqlInsertOnConflict = Just ("", Nothing) } + cmd {sqlInsertOnConflict = Just ("", Nothing)} sqlOnConflictOnColumns1 cmd columns sql = - cmd { sqlInsertOnConflict = Just (parenthesize $ sqlConcatComma columns, Just $ toSQLCommand sql) } + cmd {sqlInsertOnConflict = Just (parenthesize $ sqlConcatComma columns, Just $ toSQLCommand sql)} sqlOnConflictOnColumnsDoNothing1 cmd columns = - cmd { sqlInsertOnConflict = Just (parenthesize $ sqlConcatComma columns, Nothing) } + cmd {sqlInsertOnConflict = Just (parenthesize $ sqlConcatComma columns, Nothing)} instance SqlOnConflict SqlInsertSelect where sqlOnConflictDoNothing1 cmd = - cmd { sqlInsertSelectOnConflict = Just ("", Nothing) } + cmd {sqlInsertSelectOnConflict = Just ("", Nothing)} sqlOnConflictOnColumns1 cmd columns sql = - cmd { sqlInsertSelectOnConflict = Just (parenthesize $ sqlConcatComma columns, Just $ toSQLCommand sql) } + cmd {sqlInsertSelectOnConflict = Just (parenthesize $ sqlConcatComma columns, Just $ toSQLCommand sql)} sqlOnConflictOnColumnsDoNothing1 cmd columns = - cmd { sqlInsertSelectOnConflict = Just (parenthesize $ sqlConcatComma columns, Nothing) } + cmd {sqlInsertSelectOnConflict = Just (parenthesize $ sqlConcatComma columns, Nothing)} sqlOnConflictDoNothing :: (MonadState v m, SqlOnConflict v) => m () sqlOnConflictDoNothing = modify sqlOnConflictDoNothing1 @@ -827,19 +849,19 @@ class SqlResult a where sqlResult1 :: a -> SQL -> a instance SqlResult SqlSelect where - sqlResult1 cmd sql = cmd { sqlSelectResult = sqlSelectResult cmd ++ [sql] } + sqlResult1 cmd sql = cmd {sqlSelectResult = sqlSelectResult cmd ++ [sql]} instance SqlResult SqlInsert where - sqlResult1 cmd sql = cmd { sqlInsertResult = sqlInsertResult cmd ++ [sql] } + sqlResult1 cmd sql = cmd {sqlInsertResult = sqlInsertResult cmd ++ [sql]} instance SqlResult SqlInsertSelect where - sqlResult1 cmd sql = cmd { sqlInsertSelectResult = sqlInsertSelectResult cmd ++ [sql] } + sqlResult1 cmd sql = cmd {sqlInsertSelectResult = sqlInsertSelectResult cmd ++ [sql]} instance SqlResult SqlUpdate where - sqlResult1 cmd sql = cmd { sqlUpdateResult = sqlUpdateResult cmd ++ [sql] } + sqlResult1 cmd sql = cmd {sqlUpdateResult = sqlUpdateResult cmd ++ [sql]} instance SqlResult SqlDelete where - sqlResult1 cmd sql = cmd { sqlDeleteResult = sqlDeleteResult cmd ++ [sql] } + sqlResult1 cmd sql = cmd {sqlDeleteResult = sqlDeleteResult cmd ++ [sql]} sqlResult :: (MonadState v m, SqlResult v) => SQL -> m () sqlResult sql = modify (\cmd -> sqlResult1 cmd sql) @@ -848,11 +870,10 @@ class SqlOrderBy a where sqlOrderBy1 :: a -> SQL -> a instance SqlOrderBy SqlSelect where - sqlOrderBy1 cmd sql = cmd { sqlSelectOrderBy = sqlSelectOrderBy cmd ++ [sql] } + sqlOrderBy1 cmd sql = cmd {sqlSelectOrderBy = sqlSelectOrderBy cmd ++ [sql]} instance SqlOrderBy SqlInsertSelect where - sqlOrderBy1 cmd sql = cmd { sqlInsertSelectOrderBy = sqlInsertSelectOrderBy cmd ++ [sql] } - + sqlOrderBy1 cmd sql = cmd {sqlInsertSelectOrderBy = sqlInsertSelectOrderBy cmd ++ [sql]} sqlOrderBy :: (MonadState v m, SqlOrderBy v) => SQL -> m () sqlOrderBy sql = modify (\cmd -> sqlOrderBy1 cmd sql) @@ -862,12 +883,12 @@ class SqlGroupByHaving a where sqlHaving1 :: a -> SQL -> a instance SqlGroupByHaving SqlSelect where - sqlGroupBy1 cmd sql = cmd { sqlSelectGroupBy = sqlSelectGroupBy cmd ++ [sql] } - sqlHaving1 cmd sql = cmd { sqlSelectHaving = sqlSelectHaving cmd ++ [sql] } + sqlGroupBy1 cmd sql = cmd {sqlSelectGroupBy = sqlSelectGroupBy cmd ++ [sql]} + sqlHaving1 cmd sql = cmd {sqlSelectHaving = sqlSelectHaving cmd ++ [sql]} instance SqlGroupByHaving SqlInsertSelect where - sqlGroupBy1 cmd sql = cmd { sqlInsertSelectGroupBy = sqlInsertSelectGroupBy cmd ++ [sql] } - sqlHaving1 cmd sql = cmd { sqlInsertSelectHaving = sqlInsertSelectHaving cmd ++ [sql] } + sqlGroupBy1 cmd sql = cmd {sqlInsertSelectGroupBy = sqlInsertSelectGroupBy cmd ++ [sql]} + sqlHaving1 cmd sql = cmd {sqlInsertSelectHaving = sqlInsertSelectHaving cmd ++ [sql]} sqlGroupBy :: (MonadState v m, SqlGroupByHaving v) => SQL -> m () sqlGroupBy sql = modify (\cmd -> sqlGroupBy1 cmd sql) @@ -875,18 +896,17 @@ sqlGroupBy sql = modify (\cmd -> sqlGroupBy1 cmd sql) sqlHaving :: (MonadState v m, SqlGroupByHaving v) => SQL -> m () sqlHaving sql = modify (\cmd -> sqlHaving1 cmd sql) - class SqlOffsetLimit a where sqlOffset1 :: a -> Integer -> a sqlLimit1 :: a -> Integer -> a instance SqlOffsetLimit SqlSelect where - sqlOffset1 cmd num = cmd { sqlSelectOffset = num } - sqlLimit1 cmd num = cmd { sqlSelectLimit = num } + sqlOffset1 cmd num = cmd {sqlSelectOffset = num} + sqlLimit1 cmd num = cmd {sqlSelectLimit = num} instance SqlOffsetLimit SqlInsertSelect where - sqlOffset1 cmd num = cmd { sqlInsertSelectOffset = num } - sqlLimit1 cmd num = cmd { sqlInsertSelectLimit = num } + sqlOffset1 cmd num = cmd {sqlInsertSelectOffset = num} + sqlLimit1 cmd num = cmd {sqlInsertSelectLimit = num} sqlOffset :: (MonadState v m, SqlOffsetLimit v, Integral int) => int -> m () sqlOffset val = modify (\cmd -> sqlOffset1 cmd $ toInteger val) @@ -898,10 +918,10 @@ class SqlDistinct a where sqlDistinct1 :: a -> a instance SqlDistinct SqlSelect where - sqlDistinct1 cmd = cmd { sqlSelectDistinct = True } + sqlDistinct1 cmd = cmd {sqlSelectDistinct = True} instance SqlDistinct SqlInsertSelect where - sqlDistinct1 cmd = cmd { sqlInsertSelectDistinct = True } + sqlDistinct1 cmd = cmd {sqlInsertSelectDistinct = True} sqlDistinct :: (MonadState v m, SqlDistinct v) => m () sqlDistinct = modify sqlDistinct1 diff --git a/src/Database/PostgreSQL/PQTypes/Utils/NubList.hs b/src/Database/PostgreSQL/PQTypes/Utils/NubList.hs index c05102e..c7a0c3b 100644 --- a/src/Database/PostgreSQL/PQTypes/Utils/NubList.hs +++ b/src/Database/PostgreSQL/PQTypes/Utils/NubList.hs @@ -1,15 +1,15 @@ module Database.PostgreSQL.PQTypes.Utils.NubList - ( NubList -- opaque - , toNubList -- smart construtor - , fromNubList - , overNubList - ) where + ( NubList -- opaque + , toNubList -- smart construtor + , fromNubList + , overNubList + ) where import Data.Typeable -import qualified Text.Read as R -import qualified Data.Set as Set -import qualified Data.Semigroup as SG +import Data.Semigroup qualified as SG +import Data.Set qualified as Set +import Text.Read qualified as R {- This module is a copy-paste fork of Distribution.Utils.NubList in Cabal @@ -19,9 +19,9 @@ import qualified Data.Semigroup as SG -} -- | NubList : A de-duplicated list that maintains the original order. -newtype NubList a = - NubList { fromNubList :: [a] } - deriving (Eq, Typeable) +newtype NubList a + = NubList {fromNubList :: [a]} + deriving (Eq, Typeable) -- NubList assumes that nub retains the list order while removing duplicate -- elements (keeping the first occurence). Documentation for "Data.List.nub" @@ -37,34 +37,35 @@ overNubList :: Ord a => ([a] -> [a]) -> NubList a -> NubList a overNubList f (NubList list) = toNubList . f $ list instance Ord a => SG.Semigroup (NubList a) where - (NubList xs) <> (NubList ys) = NubList $ xs `listUnion` ys - where - listUnion :: (Ord a) => [a] -> [a] -> [a] - listUnion a b = a + (NubList xs) <> (NubList ys) = NubList $ xs `listUnion` ys + where + listUnion :: Ord a => [a] -> [a] -> [a] + listUnion a b = + a ++ ordNubBy id (filter (`Set.notMember` Set.fromList a) b) - instance Ord a => Monoid (NubList a) where - mempty = NubList [] - mappend = (SG.<>) + mempty = NubList [] + mappend = (SG.<>) instance Show a => Show (NubList a) where - show (NubList list) = show list + show (NubList list) = show list instance (Ord a, Read a) => Read (NubList a) where - readPrec = readNubList toNubList + readPrec = readNubList toNubList -- | Helper used by NubList/NubListR's Read instances. -readNubList :: (Read a) => ([a] -> l a) -> R.ReadPrec (l a) +readNubList :: Read a => ([a] -> l a) -> R.ReadPrec (l a) readNubList toList = R.parens . R.prec 10 $ fmap toList R.readPrec ordNubBy :: Ord b => (a -> b) -> [a] -> [a] ordNubBy f = go Set.empty where go !_ [] = [] - go !s (x:xs) + go !s (x : xs) | y `Set.member` s = go s xs - | otherwise = let !s' = Set.insert y s - in x : go s' xs + | otherwise = + let !s' = Set.insert y s + in x : go s' xs where y = f x diff --git a/src/Database/PostgreSQL/PQTypes/Versions.hs b/src/Database/PostgreSQL/PQTypes/Versions.hs index 96dca5a..b2c7e67 100644 --- a/src/Database/PostgreSQL/PQTypes/Versions.hs +++ b/src/Database/PostgreSQL/PQTypes/Versions.hs @@ -3,12 +3,13 @@ module Database.PostgreSQL.PQTypes.Versions where import Database.PostgreSQL.PQTypes.Model tableVersions :: Table -tableVersions = tblTable { - tblName = "table_versions" - , tblVersion = 1 - , tblColumns = [ - tblColumn { colName = "name", colType = TextT, colNullable = False } - , tblColumn { colName = "version", colType = IntegerT, colNullable = False } - ] - , tblPrimaryKey = pkOnColumn "name" - } +tableVersions = + tblTable + { tblName = "table_versions" + , tblVersion = 1 + , tblColumns = + [ tblColumn {colName = "name", colType = TextT, colNullable = False} + , tblColumn {colName = "version", colType = IntegerT, colNullable = False} + ] + , tblPrimaryKey = pkOnColumn "name" + } diff --git a/test/Main.hs b/test/Main.hs index 25b7f51..6dd7617 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,4 +1,5 @@ {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} + {-# HLINT ignore "Use head" #-} module Main where @@ -7,8 +8,8 @@ import Control.Monad.Catch import Control.Monad.IO.Class import Data.Either import Data.List (zip4) -import qualified Data.Set as Set -import qualified Data.Text as T +import Data.Set qualified as Set +import Data.Text qualified as T import Data.Typeable import Data.UUID.Types @@ -32,15 +33,16 @@ import Test.Tasty.HUnit import Test.Tasty.Options newtype ConnectionString = ConnectionString String - deriving Typeable + deriving (Typeable) instance IsOption ConnectionString where - defaultValue = ConnectionString - -- For GitHub Actions CI - "host=postgres user=postgres password=postgres" - parseValue = Just . ConnectionString - optionName = return "connection-string" - optionHelp = return "Postgres connection string" + defaultValue = + ConnectionString + -- For GitHub Actions CI + "host=postgres user=postgres password=postgres" + parseValue = Just . ConnectionString + optionName = return "connection-string" + optionHelp = return "Postgres connection string" -- Simple example schemata inspired by the one in -- < http://www.databaseanswers.org/data_models/bank_robberies/index.htm> @@ -64,22 +66,31 @@ instance IsOption ConnectionString where tableBankSchema1 :: Table tableBankSchema1 = tblTable - { tblName = "bank" - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "id", colType = UuidT - , colNullable = False - , colDefault = Just "gen_random_uuid()" } - , tblColumn { colName = "name", colType = TextT - , colCollation = Just "en_US" - , colNullable = False } - , tblColumn { colName = "location", colType = TextT - , colCollation = Just "C" - , colNullable = False } - ] - , tblPrimaryKey = pkOnColumn "id" - , tblTriggers = [] - } + { tblName = "bank" + , tblVersion = 1 + , tblColumns = + [ tblColumn + { colName = "id" + , colType = UuidT + , colNullable = False + , colDefault = Just "gen_random_uuid()" + } + , tblColumn + { colName = "name" + , colType = TextT + , colCollation = Just "en_US" + , colNullable = False + } + , tblColumn + { colName = "location" + , colType = TextT + , colCollation = Just "C" + , colNullable = False + } + ] + , tblPrimaryKey = pkOnColumn "id" + , tblTriggers = [] + } tableBankSchema2 :: Table tableBankSchema2 = tableBankSchema1 @@ -87,76 +98,100 @@ tableBankSchema2 = tableBankSchema1 tableBankSchema3 :: Table tableBankSchema3 = tableBankSchema2 -tableBankMigration4 :: (MonadDB m) => Migration m -tableBankMigration4 = Migration - { mgrTableName = tblName tableBankSchema3 - , mgrFrom = 1 - , mgrAction = StandardMigration $ do - runQuery_ $ sqlAlterTable (tblName tableBankSchema3) [ - sqlAddColumn $ tblColumn - { colName = "cash" - , colType = IntegerT - , colNullable = False - , colDefault = Just "0" - } - ] - } +tableBankMigration4 :: MonadDB m => Migration m +tableBankMigration4 = + Migration + { mgrTableName = tblName tableBankSchema3 + , mgrFrom = 1 + , mgrAction = StandardMigration $ do + runQuery_ $ + sqlAlterTable + (tblName tableBankSchema3) + [ sqlAddColumn $ + tblColumn + { colName = "cash" + , colType = IntegerT + , colNullable = False + , colDefault = Just "0" + } + ] + } tableBankSchema4 :: Table -tableBankSchema4 = tableBankSchema3 { - tblVersion = tblVersion tableBankSchema3 + 1 - , tblColumns = tblColumns tableBankSchema3 ++ [ - tblColumn - { colName = "cash", colType = IntegerT - , colNullable = False - , colDefault = Just "0" - } - ] - } - +tableBankSchema4 = + tableBankSchema3 + { tblVersion = tblVersion tableBankSchema3 + 1 + , tblColumns = + tblColumns tableBankSchema3 + ++ [ tblColumn + { colName = "cash" + , colType = IntegerT + , colNullable = False + , colDefault = Just "0" + } + ] + } -tableBankMigration5fst :: (MonadDB m) => Migration m -tableBankMigration5fst = Migration - { mgrTableName = tblName tableBankSchema3 - , mgrFrom = 2 - , mgrAction = StandardMigration $ do - runQuery_ $ sqlAlterTable (tblName tableBankSchema4) [ - sqlDropColumn "cash" - ] - } +tableBankMigration5fst :: MonadDB m => Migration m +tableBankMigration5fst = + Migration + { mgrTableName = tblName tableBankSchema3 + , mgrFrom = 2 + , mgrAction = StandardMigration $ do + runQuery_ $ + sqlAlterTable + (tblName tableBankSchema4) + [ sqlDropColumn "cash" + ] + } -tableBankMigration5snd :: (MonadDB m) => Migration m -tableBankMigration5snd = Migration - { mgrTableName = tblName tableBankSchema3 - , mgrFrom = 3 - , mgrAction = CreateIndexConcurrentlyMigration - (tblName tableBankSchema3) - ((indexOnColumn "name") { idxInclude = ["id", "location"] }) - } +tableBankMigration5snd :: MonadDB m => Migration m +tableBankMigration5snd = + Migration + { mgrTableName = tblName tableBankSchema3 + , mgrFrom = 3 + , mgrAction = + CreateIndexConcurrentlyMigration + (tblName tableBankSchema3) + ((indexOnColumn "name") {idxInclude = ["id", "location"]}) + } tableBankSchema5 :: Table -tableBankSchema5 = tableBankSchema4 { - tblVersion = tblVersion tableBankSchema4 + 2 - , tblColumns = filter (\c -> colName c /= "cash") - (tblColumns tableBankSchema4) - , tblIndexes = [(indexOnColumn "name") { idxInclude = ["id", "location"] }] - } +tableBankSchema5 = + tableBankSchema4 + { tblVersion = tblVersion tableBankSchema4 + 2 + , tblColumns = + filter + (\c -> colName c /= "cash") + (tblColumns tableBankSchema4) + , tblIndexes = [(indexOnColumn "name") {idxInclude = ["id", "location"]}] + } tableBadGuySchema1 :: Table tableBadGuySchema1 = tblTable - { tblName = "bad_guy" - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "id", colType = UuidT - , colNullable = False - , colDefault = Just "gen_random_uuid()" } - , tblColumn { colName = "firstname", colType = TextT - , colNullable = False } - , tblColumn { colName = "lastname", colType = TextT - , colNullable = False } - ] - , tblPrimaryKey = pkOnColumn "id" } + { tblName = "bad_guy" + , tblVersion = 1 + , tblColumns = + [ tblColumn + { colName = "id" + , colType = UuidT + , colNullable = False + , colDefault = Just "gen_random_uuid()" + } + , tblColumn + { colName = "firstname" + , colType = TextT + , colNullable = False + } + , tblColumn + { colName = "lastname" + , colType = TextT + , colNullable = False + } + ] + , tblPrimaryKey = pkOnColumn "id" + } tableBadGuySchema2 :: Table tableBadGuySchema2 = tableBadGuySchema1 @@ -173,19 +208,30 @@ tableBadGuySchema5 = tableBadGuySchema4 tableRobberySchema1 :: Table tableRobberySchema1 = tblTable - { tblName = "robbery" - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "id", colType = UuidT - , colNullable = False - , colDefault = Just "gen_random_uuid()" } - , tblColumn { colName = "bank_id", colType = UuidT - , colNullable = False } - , tblColumn { colName = "date", colType = DateT - , colNullable = False, colDefault = Just "now()" } - ] - , tblPrimaryKey = pkOnColumn "id" - , tblForeignKeys = [fkOnColumn "bank_id" "bank" "id"] } + { tblName = "robbery" + , tblVersion = 1 + , tblColumns = + [ tblColumn + { colName = "id" + , colType = UuidT + , colNullable = False + , colDefault = Just "gen_random_uuid()" + } + , tblColumn + { colName = "bank_id" + , colType = UuidT + , colNullable = False + } + , tblColumn + { colName = "date" + , colType = DateT + , colNullable = False + , colDefault = Just "now()" + } + ] + , tblPrimaryKey = pkOnColumn "id" + , tblForeignKeys = [fkOnColumn "bank_id" "bank" "id"] + } tableRobberySchema2 :: Table tableRobberySchema2 = tableRobberySchema1 @@ -202,17 +248,26 @@ tableRobberySchema5 = tableRobberySchema4 tableParticipatedInRobberySchema1 :: Table tableParticipatedInRobberySchema1 = tblTable - { tblName = "participated_in_robbery" - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "bad_guy_id", colType = UuidT - , colNullable = False } - , tblColumn { colName = "robbery_id", colType = UuidT - , colNullable = False } - ] - , tblPrimaryKey = pkOnColumns ["bad_guy_id", "robbery_id"] - , tblForeignKeys = [fkOnColumn "bad_guy_id" "bad_guy" "id" - ,fkOnColumn "robbery_id" "robbery" "id"] } + { tblName = "participated_in_robbery" + , tblVersion = 1 + , tblColumns = + [ tblColumn + { colName = "bad_guy_id" + , colType = UuidT + , colNullable = False + } + , tblColumn + { colName = "robbery_id" + , colType = UuidT + , colNullable = False + } + ] + , tblPrimaryKey = pkOnColumns ["bad_guy_id", "robbery_id"] + , tblForeignKeys = + [ fkOnColumn "bad_guy_id" "bad_guy" "id" + , fkOnColumn "robbery_id" "robbery" "id" + ] + } tableParticipatedInRobberySchema2 :: Table tableParticipatedInRobberySchema2 = tableParticipatedInRobberySchema1 @@ -232,18 +287,28 @@ tableWitnessName = "witness" tableWitnessSchema1 :: Table tableWitnessSchema1 = tblTable - { tblName = tableWitnessName - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "id", colType = UuidT - , colNullable = False - , colDefault = Just "gen_random_uuid()" } - , tblColumn { colName = "firstname", colType = TextT - , colNullable = False } - , tblColumn { colName = "lastname", colType = TextT - , colNullable = False } - ] - , tblPrimaryKey = pkOnColumn "id" } + { tblName = tableWitnessName + , tblVersion = 1 + , tblColumns = + [ tblColumn + { colName = "id" + , colType = UuidT + , colNullable = False + , colDefault = Just "gen_random_uuid()" + } + , tblColumn + { colName = "firstname" + , colType = TextT + , colNullable = False + } + , tblColumn + { colName = "lastname" + , colType = TextT + , colNullable = False + } + ] + , tblPrimaryKey = pkOnColumn "id" + } tableWitnessedRobberyName :: RawSQL () tableWitnessedRobberyName = "witnessed_robbery" @@ -251,17 +316,26 @@ tableWitnessedRobberyName = "witnessed_robbery" tableWitnessedRobberySchema1 :: Table tableWitnessedRobberySchema1 = tblTable - { tblName = tableWitnessedRobberyName - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "witness_id", colType = UuidT - , colNullable = False } - , tblColumn { colName = "robbery_id", colType = UuidT - , colNullable = False } - ] - , tblPrimaryKey = pkOnColumns ["witness_id", "robbery_id"] - , tblForeignKeys = [fkOnColumn "witness_id" "witness" "id" - ,fkOnColumn "robbery_id" "robbery" "id"] } + { tblName = tableWitnessedRobberyName + , tblVersion = 1 + , tblColumns = + [ tblColumn + { colName = "witness_id" + , colType = UuidT + , colNullable = False + } + , tblColumn + { colName = "robbery_id" + , colType = UuidT + , colNullable = False + } + ] + , tblPrimaryKey = pkOnColumns ["witness_id", "robbery_id"] + , tblForeignKeys = + [ fkOnColumn "witness_id" "witness" "id" + , fkOnColumn "robbery_id" "robbery" "id" + ] + } tableUnderArrestName :: RawSQL () tableUnderArrestName = "under_arrest" @@ -269,20 +343,32 @@ tableUnderArrestName = "under_arrest" tableUnderArrestSchema2 :: Table tableUnderArrestSchema2 = tblTable - { tblName = tableUnderArrestName - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "bad_guy_id", colType = UuidT - , colNullable = False } - , tblColumn { colName = "robbery_id", colType = UuidT - , colNullable = False } - , tblColumn { colName = "court_date", colType = DateT - , colNullable = False - , colDefault = Just "now()" } - ] - , tblPrimaryKey = pkOnColumns ["bad_guy_id", "robbery_id"] - , tblForeignKeys = [fkOnColumn "bad_guy_id" "bad_guy" "id" - ,fkOnColumn "robbery_id" "robbery" "id"] } + { tblName = tableUnderArrestName + , tblVersion = 1 + , tblColumns = + [ tblColumn + { colName = "bad_guy_id" + , colType = UuidT + , colNullable = False + } + , tblColumn + { colName = "robbery_id" + , colType = UuidT + , colNullable = False + } + , tblColumn + { colName = "court_date" + , colType = DateT + , colNullable = False + , colDefault = Just "now()" + } + ] + , tblPrimaryKey = pkOnColumns ["bad_guy_id", "robbery_id"] + , tblForeignKeys = + [ fkOnColumn "bad_guy_id" "bad_guy" "id" + , fkOnColumn "robbery_id" "robbery" "id" + ] + } tablePrisonSentenceName :: RawSQL () tablePrisonSentenceName = "prison_sentence" @@ -290,28 +376,43 @@ tablePrisonSentenceName = "prison_sentence" tablePrisonSentenceSchema3 :: Table tablePrisonSentenceSchema3 = tblTable - { tblName = tablePrisonSentenceName - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "bad_guy_id", colType = UuidT - , colNullable = False } - , tblColumn { colName = "robbery_id", colType = UuidT - , colNullable = False } - , tblColumn { colName = "sentence_start" - , colType = DateT - , colNullable = False - , colDefault = Just "now()" } - , tblColumn { colName = "sentence_length" - , colType = IntegerT - , colNullable = False - , colDefault = Just "6" } - , tblColumn { colName = "prison_name" - , colType = TextT - , colNullable = False } - ] - , tblPrimaryKey = pkOnColumns ["bad_guy_id", "robbery_id"] - , tblForeignKeys = [fkOnColumn "bad_guy_id" "bad_guy" "id" - ,fkOnColumn "robbery_id" "robbery" "id"] } + { tblName = tablePrisonSentenceName + , tblVersion = 1 + , tblColumns = + [ tblColumn + { colName = "bad_guy_id" + , colType = UuidT + , colNullable = False + } + , tblColumn + { colName = "robbery_id" + , colType = UuidT + , colNullable = False + } + , tblColumn + { colName = "sentence_start" + , colType = DateT + , colNullable = False + , colDefault = Just "now()" + } + , tblColumn + { colName = "sentence_length" + , colType = IntegerT + , colNullable = False + , colDefault = Just "6" + } + , tblColumn + { colName = "prison_name" + , colType = TextT + , colNullable = False + } + ] + , tblPrimaryKey = pkOnColumns ["bad_guy_id", "robbery_id"] + , tblForeignKeys = + [ fkOnColumn "bad_guy_id" "bad_guy" "id" + , fkOnColumn "robbery_id" "robbery" "id" + ] + } tablePrisonSentenceSchema4 :: Table tablePrisonSentenceSchema4 = tablePrisonSentenceSchema3 @@ -325,12 +426,12 @@ tableFlashName = "flash" tableFlash :: Table tableFlash = tblTable - { tblName = tableFlashName - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "flash_id", colType = UuidT, colNullable = False } - ] - } + { tblName = tableFlashName + , tblVersion = 1 + , tblColumns = + [ tblColumn {colName = "flash_id", colType = UuidT, colNullable = False} + ] + } tableCartelName :: RawSQL () tableCartelName = "cartel" @@ -338,46 +439,58 @@ tableCartelName = "cartel" tableCartel :: Table tableCartel = tblTable - { tblName = tableCartelName - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "cartel_member_id", colType = UuidT - , colNullable = False } - , tblColumn { colName = "cartel_boss_id", colType = UuidT - , colNullable = True } - ] - , tblPrimaryKey = pkOnColumns ["cartel_member_id"] - , tblForeignKeys = [fkOnColumn "cartel_member_id" "bad_guy" "id" - ,fkOnColumn "cartel_boss_id" "bad_guy" "id"] } + { tblName = tableCartelName + , tblVersion = 1 + , tblColumns = + [ tblColumn + { colName = "cartel_member_id" + , colType = UuidT + , colNullable = False + } + , tblColumn + { colName = "cartel_boss_id" + , colType = UuidT + , colNullable = True + } + ] + , tblPrimaryKey = pkOnColumns ["cartel_member_id"] + , tblForeignKeys = + [ fkOnColumn "cartel_member_id" "bad_guy" "id" + , fkOnColumn "cartel_boss_id" "bad_guy" "id" + ] + } tableCartelSchema1 :: Table tableCartelSchema1 = tableCartel -createTableMigration :: (MonadDB m) => Table -> Migration m -createTableMigration tbl = Migration - { mgrTableName = tblName tbl - , mgrFrom = 0 - , mgrAction = StandardMigration $ do - createTable True tbl - } - -dropTableMigration :: (MonadDB m) => Table -> Migration m -dropTableMigration tbl = Migration - { mgrTableName = tblName tbl - , mgrFrom = tblVersion tbl - , mgrAction = DropTableMigration DropTableCascade - } +createTableMigration :: MonadDB m => Table -> Migration m +createTableMigration tbl = + Migration + { mgrTableName = tblName tbl + , mgrFrom = 0 + , mgrAction = StandardMigration $ do + createTable True tbl + } + +dropTableMigration :: MonadDB m => Table -> Migration m +dropTableMigration tbl = + Migration + { mgrTableName = tblName tbl + , mgrFrom = tblVersion tbl + , mgrAction = DropTableMigration DropTableCascade + } schema1Tables :: [Table] -schema1Tables = [ tableBankSchema1 - , tableBadGuySchema1 - , tableRobberySchema1 - , tableParticipatedInRobberySchema1 - , tableWitnessSchema1 - , tableWitnessedRobberySchema1 - ] +schema1Tables = + [ tableBankSchema1 + , tableBadGuySchema1 + , tableRobberySchema1 + , tableParticipatedInRobberySchema1 + , tableWitnessSchema1 + , tableWitnessedRobberySchema1 + ] -schema1Migrations :: (MonadDB m) => [Migration m] +schema1Migrations :: MonadDB m => [Migration m] schema1Migrations = [ createTableMigration tableBankSchema1 , createTableMigration tableBadGuySchema1 @@ -388,95 +501,112 @@ schema1Migrations = ] schema2Tables :: [Table] -schema2Tables = [ tableBankSchema2 - , tableBadGuySchema2 - , tableRobberySchema2 - , tableParticipatedInRobberySchema2 - , tableUnderArrestSchema2 - ] +schema2Tables = + [ tableBankSchema2 + , tableBadGuySchema2 + , tableRobberySchema2 + , tableParticipatedInRobberySchema2 + , tableUnderArrestSchema2 + ] -schema2Migrations :: (MonadDB m) => [Migration m] -schema2Migrations = schema1Migrations - ++ [ dropTableMigration tableWitnessedRobberySchema1 - , dropTableMigration tableWitnessSchema1 - , createTableMigration tableUnderArrestSchema2 - ] +schema2Migrations :: MonadDB m => [Migration m] +schema2Migrations = + schema1Migrations + ++ [ dropTableMigration tableWitnessedRobberySchema1 + , dropTableMigration tableWitnessSchema1 + , createTableMigration tableUnderArrestSchema2 + ] schema3Tables :: [Table] -schema3Tables = [ tableBankSchema3 - , tableBadGuySchema3 - , tableRobberySchema3 - , tableParticipatedInRobberySchema3 - , tablePrisonSentenceSchema3 - ] +schema3Tables = + [ tableBankSchema3 + , tableBadGuySchema3 + , tableRobberySchema3 + , tableParticipatedInRobberySchema3 + , tablePrisonSentenceSchema3 + ] -schema3Migrations :: (MonadDB m) => [Migration m] -schema3Migrations = schema2Migrations - ++ [ dropTableMigration tableUnderArrestSchema2 - , createTableMigration tablePrisonSentenceSchema3 ] +schema3Migrations :: MonadDB m => [Migration m] +schema3Migrations = + schema2Migrations + ++ [ dropTableMigration tableUnderArrestSchema2 + , createTableMigration tablePrisonSentenceSchema3 + ] schema4Tables :: [Table] -schema4Tables = [ tableBankSchema4 - , tableBadGuySchema4 - , tableRobberySchema4 - , tableParticipatedInRobberySchema4 - , tablePrisonSentenceSchema4 - ] +schema4Tables = + [ tableBankSchema4 + , tableBadGuySchema4 + , tableRobberySchema4 + , tableParticipatedInRobberySchema4 + , tablePrisonSentenceSchema4 + ] -schema4Migrations :: (MonadDB m) => [Migration m] -schema4Migrations = schema3Migrations - ++ [ tableBankMigration4 ] +schema4Migrations :: MonadDB m => [Migration m] +schema4Migrations = + schema3Migrations + ++ [tableBankMigration4] schema5Tables :: [Table] -schema5Tables = [ tableBankSchema5 - , tableBadGuySchema5 - , tableRobberySchema5 - , tableParticipatedInRobberySchema5 - , tablePrisonSentenceSchema5 - ] +schema5Tables = + [ tableBankSchema5 + , tableBadGuySchema5 + , tableRobberySchema5 + , tableParticipatedInRobberySchema5 + , tablePrisonSentenceSchema5 + ] -schema5Migrations :: (MonadDB m) => [Migration m] -schema5Migrations = schema4Migrations - ++ [ createTableMigration tableFlash - , tableBankMigration5fst - , tableBankMigration5snd - , dropTableMigration tableFlash - ] +schema5Migrations :: MonadDB m => [Migration m] +schema5Migrations = + schema4Migrations + ++ [ createTableMigration tableFlash + , tableBankMigration5fst + , tableBankMigration5snd + , dropTableMigration tableFlash + ] schema6Tables :: [Table] schema6Tables = - [ tableBankSchema1 - , tableBadGuySchema1 - , tableRobberySchema1 - , tableParticipatedInRobberySchema1 - { tblVersion = tblVersion tableParticipatedInRobberySchema1 + 1, - tblPrimaryKey = Nothing } - , tableWitnessSchema1 - , tableWitnessedRobberySchema1 - ] + [ tableBankSchema1 + , tableBadGuySchema1 + , tableRobberySchema1 + , tableParticipatedInRobberySchema1 + { tblVersion = tblVersion tableParticipatedInRobberySchema1 + 1 + , tblPrimaryKey = Nothing + } + , tableWitnessSchema1 + , tableWitnessedRobberySchema1 + ] -schema6Migrations :: (MonadDB m) => Migration m +schema6Migrations :: MonadDB m => Migration m schema6Migrations = - Migration + Migration { mgrTableName = tblName tableParticipatedInRobberySchema1 , mgrFrom = tblVersion tableParticipatedInRobberySchema1 , mgrAction = - StandardMigration $ do - runQuery_ ("ALTER TABLE participated_in_robbery DROP CONSTRAINT \ - \pk__participated_in_robbery" :: RawSQL ()) + StandardMigration $ do + runQuery_ + ( "ALTER TABLE participated_in_robbery DROP CONSTRAINT \ + \pk__participated_in_robbery" + :: RawSQL () + ) } - type TestM a = DBT (LogT IO) a createTablesSchema1 :: (String -> TestM ()) -> TestM () createTablesSchema1 step = do - let extensions = ["pgcrypto"] - composites = [] - domains = [] + let extensions = ["pgcrypto"] + composites = [] + domains = [] step "Creating the database (schema version 1)..." - migrateDatabase defaultExtrasOptions extensions domains - composites schema1Tables schema1Migrations + migrateDatabase + defaultExtrasOptions + extensions + domains + composites + schema1Tables + schema1Migrations -- Add a local index that shouldn't trigger validation errors. runSQL_ "CREATE INDEX local_idx_bank_name ON bank(name)" @@ -489,13 +619,22 @@ testDBSchema1 step = do -- Populate the 'bank' table. runQuery_ . sqlInsert "bank" $ do - sqlSetList "name" ["HSBC" :: T.Text, "Swedbank", "Nordea", "Citi" - ,"Wells Fargo"] - sqlSetList "location" ["13 Foo St., Tucson, AZ, USa" :: T.Text - , "18 Bargatan, Stockholm, Sweden" - , "23 Baz Lane, Liverpool, UK" - , "2/3 Quux Ave., Milton Keynes, UK" - , "6600 Sunset Blvd., Los Angeles, CA, USA"] + sqlSetList + "name" + [ "HSBC" :: T.Text + , "Swedbank" + , "Nordea" + , "Citi" + , "Wells Fargo" + ] + sqlSetList + "location" + [ "13 Foo St., Tucson, AZ, USa" :: T.Text + , "18 Bargatan, Stockholm, Sweden" + , "23 Baz Lane, Liverpool, UK" + , "2/3 Quux Ave., Milton Keynes, UK" + , "6600 Sunset Blvd., Los Angeles, CA, USA" + ] sqlResult "id" (bankIds :: [UUID]) <- fetchMany runIdentity liftIO $ assertEqual "INSERT into 'bank' table" 5 (length bankIds) @@ -603,64 +742,104 @@ testDBSchema1 step = do -- Populate the 'bad_guy' table. runQuery_ . sqlInsert "bad_guy" $ do - sqlSetList "firstname" ["Neil" :: T.Text, "Lee", "Freddie", "Frankie" - ,"James", "Roy"] - sqlSetList "lastname" ["Hetzel"::T.Text, "Murray", "Foreman", "Fraser" - ,"Crosbie", "Shaw"] + sqlSetList + "firstname" + [ "Neil" :: T.Text + , "Lee" + , "Freddie" + , "Frankie" + , "James" + , "Roy" + ] + sqlSetList + "lastname" + [ "Hetzel" :: T.Text + , "Murray" + , "Foreman" + , "Fraser" + , "Crosbie" + , "Shaw" + ] sqlResult "id" (badGuyIds :: [UUID]) <- fetchMany runIdentity liftIO $ assertEqual "INSERT into 'bad_guy' table" 6 (length badGuyIds) -- Populate the 'robbery' table. runQuery_ . sqlInsert "robbery" $ do - sqlSetList "bank_id" [bankIds !! idx | idx <- [0,3]] + sqlSetList "bank_id" [bankIds !! idx | idx <- [0, 3]] sqlResult "id" (robberyIds :: [UUID]) <- fetchMany runIdentity liftIO $ assertEqual "INSERT into 'robbery' table" 2 (length robberyIds) -- Populate the 'participated_in_robbery' table. runQuery_ . sqlInsert "participated_in_robbery" $ do - sqlSetList "bad_guy_id" [badGuyIds !! idx | idx <- [0,2]] + sqlSetList "bad_guy_id" [badGuyIds !! idx | idx <- [0, 2]] sqlSet "robbery_id" (robberyIds !! 0) sqlResult "bad_guy_id" (participatorIds :: [UUID]) <- fetchMany runIdentity - liftIO $ assertEqual "INSERT into 'participated_in_robbery' table" 2 - (length participatorIds) + liftIO $ + assertEqual + "INSERT into 'participated_in_robbery' table" + 2 + (length participatorIds) runQuery_ . sqlInsert "participated_in_robbery" $ do - sqlSetList "bad_guy_id" [badGuyIds !! idx | idx <- [3,4]] + sqlSetList "bad_guy_id" [badGuyIds !! idx | idx <- [3, 4]] sqlSet "robbery_id" (robberyIds !! 1) sqlResult "bad_guy_id" (participatorIds' :: [UUID]) <- fetchMany runIdentity - liftIO $ assertEqual "INSERT into 'participated_in_robbery' table" 2 - (length participatorIds') + liftIO $ + assertEqual + "INSERT into 'participated_in_robbery' table" + 2 + (length participatorIds') -- Populate the 'witness' table. runQuery_ . sqlInsert "witness" $ do - sqlSetList "firstname" ["Meredith" :: T.Text, "Charlie", "Peter", "Emun" - ,"Benedict", "Erica"] - sqlSetList "lastname" ["Vickers"::T.Text, "Holloway", "Weyland", "Eliott" - ,"Wong", "Hackett"] + sqlSetList + "firstname" + [ "Meredith" :: T.Text + , "Charlie" + , "Peter" + , "Emun" + , "Benedict" + , "Erica" + ] + sqlSetList + "lastname" + [ "Vickers" :: T.Text + , "Holloway" + , "Weyland" + , "Eliott" + , "Wong" + , "Hackett" + ] sqlResult "id" (witnessIds :: [UUID]) <- fetchMany runIdentity liftIO $ assertEqual "INSERT into 'witness' table" 6 (length witnessIds) -- Populate the 'witnessed_robbery' table. runQuery_ . sqlInsert "witnessed_robbery" $ do - sqlSetList "witness_id" [witnessIds !! idx | idx <- [0,1]] + sqlSetList "witness_id" [witnessIds !! idx | idx <- [0, 1]] sqlSet "robbery_id" (robberyIds !! 0) sqlResult "witness_id" (robberyWitnessIds :: [UUID]) <- fetchMany runIdentity - liftIO $ assertEqual "INSERT into 'witnessed_robbery' table" 2 - (length robberyWitnessIds) + liftIO $ + assertEqual + "INSERT into 'witnessed_robbery' table" + 2 + (length robberyWitnessIds) runQuery_ . sqlInsert "witnessed_robbery" $ do - sqlSetList "witness_id" [witnessIds !! idx | idx <- [2,3,4]] + sqlSetList "witness_id" [witnessIds !! idx | idx <- [2, 3, 4]] sqlSet "robbery_id" (robberyIds !! 1) sqlResult "witness_id" (robberyWitnessIds' :: [UUID]) <- fetchMany runIdentity - liftIO $ assertEqual "INSERT into 'witnessed_robbery' table" 3 - (length robberyWitnessIds') + liftIO $ + assertEqual + "INSERT into 'witnessed_robbery' table" + 3 + (length robberyWitnessIds') -- Create a new record to test order-by case sensitivity. runQuery_ . sqlInsert "bank" $ do @@ -674,15 +853,17 @@ testDBSchema1 step = do sqlOrderBy "location" details8 <- fetchMany runIdentity - liftIO $ assertEqual "Using collation method \"C\" leads to case-sensitive ordering of results" - [ "18 Bargatan, Stockholm, Sweden" :: String - , "2/3 Quux Ave., Milton Keynes, UK" - , "23 Baz Lane, Liverpool, UK" - , "6600 Sunset Blvd., Los Angeles, CA, USA" - , "SYRIA" - , "Spain" - ] - details8 + liftIO $ + assertEqual + "Using collation method \"C\" leads to case-sensitive ordering of results" + [ "18 Bargatan, Stockholm, Sweden" :: String + , "2/3 Quux Ave., Milton Keynes, UK" + , "23 Baz Lane, Liverpool, UK" + , "6600 Sunset Blvd., Los Angeles, CA, USA" + , "SYRIA" + , "Spain" + ] + details8 -- Check that ordering results by the "name" column uses case-insensitive -- sorting (since the collation method for that column is "en_US"). @@ -691,15 +872,17 @@ testDBSchema1 step = do sqlOrderBy "name" details9 <- fetchMany runIdentity - liftIO $ assertEqual "Using collation method \"en_US\" leads to case-insensitive ordering of results" - [ "byblos bank" :: String - , "Citi" - , "Nordea" - , "Santander" - , "Swedbank" - , "Wells Fargo" - ] - details9 + liftIO $ + assertEqual + "Using collation method \"en_US\" leads to case-insensitive ordering of results" + [ "byblos bank" :: String + , "Citi" + , "Nordea" + , "Santander" + , "Swedbank" + , "Wells Fargo" + ] + details9 do deletedRows <- runQuery . sqlDelete "witness" $ do @@ -709,80 +892,109 @@ testDBSchema1 step = do liftIO $ assertEqual "DELETE FROM 'witness' table" 1 deletedRows deletedName <- fetchOne id - liftIO $ assertEqual "DELETE FROM 'witness' table RETURNING firstname, lastname" - ("Erica" :: String, "Hackett" :: String) - deletedName + liftIO $ + assertEqual + "DELETE FROM 'witness' table RETURNING firstname, lastname" + ("Erica" :: String, "Hackett" :: String) + deletedName return (badGuyIds, robberyIds) migrateDBToSchema2 :: (String -> TestM ()) -> TestM () migrateDBToSchema2 step = do - let extensions = ["pgcrypto"] - composites = [] - domains = [] + let extensions = ["pgcrypto"] + composites = [] + domains = [] step "Migrating the database (schema version 1 -> schema version 2)..." - migrateDatabase defaultExtrasOptions { eoLockTimeoutMs = Just 1000 } extensions composites domains - schema2Tables schema2Migrations + migrateDatabase + defaultExtrasOptions {eoLockTimeoutMs = Just 1000} + extensions + composites + domains + schema2Tables + schema2Migrations checkDatabase defaultExtrasOptions composites domains schema2Tables -- | Hacky version of 'migrateDBToSchema2' used by 'migrationTest3'. migrateDBToSchema2Hacky :: (String -> TestM ()) -> TestM () migrateDBToSchema2Hacky step = do - let extensions = ["pgcrypto"] - composites = [] - domains = [] - step "Hackily migrating the database (schema version 1 \ - \-> schema version 2)..." - migrateDatabase defaultExtrasOptions extensions composites domains - schema2Tables schema2Migrations' + let extensions = ["pgcrypto"] + composites = [] + domains = [] + step + "Hackily migrating the database (schema version 1 \ + \-> schema version 2)..." + migrateDatabase + defaultExtrasOptions + extensions + composites + domains + schema2Tables + schema2Migrations' checkDatabase defaultExtrasOptions composites domains schema2Tables - where - schema2Migrations' = createTableMigration tableFlash : schema2Migrations + where + schema2Migrations' = createTableMigration tableFlash : schema2Migrations testDBSchema2 :: (String -> TestM ()) -> [UUID] -> [UUID] -> TestM () testDBSchema2 step badGuyIds robberyIds = do step "Running test queries (schema version 2)..." -- Check that table 'witness' doesn't exist. - runSQL_ $ "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public'" - <> " AND tablename = 'witness')"; + runSQL_ $ + "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public'" + <> " AND tablename = 'witness')" (witnessExists :: Bool) <- fetchOne runIdentity liftIO $ assertEqual "Table 'witness' doesn't exist" False witnessExists -- Check that table 'witnessed_robbery' doesn't exist. - runSQL_ $ "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public'" - <> " AND tablename = 'witnessed_robbery')"; + runSQL_ $ + "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public'" + <> " AND tablename = 'witnessed_robbery')" (witnessedRobberyExists :: Bool) <- fetchOne runIdentity - liftIO $ assertEqual "Table 'witnessed_robbery' doesn't exist" False - witnessedRobberyExists + liftIO $ + assertEqual + "Table 'witnessed_robbery' doesn't exist" + False + witnessedRobberyExists -- Populate table 'under_arrest'. runQuery_ . sqlInsert "under_arrest" $ do - sqlSetList "bad_guy_id" [badGuyIds !! idx | idx <- [0,2]] + sqlSetList "bad_guy_id" [badGuyIds !! idx | idx <- [0, 2]] sqlSet "robbery_id" (robberyIds !! 0) sqlResult "bad_guy_id" (arrestedIds :: [UUID]) <- fetchMany runIdentity - liftIO $ assertEqual "INSERT into 'under_arrest' table" 2 - (length arrestedIds) + liftIO $ + assertEqual + "INSERT into 'under_arrest' table" + 2 + (length arrestedIds) runQuery_ . sqlInsert "under_arrest" $ do - sqlSetList "bad_guy_id" [badGuyIds !! idx | idx <- [3,4]] + sqlSetList "bad_guy_id" [badGuyIds !! idx | idx <- [3, 4]] sqlSet "robbery_id" (robberyIds !! 1) sqlResult "bad_guy_id" (arrestedIds' :: [UUID]) <- fetchMany runIdentity - liftIO $ assertEqual "INSERT into 'under_arrest' table" 2 - (length arrestedIds') + liftIO $ + assertEqual + "INSERT into 'under_arrest' table" + 2 + (length arrestedIds') return () migrateDBToSchema3 :: (String -> TestM ()) -> TestM () migrateDBToSchema3 step = do - let extensions = ["pgcrypto"] - composites = [] - domains = [] + let extensions = ["pgcrypto"] + composites = [] + domains = [] step "Migrating the database (schema version 2 -> schema version 3)..." - migrateDatabase defaultExtrasOptions extensions composites domains - schema3Tables schema3Migrations + migrateDatabase + defaultExtrasOptions + extensions + composites + domains + schema3Tables + schema3Migrations checkDatabase defaultExtrasOptions composites domains schema3Tables testDBSchema3 :: (String -> TestM ()) -> [UUID] -> [UUID] -> TestM () @@ -790,50 +1002,69 @@ testDBSchema3 step badGuyIds robberyIds = do step "Running test queries (schema version 3)..." -- Check that table 'under_arrest' doesn't exist. - runSQL_ $ "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public'" - <> " AND tablename = 'under_arrest')"; + runSQL_ $ + "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public'" + <> " AND tablename = 'under_arrest')" (underArrestExists :: Bool) <- fetchOne runIdentity - liftIO $ assertEqual "Table 'under_arrest' doesn't exist" False - underArrestExists + liftIO $ + assertEqual + "Table 'under_arrest' doesn't exist" + False + underArrestExists -- Check that the table 'prison_sentence' exists. - runSQL_ $ "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public'" - <> " AND tablename = 'prison_sentence')"; + runSQL_ $ + "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public'" + <> " AND tablename = 'prison_sentence')" (prisonSentenceExists :: Bool) <- fetchOne runIdentity - liftIO $ assertEqual "Table 'prison_sentence' does exist" True - prisonSentenceExists + liftIO $ + assertEqual + "Table 'prison_sentence' does exist" + True + prisonSentenceExists -- Populate table 'prison_sentence'. runQuery_ . sqlInsert "prison_sentence" $ do - sqlSetList "bad_guy_id" [badGuyIds !! idx | idx <- [0,2]] + sqlSetList "bad_guy_id" [badGuyIds !! idx | idx <- [0, 2]] sqlSet "robbery_id" (robberyIds !! 0) - sqlSet "sentence_length" (12::Int) - sqlSet "prison_name" ("Long Kesh"::T.Text) + sqlSet "sentence_length" (12 :: Int) + sqlSet "prison_name" ("Long Kesh" :: T.Text) sqlResult "bad_guy_id" (sentencedIds :: [UUID]) <- fetchMany runIdentity - liftIO $ assertEqual "INSERT into 'prison_sentence' table" 2 - (length sentencedIds) + liftIO $ + assertEqual + "INSERT into 'prison_sentence' table" + 2 + (length sentencedIds) runQuery_ . sqlInsert "prison_sentence" $ do - sqlSetList "bad_guy_id" [badGuyIds !! idx | idx <- [3,4]] + sqlSetList "bad_guy_id" [badGuyIds !! idx | idx <- [3, 4]] sqlSet "robbery_id" (robberyIds !! 1) - sqlSet "sentence_length" (9::Int) - sqlSet "prison_name" ("Wormwood Scrubs"::T.Text) + sqlSet "sentence_length" (9 :: Int) + sqlSet "prison_name" ("Wormwood Scrubs" :: T.Text) sqlResult "bad_guy_id" (sentencedIds' :: [UUID]) <- fetchMany runIdentity - liftIO $ assertEqual "INSERT into 'prison_sentence' table" 2 - (length sentencedIds') + liftIO $ + assertEqual + "INSERT into 'prison_sentence' table" + 2 + (length sentencedIds') return () migrateDBToSchema4 :: (String -> TestM ()) -> TestM () migrateDBToSchema4 step = do - let extensions = ["pgcrypto"] - composites = [] - domains = [] + let extensions = ["pgcrypto"] + composites = [] + domains = [] step "Migrating the database (schema version 3 -> schema version 4)..." - migrateDatabase defaultExtrasOptions extensions composites domains - schema4Tables schema4Migrations + migrateDatabase + defaultExtrasOptions + extensions + composites + domains + schema4Tables + schema4Migrations checkDatabase defaultExtrasOptions composites domains schema4Tables testDBSchema4 :: (String -> TestM ()) -> TestM () @@ -841,24 +1072,33 @@ testDBSchema4 step = do step "Running test queries (schema version 4)..." -- Check that the 'bank' table has a 'cash' column. - runSQL_ $ "SELECT EXISTS (SELECT 1 FROM information_schema.columns" - <> " WHERE table_schema = 'public'" - <> " AND table_name = 'bank'" - <> " AND column_name = 'cash')"; + runSQL_ $ + "SELECT EXISTS (SELECT 1 FROM information_schema.columns" + <> " WHERE table_schema = 'public'" + <> " AND table_name = 'bank'" + <> " AND column_name = 'cash')" (colCashExists :: Bool) <- fetchOne runIdentity - liftIO $ assertEqual "Column 'cash' in the table 'bank' does exist" True - colCashExists + liftIO $ + assertEqual + "Column 'cash' in the table 'bank' does exist" + True + colCashExists return () migrateDBToSchema5 :: (String -> TestM ()) -> TestM () migrateDBToSchema5 step = do - let extensions = ["pgcrypto"] - composites = [] - domains = [] + let extensions = ["pgcrypto"] + composites = [] + domains = [] step "Migrating the database (schema version 4 -> schema version 5)..." - migrateDatabase defaultExtrasOptions extensions composites domains - schema5Tables schema5Migrations + migrateDatabase + defaultExtrasOptions + extensions + composites + domains + schema5Tables + schema5Migrations checkDatabase defaultExtrasOptions composites domains schema5Tables testDBSchema5 :: (String -> TestM ()) -> TestM () @@ -866,17 +1106,22 @@ testDBSchema5 step = do step "Running test queries (schema version 5)..." -- Check that the 'bank' table doesn't have a 'cash' column. - runSQL_ $ "SELECT EXISTS (SELECT 1 FROM information_schema.columns" - <> " WHERE table_schema = 'public'" - <> " AND table_name = 'bank'" - <> " AND column_name = 'cash')"; + runSQL_ $ + "SELECT EXISTS (SELECT 1 FROM information_schema.columns" + <> " WHERE table_schema = 'public'" + <> " AND table_name = 'bank'" + <> " AND column_name = 'cash')" (colCashExists :: Bool) <- fetchOne runIdentity - liftIO $ assertEqual "Column 'cash' in the table 'bank' doesn't exist" False - colCashExists + liftIO $ + assertEqual + "Column 'cash' in the table 'bank' doesn't exist" + False + colCashExists -- Check that the 'flash' table doesn't exist. - runSQL_ $ "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public'" - <> " AND tablename = 'flash')"; + runSQL_ $ + "SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public'" + <> " AND tablename = 'flash')" (flashExists :: Bool) <- fetchOne runIdentity liftIO $ assertEqual "Table 'flash' doesn't exist" False flashExists @@ -884,7 +1129,7 @@ testDBSchema5 step = do -- | May require 'ALTER SCHEMA public OWNER TO $user' the first time -- you run this. -freshTestDB :: (String -> TestM ()) -> TestM () +freshTestDB :: (String -> TestM ()) -> TestM () freshTestDB step = do step "Dropping the test DB schema..." runSQL_ "DROP SCHEMA public CASCADE" @@ -895,62 +1140,64 @@ migrationTest1Body :: (String -> TestM ()) -> TestM () migrationTest1Body step = do createTablesSchema1 step (badGuyIds, robberyIds) <- - testDBSchema1 step + testDBSchema1 step - migrateDBToSchema2 step - testDBSchema2 step badGuyIds robberyIds + migrateDBToSchema2 step + testDBSchema2 step badGuyIds robberyIds - migrateDBToSchema3 step - testDBSchema3 step badGuyIds robberyIds + migrateDBToSchema3 step + testDBSchema3 step badGuyIds robberyIds - migrateDBToSchema4 step - testDBSchema4 step + migrateDBToSchema4 step + testDBSchema4 step - migrateDBToSchema5 step - testDBSchema5 step + migrateDBToSchema5 step + testDBSchema5 step bankTrigger1 :: Trigger bankTrigger1 = - Trigger { triggerTable = "bank" - , triggerName = "trigger_1" - , triggerEvents = Set.fromList [TriggerInsert] - , triggerDeferrable = False - , triggerInitiallyDeferred = False - , triggerWhen = Nothing - , triggerFunction = - "begin" - <+> " perform true;" - <+> " return null;" - <+> "end;" - } + Trigger + { triggerTable = "bank" + , triggerName = "trigger_1" + , triggerEvents = Set.fromList [TriggerInsert] + , triggerDeferrable = False + , triggerInitiallyDeferred = False + , triggerWhen = Nothing + , triggerFunction = + "begin" + <+> " perform true;" + <+> " return null;" + <+> "end;" + } bankTrigger2 :: Trigger bankTrigger2 = bankTrigger1 - { triggerFunction = - "begin" - <+> " return null;" - <+> "end;" - } + { triggerFunction = + "begin" + <+> " return null;" + <+> "end;" + } bankTrigger3 :: Trigger bankTrigger3 = - Trigger { triggerTable = "bank" - , triggerName = "trigger_3" - , triggerEvents = Set.fromList [TriggerInsert, TriggerUpdateOf [unsafeSQL "location"]] - , triggerDeferrable = True - , triggerInitiallyDeferred = True - , triggerWhen = Nothing - , triggerFunction = - "begin" - <+> " perform true;" - <+> " return null;" - <+> "end;" - } + Trigger + { triggerTable = "bank" + , triggerName = "trigger_3" + , triggerEvents = Set.fromList [TriggerInsert, TriggerUpdateOf [unsafeSQL "location"]] + , triggerDeferrable = True + , triggerInitiallyDeferred = True + , triggerWhen = Nothing + , triggerFunction = + "begin" + <+> " perform true;" + <+> " return null;" + <+> "end;" + } bankTrigger2Proper :: Trigger bankTrigger2Proper = - bankTrigger2 { triggerName = "trigger_2" } + bankTrigger2 {triggerName = "trigger_2"} testTriggers :: HasCallStack => (String -> TestM ()) -> TestM () testTriggers step = do @@ -961,31 +1208,37 @@ testTriggers step = do do let msg = "checkDatabase fails if there are triggers in the database but not in the schema" - ts = [ tableBankSchema1 { tblVersion = 2 - , tblTriggers = [] - } - ] - ms = [ createTriggerMigration 1 bankTrigger1 ] + ts = + [ tableBankSchema1 + { tblVersion = 2 + , tblTriggers = [] + } + ] + ms = [createTriggerMigration 1 bankTrigger1] step msg assertException msg $ migrate ts ms do let msg = "checkDatabase fails if there are triggers in the schema but not in the database" - ts = [ tableBankSchema1 { tblVersion = 2 - , tblTriggers = [bankTrigger1] - } - ] + ts = + [ tableBankSchema1 + { tblVersion = 2 + , tblTriggers = [bankTrigger1] + } + ] ms = [] triggerStep msg $ do assertException msg $ migrate ts ms do let msg = "test succeeds when creating a single trigger" - ts = [ tableBankSchema1 { tblVersion = 2 - , tblTriggers = [bankTrigger1] - } - ] - ms = [ createTriggerMigration 1 bankTrigger1 ] + ts = + [ tableBankSchema1 + { tblVersion = 2 + , tblTriggers = [bankTrigger1] + } + ] + ms = [createTriggerMigration 1 bankTrigger1] triggerStep msg $ do assertNoException msg $ migrate ts ms verify [bankTrigger1] True @@ -994,60 +1247,73 @@ testTriggers step = do -- Attempt to create the same triggers twice. Should fail with a DBException saying -- that function already exists. let msg = "database exception is raised if trigger is created twice" - ts = [ tableBankSchema1 { tblVersion = 3 - , tblTriggers = [bankTrigger1] - } - ] - ms = [ createTriggerMigration 1 bankTrigger1 - , createTriggerMigration 2 bankTrigger1 - ] + ts = + [ tableBankSchema1 + { tblVersion = 3 + , tblTriggers = [bankTrigger1] + } + ] + ms = + [ createTriggerMigration 1 bankTrigger1 + , createTriggerMigration 2 bankTrigger1 + ] triggerStep msg $ do assertDBException msg $ migrate ts ms do let msg = "database exception is raised if triggers only differ in function name" - ts = [ tableBankSchema1 { tblVersion = 3 - , tblTriggers = [bankTrigger1, bankTrigger2] - } - ] - ms = [ createTriggerMigration 1 bankTrigger1 - , createTriggerMigration 2 bankTrigger2 - ] + ts = + [ tableBankSchema1 + { tblVersion = 3 + , tblTriggers = [bankTrigger1, bankTrigger2] + } + ] + ms = + [ createTriggerMigration 1 bankTrigger1 + , createTriggerMigration 2 bankTrigger2 + ] triggerStep msg $ do assertDBException msg $ migrate ts ms do let msg = "successfully migrate two triggers" - ts = [ tableBankSchema1 { tblVersion = 3 - , tblTriggers = [bankTrigger1, bankTrigger2Proper] - } - ] - ms = [ createTriggerMigration 1 bankTrigger1 - , createTriggerMigration 2 bankTrigger2Proper - ] + ts = + [ tableBankSchema1 + { tblVersion = 3 + , tblTriggers = [bankTrigger1, bankTrigger2Proper] + } + ] + ms = + [ createTriggerMigration 1 bankTrigger1 + , createTriggerMigration 2 bankTrigger2Proper + ] triggerStep msg $ do assertNoException msg $ migrate ts ms verify [bankTrigger1, bankTrigger2Proper] True do let msg = "database exception is raised if trigger's WHEN is syntactically incorrect" - trg = bankTrigger1 { triggerWhen = Just "WILL FAIL" } - ts = [ tableBankSchema1 { tblVersion = 2 - , tblTriggers = [trg] - } - ] - ms = [ createTriggerMigration 1 trg ] + trg = bankTrigger1 {triggerWhen = Just "WILL FAIL"} + ts = + [ tableBankSchema1 + { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [createTriggerMigration 1 trg] triggerStep msg $ do assertDBException msg $ migrate ts ms do let msg = "database exception is raised if trigger's WHEN uses undefined column" - trg = bankTrigger1 { triggerWhen = Just "NEW.foobar = 1" } - ts = [ tableBankSchema1 { tblVersion = 2 - , tblTriggers = [trg] - } - ] - ms = [ createTriggerMigration 1 trg ] + trg = bankTrigger1 {triggerWhen = Just "NEW.foobar = 1"} + ts = + [ tableBankSchema1 + { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [createTriggerMigration 1 trg] triggerStep msg $ do assertDBException msg $ migrate ts ms @@ -1061,96 +1327,116 @@ testTriggers step = do -- clauses. On the other hand, it's probably good enough as it is. -- See the comment for 'getDBTriggers' in src/Database/PostgreSQL/PQTypes/Model/Trigger.hs. let msg = "checkDatabase fails if WHEN clauses from database and code differ" - trg = bankTrigger1 { triggerWhen = Just "NEW.name != 'foobar'" } - ts = [ tableBankSchema1 { tblVersion = 2 - , tblTriggers = [trg] - } - ] - ms = [ createTriggerMigration 1 trg ] + trg = bankTrigger1 {triggerWhen = Just "NEW.name != 'foobar'"} + ts = + [ tableBankSchema1 + { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [createTriggerMigration 1 trg] triggerStep msg $ do assertException msg $ migrate ts ms do let msg = "successfully migrate trigger with valid WHEN" - trg = bankTrigger1 { triggerWhen = Just "new.name <> 'foobar'::text" } - ts = [ tableBankSchema1 { tblVersion = 2 - , tblTriggers = [trg] - } - ] - ms = [ createTriggerMigration 1 trg ] + trg = bankTrigger1 {triggerWhen = Just "new.name <> 'foobar'::text"} + ts = + [ tableBankSchema1 + { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [createTriggerMigration 1 trg] triggerStep msg $ do assertNoException msg $ migrate ts ms verify [trg] True do let msg = "successfully migrate trigger that is deferrable" - trg = bankTrigger1 { triggerDeferrable = True } - ts = [ tableBankSchema1 { tblVersion = 2 - , tblTriggers = [trg] - } - ] - ms = [ createTriggerMigration 1 trg ] + trg = bankTrigger1 {triggerDeferrable = True} + ts = + [ tableBankSchema1 + { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [createTriggerMigration 1 trg] triggerStep msg $ do assertNoException msg $ migrate ts ms verify [trg] True do let msg = "successfully migrate trigger that is deferrable and initially deferred" - trg = bankTrigger1 { triggerDeferrable = True - , triggerInitiallyDeferred = True - } - ts = [ tableBankSchema1 { tblVersion = 2 - , tblTriggers = [trg] - } - ] - ms = [ createTriggerMigration 1 trg ] + trg = + bankTrigger1 + { triggerDeferrable = True + , triggerInitiallyDeferred = True + } + ts = + [ tableBankSchema1 + { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [createTriggerMigration 1 trg] triggerStep msg $ do assertNoException msg $ migrate ts ms verify [trg] True do let msg = "database exception is raised if trigger is initially deferred but not deferrable" - trg = bankTrigger1 { triggerDeferrable = False - , triggerInitiallyDeferred = True - } - ts = [ tableBankSchema1 { tblVersion = 2 - , tblTriggers = [trg] - } - ] - ms = [ createTriggerMigration 1 trg ] + trg = + bankTrigger1 + { triggerDeferrable = False + , triggerInitiallyDeferred = True + } + ts = + [ tableBankSchema1 + { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [createTriggerMigration 1 trg] triggerStep msg $ do assertDBException msg $ migrate ts ms do let msg = "database exception is raised if dropping trigger that does not exist" trg = bankTrigger1 - ts = [ tableBankSchema1 { tblVersion = 2 - , tblTriggers = [trg] - } - ] - ms = [ dropTriggerMigration 1 trg ] + ts = + [ tableBankSchema1 + { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [dropTriggerMigration 1 trg] triggerStep msg $ do assertDBException msg $ migrate ts ms do let msg = "database exception is raised if dropping trigger function of which does not exist" trg = bankTrigger2 - ts = [ tableBankSchema1 { tblVersion = 2 - , tblTriggers = [trg] - } - ] - ms = [ dropTriggerMigration 1 trg ] + ts = + [ tableBankSchema1 + { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [dropTriggerMigration 1 trg] triggerStep msg $ do assertDBException msg $ migrate ts ms do let msg = "successfully drop trigger" trg = bankTrigger1 - ts = [ tableBankSchema1 { tblVersion = 3 - , tblTriggers = [] - } - ] - ms = [ createTriggerMigration 1 trg, dropTriggerMigration 2 trg ] + ts = + [ tableBankSchema1 + { tblVersion = 3 + , tblTriggers = [] + } + ] + ms = [createTriggerMigration 1 trg, dropTriggerMigration 2 trg] triggerStep msg $ do assertNoException msg $ migrate ts ms verify [trg] False @@ -1158,26 +1444,29 @@ testTriggers step = do do let msg = "database exception is raised if dropping trigger twice" trg = bankTrigger2 - ts = [ tableBankSchema1 { tblVersion = 3 - , tblTriggers = [trg] - } - ] - ms = [ dropTriggerMigration 1 trg, dropTriggerMigration 2 trg ] + ts = + [ tableBankSchema1 + { tblVersion = 3 + , tblTriggers = [trg] + } + ] + ms = [dropTriggerMigration 1 trg, dropTriggerMigration 2 trg] triggerStep msg $ do assertDBException msg $ migrate ts ms do let msg = "successfully create trigger with multiple events" trg = bankTrigger3 - ts = [ tableBankSchema1 { tblVersion = 2 - , tblTriggers = [trg] - } - ] - ms = [ createTriggerMigration 1 trg ] + ts = + [ tableBankSchema1 + { tblVersion = 2 + , tblTriggers = [trg] + } + ] + ms = [createTriggerMigration 1 trg] triggerStep msg $ do assertNoException msg $ migrate ts ms verify [trg] True - where triggerStep msg rest = do recreateTriggerDB @@ -1199,11 +1488,12 @@ testTriggers step = do liftIO . assertBool err $ trans ok triggerMigration :: MonadDB m => (Trigger -> m ()) -> Int -> Trigger -> Migration m - triggerMigration fn from trg = Migration - { mgrTableName = tblName tableBankSchema1 - , mgrFrom = fromIntegral from - , mgrAction = StandardMigration $ fn trg - } + triggerMigration fn from trg = + Migration + { mgrTableName = tblName tableBankSchema1 + , mgrFrom = fromIntegral from + , mgrAction = StandardMigration $ fn trg + } createTriggerMigration :: MonadDB m => Int -> Trigger -> Migration m createTriggerMigration = triggerMigration createTrigger @@ -1217,7 +1507,7 @@ testTriggers step = do runSQL_ "DROP FUNCTION IF EXISTS trgfun__trigger_1;" runSQL_ "DROP FUNCTION IF EXISTS trgfun__trigger_2;" runSQL_ "DROP TABLE IF EXISTS bank;" - runSQL_ "DELETE FROM table_versions WHERE name = 'bank'"; + runSQL_ "DELETE FROM table_versions WHERE name = 'bank'" migrate [tableBankSchema1] [createTableMigration tableBankSchema1] testSqlWith :: HasCallStack => (String -> TestM ()) -> TestM () @@ -1300,7 +1590,7 @@ testSqlWithRecursive step = do -- Pablo is the boss of Gustavo, who is the boss of Mario runQuery_ . sqlInsert "cartel" $ do sqlSetList "cartel_member_id" badGuyIds - sqlSetList "cartel_boss_id" $ Nothing:(Just <$> take 2 badGuyIds) + sqlSetList "cartel_boss_id" $ Nothing : (Just <$> take 2 badGuyIds) step "Checking a recursive query on the cartel table" runQuery_ . sqlSelect "rcartel" $ do sqlWithRecursive "rcartel" $ do @@ -1331,12 +1621,14 @@ testSqlWithRecursive step = do toCartel (memberFn, memberLn, bossFn, bossLn) = (T.intercalate " " [memberFn, memberLn], T.intercalate " " <$> sequence [bossFn, bossLn]) results <- fetchMany toCartel - liftIO $ assertEqual "Wrong cartel hierarchy retrieved" results - [ ("Pablo Escobar", Nothing) - , ("Gustavo Rivero", Just "Pablo Escobar") - , ("Mario Vallejo", Just "Gustavo Rivero") - ] - + liftIO $ + assertEqual + "Wrong cartel hierarchy retrieved" + results + [ ("Pablo Escobar", Nothing) + , ("Gustavo Rivero", Just "Pablo Escobar") + , ("Mario Vallejo", Just "Gustavo Rivero") + ] testUnion :: HasCallStack => (String -> TestM ()) -> TestM () testUnion step = do @@ -1353,9 +1645,11 @@ testUnion step = do sqlResult "true" ] result <- fetchMany runIdentity - liftIO $ assertEqual "UNION of booleans" - [False, True] - result + liftIO $ + assertEqual + "UNION of booleans" + [False, True] + result testUnionAll :: HasCallStack => (String -> TestM ()) -> TestM () testUnionAll step = do @@ -1372,143 +1666,163 @@ testUnionAll step = do sqlResult "true" ] result <- fetchMany runIdentity - liftIO $ assertEqual "UNION ALL of booleans" - [True, False, True] - result + liftIO $ + assertEqual + "UNION ALL of booleans" + [True, False, True] + result migrationTest1 :: ConnectionSourceM (LogT IO) -> TestTree migrationTest1 connSource = testCaseSteps' "Migration test 1" connSource $ \step -> do - freshTestDB step + freshTestDB step - migrationTest1Body step + migrationTest1Body step -- | Test for behaviour of 'checkDatabase' and 'checkDatabaseAllowUnknownObjects' migrationTest2 :: ConnectionSourceM (LogT IO) -> TestTree migrationTest2 connSource = testCaseSteps' "Migration test 2" connSource $ \step -> do - freshTestDB step + freshTestDB step - createTablesSchema1 step + createTablesSchema1 step - let composite = CompositeType - { ctName = "composite" - , ctColumns = - [ CompositeColumn { ccName = "cint", ccType = UuidT } - , CompositeColumn { ccName = "ctext", ccType = TextT } - ] - } - currentSchema = schema1Tables - differentSchema = schema5Tables - extrasOptions = defaultExtrasOptions { eoEnforcePKs = True } - extrasOptionsWithUnknownObjects = extrasOptions { eoObjectsValidationMode = AllowUnknownObjects } - - runQuery_ $ sqlCreateComposite composite - - assertNoException "checkDatabase should run fine for consistent DB" $ - checkDatabase extrasOptions [composite] [] currentSchema - assertException "checkDatabase fails if composite type definition is not provided" $ - checkDatabase extrasOptions [] [] currentSchema - assertNoException "checkDatabaseAllowUnknownTables runs fine \ - \for consistent DB" $ - checkDatabase extrasOptionsWithUnknownObjects [composite] [] currentSchema - assertNoException "checkDatabaseAllowUnknownTables runs fine \ - \for consistent DB with unknown composite type in the database" $ - checkDatabase extrasOptionsWithUnknownObjects [] [] currentSchema - assertException "checkDatabase should throw exception for wrong schema" $ - checkDatabase extrasOptions [] [] differentSchema - assertException "checkDatabaseAllowUnknownObjects \ - \should throw exception for wrong scheme" $ - checkDatabase extrasOptionsWithUnknownObjects [] [] differentSchema - - runSQL_ "INSERT INTO table_versions (name, version) \ - \VALUES ('unknown_table', 0)" - assertException "checkDatabase throw when extra entry in 'table_versions'" $ - checkDatabase extrasOptions [] [] currentSchema - assertNoException "checkDatabaseAllowUnknownObjects \ - \accepts extra entry in 'table_versions'" $ - checkDatabase extrasOptionsWithUnknownObjects [] [] currentSchema - runSQL_ "DELETE FROM table_versions where name='unknown_table'" - - runSQL_ "CREATE TABLE unknown_table (title text)" - assertException "checkDatabase should throw with unknown table" $ - checkDatabase extrasOptions [] [] currentSchema - assertNoException "checkDatabaseAllowUnknownObjects accepts unknown table" $ - checkDatabase extrasOptionsWithUnknownObjects [] [] currentSchema - - runSQL_ "INSERT INTO table_versions (name, version) \ - \VALUES ('unknown_table', 0)" - assertException "checkDatabase should throw with unknown table" $ - checkDatabase extrasOptions [] [] currentSchema - assertNoException "checkDatabaseAllowUnknownObjects \ - \accepts unknown tables with version" $ - checkDatabase extrasOptionsWithUnknownObjects [] [] currentSchema - - freshTestDB step - - let schema1TablesWithMissingPK = schema6Tables - schema1MigrationsWithMissingPK = schema6Migrations - withMissingPKSchema = schema1TablesWithMissingPK - optionsNoPKCheck = defaultExtrasOptions - { eoEnforcePKs = False } - optionsWithPKCheck = defaultExtrasOptions - { eoEnforcePKs = True } - - step "Recreating the database (schema version 1, one table is missing PK)..." - - migrateDatabase optionsNoPKCheck ["pgcrypto"] [] [] - schema1TablesWithMissingPK [schema1MigrationsWithMissingPK] - checkDatabase optionsNoPKCheck [] [] withMissingPKSchema - - assertException - "checkDatabase should throw when PK missing from table \ - \'participated_in_robbery' and check is enabled" $ - checkDatabase optionsWithPKCheck [] [] withMissingPKSchema - assertNoException - "checkDatabase should not throw when PK missing from table \ - \'participated_in_robbery' and check is disabled" $ + let composite = + CompositeType + { ctName = "composite" + , ctColumns = + [ CompositeColumn {ccName = "cint", ccType = UuidT} + , CompositeColumn {ccName = "ctext", ccType = TextT} + ] + } + currentSchema = schema1Tables + differentSchema = schema5Tables + extrasOptions = defaultExtrasOptions {eoEnforcePKs = True} + extrasOptionsWithUnknownObjects = extrasOptions {eoObjectsValidationMode = AllowUnknownObjects} + + runQuery_ $ sqlCreateComposite composite + + assertNoException "checkDatabase should run fine for consistent DB" $ + checkDatabase extrasOptions [composite] [] currentSchema + assertException "checkDatabase fails if composite type definition is not provided" $ + checkDatabase extrasOptions [] [] currentSchema + assertNoException + "checkDatabaseAllowUnknownTables runs fine \ + \for consistent DB" + $ checkDatabase extrasOptionsWithUnknownObjects [composite] [] currentSchema + assertNoException + "checkDatabaseAllowUnknownTables runs fine \ + \for consistent DB with unknown composite type in the database" + $ checkDatabase extrasOptionsWithUnknownObjects [] [] currentSchema + assertException "checkDatabase should throw exception for wrong schema" $ + checkDatabase extrasOptions [] [] differentSchema + assertException + "checkDatabaseAllowUnknownObjects \ + \should throw exception for wrong scheme" + $ checkDatabase extrasOptionsWithUnknownObjects [] [] differentSchema + + runSQL_ + "INSERT INTO table_versions (name, version) \ + \VALUES ('unknown_table', 0)" + assertException "checkDatabase throw when extra entry in 'table_versions'" $ + checkDatabase extrasOptions [] [] currentSchema + assertNoException + "checkDatabaseAllowUnknownObjects \ + \accepts extra entry in 'table_versions'" + $ checkDatabase extrasOptionsWithUnknownObjects [] [] currentSchema + runSQL_ "DELETE FROM table_versions where name='unknown_table'" + + runSQL_ "CREATE TABLE unknown_table (title text)" + assertException "checkDatabase should throw with unknown table" $ + checkDatabase extrasOptions [] [] currentSchema + assertNoException "checkDatabaseAllowUnknownObjects accepts unknown table" $ + checkDatabase extrasOptionsWithUnknownObjects [] [] currentSchema + + runSQL_ + "INSERT INTO table_versions (name, version) \ + \VALUES ('unknown_table', 0)" + assertException "checkDatabase should throw with unknown table" $ + checkDatabase extrasOptions [] [] currentSchema + assertNoException + "checkDatabaseAllowUnknownObjects \ + \accepts unknown tables with version" + $ checkDatabase extrasOptionsWithUnknownObjects [] [] currentSchema + + freshTestDB step + + let schema1TablesWithMissingPK = schema6Tables + schema1MigrationsWithMissingPK = schema6Migrations + withMissingPKSchema = schema1TablesWithMissingPK + optionsNoPKCheck = + defaultExtrasOptions + { eoEnforcePKs = False + } + optionsWithPKCheck = + defaultExtrasOptions + { eoEnforcePKs = True + } + + step "Recreating the database (schema version 1, one table is missing PK)..." + + migrateDatabase + optionsNoPKCheck + ["pgcrypto"] + [] + [] + schema1TablesWithMissingPK + [schema1MigrationsWithMissingPK] checkDatabase optionsNoPKCheck [] [] withMissingPKSchema - freshTestDB step + assertException + "checkDatabase should throw when PK missing from table \ + \'participated_in_robbery' and check is enabled" + $ checkDatabase optionsWithPKCheck [] [] withMissingPKSchema + assertNoException + "checkDatabase should not throw when PK missing from table \ + \'participated_in_robbery' and check is disabled" + $ checkDatabase optionsNoPKCheck [] [] withMissingPKSchema + + freshTestDB step migrationTest3 :: ConnectionSourceM (LogT IO) -> TestTree migrationTest3 connSource = testCaseSteps' "Migration test 3" connSource $ \step -> do - freshTestDB step + freshTestDB step - createTablesSchema1 step - (badGuyIds, robberyIds) <- - testDBSchema1 step + createTablesSchema1 step + (badGuyIds, robberyIds) <- + testDBSchema1 step - migrateDBToSchema2 step - testDBSchema2 step badGuyIds robberyIds + migrateDBToSchema2 step + testDBSchema2 step badGuyIds robberyIds - assertException "Trying to run the same migration twice should fail, \ - \when starting with a createTable migration" $ - migrateDBToSchema2Hacky step + assertException + "Trying to run the same migration twice should fail, \ + \when starting with a createTable migration" + $ migrateDBToSchema2Hacky step - freshTestDB step + freshTestDB step -- | Test that running the same migrations twice doesn't result in -- unexpected errors. migrationTest4 :: ConnectionSourceM (LogT IO) -> TestTree migrationTest4 connSource = testCaseSteps' "Migration test 4" connSource $ \step -> do - freshTestDB step + freshTestDB step - migrationTest1Body step + migrationTest1Body step - -- Here we run step 5 for the second time. This should be a no-op. - migrateDBToSchema5 step - testDBSchema5 step + -- Here we run step 5 for the second time. This should be a no-op. + migrateDBToSchema5 step + testDBSchema5 step - freshTestDB step + freshTestDB step -- | Test triggers. triggerTests :: ConnectionSourceM (LogT IO) -> TestTree triggerTests connSource = testCaseSteps' "Trigger tests" connSource $ \step -> do - freshTestDB step + freshTestDB step testTriggers step sqlWithTests :: ConnectionSourceM (LogT IO) -> TestTree @@ -1555,9 +1869,10 @@ migrationTest5 connSource = -- Explicitly vacuum to update the catalog so that getting the row number estimates -- works. The bracket_ trick is here because vacuum can't run inside a transaction -- block, which every test runs in. - bracket_ (runSQL_ "COMMIT") - (runSQL_ "BEGIN") - (runSQL_ "VACUUM bank") + bracket_ + (runSQL_ "COMMIT") + (runSQL_ "BEGIN") + (runSQL_ "VACUUM bank") forM_ (zip4 tables migrations steps assertions) $ \(table, migration, step', assertion) -> do @@ -1567,36 +1882,42 @@ migrationTest5 connSource = uncurry assertNoException assertion freshTestDB step - where -- Chosen by a fair dice roll. - numbers = [1..101] :: [Int] + numbers = [1 .. 101] :: [Int] table1 = tableBankSchema1 - tables = [ table1 { tblVersion = 2 - , tblColumns = tblColumns table1 ++ [stringColumn] - } - , table1 { tblVersion = 3 - , tblColumns = tblColumns table1 ++ [stringColumn] - } - , table1 { tblVersion = 4 - , tblColumns = tblColumns table1 ++ [stringColumn, boolColumn] - } - , table1 { tblVersion = 5 - , tblColumns = tblColumns table1 ++ [stringColumn, boolColumn] - } - ] + tables = + [ table1 + { tblVersion = 2 + , tblColumns = tblColumns table1 ++ [stringColumn] + } + , table1 + { tblVersion = 3 + , tblColumns = tblColumns table1 ++ [stringColumn] + } + , table1 + { tblVersion = 4 + , tblColumns = tblColumns table1 ++ [stringColumn, boolColumn] + } + , table1 + { tblVersion = 5 + , tblColumns = tblColumns table1 ++ [stringColumn, boolColumn] + } + ] - migrations = [ addStringColumnMigration - , copyStringColumnMigration - , addBoolColumnMigration - , modifyBoolColumnMigration - ] + migrations = + [ addStringColumnMigration + , copyStringColumnMigration + , addBoolColumnMigration + , modifyBoolColumnMigration + ] - steps = [ "Adding string column (version 1 -> version 2)..." - , "Copying string column (version 2 -> version 3)..." - , "Adding bool column (version 3 -> version 4)..." - , "Modifying bool column (version 4 -> version 5)..." - ] + steps = + [ "Adding string column (version 1 -> version 2)..." + , "Copying string column (version 2 -> version 3)..." + , "Adding bool column (version 3 -> version 4)..." + , "Modifying bool column (version 4 -> version 5)..." + ] assertions = [ ("Check that the string column has been added" :: String, checkAddStringColumn) @@ -1605,48 +1926,60 @@ migrationTest5 connSource = , ("Check that the bool column has been modified", checkModifyBoolColumn) ] - stringColumn = tblColumn { colName = "name_new" - , colType = TextT - } + stringColumn = + tblColumn + { colName = "name_new" + , colType = TextT + } - boolColumn = tblColumn { colName = "name_is_true" - , colType = BoolT - , colNullable = False - , colDefault = Just "false" - } + boolColumn = + tblColumn + { colName = "name_is_true" + , colType = BoolT + , colNullable = False + , colDefault = Just "false" + } cursorSql = "SELECT id FROM bank" :: SQL - addStringColumnMigration = Migration - { mgrTableName = "bank" - , mgrFrom = 1 - , mgrAction = StandardMigration $ - runQuery_ $ sqlAlterTable "bank" [ sqlAddColumn stringColumn ] - } + addStringColumnMigration = + Migration + { mgrTableName = "bank" + , mgrFrom = 1 + , mgrAction = + StandardMigration $ + runQuery_ $ + sqlAlterTable "bank" [sqlAddColumn stringColumn] + } - copyStringColumnMigration = Migration - { mgrTableName = "bank" - , mgrFrom = 2 - , mgrAction = ModifyColumnMigration "bank" cursorSql copyColumnSql 1000 - } + copyStringColumnMigration = + Migration + { mgrTableName = "bank" + , mgrFrom = 2 + , mgrAction = ModifyColumnMigration "bank" cursorSql copyColumnSql 1000 + } copyColumnSql :: MonadDB m => [Identity UUID] -> m () copyColumnSql primaryKeys = runQuery_ . sqlUpdate "bank" $ do sqlSetCmd "name_new" "bank.name" sqlWhereEqualsAny "bank.id" $ runIdentity <$> primaryKeys - addBoolColumnMigration = Migration - { mgrTableName = "bank" - , mgrFrom = 3 - , mgrAction = StandardMigration $ - runQuery_ $ sqlAlterTable "bank" [ sqlAddColumn boolColumn ] - } + addBoolColumnMigration = + Migration + { mgrTableName = "bank" + , mgrFrom = 3 + , mgrAction = + StandardMigration $ + runQuery_ $ + sqlAlterTable "bank" [sqlAddColumn boolColumn] + } - modifyBoolColumnMigration = Migration - { mgrTableName = "bank" - , mgrFrom = 4 - , mgrAction = ModifyColumnMigration "bank" cursorSql modifyColumnSql 1000 - } + modifyBoolColumnMigration = + Migration + { mgrTableName = "bank" + , mgrFrom = 4 + , mgrAction = ModifyColumnMigration "bank" cursorSql modifyColumnSql 1000 + } modifyColumnSql :: MonadDB m => [Identity UUID] -> m () modifyColumnSql primaryKeys = runQuery_ . sqlUpdate "bank" $ do @@ -1664,7 +1997,8 @@ migrationTest5 connSource = runQuery_ . sqlSelect "bank" $ sqlResult "name" rows_old :: [Maybe T.Text] <- fetchMany runIdentity liftIO . assertEqual "All name_new are equal name" True $ - all (uncurry (==)) $ zip rows_new rows_old + all (uncurry (==)) $ + zip rows_new rows_old checkAddBoolColumn = do runQuery_ . sqlSelect "bank" $ sqlResult "name_is_true" @@ -1684,26 +2018,48 @@ foreignKeyIndexesTests connSource = step "Create database with two tables, no foreign key checking" do let options = defaultExtrasOptions - migrateDatabase options ["pgcrypto"] [] [] [table1, table2] + migrateDatabase + options + ["pgcrypto"] + [] + [] + [table1, table2] [createTableMigration table1, createTableMigration table2] checkDatabase defaultExtrasOptions [] [] [table1, table2] step "Create database with two tables, with foreign key checking" do - let options = defaultExtrasOptions { eoCheckForeignKeysIndexes = True } - assertException "Foreign keys are missing" $ migrateDatabase options ["pgcrypto"] [] [] [table1, table2] - [createTableMigration table1, createTableMigration table2] + let options = defaultExtrasOptions {eoCheckForeignKeysIndexes = True} + assertException "Foreign keys are missing" $ + migrateDatabase + options + ["pgcrypto"] + [] + [] + [table1, table2] + [createTableMigration table1, createTableMigration table2] step "Table is missing several foreign key indexes" do - let options = defaultExtrasOptions { eoCheckForeignKeysIndexes = True } - assertException "Foreign keys are missing" $ migrateDatabase options ["pgcrypto"] [] [] [table1, table2, table3] - [createTableMigration table1, createTableMigration table2, createTableMigration table3] + let options = defaultExtrasOptions {eoCheckForeignKeysIndexes = True} + assertException "Foreign keys are missing" $ + migrateDatabase + options + ["pgcrypto"] + [] + [] + [table1, table2, table3] + [createTableMigration table1, createTableMigration table2, createTableMigration table3] step "Multi column indexes covering a FK pass the checks" do - let options = defaultExtrasOptions { eoCheckForeignKeysIndexes = True } - migrateDatabase options ["pgcrypto"] [] [] [table4] + let options = defaultExtrasOptions {eoCheckForeignKeysIndexes = True} + migrateDatabase + options + ["pgcrypto"] + [] + [] + [table4] [ dropTableMigration table1 , dropTableMigration table2 , dropTableMigration table3 @@ -1712,87 +2068,98 @@ foreignKeyIndexesTests connSource = checkDatabase options [] [] [table4] step "Multi column indexes not covering a FK fail the checks" do - let options = defaultExtrasOptions { eoCheckForeignKeysIndexes = True } - assertException "Foreign keys are missing" $ migrateDatabase options ["pgcrypto"] [] [] [table5] - [ dropTableMigration table4 - , createTableMigration table5 - ] + let options = defaultExtrasOptions {eoCheckForeignKeysIndexes = True} + assertException "Foreign keys are missing" $ + migrateDatabase + options + ["pgcrypto"] + [] + [] + [table5] + [ dropTableMigration table4 + , createTableMigration table5 + ] where table1 :: Table - table1 = tblTable - { tblName = "fktest1" - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "id", colType = UuidT, colNullable = False } - , tblColumn { colName = "name", colType = TextT } - , tblColumn { colName = "location", colType = TextT } - ] - , tblPrimaryKey = pkOnColumn "id" - } + table1 = + tblTable + { tblName = "fktest1" + , tblVersion = 1 + , tblColumns = + [ tblColumn {colName = "id", colType = UuidT, colNullable = False} + , tblColumn {colName = "name", colType = TextT} + , tblColumn {colName = "location", colType = TextT} + ] + , tblPrimaryKey = pkOnColumn "id" + } table2 :: Table - table2 = tblTable - { tblName = "fktest2" - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "id", colType = UuidT, colNullable = False } - , tblColumn { colName = "fkid", colType = UuidT } - , tblColumn { colName = "fkname", colType = TextT } - ] - , tblPrimaryKey = pkOnColumn "id" - , tblForeignKeys = - [ fkOnColumn "fkid" "fktest1" "id" - ] - } + table2 = + tblTable + { tblName = "fktest2" + , tblVersion = 1 + , tblColumns = + [ tblColumn {colName = "id", colType = UuidT, colNullable = False} + , tblColumn {colName = "fkid", colType = UuidT} + , tblColumn {colName = "fkname", colType = TextT} + ] + , tblPrimaryKey = pkOnColumn "id" + , tblForeignKeys = + [ fkOnColumn "fkid" "fktest1" "id" + ] + } table3 :: Table - table3 = tblTable - { tblName = "fktest3" - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "id", colType = UuidT, colNullable = False } - , tblColumn { colName = "fk1id", colType = UuidT } - , tblColumn { colName = "fk2id", colType = UuidT } - , tblColumn { colName = "fkname", colType = TextT } - ] - , tblPrimaryKey = pkOnColumn "id" - , tblForeignKeys = - [ fkOnColumn "fk1id" "fktest1" "id" - , fkOnColumn "fk2id" "fktest2" "id" - ] - } + table3 = + tblTable + { tblName = "fktest3" + , tblVersion = 1 + , tblColumns = + [ tblColumn {colName = "id", colType = UuidT, colNullable = False} + , tblColumn {colName = "fk1id", colType = UuidT} + , tblColumn {colName = "fk2id", colType = UuidT} + , tblColumn {colName = "fkname", colType = TextT} + ] + , tblPrimaryKey = pkOnColumn "id" + , tblForeignKeys = + [ fkOnColumn "fk1id" "fktest1" "id" + , fkOnColumn "fk2id" "fktest2" "id" + ] + } table4 :: Table - table4 = tblTable - { tblName = "fktest4" - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "id", colType = UuidT, colNullable = False } - , tblColumn { colName = "fk4id", colType = UuidT } - , tblColumn { colName = "fk4name", colType = TextT } - ] - , tblPrimaryKey = pkOnColumn "id" - , tblForeignKeys = - [ fkOnColumn "fk4id" "fktest4" "id" - ] - , tblIndexes = - [ indexOnColumns [ indexColumn "fk4id", indexColumn "fk4name" ] - ] - } + table4 = + tblTable + { tblName = "fktest4" + , tblVersion = 1 + , tblColumns = + [ tblColumn {colName = "id", colType = UuidT, colNullable = False} + , tblColumn {colName = "fk4id", colType = UuidT} + , tblColumn {colName = "fk4name", colType = TextT} + ] + , tblPrimaryKey = pkOnColumn "id" + , tblForeignKeys = + [ fkOnColumn "fk4id" "fktest4" "id" + ] + , tblIndexes = + [ indexOnColumns [indexColumn "fk4id", indexColumn "fk4name"] + ] + } table5 :: Table - table5 = tblTable - { tblName = "fktest5" - , tblVersion = 1 - , tblColumns = - [ tblColumn { colName = "id", colType = UuidT, colNullable = False } - , tblColumn { colName = "fk5id", colType = UuidT } - , tblColumn { colName = "fk5name", colType = TextT } - ] - , tblPrimaryKey = pkOnColumn "id" - , tblForeignKeys = - [ fkOnColumn "fk5id" "fktest5" "id" - ] - , tblIndexes = - [ indexOnColumns [ indexColumn "fk5thing", indexColumn "fk5id" ] - ] - } + table5 = + tblTable + { tblName = "fktest5" + , tblVersion = 1 + , tblColumns = + [ tblColumn {colName = "id", colType = UuidT, colNullable = False} + , tblColumn {colName = "fk5id", colType = UuidT} + , tblColumn {colName = "fk5name", colType = TextT} + ] + , tblPrimaryKey = pkOnColumn "id" + , tblForeignKeys = + [ fkOnColumn "fk5id" "fktest5" "id" + ] + , tblIndexes = + [ indexOnColumns [indexColumn "fk5thing", indexColumn "fk5id"] + ] + } overlapingIndexesTests :: ConnectionSourceM (LogT IO) -> TestTree overlapingIndexesTests connSource = do @@ -1801,25 +2168,32 @@ overlapingIndexesTests connSource = do step "Check that overlapping indexes get flagged" do - let options = defaultExtrasOptions { eoCheckOverlappingIndexes = True } - assertException "Some indexes are overlapping" $ migrateDatabase options ["pgcrypto"] [] [] [table1] - [createTableMigration table1] - where - table1 :: Table - table1 = tblTable + let options = defaultExtrasOptions {eoCheckOverlappingIndexes = True} + assertException "Some indexes are overlapping" $ + migrateDatabase + options + ["pgcrypto"] + [] + [] + [table1] + [createTableMigration table1] + where + table1 :: Table + table1 = + tblTable { tblName = "idxTest" , tblVersion = 1 , tblColumns = - [ tblColumn { colName = "id", colType = UuidT, colNullable = False } - , tblColumn { colName = "idx1", colType = UuidT } - , tblColumn { colName = "idx2", colType = UuidT } - , tblColumn { colName = "idx3", colType = UuidT } - ] + [ tblColumn {colName = "id", colType = UuidT, colNullable = False} + , tblColumn {colName = "idx1", colType = UuidT} + , tblColumn {colName = "idx2", colType = UuidT} + , tblColumn {colName = "idx3", colType = UuidT} + ] , tblPrimaryKey = pkOnColumn "id" , tblIndexes = - [ indexOnColumns [ indexColumn "idx1", indexColumn "idx2" ] - , indexOnColumns [ indexColumn "idx1" ] - ] + [ indexOnColumns [indexColumn "idx1", indexColumn "idx2"] + , indexOnColumns [indexColumn "idx1"] + ] } nullsNotDistinctTests :: ConnectionSourceM (LogT IO) -> TestTree @@ -1829,7 +2203,12 @@ nullsNotDistinctTests connSource = do step "Create a database with indexes" do - migrateDatabase defaultExtrasOptions ["pgcrypto"] [] [] [nullTableTest1, nullTableTest2] + migrateDatabase + defaultExtrasOptions + ["pgcrypto"] + [] + [] + [nullTableTest1, nullTableTest2] [createTableMigration nullTableTest1, createTableMigration nullTableTest2] checkDatabase defaultExtrasOptions [] [] [nullTableTest1, nullTableTest2] @@ -1846,148 +2225,181 @@ nullsNotDistinctTests connSource = do sqlSet "content" (Nothing @T.Text) assertDBException "Cannot insert twice a null value with NULLS NOT DISTINCT" $ runQuery_ . sqlInsert "nulltests2" $ do sqlSet "content" (Nothing @T.Text) - - where - nullTableTest1 = tblTable + where + nullTableTest1 = + tblTable { tblName = "nulltests1" , tblVersion = 1 , tblColumns = - [ tblColumn { colName = "id", colType = UuidT, colNullable = False, colDefault = Just "gen_random_uuid()" } - , tblColumn { colName = "content", colType = TextT, colNullable = True } - ] + [ tblColumn {colName = "id", colType = UuidT, colNullable = False, colDefault = Just "gen_random_uuid()"} + , tblColumn {colName = "content", colType = TextT, colNullable = True} + ] , tblPrimaryKey = pkOnColumn "id" , tblIndexes = - [ uniqueIndexOnColumn "content" - ] + [ uniqueIndexOnColumn "content" + ] } - nullTableTest2 = tblTable + nullTableTest2 = + tblTable { tblName = "nulltests2" , tblVersion = 1 , tblColumns = - [ tblColumn { colName = "id", colType = UuidT, colNullable = False, colDefault = Just "gen_random_uuid()" } - , tblColumn { colName = "content", colType = TextT, colNullable = True } - ] + [ tblColumn {colName = "id", colType = UuidT, colNullable = False, colDefault = Just "gen_random_uuid()"} + , tblColumn {colName = "content", colType = TextT, colNullable = True} + ] , tblPrimaryKey = pkOnColumn "id" , tblIndexes = - [ (uniqueIndexOnColumn "content") { idxNotDistinctNulls = True } - ] + [ (uniqueIndexOnColumn "content") {idxNotDistinctNulls = True} + ] } sqlAnyAllTests :: TestTree -sqlAnyAllTests = testGroup "SQL ANY/ALL tests" - [ testCase "sqlAll produces correct queries" $ do - assertSqlEqual "empty sqlAll is TRUE" "TRUE" . sqlAll $ pure () - assertSqlEqual "sigle condition is emmited as is" "cond" $ sqlAll $ sqlWhere "cond" - assertSqlEqual "each condition as well as entire condition is parenthesized" - "((cond1) AND (cond2))" $ sqlAll $ do - sqlWhere "cond1" - sqlWhere "cond2" - - assertSqlEqual "sqlAll can be nested" - "((cond1) AND (cond2) AND (((cond3) AND (cond4))))" $ - sqlAll $ do +sqlAnyAllTests = + testGroup + "SQL ANY/ALL tests" + [ testCase "sqlAll produces correct queries" $ do + assertSqlEqual "empty sqlAll is TRUE" "TRUE" . sqlAll $ pure () + assertSqlEqual "sigle condition is emmited as is" "cond" $ sqlAll $ sqlWhere "cond" + assertSqlEqual + "each condition as well as entire condition is parenthesized" + "((cond1) AND (cond2))" + $ sqlAll + $ do + sqlWhere "cond1" + sqlWhere "cond2" + + assertSqlEqual + "sqlAll can be nested" + "((cond1) AND (cond2) AND (((cond3) AND (cond4))))" + $ sqlAll + $ do sqlWhere "cond1" sqlWhere "cond2" sqlWhere . sqlAll $ do sqlWhere "cond3" sqlWhere "cond4" - , testCase "sqlAny produces correct queries" $ do - assertSqlEqual "empty sqlAny is FALSE" "FALSE" . sqlAny $ pure () - assertSqlEqual "sigle condition is emmited as is" "cond" $ sqlAny $ sqlWhere "cond" - assertSqlEqual "each condition as well as entire condition is parenthesized" - "((cond1) OR (cond2))" $ sqlAny $ do - sqlWhere "cond1" - sqlWhere "cond2" - - assertSqlEqual "sqlAny can be nested" - "((cond1) OR (cond2) OR (((cond3) OR (cond4))))" $ - sqlAny $ do + , testCase "sqlAny produces correct queries" $ do + assertSqlEqual "empty sqlAny is FALSE" "FALSE" . sqlAny $ pure () + assertSqlEqual "sigle condition is emmited as is" "cond" $ sqlAny $ sqlWhere "cond" + assertSqlEqual + "each condition as well as entire condition is parenthesized" + "((cond1) OR (cond2))" + $ sqlAny + $ do + sqlWhere "cond1" + sqlWhere "cond2" + + assertSqlEqual + "sqlAny can be nested" + "((cond1) OR (cond2) OR (((cond3) OR (cond4))))" + $ sqlAny + $ do sqlWhere "cond1" sqlWhere "cond2" sqlWhere . sqlAny $ do sqlWhere "cond3" sqlWhere "cond4" - , testCase "mixing sqlAny and all produces correct queries" $ do - assertSqlEqual "sqlAny and sqlAll can be mixed" - "((((cond1) OR (cond2))) AND (((cond3) OR (cond4))))" $ - sqlAll $ do + , testCase "mixing sqlAny and all produces correct queries" $ do + assertSqlEqual + "sqlAny and sqlAll can be mixed" + "((((cond1) OR (cond2))) AND (((cond3) OR (cond4))))" + $ sqlAll + $ do sqlWhere . sqlAny $ do sqlWhere "cond1" sqlWhere "cond2" sqlWhere . sqlAny $ do sqlWhere "cond3" sqlWhere "cond4" - , testCase "sqlWhereAny produces correct queries" $ do - -- `sqlWhereAny` has to be wrapped in `sqlAll` to disambiguate the `SqlWhere` monad. - assertSqlEqual "empty sqlWhereAny is FALSE" "FALSE" . sqlAll $ sqlWhereAny [] - assertSqlEqual "each condition as well as entire condition is parenthesized and joined with OR" - "((cond1) OR (cond2))" . sqlAll $ sqlWhereAny [sqlWhere "cond1", sqlWhere "cond2"] - assertSqlEqual "nested multi-conditions are parenthesized and joined with AND" - "((cond1) OR (((cond2) AND (cond3))) OR (cond4))" . sqlAll $ sqlWhereAny - [ sqlWhere "cond1" - , do - sqlWhere "cond2" - sqlWhere "cond3" - , sqlWhere "cond4" - ] - ] + , testCase "sqlWhereAny produces correct queries" $ do + -- `sqlWhereAny` has to be wrapped in `sqlAll` to disambiguate the `SqlWhere` monad. + assertSqlEqual "empty sqlWhereAny is FALSE" "FALSE" . sqlAll $ sqlWhereAny [] + assertSqlEqual + "each condition as well as entire condition is parenthesized and joined with OR" + "((cond1) OR (cond2))" + . sqlAll + $ sqlWhereAny [sqlWhere "cond1", sqlWhere "cond2"] + assertSqlEqual + "nested multi-conditions are parenthesized and joined with AND" + "((cond1) OR (((cond2) AND (cond3))) OR (cond4))" + . sqlAll + $ sqlWhereAny + [ sqlWhere "cond1" + , do + sqlWhere "cond2" + sqlWhere "cond3" + , sqlWhere "cond4" + ] + ] where - assertSqlEqual :: (Sqlable a) => String -> a -> a -> Assertion - assertSqlEqual msg a b = assertEqual msg - (show $ toSQLCommand a) (show $ toSQLCommand b) + assertSqlEqual :: Sqlable a => String -> a -> a -> Assertion + assertSqlEqual msg a b = + assertEqual + msg + (show $ toSQLCommand a) + (show $ toSQLCommand b) assertNoException :: String -> TestM () -> TestM () -assertNoException t = eitherExc - (const $ liftIO $ assertFailure ("Exception thrown for: " ++ t)) - (const $ return ()) +assertNoException t = + eitherExc + (const $ liftIO $ assertFailure ("Exception thrown for: " ++ t)) + (const $ return ()) assertException :: String -> TestM () -> TestM () -assertException t = eitherExc - (const $ return ()) - (const $ liftIO $ assertFailure ("No exception thrown for: " ++ t)) +assertException t = + eitherExc + (const $ return ()) + (const $ liftIO $ assertFailure ("No exception thrown for: " ++ t)) assertDBException :: String -> TestM () -> TestM () assertDBException t c = - try c >>= either (\DBException{} -> pure ()) - (const . liftIO . assertFailure $ "No DBException thrown for: " ++ t) + try c + >>= either + (\DBException {} -> pure ()) + (const . liftIO . assertFailure $ "No DBException thrown for: " ++ t) -- | A variant of testCaseSteps that works in TestM monad. -testCaseSteps' :: TestName -> ConnectionSourceM (LogT IO) - -> ((String -> TestM ()) -> TestM ()) - -> TestTree +testCaseSteps' + :: TestName + -> ConnectionSourceM (LogT IO) + -> ((String -> TestM ()) -> TestM ()) + -> TestTree testCaseSteps' testName connSource f = testCaseSteps testName $ \step' -> do - let step s = liftIO $ step' s - withStdOutLogger $ \logger -> - runLogT "hpqtypes-extras-test" logger defaultLogLevel $ - runDBT connSource defaultTransactionSettings $ - f step + let step s = liftIO $ step' s + withStdOutLogger $ \logger -> + runLogT "hpqtypes-extras-test" logger defaultLogLevel $ + runDBT connSource defaultTransactionSettings $ + f step main :: IO () main = do defaultMainWithIngredients ings $ askOption $ \(ConnectionString connectionString) -> - let connSettings = defaultConnectionSettings - { csConnInfo = T.pack connectionString } - ConnectionSource connSource = simpleSource connSettings - in - testGroup "DB tests" [ migrationTest1 connSource - , migrationTest2 connSource - , migrationTest3 connSource - , migrationTest4 connSource - , migrationTest5 connSource - , triggerTests connSource - , sqlWithTests connSource - , unionTests connSource - , unionAllTests connSource - , sqlWithRecursiveTests connSource - , foreignKeyIndexesTests connSource - , overlapingIndexesTests connSource - , nullsNotDistinctTests connSource - , sqlAnyAllTests - ] + let connSettings = + defaultConnectionSettings + { csConnInfo = T.pack connectionString + } + ConnectionSource connSource = simpleSource connSettings + in testGroup + "DB tests" + [ migrationTest1 connSource + , migrationTest2 connSource + , migrationTest3 connSource + , migrationTest4 connSource + , migrationTest5 connSource + , triggerTests connSource + , sqlWithTests connSource + , unionTests connSource + , unionAllTests connSource + , sqlWithRecursiveTests connSource + , foreignKeyIndexesTests connSource + , overlapingIndexesTests connSource + , nullsNotDistinctTests connSource + , sqlAnyAllTests + ] where ings = includingOptions [Option (Proxy :: Proxy ConnectionString)] - : defaultIngredients + : defaultIngredients