Skip to content

Commit

Permalink
Write unit tests for Postgres package (#69)
Browse files Browse the repository at this point in the history
Changes:
* I add postgres_databaseutils_test.go for containing all utilities for mocking the database.
* I put appName as a required parameter passed to the New() since it is needed to form the namespace column name.
* I add GetMaxInterval() since it might be needed by the consumer and the test.
* I move the SQL string package variables directly into the function that need it so we could avoid maintenance nightmare in the future.
  • Loading branch information
vincent6767 authored and harlow committed Oct 14, 2018
1 parent d3b7634 commit cb35697
Show file tree
Hide file tree
Showing 5 changed files with 359 additions and 21 deletions.
19 changes: 18 additions & 1 deletion Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Gopkg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,7 @@
[prune]
go-tests = true
unused-packages = true

[[constraint]]
name = "gopkg.in/DATA-DOG/go-sqlmock.v1"
version = "1.3.0"
42 changes: 22 additions & 20 deletions checkpoint/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,6 @@ import (
_ "github.com/lib/pq"
)

var getCheckpointQuery = `SELECT sequence_number
FROM %s
WHERE namespace=$1 AND shard_id=$2`

var upsertCheckpoint = `INSERT INTO %s (namespace, shard_id, sequence_number)
VALUES($1, $2, $3)
ON CONFLICT (namespace, shard_id)
DO
UPDATE
SET sequence_number= $3`

type key struct {
streamName string
shardID string
Expand All @@ -36,9 +25,10 @@ func WithMaxInterval(maxInterval time.Duration) Option {
}
}

// Checkpoint stores and retreives the last evaluated key from a DDB scan
// Checkpoint stores and retrieves the last evaluated key from a DDB scan
type Checkpoint struct {
appName string
tableName string
conn *sql.DB
mu *sync.Mutex // protects the checkpoints
done chan struct{}
Expand All @@ -49,9 +39,12 @@ type Checkpoint struct {
// New returns a checkpoint that uses PostgresDB for underlying storage
// Using connectionStr turn it more flexible to use specific db configs
func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint, error) {
if appName == "" {
return nil, errors.New("application name not defined")
}

if tableName == "" {
return nil, errors.New("Table name not defined")
return nil, errors.New("table name not defined")
}

conn, err := sql.Open("postgres", connectionStr)
Expand All @@ -60,14 +53,12 @@ func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint,
return nil, err
}

getCheckpointQuery = fmt.Sprintf(getCheckpointQuery, tableName)
upsertCheckpoint = fmt.Sprintf(upsertCheckpoint, tableName)

ck := &Checkpoint{
conn: conn,
appName: appName,
tableName: tableName,
done: make(chan struct{}),
maxInterval: time.Duration(1 * time.Minute),
maxInterval: 1 * time.Minute,
mu: new(sync.Mutex),
checkpoints: map[key]string{},
}
Expand All @@ -81,21 +72,25 @@ func New(appName, tableName, connectionStr string, opts ...Option) (*Checkpoint,
return ck, nil
}

// GetMaxInterval returns the maximum interval before the checkpoint
func (c *Checkpoint) GetMaxInterval() time.Duration {
return c.maxInterval
}

// Get determines if a checkpoint for a particular Shard exists.
// Typically used to determine whether we should start processing the shard with
// TRIM_HORIZON or AFTER_SEQUENCE_NUMBER (if checkpoint exists).
func (c *Checkpoint) Get(streamName, shardID string) (string, error) {
namespace := fmt.Sprintf("%s-%s", c.appName, streamName)

var sequenceNumber string

getCheckpointQuery := fmt.Sprintf(`SELECT sequence_number FROM %s WHERE namespace=$1 AND shard_id=$2;`, c.tableName) //nolint: gas, it replaces only the table name
err := c.conn.QueryRow(getCheckpointQuery, namespace, shardID).Scan(&sequenceNumber)

if err != nil {
if err == sql.ErrNoRows {
return "", nil
}

return "", err
}

Expand Down Expand Up @@ -150,8 +145,15 @@ func (c *Checkpoint) save() error {
c.mu.Lock()
defer c.mu.Unlock()

for key, sequenceNumber := range c.checkpoints {
//nolint: gas, it replaces only the table name
upsertCheckpoint := fmt.Sprintf(`INSERT INTO %s (namespace, shard_id, sequence_number)
VALUES($1, $2, $3)
ON CONFLICT (namespace, shard_id)
DO
UPDATE
SET sequence_number= $3;`, c.tableName)

for key, sequenceNumber := range c.checkpoints {
if _, err := c.conn.Exec(upsertCheckpoint, fmt.Sprintf("%s-%s", c.appName, key.streamName), key.shardID, sequenceNumber); err != nil {
return err
}
Expand Down
7 changes: 7 additions & 0 deletions checkpoint/postgres/postgres_databaseutils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package postgres

import "database/sql"

func (c *Checkpoint) SetConn(conn *sql.DB) {
c.conn = conn
}
Loading

0 comments on commit cb35697

Please sign in to comment.