Skip to content

Commit

Permalink
Delay initializing the db object (#1323)
Browse files Browse the repository at this point in the history
* later init db object

* fix ci

* fix ci

* fix ci

* fix ci
  • Loading branch information
Yancey1989 authored Dec 4, 2019
1 parent 1d81083 commit c103f83
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 60 deletions.
27 changes: 15 additions & 12 deletions cmd/repl/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,18 @@ func flagPassed(name ...string) bool {
return found
}

func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds string) error {
func runStmt(stmt string, isTerminal bool, modelDir string, ds string) error {
if !isTerminal {
fmt.Println("sqlflow>", stmt)
}
tableRendered := false
table := tablewriter.NewWriter(os.Stdout)
sess := makeSessionFromEnv()
if ds != "" {
sess.DbConnStr = ds
}

stream := sql.RunSQLProgram(stmt, db, modelDir, sess)
stream := sql.RunSQLProgram(stmt, modelDir, sess)
for rsp := range stream.ReadAll() {
// pagination. avoid exceed memory
if render(rsp, table) && table.NumLines() == tablePageSize {
Expand All @@ -128,14 +131,19 @@ func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds strin
return nil
}

func repl(scanner *bufio.Scanner, modelDir string, db *sql.DB, ds string) {
func repl(scanner *bufio.Scanner, modelDir string, ds string) {
db, err := sql.NewDB(ds)
if err != nil {
log.Fatalf("failed to open database: %v", err)
}
defer db.Close()
for {
stmt, err := readStmt(scanner)
fmt.Println()
if err == io.EOF && stmt == "" {
return
}
if err := runStmt(stmt, false, modelDir, db, ds); err != nil {
if err := runStmt(stmt, false, modelDir, ds); err != nil {
log.Fatalf("run SQL statment failed: %v", err)
}
}
Expand Down Expand Up @@ -196,12 +204,6 @@ func main() {
os.Exit(0)
}

db, err := sql.NewDB(*ds)
if err != nil {
log.Fatalf("failed to open database: %v", err)
}
defer db.Close()

if *modelDir != "" {
if _, derr := os.Stat(*modelDir); derr != nil {
os.Mkdir(*modelDir, os.ModePerm)
Expand All @@ -211,6 +213,7 @@ func main() {
isTerminal := !flagPassed("execute", "e", "file", "f") && terminal.IsTerminal(syscall.Stdin)

sqlFile := os.Stdin
var err error
if flagPassed("file", "f") {
sqlFile, err = os.Open(*sqlFileName)
if err != nil {
Expand All @@ -225,8 +228,8 @@ func main() {
}
scanner := bufio.NewScanner(reader)
if isTerminal {
runPrompt(func(stmt string) { runStmt(stmt, true, *modelDir, db, *ds) })
runPrompt(func(stmt string) { runStmt(stmt, true, *modelDir, *ds) })
} else {
repl(scanner, *modelDir, db, *ds)
repl(scanner, *modelDir, *ds)
}
}
2 changes: 1 addition & 1 deletion cmd/repl/repl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ INTO sqlflow_models.mymodel;`

// run one train SQL to save the model then test predict/analyze use the model
sess := &irpb.Session{DbConnStr: dataSourceStr}
stream := sf.RunSQLProgram(trainSQL, testdb, "", sess)
stream := sf.RunSQLProgram(trainSQL, "", sess)
lastResp := list.New()
keepSize := 10
for rsp := range stream.ReadAll() {
Expand Down
13 changes: 3 additions & 10 deletions pkg/server/sqlflowserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ import (
)

// NewServer returns a server instance
func NewServer(run func(string, *sf.DB, string, *pb.Session) *sf.PipeReader,
func NewServer(run func(string, string, *pb.Session) *sf.PipeReader,
modelDir string) *Server {
return &Server{run: run, modelDir: modelDir}
}

// Server is the instance will be used to connect to DB and execute training
type Server struct {
run func(sql string, db *sf.DB, modelDir string, session *pb.Session) *sf.PipeReader
run func(sql string, modelDir string, session *pb.Session) *sf.PipeReader
modelDir string
}

Expand All @@ -51,18 +51,11 @@ func (s *Server) Fetch(ctx context.Context, job *pb.Job) (*pb.JobStatus, error)

// Run implements `rpc Run (Request) returns (stream Response)`
func (s *Server) Run(req *pb.Request, stream pb.SQLFlow_RunServer) error {
var db *sf.DB

var err error
if db, err = sf.NewDB(req.Session.DbConnStr); err != nil {
return fmt.Errorf("create DB failed: %v", err)
}
defer db.Close()
sqlStatements, err := sf.SplitMultipleSQL(req.Sql)
if err != nil {
return err
}
rd := s.run(req.Sql, db, s.modelDir, req.Session)
rd := s.run(req.Sql, s.modelDir, req.Session)
defer rd.Close()

for r := range rd.ReadAll() {
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/sqlflowserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ const (

var testServerAddress string

func mockRun(sql string, db *sf.DB, modelDir string, session *pb.Session) *sf.PipeReader {
func mockRun(sql string, modelDir string, session *pb.Session) *sf.PipeReader {
rd, wr := sf.Pipe()
singleSQL := sql
go func() {
Expand Down
19 changes: 14 additions & 5 deletions pkg/sql/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ type DB struct {
// In addition to sql.Open, it also does the book keeping on driverName and
// dataSourceName
func open(datasource string) (*DB, error) {
dses := strings.Split(datasource, "://")
if len(dses) != 2 {
return nil, fmt.Errorf("Expecting but cannot find :// in datasource %v", datasource)
driverName, datasourName, err := SplitDataSource(datasource)
if err != nil {
return nil, err
}
db := &DB{driverName: dses[0], dataSourceName: dses[1]}
db := &DB{driverName: driverName, dataSourceName: datasourName}

err := openDB(db)
err = openDB(db)
return db, err
}

Expand All @@ -63,6 +63,15 @@ func openDB(db *DB) error {
return fmt.Errorf("sqlflow currently doesn't support DB %s", db.driverName)
}

// SplitDataSource splits a datasource string of the form
// "drivername://datasourcename" into its driver name and data source
// name parts. It returns an error if the "://" separator is missing
// or appears more than once.
func SplitDataSource(datasource string) (string, string, error) {
	dses := strings.Split(datasource, "://")
	// Exactly one "://" yields exactly two parts; anything else is malformed.
	if len(dses) != 2 {
		// Error strings are lowercase per Go convention (staticcheck ST1005).
		return "", "", fmt.Errorf("expecting but cannot find :// in datasource %v", datasource)
	}
	return dses[0], dses[1], nil
}

// NewDB returns a DB object with verifying the datasource name.
func NewDB(datasource string) (*DB, error) {
db, err := open(datasource)
Expand Down
9 changes: 9 additions & 0 deletions pkg/sql/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,12 @@ func TestDatabaseOpenMysql(t *testing.T) {
_, e = db.Exec("show databases")
a.NoError(e)
}

// TestSplitDataSource verifies that SplitDataSource separates a
// datasource string into its driver name and data source name.
func TestSplitDataSource(t *testing.T) {
	a := assert.New(t)
	ds := "mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0"
	driverName, datasourceName, e := SplitDataSource(ds)
	// Assert the error first so the value checks below are meaningful.
	a.NoError(e)
	// testify takes the expected value first, then the actual value.
	a.EqualValues("mysql", driverName)
	a.EqualValues("root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0", datasourceName)
}
33 changes: 24 additions & 9 deletions pkg/sql/executor_ir.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,17 @@ type WorkflowJob struct {
}

// RunSQLProgram run a SQL program.
func RunSQLProgram(sqlProgram string, db *DB, modelDir string, session *pb.Session) *PipeReader {
func RunSQLProgram(sqlProgram string, modelDir string, session *pb.Session) *PipeReader {
rd, wr := Pipe()
go func() {
var db *DB
var err error
if db, err = NewDB(session.DbConnStr); err != nil {
wr.Write(fmt.Sprintf("create DB failed: %v", err))
log.Errorf("create DB failed: %v", err)
}
defer wr.Close()
err := runSQLProgram(wr, sqlProgram, db, modelDir, session)
err = runSQLProgram(wr, sqlProgram, db, modelDir, session)

if err != nil {
log.Errorf("runSQLProgram error: %v", err)
Expand Down Expand Up @@ -106,13 +112,17 @@ func ParseSQLStatement(sql string, session *pb.Session) (string, error) {
}

// SubmitWorkflow submits an Argo workflow
func SubmitWorkflow(sqlProgram string, db *DB, modelDir string, session *pb.Session) *PipeReader {
func SubmitWorkflow(sqlProgram string, modelDir string, session *pb.Session) *PipeReader {
rd, wr := Pipe()
go func() {
defer wr.Close()
err := submitWorkflow(wr, sqlProgram, db, modelDir, session)
err := submitWorkflow(wr, sqlProgram, modelDir, session)
if err != nil {
log.Errorf("submit Workflow error: %v", err)
if err != ErrClosedPipe {
if err := wr.Write(err); err != nil {
log.Errorf("submit workflow error(piping): %v", err)
}
}
}
}()
return rd
Expand Down Expand Up @@ -153,8 +163,12 @@ func writeArgoFile(coulerFileName string) (string, error) {
return argoYaml.Name(), nil
}

func submitWorkflow(wr *PipeWriter, sqlProgram string, db *DB, modelDir string, session *pb.Session) error {
sqls, err := parse(db.driverName, sqlProgram)
func submitWorkflow(wr *PipeWriter, sqlProgram string, modelDir string, session *pb.Session) error {
driverName, dataSourceName, err := SplitDataSource(session.DbConnStr)
if err != nil {
return err
}
sqls, err := parse(driverName, sqlProgram)
if err != nil {
return err
}
Expand All @@ -166,7 +180,7 @@ func submitWorkflow(wr *PipeWriter, sqlProgram string, db *DB, modelDir string,
spIRs := []ir.SQLStatement{}
for _, sql := range sqls {
var r ir.SQLStatement
connStr := fmt.Sprintf("%s://%s", db.driverName, db.dataSourceName)
connStr := fmt.Sprintf("%s://%s", driverName, dataSourceName)
if sql.extended != nil {
parsed := sql.extended
if parsed.train {
Expand Down Expand Up @@ -203,7 +217,7 @@ func submitWorkflow(wr *PipeWriter, sqlProgram string, db *DB, modelDir string,
cmd := exec.Command("kubectl", "create", "-f", argoFile)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("submit Argo YAML error: %v", err)
return fmt.Errorf("submit Argo YAML error: %v, output: %s", err, string(output))
}
reWorkflow := regexp.MustCompile(`.+/(.+) .+`)
wf := reWorkflow.FindStringSubmatch(string(output))
Expand Down Expand Up @@ -409,6 +423,7 @@ func loadModelMeta(pr *extendedSelect, db *DB, cwd, modelDir, modelName string)
if modelDir != "" {
modelURI = fmt.Sprintf("file://%s/%s", modelDir, modelName)
}

m, e = load(modelURI, cwd, db)
if e != nil {
return nil, fmt.Errorf("load %v", e)
Expand Down
36 changes: 18 additions & 18 deletions pkg/sql/executor_ir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ USING sqlflow_models.my_xgboost_model_by_program;
SELECT sepal_length as sl, sepal_width as sw, class FROM iris.train
TO EXPLAIN sqlflow_models.my_xgboost_model_by_program
USING TreeExplainer;
`, testDB, modelDir, getDefaultSession())
`, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})

Expand All @@ -152,17 +152,17 @@ func TestExecuteXGBoostClassifier(t *testing.T) {
a := assert.New(t)
modelDir := ""
a.NotPanics(func() {
stream := RunSQLProgram(testTrainSelectWithLimit, testDB, modelDir, getDefaultSession())
stream := RunSQLProgram(testTrainSelectWithLimit, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testXGBoostPredictIris, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testXGBoostPredictIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
a.NotPanics(func() {
stream := RunSQLProgram(testXGBoostTrainSelectIris, testDB, modelDir, getDefaultSession())
stream := RunSQLProgram(testXGBoostTrainSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testAnalyzeTreeModelSelectIris, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testAnalyzeTreeModelSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testXGBoostPredictIris, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testXGBoostPredictIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
}
Expand All @@ -171,11 +171,11 @@ func TestExecuteXGBoostRegression(t *testing.T) {
a := assert.New(t)
modelDir := ""
a.NotPanics(func() {
stream := RunSQLProgram(testXGBoostTrainSelectHousing, testDB, modelDir, getDefaultSession())
stream := RunSQLProgram(testXGBoostTrainSelectHousing, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testAnalyzeTreeModelSelectIris, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testAnalyzeTreeModelSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testXGBoostPredictHousing, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testXGBoostPredictHousing, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
}
Expand All @@ -184,9 +184,9 @@ func TestExecutorTrainAndPredictDNN(t *testing.T) {
a := assert.New(t)
modelDir := ""
a.NotPanics(func() {
stream := RunSQLProgram(testTrainSelectIris, testDB, modelDir, getDefaultSession())
stream := RunSQLProgram(testTrainSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testPredictSelectIris, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testPredictSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
}
Expand All @@ -197,9 +197,9 @@ func TestExecutorTrainAndPredictClusteringLocalFS(t *testing.T) {
a.Nil(e)
defer os.RemoveAll(modelDir)
a.NotPanics(func() {
stream := RunSQLProgram(testClusteringTrain, testDB, modelDir, getDefaultSession())
stream := RunSQLProgram(testClusteringTrain, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testClusteringPredict, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testClusteringPredict, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
}
Expand All @@ -210,9 +210,9 @@ func TestExecutorTrainAndPredictDNNLocalFS(t *testing.T) {
a.Nil(e)
defer os.RemoveAll(modelDir)
a.NotPanics(func() {
stream := RunSQLProgram(testTrainSelectIris, testDB, modelDir, getDefaultSession())
stream := RunSQLProgram(testTrainSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testPredictSelectIris, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testPredictSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
}
Expand All @@ -234,14 +234,14 @@ train.verbose = 1
COLUMN NUMERIC(dense, 4)
LABEL class
INTO sqlflow_models.my_dense_dnn_model;`
stream := RunSQLProgram(trainSQL, testDB, "", getDefaultSession())
stream := RunSQLProgram(trainSQL, "", getDefaultSession())
a.True(goodStream(stream.ReadAll()))

predSQL := `SELECT * FROM iris.test_dense
TO PREDICT iris.predict_dense.class
USING sqlflow_models.my_dense_dnn_model
;`
stream = RunSQLProgram(predSQL, testDB, "", getDefaultSession())
stream = RunSQLProgram(predSQL, "", getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
}
Expand Down Expand Up @@ -287,7 +287,7 @@ func TestSubmitWorkflow(t *testing.T) {
a := assert.New(t)
modelDir := ""
a.NotPanics(func() {
rd := SubmitWorkflow(testXGBoostTrainSelectIris, testDB, modelDir, getDefaultSession())
rd := SubmitWorkflow(testXGBoostTrainSelectIris, modelDir, getDefaultSession())
for r := range rd.ReadAll() {
switch r.(type) {
case WorkflowJob:
Expand Down
5 changes: 4 additions & 1 deletion pkg/sql/executor_standard_sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ package sql

import (
"container/list"
"fmt"
"strings"
"testing"

"github.com/stretchr/testify/assert"
pb "sqlflow.org/sqlflow/pkg/proto"
)

const (
Expand Down Expand Up @@ -85,7 +87,8 @@ func TestStandardSQL(t *testing.T) {

func TestSQLLexerError(t *testing.T) {
a := assert.New(t)
stream := RunSQLProgram("SELECT * FROM ``?[] AS WHERE LIMIT;", testDB, "", nil)
ds := fmt.Sprintf("%s://%s", testDB.driverName, testDB.dataSourceName)
stream := RunSQLProgram("SELECT * FROM ``?[] AS WHERE LIMIT;", "", &pb.Session{DbConnStr: ds})
a.False(goodStream(stream.ReadAll()))
}

Expand Down
Loading

0 comments on commit c103f83

Please sign in to comment.