diff --git a/cmd/repl/repl.go b/cmd/repl/repl.go index 2824027f07..a95ad01a4f 100644 --- a/cmd/repl/repl.go +++ b/cmd/repl/repl.go @@ -105,15 +105,18 @@ func flagPassed(name ...string) bool { return found } -func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds string) error { +func runStmt(stmt string, isTerminal bool, modelDir string, ds string) error { if !isTerminal { fmt.Println("sqlflow>", stmt) } tableRendered := false table := tablewriter.NewWriter(os.Stdout) sess := makeSessionFromEnv() + if ds != "" { + sess.DbConnStr = ds + } - stream := sql.RunSQLProgram(stmt, db, modelDir, sess) + stream := sql.RunSQLProgram(stmt, modelDir, sess) for rsp := range stream.ReadAll() { // pagination. avoid exceed memory if render(rsp, table) && table.NumLines() == tablePageSize { @@ -128,14 +131,19 @@ func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds strin return nil } -func repl(scanner *bufio.Scanner, modelDir string, db *sql.DB, ds string) { +func repl(scanner *bufio.Scanner, modelDir string, ds string) { + db, err := sql.NewDB(ds) + if err != nil { + log.Fatalf("failed to open database: %v", err) + } + defer db.Close() for { stmt, err := readStmt(scanner) fmt.Println() if err == io.EOF && stmt == "" { return } - if err := runStmt(stmt, false, modelDir, db, ds); err != nil { + if err := runStmt(stmt, false, modelDir, ds); err != nil { log.Fatalf("run SQL statment failed: %v", err) } } @@ -196,12 +204,6 @@ func main() { os.Exit(0) } - db, err := sql.NewDB(*ds) - if err != nil { - log.Fatalf("failed to open database: %v", err) - } - defer db.Close() - if *modelDir != "" { if _, derr := os.Stat(*modelDir); derr != nil { os.Mkdir(*modelDir, os.ModePerm) @@ -211,6 +213,7 @@ func main() { isTerminal := !flagPassed("execute", "e", "file", "f") && terminal.IsTerminal(syscall.Stdin) sqlFile := os.Stdin + var err error if flagPassed("file", "f") { sqlFile, err = os.Open(*sqlFileName) if err != nil { @@ -225,8 +228,8 @@ func main() { } 
scanner := bufio.NewScanner(reader) if isTerminal { - runPrompt(func(stmt string) { runStmt(stmt, true, *modelDir, db, *ds) }) + runPrompt(func(stmt string) { runStmt(stmt, true, *modelDir, *ds) }) } else { - repl(scanner, *modelDir, db, *ds) + repl(scanner, *modelDir, *ds) } } diff --git a/cmd/repl/repl_test.go b/cmd/repl/repl_test.go index 6a0786feec..758384f4b9 100644 --- a/cmd/repl/repl_test.go +++ b/cmd/repl/repl_test.go @@ -230,7 +230,7 @@ INTO sqlflow_models.mymodel;` // run one train SQL to save the model then test predict/analyze use the model sess := &irpb.Session{DbConnStr: dataSourceStr} - stream := sf.RunSQLProgram(trainSQL, testdb, "", sess) + stream := sf.RunSQLProgram(trainSQL, "", sess) lastResp := list.New() keepSize := 10 for rsp := range stream.ReadAll() { diff --git a/pkg/server/sqlflowserver.go b/pkg/server/sqlflowserver.go index b39c82cf7a..6689321b29 100644 --- a/pkg/server/sqlflowserver.go +++ b/pkg/server/sqlflowserver.go @@ -32,14 +32,14 @@ import ( ) // NewServer returns a server instance -func NewServer(run func(string, *sf.DB, string, *pb.Session) *sf.PipeReader, +func NewServer(run func(string, string, *pb.Session) *sf.PipeReader, modelDir string) *Server { return &Server{run: run, modelDir: modelDir} } // Server is the instance will be used to connect to DB and execute training type Server struct { - run func(sql string, db *sf.DB, modelDir string, session *pb.Session) *sf.PipeReader + run func(sql string, modelDir string, session *pb.Session) *sf.PipeReader modelDir string } @@ -51,18 +51,11 @@ func (s *Server) Fetch(ctx context.Context, job *pb.Job) (*pb.JobStatus, error) // Run implements `rpc Run (Request) returns (stream Response)` func (s *Server) Run(req *pb.Request, stream pb.SQLFlow_RunServer) error { - var db *sf.DB - - var err error - if db, err = sf.NewDB(req.Session.DbConnStr); err != nil { - return fmt.Errorf("create DB failed: %v", err) - } - defer db.Close() sqlStatements, err := sf.SplitMultipleSQL(req.Sql) if err != 
nil { return err } - rd := s.run(req.Sql, db, s.modelDir, req.Session) + rd := s.run(req.Sql, s.modelDir, req.Session) defer rd.Close() for r := range rd.ReadAll() { diff --git a/pkg/server/sqlflowserver_test.go b/pkg/server/sqlflowserver_test.go index c240cd9363..9a1660e3c8 100644 --- a/pkg/server/sqlflowserver_test.go +++ b/pkg/server/sqlflowserver_test.go @@ -46,7 +46,7 @@ const ( var testServerAddress string -func mockRun(sql string, db *sf.DB, modelDir string, session *pb.Session) *sf.PipeReader { +func mockRun(sql string, modelDir string, session *pb.Session) *sf.PipeReader { rd, wr := sf.Pipe() singleSQL := sql go func() { diff --git a/pkg/sql/database.go b/pkg/sql/database.go index 1b4afd060f..2dce78039d 100644 --- a/pkg/sql/database.go +++ b/pkg/sql/database.go @@ -39,13 +39,13 @@ type DB struct { // In addition to sql.Open, it also does the book keeping on driverName and // dataSourceName func open(datasource string) (*DB, error) { - dses := strings.Split(datasource, "://") - if len(dses) != 2 { - return nil, fmt.Errorf("Expecting but cannot find :// in datasource %v", datasource) + driverName, dataSourceName, err := SplitDataSource(datasource) + if err != nil { + return nil, err } - db := &DB{driverName: dses[0], dataSourceName: dses[1]} + db := &DB{driverName: driverName, dataSourceName: dataSourceName} - err := openDB(db) + err = openDB(db) return db, err } @@ -63,6 +63,15 @@ func openDB(db *DB) error { return fmt.Errorf("sqlflow currently doesn't support DB %s", db.driverName) } +// SplitDataSource splits the datasource into drivername and datasource name +func SplitDataSource(datasource string) (string, string, error) { + dses := strings.Split(datasource, "://") + if len(dses) != 2 { + return "", "", fmt.Errorf("Expecting but cannot find :// in datasource %v", datasource) + } + return dses[0], dses[1], nil +} + // NewDB returns a DB object with verifying the datasource name. 
func NewDB(datasource string) (*DB, error) { db, err := open(datasource) diff --git a/pkg/sql/database_test.go b/pkg/sql/database_test.go index 9a351a9c31..d16ca79537 100644 --- a/pkg/sql/database_test.go +++ b/pkg/sql/database_test.go @@ -42,3 +42,12 @@ func TestDatabaseOpenMysql(t *testing.T) { _, e = db.Exec("show databases") a.NoError(e) } + +func TestSplitDataSource(t *testing.T) { + a := assert.New(t) + ds := "mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0" + driverName, datasourceName, e := SplitDataSource(ds) + a.EqualValues(driverName, "mysql") + a.EqualValues(datasourceName, "root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0") + a.NoError(e) +} diff --git a/pkg/sql/executor_ir.go b/pkg/sql/executor_ir.go index a63a53d1f9..8884e99578 100644 --- a/pkg/sql/executor_ir.go +++ b/pkg/sql/executor_ir.go @@ -42,11 +42,20 @@ type WorkflowJob struct { } // RunSQLProgram run a SQL program. -func RunSQLProgram(sqlProgram string, db *DB, modelDir string, session *pb.Session) *PipeReader { +func RunSQLProgram(sqlProgram string, modelDir string, session *pb.Session) *PipeReader { rd, wr := Pipe() go func() { + var db *DB + var err error + if db, err = NewDB(session.DbConnStr); err != nil { + wr.Write(fmt.Sprintf("create DB failed: %v", err)) + log.Errorf("create DB failed: %v", err) + wr.Close() + return + } + defer db.Close() defer wr.Close() - err := runSQLProgram(wr, sqlProgram, db, modelDir, session) + err = runSQLProgram(wr, sqlProgram, db, modelDir, session) if err != nil { log.Errorf("runSQLProgram error: %v", err) @@ -106,13 +112,17 @@ func ParseSQLStatement(sql string, session *pb.Session) (string, error) { } // SubmitWorkflow submits an Argo workflow -func SubmitWorkflow(sqlProgram string, db *DB, modelDir string, session *pb.Session) *PipeReader { +func SubmitWorkflow(sqlProgram string, modelDir string, session *pb.Session) *PipeReader { rd, wr := Pipe() go func() { defer wr.Close() - err := submitWorkflow(wr, sqlProgram, db, modelDir, session) + err := submitWorkflow(wr, sqlProgram, 
modelDir, session) if err != nil { - log.Errorf("submit Workflow error: %v", err) + if err != ErrClosedPipe { + if err := wr.Write(err); err != nil { + log.Errorf("submit workflow error(piping): %v", err) + } + } } }() return rd @@ -153,8 +163,12 @@ func writeArgoFile(coulerFileName string) (string, error) { return argoYaml.Name(), nil } -func submitWorkflow(wr *PipeWriter, sqlProgram string, db *DB, modelDir string, session *pb.Session) error { - sqls, err := parse(db.driverName, sqlProgram) +func submitWorkflow(wr *PipeWriter, sqlProgram string, modelDir string, session *pb.Session) error { + driverName, dataSourceName, err := SplitDataSource(session.DbConnStr) + if err != nil { + return err + } + sqls, err := parse(driverName, sqlProgram) if err != nil { return err } @@ -166,7 +180,7 @@ func submitWorkflow(wr *PipeWriter, sqlProgram string, db *DB, modelDir string, spIRs := []ir.SQLStatement{} for _, sql := range sqls { var r ir.SQLStatement - connStr := fmt.Sprintf("%s://%s", db.driverName, db.dataSourceName) + connStr := fmt.Sprintf("%s://%s", driverName, dataSourceName) if sql.extended != nil { parsed := sql.extended if parsed.train { @@ -203,7 +217,7 @@ func submitWorkflow(wr *PipeWriter, sqlProgram string, db *DB, modelDir string, cmd := exec.Command("kubectl", "create", "-f", argoFile) output, err := cmd.CombinedOutput() if err != nil { - return fmt.Errorf("submit Argo YAML error: %v", err) + return fmt.Errorf("submit Argo YAML error: %v, output: %s", err, string(output)) } reWorkflow := regexp.MustCompile(`.+/(.+) .+`) wf := reWorkflow.FindStringSubmatch(string(output)) @@ -409,6 +423,7 @@ func loadModelMeta(pr *extendedSelect, db *DB, cwd, modelDir, modelName string) if modelDir != "" { modelURI = fmt.Sprintf("file://%s/%s", modelDir, modelName) } + m, e = load(modelURI, cwd, db) if e != nil { return nil, fmt.Errorf("load %v", e) diff --git a/pkg/sql/executor_ir_test.go b/pkg/sql/executor_ir_test.go index a97fe994d8..b32a1226d8 100644 --- 
a/pkg/sql/executor_ir_test.go +++ b/pkg/sql/executor_ir_test.go @@ -142,7 +142,7 @@ USING sqlflow_models.my_xgboost_model_by_program; SELECT sepal_length as sl, sepal_width as sw, class FROM iris.train TO EXPLAIN sqlflow_models.my_xgboost_model_by_program USING TreeExplainer; -`, testDB, modelDir, getDefaultSession()) +`, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) }) @@ -152,17 +152,17 @@ func TestExecuteXGBoostClassifier(t *testing.T) { a := assert.New(t) modelDir := "" a.NotPanics(func() { - stream := RunSQLProgram(testTrainSelectWithLimit, testDB, modelDir, getDefaultSession()) + stream := RunSQLProgram(testTrainSelectWithLimit, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) - stream = RunSQLProgram(testXGBoostPredictIris, testDB, modelDir, getDefaultSession()) + stream = RunSQLProgram(testXGBoostPredictIris, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) }) a.NotPanics(func() { - stream := RunSQLProgram(testXGBoostTrainSelectIris, testDB, modelDir, getDefaultSession()) + stream := RunSQLProgram(testXGBoostTrainSelectIris, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) - stream = RunSQLProgram(testAnalyzeTreeModelSelectIris, testDB, modelDir, getDefaultSession()) + stream = RunSQLProgram(testAnalyzeTreeModelSelectIris, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) - stream = RunSQLProgram(testXGBoostPredictIris, testDB, modelDir, getDefaultSession()) + stream = RunSQLProgram(testXGBoostPredictIris, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) }) } @@ -171,11 +171,11 @@ func TestExecuteXGBoostRegression(t *testing.T) { a := assert.New(t) modelDir := "" a.NotPanics(func() { - stream := RunSQLProgram(testXGBoostTrainSelectHousing, testDB, modelDir, getDefaultSession()) + stream := RunSQLProgram(testXGBoostTrainSelectHousing, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) - stream = 
RunSQLProgram(testAnalyzeTreeModelSelectIris, testDB, modelDir, getDefaultSession()) + stream = RunSQLProgram(testAnalyzeTreeModelSelectIris, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) - stream = RunSQLProgram(testXGBoostPredictHousing, testDB, modelDir, getDefaultSession()) + stream = RunSQLProgram(testXGBoostPredictHousing, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) }) } @@ -184,9 +184,9 @@ func TestExecutorTrainAndPredictDNN(t *testing.T) { a := assert.New(t) modelDir := "" a.NotPanics(func() { - stream := RunSQLProgram(testTrainSelectIris, testDB, modelDir, getDefaultSession()) + stream := RunSQLProgram(testTrainSelectIris, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) - stream = RunSQLProgram(testPredictSelectIris, testDB, modelDir, getDefaultSession()) + stream = RunSQLProgram(testPredictSelectIris, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) }) } @@ -197,9 +197,9 @@ func TestExecutorTrainAndPredictClusteringLocalFS(t *testing.T) { a.Nil(e) defer os.RemoveAll(modelDir) a.NotPanics(func() { - stream := RunSQLProgram(testClusteringTrain, testDB, modelDir, getDefaultSession()) + stream := RunSQLProgram(testClusteringTrain, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) - stream = RunSQLProgram(testClusteringPredict, testDB, modelDir, getDefaultSession()) + stream = RunSQLProgram(testClusteringPredict, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) }) } @@ -210,9 +210,9 @@ func TestExecutorTrainAndPredictDNNLocalFS(t *testing.T) { a.Nil(e) defer os.RemoveAll(modelDir) a.NotPanics(func() { - stream := RunSQLProgram(testTrainSelectIris, testDB, modelDir, getDefaultSession()) + stream := RunSQLProgram(testTrainSelectIris, modelDir, getDefaultSession()) a.True(goodStream(stream.ReadAll())) - stream = RunSQLProgram(testPredictSelectIris, testDB, modelDir, getDefaultSession()) + stream = RunSQLProgram(testPredictSelectIris, modelDir, 
getDefaultSession()) a.True(goodStream(stream.ReadAll())) }) } @@ -234,14 +234,14 @@ train.verbose = 1 COLUMN NUMERIC(dense, 4) LABEL class INTO sqlflow_models.my_dense_dnn_model;` - stream := RunSQLProgram(trainSQL, testDB, "", getDefaultSession()) + stream := RunSQLProgram(trainSQL, "", getDefaultSession()) a.True(goodStream(stream.ReadAll())) predSQL := `SELECT * FROM iris.test_dense TO PREDICT iris.predict_dense.class USING sqlflow_models.my_dense_dnn_model ;` - stream = RunSQLProgram(predSQL, testDB, "", getDefaultSession()) + stream = RunSQLProgram(predSQL, "", getDefaultSession()) a.True(goodStream(stream.ReadAll())) }) } @@ -287,7 +287,7 @@ func TestSubmitWorkflow(t *testing.T) { a := assert.New(t) modelDir := "" a.NotPanics(func() { - rd := SubmitWorkflow(testXGBoostTrainSelectIris, testDB, modelDir, getDefaultSession()) + rd := SubmitWorkflow(testXGBoostTrainSelectIris, modelDir, getDefaultSession()) for r := range rd.ReadAll() { switch r.(type) { case WorkflowJob: diff --git a/pkg/sql/executor_standard_sql_test.go b/pkg/sql/executor_standard_sql_test.go index 6b19ef8747..325a7b50a9 100644 --- a/pkg/sql/executor_standard_sql_test.go +++ b/pkg/sql/executor_standard_sql_test.go @@ -15,10 +15,12 @@ package sql import ( "container/list" + "fmt" "strings" "testing" "github.com/stretchr/testify/assert" + pb "sqlflow.org/sqlflow/pkg/proto" ) const ( @@ -85,7 +87,8 @@ func TestStandardSQL(t *testing.T) { func TestSQLLexerError(t *testing.T) { a := assert.New(t) - stream := RunSQLProgram("SELECT * FROM ``?[] AS WHERE LIMIT;", testDB, "", nil) + ds := fmt.Sprintf("%s://%s", testDB.driverName, testDB.dataSourceName) + stream := RunSQLProgram("SELECT * FROM ``?[] AS WHERE LIMIT;", "", &pb.Session{DbConnStr: ds}) a.False(goodStream(stream.ReadAll())) } diff --git a/pkg/sql/ir_generator_test.go b/pkg/sql/ir_generator_test.go index 4f73fd7734..ece264977b 100644 --- a/pkg/sql/ir_generator_test.go +++ b/pkg/sql/ir_generator_test.go @@ -20,6 +20,7 @@ import ( "testing" 
"github.com/stretchr/testify/assert" + pb "sqlflow.org/sqlflow/pkg/proto" "sqlflow.org/sqlflow/pkg/sql/ir" ) @@ -207,7 +208,7 @@ TO TRAIN DNNClassifier WITH model.n_classes=3, model.hidden_units=[10,20] COLUMN sepal_length, sepal_width, petal_length, petal_width LABEL class -INTO sqlflow_models.mymodel;`, testDB, modelDir, nil) +INTO sqlflow_models.mymodel;`, modelDir, &pb.Session{DbConnStr: connStr}) a.True(goodStream(stream.ReadAll())) predIR, err := generatePredictIR(r, connStr, modelDir, true) @@ -227,6 +228,7 @@ func TestGenerateAnalyzeIR(t *testing.T) { t.Skip(fmt.Sprintf("%s: skip test", getEnv("SQLFLOW_TEST_DB", "mysql"))) } a := assert.New(t) + connStr := "mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0" modelDir, e := ioutil.TempDir("/tmp", "sqlflow_models") a.Nil(e) @@ -241,7 +243,7 @@ WITH COLUMN sepal_length, sepal_width, petal_length, petal_width LABEL class INTO sqlflow_models.my_xgboost_model; -`, testDB, modelDir, nil) +`, modelDir, &pb.Session{DbConnStr: connStr}) a.NoError(e) a.True(goodStream(stream.ReadAll())) @@ -257,7 +259,6 @@ INTO sqlflow_models.my_xgboost_model; `) a.NoError(e) - connStr := "mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0" AnalyzeIR, e := generateAnalyzeIR(pr, connStr, modelDir, true) a.NoError(e) a.Equal(AnalyzeIR.DataSource, connStr)