Skip to content

Commit

Permalink
Rudimentary enum support
Browse files Browse the repository at this point in the history
  • Loading branch information
zlondrej committed Jan 8, 2025
1 parent a993e40 commit fcd4d9f
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 3 deletions.
1 change: 1 addition & 0 deletions hpqtypes-extras.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ library
, Database.PostgreSQL.PQTypes.Model.ColumnType
, Database.PostgreSQL.PQTypes.Model.CompositeType
, Database.PostgreSQL.PQTypes.Model.Domain
, Database.PostgreSQL.PQTypes.Model.EnumType
, Database.PostgreSQL.PQTypes.Model.Extension
, Database.PostgreSQL.PQTypes.Model.ForeignKey
, Database.PostgreSQL.PQTypes.Model.Index
Expand Down
49 changes: 46 additions & 3 deletions src/Database/PostgreSQL/PQTypes/Checks.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ migrateDatabase
=> ExtrasOptions
-> [Extension]
-> [CompositeType]
-> [EnumType]
-> [Domain]
-> [Table]
-> [Migration m]
Expand All @@ -63,20 +64,22 @@ migrateDatabase
options
extensions
composites
enums
domains
tables
migrations = do
setDBTimeZoneToUTC
mapM_ checkExtension extensions
tablesWithVersions <- getTableVersions (tableVersions : tables)
-- 'checkDBConsistency' also performs migrations.
checkDBConsistency options domains tablesWithVersions migrations
checkDBConsistency options domains enums tablesWithVersions migrations
resultCheck
=<< checkCompositesStructure
tablesWithVersions
CreateCompositesIfDatabaseEmpty
(eoObjectsValidationMode options)
composites
resultCheck =<< checkEnumTypes enums
resultCheck =<< checkDomainsStructure domains
resultCheck =<< checkDBStructure options tablesWithVersions
resultCheck =<< checkTablesWereDropped migrations
Expand All @@ -98,10 +101,11 @@ checkDatabase
. (MonadDB m, MonadLog m, MonadThrow m)
=> ExtrasOptions
-> [CompositeType]
-> [EnumType]
-> [Domain]
-> [Table]
-> m ()
checkDatabase options composites domains tables = do
checkDatabase options composites enums domains tables = do
tablesWithVersions <- getTableVersions (tableVersions : tables)
resultCheck $ checkVersions options tablesWithVersions
resultCheck
Expand All @@ -110,6 +114,7 @@ checkDatabase options composites domains tables = do
DontCreateComposites
(eoObjectsValidationMode options)
composites
resultCheck =<< checkEnumTypes enums
resultCheck =<< checkDomainsStructure domains
resultCheck =<< checkDBStructure options tablesWithVersions
when (eoObjectsValidationMode options == DontAllowUnknownObjects) $ do
Expand Down Expand Up @@ -340,6 +345,41 @@ checkDomainsStructure defs = fmap mconcat . forM defs $ \def -> do
<+> T.pack (show $ attr def)
<> ")"

checkEnumTypes
:: (MonadDB m, MonadThrow m)
=> [EnumType]
-> m ValidationResult
checkEnumTypes defs = fmap mconcat . forM defs $ \def -> do
runQuery_ . sqlSelect "pg_catalog.pg_type t" $ do
sqlResult "t.typname::text" -- name
sqlResult
"ARRAY(select e.enumlabel from pg_catalog.pg_enum e where e.enumtypid = t.oid order by e.enumsortorder)" -- values
sqlWhereEq "t.typname" $ unRawSQL $ etName def
enum <- fetchMaybe $
\(enumName, enumValues) ->
EnumType
{ etName = unsafeSQL enumName
, etValues = map unsafeSQL $ unArray1 enumValues
}
return $ case enum of
Just e
| e /= def ->
topMessage "enum" (unRawSQL $ etName e) $
validationError $
"Enum '"
<> unRawSQL (etName e)
<> "' does not match (database:"
<+> T.pack (show $ etValues e)
<> ", definition:"
<+> T.pack (show $ etValues def)
<> ")"
| otherwise -> mempty
Nothing ->
validationError $
"Enum '"
<> unRawSQL (etName def)
<> "' doesn't exist in the database"

-- | Check that the tables that must have been dropped are actually
-- missing from the DB.
checkTablesWereDropped
Expand Down Expand Up @@ -748,10 +788,11 @@ checkDBConsistency
. (MonadIO m, MonadDB m, MonadLog m, MonadMask m)
=> ExtrasOptions
-> [Domain]
-> [EnumType]
-> TablesWithVersions
-> [Migration m]
-> m ()
checkDBConsistency options domains tablesWithVersions migrations = do
checkDBConsistency options domains enums tablesWithVersions migrations = do
autoTransaction <- tsAutoTransaction <$> getTransactionSettings
unless autoTransaction $ do
error "checkDBConsistency: tsAutoTransaction setting needs to be True"
Expand Down Expand Up @@ -876,6 +917,8 @@ checkDBConsistency options domains tablesWithVersions migrations = do
createDBSchema = do
logInfo_ "Creating domains..."
mapM_ createDomain domains
logInfo_ "Creating enums..."
mapM_ (runQuery_ . sqlCreateEnum) enums
-- Create all tables with no constraints first to allow cyclic references.
logInfo_ "Creating tables..."
mapM_ (createTable False) tables
Expand Down
2 changes: 2 additions & 0 deletions src/Database/PostgreSQL/PQTypes/Model.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Database.PostgreSQL.PQTypes.Model
, module Database.PostgreSQL.PQTypes.Model.ColumnType
, module Database.PostgreSQL.PQTypes.Model.CompositeType
, module Database.PostgreSQL.PQTypes.Model.Domain
, module Database.PostgreSQL.PQTypes.Model.EnumType
, module Database.PostgreSQL.PQTypes.Model.Extension
, module Database.PostgreSQL.PQTypes.Model.ForeignKey
, module Database.PostgreSQL.PQTypes.Model.Index
Expand All @@ -16,6 +17,7 @@ import Database.PostgreSQL.PQTypes.Model.Check
import Database.PostgreSQL.PQTypes.Model.ColumnType
import Database.PostgreSQL.PQTypes.Model.CompositeType
import Database.PostgreSQL.PQTypes.Model.Domain
import Database.PostgreSQL.PQTypes.Model.EnumType
import Database.PostgreSQL.PQTypes.Model.Extension
import Database.PostgreSQL.PQTypes.Model.ForeignKey
import Database.PostgreSQL.PQTypes.Model.Index
Expand Down
32 changes: 32 additions & 0 deletions src/Database/PostgreSQL/PQTypes/Model/EnumType.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
module Database.PostgreSQL.PQTypes.Model.EnumType
( EnumType (..)
, sqlCreateEnum
, sqlDropEnum
) where

import Data.Monoid.Utils
import Data.Text qualified as T
import Database.PostgreSQL.PQTypes

data EnumType = EnumType
{ etName :: !(RawSQL ())
, etValues :: ![RawSQL ()]
}
deriving (Eq, Ord, Show)

-- | Make SQL query that creates a composite type.
sqlCreateEnum :: EnumType -> RawSQL ()
sqlCreateEnum EnumType {..} =
smconcat
[ "CREATE TYPE"
, etName
, "AS ENUM ("
, mintercalate ", " $ map quotedValue etValues
, ")"
]
where
quotedValue v = rawSQL ("'" <> T.replace "'" "''" (unRawSQL v) <> "'" :: T.Text) ()

-- | Make SQL query that drops a composite type.
sqlDropEnum :: RawSQL () -> RawSQL ()
sqlDropEnum = ("DROP TYPE" <+>)

0 comments on commit fcd4d9f

Please sign in to comment.