From c73180a62a92990cdba3da03127ab46130c5c4eb Mon Sep 17 00:00:00 2001 From: Jacques Fuentes Date: Mon, 23 Jan 2017 10:05:07 -0500 Subject: [PATCH] WIP: use phantom types to restrict roles --- src/main/scala/com/simple/jdub/Database.scala | 57 +++++++++++-------- .../scala/com/simple/jdub/Queryable.scala | 18 +++--- src/main/scala/com/simple/jdub/RawQuery.scala | 4 +- .../scala/com/simple/jdub/Transaction.scala | 21 ++++--- .../com/simple/jdub/TransactionManager.scala | 29 +++++----- .../com/simple/jdub/tests/DatabaseSpec.scala | 10 ++++ .../com/simple/jdub/tests/JdubSpec.scala | 4 +- 7 files changed, 87 insertions(+), 56 deletions(-) diff --git a/src/main/scala/com/simple/jdub/Database.scala b/src/main/scala/com/simple/jdub/Database.scala index f9af0d4..a75c6a1 100644 --- a/src/main/scala/com/simple/jdub/Database.scala +++ b/src/main/scala/com/simple/jdub/Database.scala @@ -4,14 +4,21 @@ import com.zaxxer.hikari.{HikariConfig, HikariDataSource} import java.io.FileInputStream import java.security.KeyStore -import java.util.{UUID, Properties} +import java.util.{Properties, UUID} import javax.sql.DataSource - import com.codahale.metrics.MetricRegistry import com.codahale.metrics.health.HealthCheckRegistry +import com.simple.jdub.Database.Primary +import com.simple.jdub.Database.Replica + +import scala.annotation.implicitNotFound object Database { + sealed trait Role + final class Primary extends Role + final class Replica extends Role + /** * Create a pool of connections to the given database. * @@ -20,7 +27,7 @@ object Database { * @param password the database password * @param sslSettings if present, uses the given SSL settings for a client-side SSL cert. */ - def connect(url: String, + def connect[R <: Role](url: String, username: String, password: String, name: Option[String] = None, @@ -30,7 +37,7 @@ object Database { sslSettings: Option[SslSettings] = None, healthCheckRegistry: Option[HealthCheckRegistry] = None, metricRegistry: Option[MetricRegistry] = None, - connectionInitSql: Option[String] = None): Database = { + connectionInitSql: Option[String] = None): Database[R] = { val properties = new Properties @@ -95,8 +102,8 @@ object Database { /** * A set of pooled connections to a database. */ -class Database protected(val source: DataSource, metrics: Option[MetricRegistry]) - extends Queryable { +class Database[R <: Database.Role] protected(val source: DataSource, metrics: Option[MetricRegistry]) + extends Queryable[R] { private[jdub] def time[A](klass: java.lang.Class[_])(f: => A) = { metrics.fold(f) { registry => @@ -110,34 +117,38 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry] } } - val transactionProvider: TransactionProvider = new TransactionManager + val transactionProvider: TransactionProvider[R] = new TransactionManager + + def replica: Database[Replica] = new Database[Replica](source, metrics) + + def primary: Database[Primary] = new Database[Primary](source, metrics) /** * Opens a transaction which is committed after `f` is called. If `f` throws * an exception, the transaction is rolled back. */ - def transaction[A](f: Transaction => A): A = transaction(true, f) + def transaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A = transaction(true, f) /** * Opens a transaction which is committed after `f` is called. If `f` throws * an exception, the transaction is rolled back, but the exception is not * logged (since it is rethrown). */ - def quietTransaction[A](f: Transaction => A): A = transaction(false, f) + def quietTransaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A = transaction(false, f) - def transaction[A](logError: Boolean, f: Transaction => A): A = transaction(false, false, f) + def transaction[A](logError: Boolean, f: Transaction[R] => A)(implicit ev: R =:= Primary): A = transaction(false, false, f) /** * Opens a transaction which is committed after `f` is called. If `f` throws * an exception, the transaction is rolled back. */ - def transaction[A](logError: Boolean, forceNew: Boolean, f: Transaction => A): A = { + def transaction[A](logError: Boolean, forceNew: Boolean, f: Transaction[R] => A)(implicit ev: R =:= Primary): A = { if (!forceNew && transactionProvider.transactionExists) { f(transactionProvider.currentTransaction) } else { val connection = source.getConnection connection.setAutoCommit(false) - val txn = new Transaction(connection) + val txn = new Transaction[R](connection) try { logger.debug("Starting transaction") val result = f(txn) @@ -162,8 +173,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry] * thread within the scope of `f`. If `f` throws an exception the transaction * is rolled back. Logs exceptions thrown by `f` as errors. */ - def transactionScope[A](f: => A): A = { - transaction(logError = true, forceNew = false, (txn: Transaction) => { + def transactionScope[A](f: => A)(implicit ev: R =:= Primary): A = { + transaction(logError = true, forceNew = false, (txn: Transaction[R]) => { transactionProvider.begin(txn) try { f @@ -181,8 +192,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry] * exception the transaction is rolled back. Logs exceptions thrown by * `f` as errors. */ - def newTransactionScope[A](f: => A): A = { - transaction(logError = true, forceNew = true, (txn: Transaction) => { + def newTransactionScope[A](f: => A)(implicit ev: R =:= Primary): A = { + transaction(logError = true, forceNew = true, (txn: Transaction[R]) => { transactionProvider.begin(txn) try { f @@ -197,8 +208,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry] * thread within the scope of `f`. If `f` throws an exception the transaction * is rolled back. Will not log exceptions thrown by `f`. */ - def quietTransactionScope[A](f: => A): A = { - transaction(logError = false, forceNew = false, (txn: Transaction) => { + def quietTransactionScope[A](f: => A)(implicit ev: R =:= Primary): A = { + transaction(logError = false, forceNew = false, (txn: Transaction[R]) => { transactionProvider.begin(txn) try { f @@ -216,8 +227,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry] * exception the transaction is rolled back. Will not log exceptions * thrown by `f`. */ - def newQuietTransactionScope[A](f: => A): A = { - transaction(logError = false, forceNew = true, (txn: Transaction) => { + def newQuietTransactionScope[A](f: => A)(implicit ev: R =:= Primary): A = { + transaction(logError = false, forceNew = true, (txn: Transaction[R]) => { transactionProvider.begin(txn) try { f @@ -230,7 +241,7 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry] /** * The transaction currently scoped via transactionScope. */ - def currentTransaction = { + def currentTransaction(implicit ev: R =:= Primary) = { transactionProvider.currentTransaction } @@ -260,7 +271,7 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry] /** * Executes an update, insert, delete, or DDL statement. */ - def execute(statement: Statement) = { + def execute(statement: Statement)(implicit ev: R =:= Primary): Int = { if (transactionProvider.transactionExists) { transactionProvider.currentTransaction.execute(statement) } else { @@ -278,7 +289,7 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry] /** * Rollback any existing ambient transaction */ - def rollback() { + def rollback()(implicit ev: R =:= Primary) { transactionProvider.rollback } diff --git a/src/main/scala/com/simple/jdub/Queryable.scala b/src/main/scala/com/simple/jdub/Queryable.scala index fd10b78..a5c9c5f 100644 --- a/src/main/scala/com/simple/jdub/Queryable.scala +++ b/src/main/scala/com/simple/jdub/Queryable.scala @@ -5,11 +5,13 @@ package com.simple.jdub -import java.sql.Connection +import com.simple.jdub.Database.Primary +import com.simple.jdub.Database.Role +import java.sql.Connection import grizzled.slf4j.Logging -trait Queryable extends Logging { +trait Queryable[R <: Role] extends Logging { import Utils._ /** @@ -35,7 +37,7 @@ trait Queryable extends Logging { /** * Executes an update, insert, delete, or DDL statement. */ - def execute(connection: Connection, statement: Statement): Int = { + def execute(connection: Connection, statement: Statement)(implicit ev: R =:= Primary): Int = { logger.debug("%s with %s".format(statement.sql, statement.values.mkString("(", ", ", ")"))) val stmt = connection.prepareStatement(prependComment(statement, statement.sql)) @@ -47,9 +49,9 @@ trait Queryable extends Logging { } } - def execute(statement: Statement): Int + def execute(statement: Statement)(implicit ev: R =:= Primary): Int def apply[A](query: RawQuery[A]): A - def transaction[A](f: Transaction => A): A + def transaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A /** * Performs a query and returns the results. @@ -59,15 +61,15 @@ trait Queryable extends Logging { /** * Executes an update statement. */ - def update(statement: Statement): Int = execute(statement) + def update(statement: Statement)(implicit ev: R =:= Primary): Int = execute(statement) /** * Executes an insert statement. */ - def insert(statement: Statement): Int = execute(statement) + def insert(statement: Statement)(implicit ev: R =:= Primary): Int = execute(statement) /** * Executes a delete statement. */ - def delete(statement: Statement): Int = execute(statement) + def delete(statement: Statement)(implicit ev: R =:= Primary): Int = execute(statement) } diff --git a/src/main/scala/com/simple/jdub/RawQuery.scala b/src/main/scala/com/simple/jdub/RawQuery.scala index 1d067a3..42038e5 100644 --- a/src/main/scala/com/simple/jdub/RawQuery.scala +++ b/src/main/scala/com/simple/jdub/RawQuery.scala @@ -1,5 +1,7 @@ package com.simple.jdub +import com.simple.jdub.Database.Role + import java.sql.ResultSet trait RawQuery[A] extends SqlBase { @@ -9,5 +11,5 @@ trait RawQuery[A] extends SqlBase { def handle(results: ResultSet): A - def apply(db: Database): A = db(this) + def apply(db: Database[Role]): A = db(this) } diff --git a/src/main/scala/com/simple/jdub/Transaction.scala b/src/main/scala/com/simple/jdub/Transaction.scala index 83340d7..927699f 100644 --- a/src/main/scala/com/simple/jdub/Transaction.scala +++ b/src/main/scala/com/simple/jdub/Transaction.scala @@ -1,9 +1,12 @@ package com.simple.jdub +import com.simple.jdub.Database.Primary +import com.simple.jdub.Database.Role + import java.sql.{Connection, Savepoint} import scala.collection.mutable.ListBuffer -class Transaction(val connection: Connection) extends Queryable { +class Transaction[R <: Role](val connection: Connection) extends Queryable[R] { private[this] var rolledback = false /** @@ -14,12 +17,12 @@ class Transaction(val connection: Connection) extends Queryable { /** * Executes an update, insert, delete, or DDL statement. */ - def execute(statement: Statement) = execute(connection, statement) + def execute(statement: Statement)(implicit ev: R =:= Primary) = execute(connection, statement) /** * Roll back the transaction. */ - def rollback() { + def rollback()(implicit ev: R =:= Primary) { logger.debug("Rolling back transaction") connection.rollback() rolledback = true @@ -29,7 +32,7 @@ class Transaction(val connection: Connection) extends Queryable { /** * Roll back the transaction to a savepoint. */ - def rollback(savepoint: Savepoint) { + def rollback(savepoint: Savepoint)(implicit ev: R =:= Primary) { logger.debug("Rolling back to savepoint") connection.rollback(savepoint) } @@ -37,7 +40,7 @@ class Transaction(val connection: Connection) extends Queryable { /** * Release a transaction from a savepoint. */ - def release(savepoint: Savepoint) { + def release(savepoint: Savepoint)(implicit ev: R =:= Primary) { logger.debug("Releasing savepoint") connection.releaseSavepoint(savepoint) } @@ -45,7 +48,7 @@ class Transaction(val connection: Connection) extends Queryable { /** * Set an unnamed savepoint. */ - def savepoint(): Savepoint = { + def savepoint()(implicit ev: R =:= Primary): Savepoint = { logger.debug("Setting unnamed savepoint") connection.setSavepoint() } @@ -53,12 +56,12 @@ class Transaction(val connection: Connection) extends Queryable { /** * Set a named savepoint. */ - def savepoint(name: String): Savepoint = { + def savepoint(name: String)(implicit ev: R =:= Primary): Savepoint = { logger.debug("Setting savepoint") connection.setSavepoint(name) } - private[jdub] def commit() { + private[jdub] def commit()(implicit ev: R =:= Primary) { if (!rolledback) { logger.debug("Committing transaction") connection.commit() @@ -72,7 +75,7 @@ class Transaction(val connection: Connection) extends Queryable { onClose.foreach(_()) } - def transaction[A](f: Transaction => A): A = f(this) + def transaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A = f(this) var onCommit: ListBuffer[() => Unit] = ListBuffer.empty[() => Unit] var onClose: ListBuffer[() => Unit] = ListBuffer.empty[() => Unit] diff --git a/src/main/scala/com/simple/jdub/TransactionManager.scala b/src/main/scala/com/simple/jdub/TransactionManager.scala index 518fdd8..3d114e9 100644 --- a/src/main/scala/com/simple/jdub/TransactionManager.scala +++ b/src/main/scala/com/simple/jdub/TransactionManager.scala @@ -5,18 +5,21 @@ package com.simple.jdub +import com.simple.jdub.Database.Primary +import com.simple.jdub.Database.Role + import java.util.Stack -trait TransactionProvider { +trait TransactionProvider[R <: Role] { def transactionExists: Boolean - def currentTransaction: Transaction - def begin(transaction: Transaction): Unit - def end(): Unit - def rollback(): Unit + def currentTransaction: Transaction[R] + def begin(transaction: Transaction[R])(implicit ev: R =:= Primary): Unit + def end()(implicit ev: R =:= Primary): Unit + def rollback()(implicit ev: R =:= Primary): Unit } -class TransactionManager extends TransactionProvider { - case class TransactionState(transactions: Stack[Transaction]) +class TransactionManager[R <: Role] extends TransactionProvider[R] { + case class TransactionState(transactions: Stack[Transaction[R]]) private val localTransactionStorage = new ThreadLocal[Option[TransactionState]] { override def initialValue = None @@ -26,7 +29,7 @@ class TransactionManager extends TransactionProvider { localTransactionStorage.get } - protected def ambientTransaction: Option[Transaction] = { + protected def ambientTransaction: Option[Transaction[R]] = { ambientTransactionState.map(_.transactions.peek) } @@ -40,15 +43,15 @@ class TransactionManager extends TransactionProvider { ambientTransactionState.isDefined } - def currentTransaction: Transaction = { + def currentTransaction: Transaction[R] = { ambientTransaction.getOrElse( throw new Exception("No transaction in current context") ) } - def begin(transaction: Transaction): Unit = { + def begin(transaction: Transaction[R])(implicit ev: R =:= Primary): Unit = { if (!transactionExists) { - val stack = new Stack[Transaction]() + val stack = new Stack[Transaction[R]]() stack.push(transaction) localTransactionStorage.set(Some(new TransactionState(stack))) } else { @@ -56,7 +59,7 @@ class TransactionManager extends TransactionProvider { } } - def end(): Unit = { + def end()(implicit ev: R =:= Primary): Unit = { if (!transactionExists) { throw new Exception("No transaction in current context") } else { @@ -67,7 +70,7 @@ class TransactionManager extends TransactionProvider { } } - def rollback(): Unit = { + def rollback()(implicit ev: R =:= Primary): Unit = { currentTransaction.rollback } } diff --git a/src/test/scala/com/simple/jdub/tests/DatabaseSpec.scala b/src/test/scala/com/simple/jdub/tests/DatabaseSpec.scala index c01d9fb..8444552 100644 --- a/src/test/scala/com/simple/jdub/tests/DatabaseSpec.scala +++ b/src/test/scala/com/simple/jdub/tests/DatabaseSpec.scala @@ -1,5 +1,6 @@ package com.simple.jdub.tests +import com.simple.jdub.Database.Replica import com.simple.jdub._ class DatabaseSpec extends JdubSpec { @@ -58,4 +59,13 @@ class DatabaseSpec extends JdubSpec { .map(_.map(_.toSeq)) // arrays compare by reference, seqs by value .must(be(Seq(None))) } + + test("replica databases can only execute statements and read-only queries") { + val replica = db.replica + + replica(AgeQuery("Coda Hale")).must(be(Some(29))) + + // the following will not compile when using a replica + // replica.execute(SQL("CREATE TABLE i_cannot_do_this()")) + } } diff --git a/src/test/scala/com/simple/jdub/tests/JdubSpec.scala b/src/test/scala/com/simple/jdub/tests/JdubSpec.scala index 0a448a3..831348d 100644 --- a/src/test/scala/com/simple/jdub/tests/JdubSpec.scala +++ b/src/test/scala/com/simple/jdub/tests/JdubSpec.scala @@ -4,7 +4,7 @@ package com.simple.jdub.tests import com.simple.jdub.Database - +import com.simple.jdub.Database.Primary import org.scalatest.BeforeAndAfter import org.scalatest.BeforeAndAfterEach import org.scalatest.FunSuite @@ -18,7 +18,7 @@ import java.util.concurrent.atomic.AtomicInteger object Global { val i = new AtomicInteger - val db = Database.connect("jdbc:hsqldb:mem:DbTest" + Global.i.incrementAndGet(), "sa", "") + val db = Database.connect[Primary]("jdbc:hsqldb:mem:DbTest" + Global.i.incrementAndGet(), "sa", "") } trait JdubSpec extends FunSuite