Skip to content

Commit

Permalink
Add support for RECURSIVE withs
Browse files Browse the repository at this point in the history
  • Loading branch information
Raveline committed Jul 18, 2024
1 parent 6326d31 commit 11631a0
Showing 1 changed file with 107 additions and 67 deletions.
174 changes: 107 additions & 67 deletions src/Database/PostgreSQL/PQTypes/SQL/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,12 @@ module Database.PostgreSQL.PQTypes.SQL.Builder
, sqlLimit
, sqlDistinct
, sqlWith
, sqlWithRecursive
, sqlWithMaterialized
, sqlUnion
, sqlUnionAll
, checkAndRememberMaterializationSupport
, checkAndRememberRecursiveSupport

, sqlSelect
, sqlSelect2
Expand Down Expand Up @@ -169,11 +171,11 @@ module Database.PostgreSQL.PQTypes.SQL.Builder
)
where

import Control.Monad.State
import Control.Monad.Catch
import Control.Monad.State
import Data.Either
import Data.Int
import Data.IORef
import Data.Either
import Data.List
import Data.Maybe
import Data.Monoid.Utils
Expand Down Expand Up @@ -230,59 +232,64 @@ instance Sqlable SqlCondition where
toSQLCommand (SqlExistsCondition a) = "EXISTS (" <> toSQLCommand (a { sqlSelectResult = ["TRUE"] }) <> ")"

data SqlSelect = SqlSelect
{ sqlSelectFrom :: SQL
, sqlSelectUnion :: [SQL]
, sqlSelectUnionAll :: [SQL]
, sqlSelectDistinct :: Bool
, sqlSelectResult :: [SQL]
, sqlSelectWhere :: [SqlCondition]
, sqlSelectOrderBy :: [SQL]
, sqlSelectGroupBy :: [SQL]
, sqlSelectHaving :: [SQL]
, sqlSelectOffset :: Integer
, sqlSelectLimit :: Integer
, sqlSelectWith :: [(SQL, SQL, Materialized)]
{ sqlSelectFrom :: SQL
, sqlSelectUnion :: [SQL]
, sqlSelectUnionAll :: [SQL]
, sqlSelectDistinct :: Bool
, sqlSelectResult :: [SQL]
, sqlSelectWhere :: [SqlCondition]
, sqlSelectOrderBy :: [SQL]
, sqlSelectGroupBy :: [SQL]
, sqlSelectHaving :: [SQL]
, sqlSelectOffset :: Integer
, sqlSelectLimit :: Integer
, sqlSelectWith :: [(SQL, SQL, Materialized)]
, sqlSelectRecursiveWith :: Recursive
}

data SqlUpdate = SqlUpdate
{ sqlUpdateWhat :: SQL
, sqlUpdateFrom :: SQL
, sqlUpdateWhere :: [SqlCondition]
, sqlUpdateSet :: [(SQL,SQL)]
, sqlUpdateResult :: [SQL]
, sqlUpdateWith :: [(SQL, SQL, Materialized)]
{ sqlUpdateWhat :: SQL
, sqlUpdateFrom :: SQL
, sqlUpdateWhere :: [SqlCondition]
, sqlUpdateSet :: [(SQL,SQL)]
, sqlUpdateResult :: [SQL]
, sqlUpdateWith :: [(SQL, SQL, Materialized)]
, sqlUpdateRecursiveWith :: Recursive
}

data SqlInsert = SqlInsert
{ sqlInsertWhat :: SQL
, sqlInsertOnConflict :: Maybe (SQL, Maybe SQL)
, sqlInsertSet :: [(SQL, Multiplicity SQL)]
, sqlInsertResult :: [SQL]
, sqlInsertWith :: [(SQL, SQL, Materialized)]
{ sqlInsertWhat :: SQL
, sqlInsertOnConflict :: Maybe (SQL, Maybe SQL)
, sqlInsertSet :: [(SQL, Multiplicity SQL)]
, sqlInsertResult :: [SQL]
, sqlInsertWith :: [(SQL, SQL, Materialized)]
, sqlInsertRecursiveWith :: Recursive
}

data SqlInsertSelect = SqlInsertSelect
{ sqlInsertSelectWhat :: SQL
, sqlInsertSelectOnConflict :: Maybe (SQL, Maybe SQL)
, sqlInsertSelectDistinct :: Bool
, sqlInsertSelectSet :: [(SQL, SQL)]
, sqlInsertSelectResult :: [SQL]
, sqlInsertSelectFrom :: SQL
, sqlInsertSelectWhere :: [SqlCondition]
, sqlInsertSelectOrderBy :: [SQL]
, sqlInsertSelectGroupBy :: [SQL]
, sqlInsertSelectHaving :: [SQL]
, sqlInsertSelectOffset :: Integer
, sqlInsertSelectLimit :: Integer
, sqlInsertSelectWith :: [(SQL, SQL, Materialized)]
{ sqlInsertSelectWhat :: SQL
, sqlInsertSelectOnConflict :: Maybe (SQL, Maybe SQL)
, sqlInsertSelectDistinct :: Bool
, sqlInsertSelectSet :: [(SQL, SQL)]
, sqlInsertSelectResult :: [SQL]
, sqlInsertSelectFrom :: SQL
, sqlInsertSelectWhere :: [SqlCondition]
, sqlInsertSelectOrderBy :: [SQL]
, sqlInsertSelectGroupBy :: [SQL]
, sqlInsertSelectHaving :: [SQL]
, sqlInsertSelectOffset :: Integer
, sqlInsertSelectLimit :: Integer
, sqlInsertSelectWith :: [(SQL, SQL, Materialized)]
, sqlInsertSelectRecursiveWith :: Recursive
}

data SqlDelete = SqlDelete
{ sqlDeleteFrom :: SQL
, sqlDeleteUsing :: SQL
, sqlDeleteWhere :: [SqlCondition]
, sqlDeleteResult :: [SQL]
, sqlDeleteWith :: [(SQL, SQL, Materialized)]
{ sqlDeleteFrom :: SQL
, sqlDeleteUsing :: SQL
, sqlDeleteWhere :: [SqlCondition]
, sqlDeleteResult :: [SQL]
, sqlDeleteWith :: [(SQL, SQL, Materialized)]
, sqlDeleteRecursiveWith :: Recursive
}

-- | This is not exported and is used as an implementation detail in
Expand Down Expand Up @@ -340,8 +347,8 @@ instance IsSQL SqlDelete where

instance Sqlable SqlSelect where
toSQLCommand cmd = smconcat
[ emitClausesSepComma "WITH" $
map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlSelectWith cmd)
[ emitClausesSepComma (recursiveClause $ sqlSelectRecursiveWith cmd) $
map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlSelectWith cmd)
, if hasUnion || hasUnionAll
then emitClausesSep "" unionKeyword (mainSelectClause : unionCmd)
else mainSelectClause
Expand Down Expand Up @@ -394,7 +401,8 @@ emitClauseOnConflictForInsert = \case

instance Sqlable SqlInsert where
toSQLCommand cmd =
emitClausesSepComma "WITH" (map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlInsertWith cmd)) <+>
emitClausesSepComma (recursiveClause $ sqlInsertRecursiveWith cmd)
(map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlInsertWith cmd)) <+>
"INSERT INTO" <+> sqlInsertWhat cmd <+>
parenthesize (sqlConcatComma (map fst (sqlInsertSet cmd))) <+>
emitClausesSep "VALUES" "," (map sqlConcatComma (transpose (map (makeLongEnough . snd) (sqlInsertSet cmd)))) <+>
Expand All @@ -412,8 +420,8 @@ instance Sqlable SqlInsertSelect where
toSQLCommand cmd = smconcat
-- WITH clause needs to be at the top level, so we emit it here and not
-- include it in the SqlSelect below.
[ emitClausesSepComma "WITH" $
map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlInsertSelectWith cmd)
[ emitClausesSepComma (recursiveClause $ sqlInsertSelectRecursiveWith cmd) $
map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlInsertSelectWith cmd)
, "INSERT INTO" <+> sqlInsertSelectWhat cmd
, parenthesize . sqlConcatComma . map fst $ sqlInsertSelectSet cmd
, parenthesize . toSQLCommand $ SqlSelect { sqlSelectFrom = sqlInsertSelectFrom cmd
Expand All @@ -428,6 +436,7 @@ instance Sqlable SqlInsertSelect where
, sqlSelectOffset = sqlInsertSelectOffset cmd
, sqlSelectLimit = sqlInsertSelectLimit cmd
, sqlSelectWith = []
, sqlSelectRecursiveWith = NonRecursive
}
, emitClauseOnConflictForInsert (sqlInsertSelectOnConflict cmd)
, emitClausesSepComma "RETURNING" $ sqlInsertSelectResult cmd
Expand All @@ -443,6 +452,16 @@ checkAndRememberMaterializationSupport = do
fetchOne runIdentity
liftIO $ writeIORef withMaterializedSupported (isRight res)

-- This function has to be called as one of first things in your program
-- for the library to make sure that it is aware if the "WITH RECURSIVE"
-- clause is supported by your PostgreSQL version.
checkAndRememberRecursiveSupport :: (MonadDB m, MonadIO m, MonadMask m) => m ()
checkAndRememberRecursiveSupport = do
res :: Either DBException Int64 <- try . withNewConnection $ do
runSQL01_ "WITH RECURSIVE t(n) AS (SELECT (1 :: bigint)) SELECT n FROM t LIMIT 1"
fetchOne runIdentity
liftIO $ writeIORef withRecursiveSupported (isRight res)

withMaterializedSupported :: IORef Bool
{-# NOINLINE withMaterializedSupported #-}
withMaterializedSupported = unsafePerformIO $ newIORef False
Expand All @@ -451,13 +470,26 @@ isWithMaterializedSupported :: Bool
{-# NOINLINE isWithMaterializedSupported #-}
isWithMaterializedSupported = unsafePerformIO $ readIORef withMaterializedSupported

withRecursiveSupported :: IORef Bool
{-# NOINLINE withRecursiveSupported #-}
withRecursiveSupported = unsafePerformIO $ newIORef False

isWithRecursiveSupported :: Bool
{-# NOINLINE isWithRecursiveSupported #-}
isWithRecursiveSupported = unsafePerformIO $ readIORef withRecursiveSupported

materializedClause :: Materialized -> SQL
materializedClause Materialized = if isWithMaterializedSupported then "MATERIALIZED" else ""
materializedClause NonMaterialized = if isWithMaterializedSupported then "NOT MATERIALIZED" else ""

recursiveClause :: Recursive -> SQL
recursiveClause Recursive = if isWithRecursiveSupported then "WITH RECURSIVE" else "WITH"
recursiveClause NonRecursive = "WITH"

instance Sqlable SqlUpdate where
toSQLCommand cmd =
emitClausesSepComma "WITH" (map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlUpdateWith cmd)) <+>
emitClausesSepComma (recursiveClause $ sqlUpdateRecursiveWith cmd)
(map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlUpdateWith cmd)) <+>
"UPDATE" <+> sqlUpdateWhat cmd <+> "SET" <+>
sqlConcatComma (map (\(name, command) -> name <> "=" <> command) (sqlUpdateSet cmd)) <+>
emitClause "FROM" (sqlUpdateFrom cmd) <+>
Expand All @@ -466,7 +498,8 @@ instance Sqlable SqlUpdate where

instance Sqlable SqlDelete where
toSQLCommand cmd =
emitClausesSepComma "WITH" (map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlDeleteWith cmd)) <+>
emitClausesSepComma (recursiveClause $ sqlDeleteRecursiveWith cmd)
(map (\(name,command,mat) -> name <+> "AS" <+> materializedClause mat <+> parenthesize command) (sqlDeleteWith cmd)) <+>
"DELETE FROM" <+> sqlDeleteFrom cmd <+>
emitClause "USING" (sqlDeleteUsing cmd) <+>
emitClausesSep "WHERE" "AND" (map toSQLCommand $ sqlDeleteWhere cmd) <+>
Expand All @@ -479,15 +512,15 @@ instance Sqlable SqlAll where

sqlSelect :: SQL -> State SqlSelect () -> SqlSelect
sqlSelect table refine =
execState refine (SqlSelect table [] [] False [] [] [] [] [] 0 (-1) [])
execState refine (SqlSelect table [] [] False [] [] [] [] [] 0 (-1) [] NonRecursive)

sqlSelect2 :: SQL -> State SqlSelect () -> SqlSelect
sqlSelect2 from refine =
execState refine (SqlSelect from [] [] False [] [] [] [] [] 0 (-1) [])
execState refine (SqlSelect from [] [] False [] [] [] [] [] 0 (-1) [] NonRecursive)

sqlInsert :: SQL -> State SqlInsert () -> SqlInsert
sqlInsert table refine =
execState refine (SqlInsert table Nothing mempty [] [])
execState refine (SqlInsert table Nothing mempty [] [] NonRecursive)

sqlInsertSelect :: SQL -> SQL -> State SqlInsertSelect () -> SqlInsertSelect
sqlInsertSelect table from refine =
Expand All @@ -505,11 +538,12 @@ sqlInsertSelect table from refine =
, sqlInsertSelectOffset = 0
, sqlInsertSelectLimit = -1
, sqlInsertSelectWith = []
, sqlInsertSelectRecursiveWith = NonRecursive
})

sqlUpdate :: SQL -> State SqlUpdate () -> SqlUpdate
sqlUpdate table refine =
execState refine (SqlUpdate table mempty [] [] [] [])
execState refine (SqlUpdate table mempty [] [] [] [] NonRecursive)

sqlDelete :: SQL -> State SqlDelete () -> SqlDelete
sqlDelete table refine =
Expand All @@ -518,32 +552,38 @@ sqlDelete table refine =
, sqlDeleteWhere = []
, sqlDeleteResult = []
, sqlDeleteWith = []
, sqlDeleteRecursiveWith = NonRecursive
})


data Materialized = Materialized | NonMaterialized
data Recursive = Recursive | NonRecursive

class SqlWith a where
sqlWith1 :: a -> SQL -> SQL -> Materialized -> a
sqlWith1 :: a -> SQL -> SQL -> Materialized -> Recursive -> a


instance SqlWith SqlSelect where
sqlWith1 cmd name sql mat = cmd { sqlSelectWith = sqlSelectWith cmd ++ [(name,sql, mat)] }
sqlWith1 cmd name sql mat recurse = cmd { sqlSelectWith = sqlSelectWith cmd ++ [(name,sql,mat)], sqlSelectRecursiveWith = recurse }

instance SqlWith SqlInsertSelect where
sqlWith1 cmd name sql mat = cmd { sqlInsertSelectWith = sqlInsertSelectWith cmd ++ [(name,sql,mat)] }
sqlWith1 cmd name sql mat recurse = cmd { sqlInsertSelectWith = sqlInsertSelectWith cmd ++ [(name,sql,mat)], sqlInsertSelectRecursiveWith = recurse }

instance SqlWith SqlUpdate where
sqlWith1 cmd name sql mat = cmd { sqlUpdateWith = sqlUpdateWith cmd ++ [(name,sql,mat)] }
sqlWith1 cmd name sql mat recurse = cmd { sqlUpdateWith = sqlUpdateWith cmd ++ [(name,sql,mat)], sqlUpdateRecursiveWith = recurse }

instance SqlWith SqlDelete where
sqlWith1 cmd name sql mat = cmd { sqlDeleteWith = sqlDeleteWith cmd ++ [(name,sql,mat)] }
sqlWith1 cmd name sql mat recurse = cmd { sqlDeleteWith = sqlDeleteWith cmd ++ [(name,sql,mat)], sqlDeleteRecursiveWith = recurse }

sqlWith :: (MonadState v m, SqlWith v, Sqlable s) => SQL -> s -> m ()
sqlWith name sql = modify (\cmd -> sqlWith1 cmd name (toSQLCommand sql) NonMaterialized)
sqlWith name sql = modify (\cmd -> sqlWith1 cmd name (toSQLCommand sql) NonMaterialized NonRecursive)

sqlWithMaterialized :: (MonadState v m, SqlWith v, Sqlable s) => SQL -> s -> m ()
sqlWithMaterialized name sql = modify (\cmd -> sqlWith1 cmd name (toSQLCommand sql) Materialized)
sqlWithMaterialized name sql = modify (\cmd -> sqlWith1 cmd name (toSQLCommand sql) Materialized NonRecursive)

-- | Note: RECURSIVE only powers SELECTs (but the SELECT can feed an UPDATE outside of the recursive query).
sqlWithRecursive :: (MonadState v m, SqlWith v, Sqlable s) => SQL -> s -> m ()
sqlWithRecursive name sql = modify (\cmd -> sqlWith1 cmd name (toSQLCommand sql) NonMaterialized Recursive)

-- | Note: WHERE clause of the main SELECT is treated specially, i.e. it only
-- applies to the main SELECT, not the whole union.
Expand Down Expand Up @@ -731,19 +771,19 @@ class SqlOnConflict a where
sqlOnConflictOnColumns1 :: Sqlable sql => a -> [SQL] -> sql -> a

instance SqlOnConflict SqlInsert where
sqlOnConflictDoNothing1 cmd =
sqlOnConflictDoNothing1 cmd =
cmd { sqlInsertOnConflict = Just ("", Nothing) }
sqlOnConflictOnColumns1 cmd columns sql =
sqlOnConflictOnColumns1 cmd columns sql =
cmd { sqlInsertOnConflict = Just (parenthesize $ sqlConcatComma columns, Just $ toSQLCommand sql) }
sqlOnConflictOnColumnsDoNothing1 cmd columns =
sqlOnConflictOnColumnsDoNothing1 cmd columns =
cmd { sqlInsertOnConflict = Just (parenthesize $ sqlConcatComma columns, Nothing) }

instance SqlOnConflict SqlInsertSelect where
sqlOnConflictDoNothing1 cmd =
sqlOnConflictDoNothing1 cmd =
cmd { sqlInsertSelectOnConflict = Just ("", Nothing) }
sqlOnConflictOnColumns1 cmd columns sql =
sqlOnConflictOnColumns1 cmd columns sql =
cmd { sqlInsertSelectOnConflict = Just (parenthesize $ sqlConcatComma columns, Just $ toSQLCommand sql) }
sqlOnConflictOnColumnsDoNothing1 cmd columns =
sqlOnConflictOnColumnsDoNothing1 cmd columns =
cmd { sqlInsertSelectOnConflict = Just (parenthesize $ sqlConcatComma columns, Nothing) }

sqlOnConflictDoNothing :: (MonadState v m, SqlOnConflict v) => m ()
Expand Down

0 comments on commit 11631a0

Please sign in to comment.