Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] use phantom types to restrict roles #63

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions src/main/scala/com/simple/jdub/Database.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,21 @@ import com.zaxxer.hikari.{HikariConfig, HikariDataSource}

import java.io.FileInputStream
import java.security.KeyStore
import java.util.{UUID, Properties}
import java.util.{Properties, UUID}
import javax.sql.DataSource

import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.health.HealthCheckRegistry
import com.simple.jdub.Database.Primary
import com.simple.jdub.Database.Replica

import scala.annotation.implicitNotFound

object Database {

sealed trait Role
final class Primary extends Role
final class Replica extends Role

/**
* Create a pool of connections to the given database.
*
Expand All @@ -20,7 +27,7 @@ object Database {
* @param password the database password
* @param sslSettings if present, uses the given SSL settings for a client-side SSL cert.
*/
def connect(url: String,
def connect[R <: Role](url: String,
username: String,
password: String,
name: Option[String] = None,
Expand All @@ -30,7 +37,7 @@ object Database {
sslSettings: Option[SslSettings] = None,
healthCheckRegistry: Option[HealthCheckRegistry] = None,
metricRegistry: Option[MetricRegistry] = None,
connectionInitSql: Option[String] = None): Database = {
connectionInitSql: Option[String] = None): Database[R] = {

val properties = new Properties

Expand Down Expand Up @@ -95,8 +102,8 @@ object Database {
/**
* A set of pooled connections to a database.
*/
class Database protected(val source: DataSource, metrics: Option[MetricRegistry])
extends Queryable {
class Database[R <: Database.Role] protected(val source: DataSource, metrics: Option[MetricRegistry])
extends Queryable[R] {

private[jdub] def time[A](klass: java.lang.Class[_])(f: => A) = {
metrics.fold(f) { registry =>
Expand All @@ -110,34 +117,38 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
}
}

val transactionProvider: TransactionProvider = new TransactionManager
val transactionProvider: TransactionProvider[R] = new TransactionManager

def replica: Database[Replica] = new Database[Replica](source, metrics)

def primary: Database[Primary] = new Database[Primary](source, metrics)

/**
* Opens a transaction which is committed after `f` is called. If `f` throws
* an exception, the transaction is rolled back.
*/
def transaction[A](f: Transaction => A): A = transaction(true, f)
def transaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A = transaction(true, f)

/**
* Opens a transaction which is committed after `f` is called. If `f` throws
* an exception, the transaction is rolled back, but the exception is not
* logged (since it is rethrown).
*/
def quietTransaction[A](f: Transaction => A): A = transaction(false, f)
def quietTransaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A = transaction(false, f)

def transaction[A](logError: Boolean, f: Transaction => A): A = transaction(false, false, f)
def transaction[A](logError: Boolean, f: Transaction[R] => A)(implicit ev: R =:= Primary): A = transaction(false, false, f)

/**
* Opens a transaction which is committed after `f` is called. If `f` throws
* an exception, the transaction is rolled back.
*/
def transaction[A](logError: Boolean, forceNew: Boolean, f: Transaction => A): A = {
def transaction[A](logError: Boolean, forceNew: Boolean, f: Transaction[R] => A)(implicit ev: R =:= Primary): A = {
if (!forceNew && transactionProvider.transactionExists) {
f(transactionProvider.currentTransaction)
} else {
val connection = source.getConnection
connection.setAutoCommit(false)
val txn = new Transaction(connection)
val txn = new Transaction[R](connection)
try {
logger.debug("Starting transaction")
val result = f(txn)
Expand All @@ -162,8 +173,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
* thread within the scope of `f`. If `f` throws an exception the transaction
* is rolled back. Logs exceptions thrown by `f` as errors.
*/
def transactionScope[A](f: => A): A = {
transaction(logError = true, forceNew = false, (txn: Transaction) => {
def transactionScope[A](f: => A)(implicit ev: R =:= Primary): A = {
transaction(logError = true, forceNew = false, (txn: Transaction[R]) => {
transactionProvider.begin(txn)
try {
f
Expand All @@ -181,8 +192,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
* exception the transaction is rolled back. Logs exceptions thrown by
* `f` as errors.
*/
def newTransactionScope[A](f: => A): A = {
transaction(logError = true, forceNew = true, (txn: Transaction) => {
def newTransactionScope[A](f: => A)(implicit ev: R =:= Primary): A = {
transaction(logError = true, forceNew = true, (txn: Transaction[R]) => {
transactionProvider.begin(txn)
try {
f
Expand All @@ -197,8 +208,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
* thread within the scope of `f`. If `f` throws an exception the transaction
* is rolled back. Will not log exceptions thrown by `f`.
*/
def quietTransactionScope[A](f: => A): A = {
transaction(logError = false, forceNew = false, (txn: Transaction) => {
def quietTransactionScope[A](f: => A)(implicit ev: R =:= Primary): A = {
transaction(logError = false, forceNew = false, (txn: Transaction[R]) => {
transactionProvider.begin(txn)
try {
f
Expand All @@ -216,8 +227,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
* exception the transaction is rolled back. Will not log exceptions
* thrown by `f`.
*/
def newQuietTransactionScope[A](f: => A): A = {
transaction(logError = false, forceNew = true, (txn: Transaction) => {
def newQuietTransactionScope[A](f: => A)(implicit ev: R =:= Primary): A = {
transaction(logError = false, forceNew = true, (txn: Transaction[R]) => {
transactionProvider.begin(txn)
try {
f
Expand All @@ -230,7 +241,7 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
/**
* The transaction currently scoped via transactionScope.
*/
def currentTransaction = {
def currentTransaction(implicit ev: R =:= Primary) = {
transactionProvider.currentTransaction
}

Expand Down Expand Up @@ -260,7 +271,7 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
/**
* Executes an update, insert, delete, or DDL statement.
*/
def execute(statement: Statement) = {
def execute(statement: Statement)(implicit ev: R =:= Primary): Int = {
if (transactionProvider.transactionExists) {
transactionProvider.currentTransaction.execute(statement)
} else {
Expand All @@ -278,7 +289,7 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
/**
* Rollback any existing ambient transaction
*/
def rollback() {
def rollback()(implicit ev: R =:= Primary) {
transactionProvider.rollback
}

Expand Down
18 changes: 10 additions & 8 deletions src/main/scala/com/simple/jdub/Queryable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

package com.simple.jdub

import java.sql.Connection
import com.simple.jdub.Database.Primary
import com.simple.jdub.Database.Role

import java.sql.Connection
import grizzled.slf4j.Logging

trait Queryable extends Logging {
trait Queryable[R <: Role] extends Logging {
import Utils._

/**
Expand All @@ -35,7 +37,7 @@ trait Queryable extends Logging {
/**
* Executes an update, insert, delete, or DDL statement.
*/
def execute(connection: Connection, statement: Statement): Int = {
def execute(connection: Connection, statement: Statement)(implicit ev: R =:= Primary): Int = {
logger.debug("%s with %s".format(statement.sql,
statement.values.mkString("(", ", ", ")")))
val stmt = connection.prepareStatement(prependComment(statement, statement.sql))
Expand All @@ -47,9 +49,9 @@ trait Queryable extends Logging {
}
}

def execute(statement: Statement): Int
def execute(statement: Statement)(implicit ev: R =:= Primary): Int
def apply[A](query: RawQuery[A]): A
def transaction[A](f: Transaction => A): A
def transaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A

/**
* Performs a query and returns the results.
Expand All @@ -59,15 +61,15 @@ trait Queryable extends Logging {
/**
* Executes an update statement.
*/
def update(statement: Statement): Int = execute(statement)
def update(statement: Statement)(implicit ev: R =:= Primary): Int = execute(statement)

/**
* Executes an insert statement.
*/
def insert(statement: Statement): Int = execute(statement)
def insert(statement: Statement)(implicit ev: R =:= Primary): Int = execute(statement)

/**
* Executes a delete statement.
*/
def delete(statement: Statement): Int = execute(statement)
def delete(statement: Statement)(implicit ev: R =:= Primary): Int = execute(statement)
}
4 changes: 3 additions & 1 deletion src/main/scala/com/simple/jdub/RawQuery.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.simple.jdub

import com.simple.jdub.Database.Role

import java.sql.ResultSet

trait RawQuery[A] extends SqlBase {
Expand All @@ -9,5 +11,5 @@ trait RawQuery[A] extends SqlBase {

def handle(results: ResultSet): A

def apply(db: Database): A = db(this)
def apply(db: Database[Role]): A = db(this)
}
21 changes: 12 additions & 9 deletions src/main/scala/com/simple/jdub/Transaction.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package com.simple.jdub

import com.simple.jdub.Database.Primary
import com.simple.jdub.Database.Role

import java.sql.{Connection, Savepoint}
import scala.collection.mutable.ListBuffer

class Transaction(val connection: Connection) extends Queryable {
class Transaction[R <: Role](val connection: Connection) extends Queryable[R] {
private[this] var rolledback = false

/**
Expand All @@ -14,12 +17,12 @@ class Transaction(val connection: Connection) extends Queryable {
/**
* Executes an update, insert, delete, or DDL statement.
*/
def execute(statement: Statement) = execute(connection, statement)
def execute(statement: Statement)(implicit ev: R =:= Primary) = execute(connection, statement)

/**
* Roll back the transaction.
*/
def rollback() {
def rollback()(implicit ev: R =:= Primary) {
logger.debug("Rolling back transaction")
connection.rollback()
rolledback = true
Expand All @@ -29,36 +32,36 @@ class Transaction(val connection: Connection) extends Queryable {
/**
* Roll back the transaction to a savepoint.
*/
def rollback(savepoint: Savepoint) {
def rollback(savepoint: Savepoint)(implicit ev: R =:= Primary) {
logger.debug("Rolling back to savepoint")
connection.rollback(savepoint)
}

/**
* Release a transaction from a savepoint.
*/
def release(savepoint: Savepoint) {
def release(savepoint: Savepoint)(implicit ev: R =:= Primary) {
logger.debug("Releasing savepoint")
connection.releaseSavepoint(savepoint)
}

/**
* Set an unnamed savepoint.
*/
def savepoint(): Savepoint = {
def savepoint()(implicit ev: R =:= Primary): Savepoint = {
logger.debug("Setting unnamed savepoint")
connection.setSavepoint()
}

/**
* Set a named savepoint.
*/
def savepoint(name: String): Savepoint = {
def savepoint(name: String)(implicit ev: R =:= Primary): Savepoint = {
logger.debug("Setting savepoint")
connection.setSavepoint(name)
}

private[jdub] def commit() {
private[jdub] def commit()(implicit ev: R =:= Primary) {
if (!rolledback) {
logger.debug("Committing transaction")
connection.commit()
Expand All @@ -72,7 +75,7 @@ class Transaction(val connection: Connection) extends Queryable {
onClose.foreach(_())
}

def transaction[A](f: Transaction => A): A = f(this)
def transaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A = f(this)

var onCommit: ListBuffer[() => Unit] = ListBuffer.empty[() => Unit]
var onClose: ListBuffer[() => Unit] = ListBuffer.empty[() => Unit]
Expand Down
29 changes: 16 additions & 13 deletions src/main/scala/com/simple/jdub/TransactionManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@

package com.simple.jdub

import com.simple.jdub.Database.Primary
import com.simple.jdub.Database.Role

import java.util.Stack

trait TransactionProvider {
trait TransactionProvider[R <: Role] {
def transactionExists: Boolean
def currentTransaction: Transaction
def begin(transaction: Transaction): Unit
def end(): Unit
def rollback(): Unit
def currentTransaction: Transaction[R]
def begin(transaction: Transaction[R])(implicit ev: R =:= Primary): Unit
def end()(implicit ev: R =:= Primary): Unit
def rollback()(implicit ev: R =:= Primary): Unit
}

class TransactionManager extends TransactionProvider {
case class TransactionState(transactions: Stack[Transaction])
class TransactionManager[R <: Role] extends TransactionProvider[R] {
case class TransactionState(transactions: Stack[Transaction[R]])

private val localTransactionStorage = new ThreadLocal[Option[TransactionState]] {
override def initialValue = None
Expand All @@ -26,7 +29,7 @@ class TransactionManager extends TransactionProvider {
localTransactionStorage.get
}

protected def ambientTransaction: Option[Transaction] = {
protected def ambientTransaction: Option[Transaction[R]] = {
ambientTransactionState.map(_.transactions.peek)
}

Expand All @@ -40,23 +43,23 @@ class TransactionManager extends TransactionProvider {
ambientTransactionState.isDefined
}

def currentTransaction: Transaction = {
def currentTransaction: Transaction[R] = {
ambientTransaction.getOrElse(
throw new Exception("No transaction in current context")
)
}

def begin(transaction: Transaction): Unit = {
def begin(transaction: Transaction[R])(implicit ev: R =:= Primary): Unit = {
if (!transactionExists) {
val stack = new Stack[Transaction]()
val stack = new Stack[Transaction[R]]()
stack.push(transaction)
localTransactionStorage.set(Some(new TransactionState(stack)))
} else {
currentTransactionState.transactions.push(transaction)
}
}

def end(): Unit = {
def end()(implicit ev: R =:= Primary): Unit = {
if (!transactionExists) {
throw new Exception("No transaction in current context")
} else {
Expand All @@ -67,7 +70,7 @@ class TransactionManager extends TransactionProvider {
}
}

def rollback(): Unit = {
def rollback()(implicit ev: R =:= Primary): Unit = {
currentTransaction.rollback
}
}
Loading