Skip to content

Commit

Permalink
WIP: use phantom types to restrict roles
Browse files Browse the repository at this point in the history
  • Loading branch information
jpfuentes2 committed Jan 23, 2017
1 parent 695b4ca commit 4253ae0
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 56 deletions.
57 changes: 34 additions & 23 deletions src/main/scala/com/simple/jdub/Database.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,21 @@ import com.zaxxer.hikari.{HikariConfig, HikariDataSource}

import java.io.FileInputStream
import java.security.KeyStore
import java.util.{UUID, Properties}
import java.util.{Properties, UUID}
import javax.sql.DataSource

import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.health.HealthCheckRegistry
import com.simple.jdub.Database.Primary
import com.simple.jdub.Database.Replica

import scala.annotation.implicitNotFound

object Database {

sealed trait Role
sealed trait Primary extends Role
sealed trait Replica extends Role

/**
* Create a pool of connections to the given database.
*
Expand All @@ -20,7 +27,7 @@ object Database {
* @param password the database password
* @param sslSettings if present, uses the given SSL settings for a client-side SSL cert.
*/
def connect(url: String,
def connect[R <: Role](url: String,
username: String,
password: String,
name: Option[String] = None,
Expand All @@ -30,7 +37,7 @@ object Database {
sslSettings: Option[SslSettings] = None,
healthCheckRegistry: Option[HealthCheckRegistry] = None,
metricRegistry: Option[MetricRegistry] = None,
connectionInitSql: Option[String] = None): Database = {
connectionInitSql: Option[String] = None): Database[R] = {

val properties = new Properties

Expand Down Expand Up @@ -95,8 +102,8 @@ object Database {
/**
* A set of pooled connections to a database.
*/
class Database protected(val source: DataSource, metrics: Option[MetricRegistry])
extends Queryable {
class Database[R <: Database.Role] protected(val source: DataSource, metrics: Option[MetricRegistry])
extends Queryable[R] {

private[jdub] def time[A](klass: java.lang.Class[_])(f: => A) = {
metrics.fold(f) { registry =>
Expand All @@ -110,34 +117,38 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
}
}

val transactionProvider: TransactionProvider = new TransactionManager
val transactionProvider: TransactionProvider[R] = new TransactionManager

def replica: Database[Replica] = new Database[Replica](source, metrics)

def primary: Database[Primary] = new Database[Primary](source, metrics)

/**
* Opens a transaction which is committed after `f` is called. If `f` throws
* an exception, the transaction is rolled back.
*/
def transaction[A](f: Transaction => A): A = transaction(true, f)
def transaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A = transaction(true, f)

/**
* Opens a transaction which is committed after `f` is called. If `f` throws
* an exception, the transaction is rolled back, but the exception is not
* logged (since it is rethrown).
*/
def quietTransaction[A](f: Transaction => A): A = transaction(false, f)
def quietTransaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A = transaction(false, f)

def transaction[A](logError: Boolean, f: Transaction => A): A = transaction(false, false, f)
def transaction[A](logError: Boolean, f: Transaction[R] => A)(implicit ev: R =:= Primary): A = transaction(false, false, f)

/**
* Opens a transaction which is committed after `f` is called. If `f` throws
* an exception, the transaction is rolled back.
*/
def transaction[A](logError: Boolean, forceNew: Boolean, f: Transaction => A): A = {
def transaction[A](logError: Boolean, forceNew: Boolean, f: Transaction[R] => A)(implicit ev: R =:= Primary): A = {
if (!forceNew && transactionProvider.transactionExists) {
f(transactionProvider.currentTransaction)
} else {
val connection = source.getConnection
connection.setAutoCommit(false)
val txn = new Transaction(connection)
val txn = new Transaction[R](connection)
try {
logger.debug("Starting transaction")
val result = f(txn)
Expand All @@ -162,8 +173,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
* thread within the scope of `f`. If `f` throws an exception the transaction
* is rolled back. Logs exceptions thrown by `f` as errors.
*/
def transactionScope[A](f: => A): A = {
transaction(logError = true, forceNew = false, (txn: Transaction) => {
def transactionScope[A](f: => A)(implicit ev: R =:= Primary): A = {
transaction(logError = true, forceNew = false, (txn: Transaction[R]) => {
transactionProvider.begin(txn)
try {
f
Expand All @@ -181,8 +192,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
* exception the transaction is rolled back. Logs exceptions thrown by
* `f` as errors.
*/
def newTransactionScope[A](f: => A): A = {
transaction(logError = true, forceNew = true, (txn: Transaction) => {
def newTransactionScope[A](f: => A)(implicit ev: R =:= Primary): A = {
transaction(logError = true, forceNew = true, (txn: Transaction[R]) => {
transactionProvider.begin(txn)
try {
f
Expand All @@ -197,8 +208,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
* thread within the scope of `f`. If `f` throws an exception the transaction
* is rolled back. Will not log exceptions thrown by `f`.
*/
def quietTransactionScope[A](f: => A): A = {
transaction(logError = false, forceNew = false, (txn: Transaction) => {
def quietTransactionScope[A](f: => A)(implicit ev: R =:= Primary): A = {
transaction(logError = false, forceNew = false, (txn: Transaction[R]) => {
transactionProvider.begin(txn)
try {
f
Expand All @@ -216,8 +227,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
* exception the transaction is rolled back. Will not log exceptions
* thrown by `f`.
*/
def newQuietTransactionScope[A](f: => A): A = {
transaction(logError = false, forceNew = true, (txn: Transaction) => {
def newQuietTransactionScope[A](f: => A)(implicit ev: R =:= Primary): A = {
transaction(logError = false, forceNew = true, (txn: Transaction[R]) => {
transactionProvider.begin(txn)
try {
f
Expand All @@ -230,7 +241,7 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
/**
* The transaction currently scoped via transactionScope.
*/
def currentTransaction = {
def currentTransaction(implicit ev: R =:= Primary) = {
transactionProvider.currentTransaction
}

Expand Down Expand Up @@ -260,7 +271,7 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
/**
* Executes an update, insert, delete, or DDL statement.
*/
def execute(statement: Statement) = {
def execute(statement: Statement)(implicit ev: R =:= Primary): Int = {
if (transactionProvider.transactionExists) {
transactionProvider.currentTransaction.execute(statement)
} else {
Expand All @@ -278,7 +289,7 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
/**
* Rollback any existing ambient transaction
*/
def rollback() {
def rollback()(implicit ev: R =:= Primary) {
transactionProvider.rollback
}

Expand Down
18 changes: 10 additions & 8 deletions src/main/scala/com/simple/jdub/Queryable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

package com.simple.jdub

import java.sql.Connection
import com.simple.jdub.Database.Primary
import com.simple.jdub.Database.Role

import java.sql.Connection
import grizzled.slf4j.Logging

trait Queryable extends Logging {
trait Queryable[R <: Role] extends Logging {
import Utils._

/**
Expand All @@ -35,7 +37,7 @@ trait Queryable extends Logging {
/**
* Executes an update, insert, delete, or DDL statement.
*/
def execute(connection: Connection, statement: Statement): Int = {
def execute(connection: Connection, statement: Statement)(implicit ev: R =:= Primary): Int = {
logger.debug("%s with %s".format(statement.sql,
statement.values.mkString("(", ", ", ")")))
val stmt = connection.prepareStatement(prependComment(statement, statement.sql))
Expand All @@ -47,9 +49,9 @@ trait Queryable extends Logging {
}
}

def execute(statement: Statement): Int
def execute(statement: Statement)(implicit ev: R =:= Primary): Int
def apply[A](query: RawQuery[A]): A
def transaction[A](f: Transaction => A): A
def transaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A

/**
* Performs a query and returns the results.
Expand All @@ -59,15 +61,15 @@ trait Queryable extends Logging {
/**
* Executes an update statement.
*/
def update(statement: Statement): Int = execute(statement)
def update(statement: Statement)(implicit ev: R =:= Primary): Int = execute(statement)

/**
* Executes an insert statement.
*/
def insert(statement: Statement): Int = execute(statement)
def insert(statement: Statement)(implicit ev: R =:= Primary): Int = execute(statement)

/**
* Executes a delete statement.
*/
def delete(statement: Statement): Int = execute(statement)
def delete(statement: Statement)(implicit ev: R =:= Primary): Int = execute(statement)
}
4 changes: 3 additions & 1 deletion src/main/scala/com/simple/jdub/RawQuery.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.simple.jdub

import com.simple.jdub.Database.Role

import java.sql.ResultSet

trait RawQuery[A] extends SqlBase {
Expand All @@ -9,5 +11,5 @@ trait RawQuery[A] extends SqlBase {

def handle(results: ResultSet): A

def apply(db: Database): A = db(this)
def apply(db: Database[Role]): A = db(this)
}
21 changes: 12 additions & 9 deletions src/main/scala/com/simple/jdub/Transaction.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package com.simple.jdub

import com.simple.jdub.Database.Primary
import com.simple.jdub.Database.Role

import java.sql.{Connection, Savepoint}
import scala.collection.mutable.ListBuffer

class Transaction(val connection: Connection) extends Queryable {
class Transaction[R <: Role](val connection: Connection) extends Queryable[R] {
private[this] var rolledback = false

/**
Expand All @@ -14,12 +17,12 @@ class Transaction(val connection: Connection) extends Queryable {
/**
* Executes an update, insert, delete, or DDL statement.
*/
def execute(statement: Statement) = execute(connection, statement)
def execute(statement: Statement)(implicit ev: R =:= Primary) = execute(connection, statement)

/**
* Roll back the transaction.
*/
def rollback() {
def rollback()(implicit ev: R =:= Primary) {
logger.debug("Rolling back transaction")
connection.rollback()
rolledback = true
Expand All @@ -29,36 +32,36 @@ class Transaction(val connection: Connection) extends Queryable {
/**
* Roll back the transaction to a savepoint.
*/
def rollback(savepoint: Savepoint) {
def rollback(savepoint: Savepoint)(implicit ev: R =:= Primary) {
logger.debug("Rolling back to savepoint")
connection.rollback(savepoint)
}

/**
* Release a transaction from a savepoint.
*/
def release(savepoint: Savepoint) {
def release(savepoint: Savepoint)(implicit ev: R =:= Primary) {
logger.debug("Releasing savepoint")
connection.releaseSavepoint(savepoint)
}

/**
* Set an unnamed savepoint.
*/
def savepoint(): Savepoint = {
def savepoint()(implicit ev: R =:= Primary): Savepoint = {
logger.debug("Setting unnamed savepoint")
connection.setSavepoint()
}

/**
* Set a named savepoint.
*/
def savepoint(name: String): Savepoint = {
def savepoint(name: String)(implicit ev: R =:= Primary): Savepoint = {
logger.debug("Setting savepoint")
connection.setSavepoint(name)
}

private[jdub] def commit() {
private[jdub] def commit()(implicit ev: R =:= Primary) {
if (!rolledback) {
logger.debug("Committing transaction")
connection.commit()
Expand All @@ -72,7 +75,7 @@ class Transaction(val connection: Connection) extends Queryable {
onClose.foreach(_())
}

def transaction[A](f: Transaction => A): A = f(this)
def transaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A = f(this)

var onCommit: ListBuffer[() => Unit] = ListBuffer.empty[() => Unit]
var onClose: ListBuffer[() => Unit] = ListBuffer.empty[() => Unit]
Expand Down
29 changes: 16 additions & 13 deletions src/main/scala/com/simple/jdub/TransactionManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@

package com.simple.jdub

import com.simple.jdub.Database.Primary
import com.simple.jdub.Database.Role

import java.util.Stack

trait TransactionProvider {
trait TransactionProvider[R <: Role] {
def transactionExists: Boolean
def currentTransaction: Transaction
def begin(transaction: Transaction): Unit
def end(): Unit
def rollback(): Unit
def currentTransaction: Transaction[R]
def begin(transaction: Transaction[R])(implicit ev: R =:= Primary): Unit
def end()(implicit ev: R =:= Primary): Unit
def rollback()(implicit ev: R =:= Primary): Unit
}

class TransactionManager extends TransactionProvider {
case class TransactionState(transactions: Stack[Transaction])
class TransactionManager[R <: Role] extends TransactionProvider[R] {
case class TransactionState(transactions: Stack[Transaction[R]])

private val localTransactionStorage = new ThreadLocal[Option[TransactionState]] {
override def initialValue = None
Expand All @@ -26,7 +29,7 @@ class TransactionManager extends TransactionProvider {
localTransactionStorage.get
}

protected def ambientTransaction: Option[Transaction] = {
protected def ambientTransaction: Option[Transaction[R]] = {
ambientTransactionState.map(_.transactions.peek)
}

Expand All @@ -40,23 +43,23 @@ class TransactionManager extends TransactionProvider {
ambientTransactionState.isDefined
}

def currentTransaction: Transaction = {
def currentTransaction: Transaction[R] = {
ambientTransaction.getOrElse(
throw new Exception("No transaction in current context")
)
}

def begin(transaction: Transaction): Unit = {
def begin(transaction: Transaction[R])(implicit ev: R =:= Primary): Unit = {
if (!transactionExists) {
val stack = new Stack[Transaction]()
val stack = new Stack[Transaction[R]]()
stack.push(transaction)
localTransactionStorage.set(Some(new TransactionState(stack)))
} else {
currentTransactionState.transactions.push(transaction)
}
}

def end(): Unit = {
def end()(implicit ev: R =:= Primary): Unit = {
if (!transactionExists) {
throw new Exception("No transaction in current context")
} else {
Expand All @@ -67,7 +70,7 @@ class TransactionManager extends TransactionProvider {
}
}

def rollback(): Unit = {
def rollback()(implicit ev: R =:= Primary): Unit = {
currentTransaction.rollback
}
}
Loading

0 comments on commit 4253ae0

Please sign in to comment.