Delay initialize db object #1323

Merged: 8 commits, Dec 4, 2019
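
In short, this PR delays creating the DB object: callers no longer open a *DB up front and thread it through every call; they pass only the datasource string (carried in the session's DbConnStr), and the DB is opened where it is actually needed. A minimal before/after sketch of the central signature change, reconstructed from the diffs below:

// Before: every entry point needed a live DB handle.
func RunSQLProgram(sqlProgram string, db *DB, modelDir string, session *pb.Session) *PipeReader

// After: only the connection string travels; RunSQLProgram opens the DB itself
// from session.DbConnStr.
func RunSQLProgram(sqlProgram string, modelDir string, session *pb.Session) *PipeReader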
27 changes: 15 additions & 12 deletions cmd/repl/repl.go
@@ -105,15 +105,18 @@ func flagPassed(name ...string) bool {
return found
}

func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds string) error {
func runStmt(stmt string, isTerminal bool, modelDir string, ds string) error {
Review comment (Collaborator):
ds -> datasource?

Reply (Collaborator, PR author):
Good idea, will polish it in the next PR.

if !isTerminal {
fmt.Println("sqlflow>", stmt)
}
tableRendered := false
table := tablewriter.NewWriter(os.Stdout)
sess := makeSessionFromEnv()
if ds != "" {
sess.DbConnStr = ds
}

stream := sql.RunSQLProgram(stmt, db, modelDir, sess)
stream := sql.RunSQLProgram(stmt, modelDir, sess)
for rsp := range stream.ReadAll() {
// pagination. avoid exceed memory
if render(rsp, table) && table.NumLines() == tablePageSize {
@@ -128,14 +131,19 @@ func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds string) error {
return nil
}

func repl(scanner *bufio.Scanner, modelDir string, db *sql.DB, ds string) {
func repl(scanner *bufio.Scanner, modelDir string, ds string) {
db, err := sql.NewDB(ds)
if err != nil {
log.Fatalf("failed to open database: %v", err)
}
defer db.Close()
for {
stmt, err := readStmt(scanner)
fmt.Println()
if err == io.EOF && stmt == "" {
return
}
if err := runStmt(stmt, false, modelDir, db, ds); err != nil {
if err := runStmt(stmt, false, modelDir, ds); err != nil {
log.Fatalf("run SQL statment failed: %v", err)
}
}
@@ -196,12 +204,6 @@ func main() {
os.Exit(0)
}

db, err := sql.NewDB(*ds)
if err != nil {
log.Fatalf("failed to open database: %v", err)
}
defer db.Close()

if *modelDir != "" {
if _, derr := os.Stat(*modelDir); derr != nil {
os.Mkdir(*modelDir, os.ModePerm)
@@ -211,6 +213,7 @@
isTerminal := !flagPassed("execute", "e", "file", "f") && terminal.IsTerminal(syscall.Stdin)

sqlFile := os.Stdin
var err error
if flagPassed("file", "f") {
sqlFile, err = os.Open(*sqlFileName)
if err != nil {
@@ -225,8 +228,8 @@
}
scanner := bufio.NewScanner(reader)
if isTerminal {
runPrompt(func(stmt string) { runStmt(stmt, true, *modelDir, db, *ds) })
runPrompt(func(stmt string) { runStmt(stmt, true, *modelDir, *ds) })
} else {
repl(scanner, *modelDir, db, *ds)
repl(scanner, *modelDir, *ds)
}
}
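
Within the REPL, the datasource now reaches the executor through the session rather than a shared handle. A short sketch of the precedence, consolidating the lines from runStmt above (makeSessionFromEnv and ds are the names used in repl.go):

sess := makeSessionFromEnv()   // picks up the connection string from the environment
if ds != "" {
	sess.DbConnStr = ds        // an explicit datasource argument overrides it
}
stream := sql.RunSQLProgram(stmt, modelDir, sess) // the DB is opened from sess.DbConnStr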
2 changes: 1 addition & 1 deletion cmd/repl/repl_test.go
@@ -230,7 +230,7 @@ INTO sqlflow_models.mymodel;`

// run one train SQL to save the model then test predict/analyze use the model
sess := &irpb.Session{DbConnStr: dataSourceStr}
stream := sf.RunSQLProgram(trainSQL, testdb, "", sess)
stream := sf.RunSQLProgram(trainSQL, "", sess)
lastResp := list.New()
keepSize := 10
for rsp := range stream.ReadAll() {
13 changes: 3 additions & 10 deletions pkg/server/sqlflowserver.go
@@ -32,14 +32,14 @@ import (
)

// NewServer returns a server instance
func NewServer(run func(string, *sf.DB, string, *pb.Session) *sf.PipeReader,
func NewServer(run func(string, string, *pb.Session) *sf.PipeReader,
modelDir string) *Server {
return &Server{run: run, modelDir: modelDir}
}

// Server is the instance will be used to connect to DB and execute training
type Server struct {
run func(sql string, db *sf.DB, modelDir string, session *pb.Session) *sf.PipeReader
run func(sql string, modelDir string, session *pb.Session) *sf.PipeReader
modelDir string
}

@@ -51,18 +51,11 @@ func (s *Server) Fetch(ctx context.Context, job *pb.Job) (*pb.JobStatus, error)

// Run implements `rpc Run (Request) returns (stream Response)`
func (s *Server) Run(req *pb.Request, stream pb.SQLFlow_RunServer) error {
var db *sf.DB

var err error
if db, err = sf.NewDB(req.Session.DbConnStr); err != nil {
return fmt.Errorf("create DB failed: %v", err)
}
defer db.Close()
sqlStatements, err := sf.SplitMultipleSQL(req.Sql)
if err != nil {
return err
}
rd := s.run(req.Sql, db, s.modelDir, req.Session)
rd := s.run(req.Sql, s.modelDir, req.Session)
defer rd.Close()

for r := range rd.ReadAll() {
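
With the db parameter gone from the run callback, the server constructor only needs a function matching the narrowed signature. A minimal wiring sketch; the entry point that registers the callback is not part of this diff, so the call below is an assumption for illustration:

// Hypothetical wiring (the production main() is not shown in this PR).
srv := server.NewServer(sf.RunSQLProgram, modelDir) // run: func(sql string, modelDir string, session *pb.Session) *sf.PipeReader
// Each Run() request then has its DB opened inside RunSQLProgram from req.Session.DbConnStr.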
2 changes: 1 addition & 1 deletion pkg/server/sqlflowserver_test.go
@@ -46,7 +46,7 @@ const (

var testServerAddress string

func mockRun(sql string, db *sf.DB, modelDir string, session *pb.Session) *sf.PipeReader {
func mockRun(sql string, modelDir string, session *pb.Session) *sf.PipeReader {
rd, wr := sf.Pipe()
singleSQL := sql
go func() {
19 changes: 14 additions & 5 deletions pkg/sql/database.go
@@ -39,13 +39,13 @@ type DB struct {
// In addition to sql.Open, it also does the book keeping on driverName and
// dataSourceName
func open(datasource string) (*DB, error) {
dses := strings.Split(datasource, "://")
if len(dses) != 2 {
return nil, fmt.Errorf("Expecting but cannot find :// in datasource %v", datasource)
driverName, datasourName, err := SplitDataSource(datasource)
if err != nil {
return nil, err
}
db := &DB{driverName: dses[0], dataSourceName: dses[1]}
db := &DB{driverName: driverName, dataSourceName: datasourName}

err := openDB(db)
err = openDB(db)
return db, err
}

@@ -63,6 +63,15 @@ func openDB(db *DB) error {
return fmt.Errorf("sqlflow currently doesn't support DB %s", db.driverName)
}

// SplitDataSource splits the datasource into drivername and datasource name
func SplitDataSource(datasource string) (string, string, error) {
dses := strings.Split(datasource, "://")
if len(dses) != 2 {
return "", "", fmt.Errorf("Expecting but cannot find :// in datasource %v", datasource)
}
return dses[0], dses[1], nil
}

// NewDB returns a DB object with verifying the datasource name.
func NewDB(datasource string) (*DB, error) {
db, err := open(datasource)
9 changes: 9 additions & 0 deletions pkg/sql/database_test.go
@@ -42,3 +42,12 @@ func TestDatabaseOpenMysql(t *testing.T) {
_, e = db.Exec("show databases")
a.NoError(e)
}

func TestSplitDataSource(t *testing.T) {
a := assert.New(t)
ds := "mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0"
driverName, datasourceName, e := SplitDataSource(ds)
a.EqualValues(driverName, "mysql")
a.EqualValues(datasourceName, "root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0")
a.NoError(e)
}
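
SplitDataSource is the inverse of the fmt.Sprintf join that submitWorkflow performs below. A short sketch of the round trip, using the datasource string from the test above:

driver, dsn, err := SplitDataSource("mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0")
// driver == "mysql", dsn == "root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0", err == nil
connStr := fmt.Sprintf("%s://%s", driver, dsn) // rebuilds the original datasource string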
33 changes: 24 additions & 9 deletions pkg/sql/executor_ir.go
@@ -42,11 +42,17 @@ type WorkflowJob struct {
}

// RunSQLProgram run a SQL program.
func RunSQLProgram(sqlProgram string, db *DB, modelDir string, session *pb.Session) *PipeReader {
func RunSQLProgram(sqlProgram string, modelDir string, session *pb.Session) *PipeReader {
rd, wr := Pipe()
go func() {
var db *DB
var err error
if db, err = NewDB(session.DbConnStr); err != nil {
wr.Write(fmt.Sprintf("create DB failed: %v", err))
log.Errorf("create DB failed: %v", err)
}
defer wr.Close()
err := runSQLProgram(wr, sqlProgram, db, modelDir, session)
err = runSQLProgram(wr, sqlProgram, db, modelDir, session)

if err != nil {
log.Errorf("runSQLProgram error: %v", err)
@@ -106,13 +112,17 @@ func ParseSQLStatement(sql string, session *pb.Session) (string, error) {
}

// SubmitWorkflow submits an Argo workflow
func SubmitWorkflow(sqlProgram string, db *DB, modelDir string, session *pb.Session) *PipeReader {
func SubmitWorkflow(sqlProgram string, modelDir string, session *pb.Session) *PipeReader {
rd, wr := Pipe()
go func() {
defer wr.Close()
err := submitWorkflow(wr, sqlProgram, db, modelDir, session)
err := submitWorkflow(wr, sqlProgram, modelDir, session)
if err != nil {
log.Errorf("submit Workflow error: %v", err)
if err != ErrClosedPipe {
if err := wr.Write(err); err != nil {
log.Errorf("submit workflow error(piping): %v", err)
}
}
}
}()
return rd
@@ -153,8 +163,12 @@ func writeArgoFile(coulerFileName string) (string, error) {
return argoYaml.Name(), nil
}

func submitWorkflow(wr *PipeWriter, sqlProgram string, db *DB, modelDir string, session *pb.Session) error {
sqls, err := parse(db.driverName, sqlProgram)
func submitWorkflow(wr *PipeWriter, sqlProgram string, modelDir string, session *pb.Session) error {
driverName, dataSourceName, err := SplitDataSource(session.DbConnStr)
if err != nil {
return err
}
sqls, err := parse(driverName, sqlProgram)
if err != nil {
return err
}
@@ -166,7 +180,7 @@ func submitWorkflow(wr *PipeWriter, sqlProgram string, db *DB, modelDir string, session *pb.Session) error {
spIRs := []ir.SQLStatement{}
for _, sql := range sqls {
var r ir.SQLStatement
connStr := fmt.Sprintf("%s://%s", db.driverName, db.dataSourceName)
connStr := fmt.Sprintf("%s://%s", driverName, dataSourceName)
if sql.extended != nil {
parsed := sql.extended
if parsed.train {
@@ -203,7 +217,7 @@
cmd := exec.Command("kubectl", "create", "-f", argoFile)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("submit Argo YAML error: %v", err)
return fmt.Errorf("submit Argo YAML error: %v, output: %s", err, string(output))
}
reWorkflow := regexp.MustCompile(`.+/(.+) .+`)
wf := reWorkflow.FindStringSubmatch(string(output))
@@ -409,6 +423,7 @@ func loadModelMeta(pr *extendedSelect, db *DB, cwd, modelDir, modelName string)
if modelDir != "" {
modelURI = fmt.Sprintf("file://%s/%s", modelDir, modelName)
}

m, e = load(modelURI, cwd, db)
if e != nil {
return nil, fmt.Errorf("load %v", e)
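
Putting the pieces together, here is a minimal sketch of the new calling convention, modeled on the updated tests below: the datasource rides in the session and the caller never constructs a DB. The connection string and the query are placeholders, not part of this PR.

package main

import (
	"fmt"

	pb "sqlflow.org/sqlflow/pkg/proto"
	sf "sqlflow.org/sqlflow/pkg/sql"
)

func main() {
	// Placeholder datasource; RunSQLProgram opens the DB from it lazily.
	sess := &pb.Session{DbConnStr: "mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0"}
	stream := sf.RunSQLProgram("SELECT * FROM iris.train LIMIT 5;", "" /* modelDir */, sess)
	defer stream.Close()
	for rsp := range stream.ReadAll() {
		fmt.Println(rsp)
	}
}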
36 changes: 18 additions & 18 deletions pkg/sql/executor_ir_test.go
@@ -142,7 +142,7 @@ USING sqlflow_models.my_xgboost_model_by_program;
SELECT sepal_length as sl, sepal_width as sw, class FROM iris.train
TO EXPLAIN sqlflow_models.my_xgboost_model_by_program
USING TreeExplainer;
`, testDB, modelDir, getDefaultSession())
`, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})

@@ -152,17 +152,17 @@ func TestExecuteXGBoostClassifier(t *testing.T) {
a := assert.New(t)
modelDir := ""
a.NotPanics(func() {
stream := RunSQLProgram(testTrainSelectWithLimit, testDB, modelDir, getDefaultSession())
stream := RunSQLProgram(testTrainSelectWithLimit, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testXGBoostPredictIris, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testXGBoostPredictIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
a.NotPanics(func() {
stream := RunSQLProgram(testXGBoostTrainSelectIris, testDB, modelDir, getDefaultSession())
stream := RunSQLProgram(testXGBoostTrainSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testAnalyzeTreeModelSelectIris, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testAnalyzeTreeModelSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testXGBoostPredictIris, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testXGBoostPredictIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
}
@@ -171,11 +171,11 @@ func TestExecuteXGBoostRegression(t *testing.T) {
a := assert.New(t)
modelDir := ""
a.NotPanics(func() {
stream := RunSQLProgram(testXGBoostTrainSelectHousing, testDB, modelDir, getDefaultSession())
stream := RunSQLProgram(testXGBoostTrainSelectHousing, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testAnalyzeTreeModelSelectIris, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testAnalyzeTreeModelSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testXGBoostPredictHousing, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testXGBoostPredictHousing, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
}
@@ -184,9 +184,9 @@ func TestExecutorTrainAndPredictDNN(t *testing.T) {
a := assert.New(t)
modelDir := ""
a.NotPanics(func() {
stream := RunSQLProgram(testTrainSelectIris, testDB, modelDir, getDefaultSession())
stream := RunSQLProgram(testTrainSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testPredictSelectIris, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testPredictSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
}
@@ -197,9 +197,9 @@ func TestExecutorTrainAndPredictClusteringLocalFS(t *testing.T) {
a.Nil(e)
defer os.RemoveAll(modelDir)
a.NotPanics(func() {
stream := RunSQLProgram(testClusteringTrain, testDB, modelDir, getDefaultSession())
stream := RunSQLProgram(testClusteringTrain, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testClusteringPredict, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testClusteringPredict, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
}
@@ -210,9 +210,9 @@ func TestExecutorTrainAndPredictDNNLocalFS(t *testing.T) {
a.Nil(e)
defer os.RemoveAll(modelDir)
a.NotPanics(func() {
stream := RunSQLProgram(testTrainSelectIris, testDB, modelDir, getDefaultSession())
stream := RunSQLProgram(testTrainSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
stream = RunSQLProgram(testPredictSelectIris, testDB, modelDir, getDefaultSession())
stream = RunSQLProgram(testPredictSelectIris, modelDir, getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
}
@@ -234,14 +234,14 @@ train.verbose = 1
COLUMN NUMERIC(dense, 4)
LABEL class
INTO sqlflow_models.my_dense_dnn_model;`
stream := RunSQLProgram(trainSQL, testDB, "", getDefaultSession())
stream := RunSQLProgram(trainSQL, "", getDefaultSession())
a.True(goodStream(stream.ReadAll()))

predSQL := `SELECT * FROM iris.test_dense
TO PREDICT iris.predict_dense.class
USING sqlflow_models.my_dense_dnn_model
;`
stream = RunSQLProgram(predSQL, testDB, "", getDefaultSession())
stream = RunSQLProgram(predSQL, "", getDefaultSession())
a.True(goodStream(stream.ReadAll()))
})
}
@@ -287,7 +287,7 @@ func TestSubmitWorkflow(t *testing.T) {
a := assert.New(t)
modelDir := ""
a.NotPanics(func() {
rd := SubmitWorkflow(testXGBoostTrainSelectIris, testDB, modelDir, getDefaultSession())
rd := SubmitWorkflow(testXGBoostTrainSelectIris, modelDir, getDefaultSession())
for r := range rd.ReadAll() {
switch r.(type) {
case WorkflowJob:
5 changes: 4 additions & 1 deletion pkg/sql/executor_standard_sql_test.go
@@ -15,10 +15,12 @@ package sql

import (
"container/list"
"fmt"
"strings"
"testing"

"github.com/stretchr/testify/assert"
pb "sqlflow.org/sqlflow/pkg/proto"
)

const (
@@ -85,7 +87,8 @@ func TestStandardSQL(t *testing.T) {

func TestSQLLexerError(t *testing.T) {
a := assert.New(t)
stream := RunSQLProgram("SELECT * FROM ``?[] AS WHERE LIMIT;", testDB, "", nil)
ds := fmt.Sprintf("%s://%s", testDB.driverName, testDB.dataSourceName)
stream := RunSQLProgram("SELECT * FROM ``?[] AS WHERE LIMIT;", "", &pb.Session{DbConnStr: ds})
a.False(goodStream(stream.ReadAll()))
}
