From 4cc901233b652de06a8079274b22c670645e1dfd Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Mon, 2 Dec 2019 17:41:37 +0800 Subject: [PATCH 1/5] fix repl hive --- cmd/repl/repl.go | 50 ++++++++++++++++++++++++++++-------------------- pkg/sql/model.go | 5 ++++- 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/cmd/repl/repl.go b/cmd/repl/repl.go index 879473078a..3c536dc930 100644 --- a/cmd/repl/repl.go +++ b/cmd/repl/repl.go @@ -64,7 +64,7 @@ func header(head map[string]interface{}) ([]string, error) { return cols, nil } -func render(rsp interface{}, table *tablewriter.Table) bool { +func render(rsp interface{}, table *tablewriter.Table) (bool, error) { isTable := false switch s := rsp.(type) { case map[string]interface{}: // table header @@ -81,15 +81,13 @@ func render(rsp interface{}, table *tablewriter.Table) bool { table.Append(row) isTable = true case error: - if os.Getenv("SQLFLOW_log_dir") != "" { // To avoid printing duplicated error message to console - log.New(os.Stderr, "", 0).Printf("ERROR: %v\n", s) - } + return false, s case sql.EndOfExecution: - return isTable + return isTable, nil default: fmt.Println(s) } - return isTable + return isTable, nil } func flagPassed(name ...string) bool { @@ -104,16 +102,21 @@ func flagPassed(name ...string) bool { return found } -func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds string) { +func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds string) error { if !isTerminal { fmt.Println("sqlflow>", stmt) } isTable, tableRendered := false, false + var err error table := tablewriter.NewWriter(os.Stdout) + sess := makeSessionFromEnv() - stream := sql.RunSQLProgram(stmt, db, modelDir, &pb.Session{}) + stream := sql.RunSQLProgram(stmt, db, modelDir, sess) for rsp := range stream.ReadAll() { - isTable = render(rsp, table) + isTable, err = render(rsp, table) + if err != nil { + return err + } // pagination. avoid exceed memory if isTable && table.NumLines() == tablePageSize { @@ -125,6 +128,7 @@ func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds strin if table.NumLines() > 0 || !tableRendered { table.Render() } + return nil } func repl(scanner *bufio.Scanner, modelDir string, db *sql.DB, ds string) { @@ -134,9 +138,23 @@ func repl(scanner *bufio.Scanner, modelDir string, db *sql.DB, ds string) { if err == io.EOF && stmt == "" { return } - runStmt(stmt, false, modelDir, db, ds) + if err := runStmt(stmt, false, modelDir, db, ds); err != nil { + log.Fatalf("run SQL statment failed: %v", err) + } } +} +func makeSessionFromEnv() *pb.Session { + return &pb.Session{ + Token: os.Getenv("SQLFLOW_USER_TOKEN"), + DbConnStr: os.Getenv("SQLFLOW_DATASOURCE"), + ExitOnSubmit: strings.ToLower(os.Getenv("SQLFLOW_EXIT_ON_SUBMIT")) == "true", + UserId: os.Getenv("SQLFLOW_USER_ID"), + HiveLocation: os.Getenv("SQLFLOW_HIVE_LOCATION"), + HdfsNamenodeAddr: os.Getenv("SQLFLOW_HDFS_NAMENODE_ADDR"), + HdfsUser: os.Getenv("JUPYTER_HADOOP_USER"), + HdfsPass: os.Getenv("JUPYTER_HADOOP_PASS"), + } } func parseSQLFromStdin(stdin io.Reader) (string, error) { @@ -152,17 +170,7 @@ func parseSQLFromStdin(stdin io.Reader) (string, error) { if sqlflowDatasrouce == "" { return "", fmt.Errorf("no SQLFLOW_DATASOURCE env provided") } - - sess := &pb.Session{ - Token: os.Getenv("SQLFLOW_USER_TOKEN"), - DbConnStr: os.Getenv("SQLFLOW_DATASOURCE"), - ExitOnSubmit: strings.ToLower(os.Getenv("SQLFLOW_EXIT_ON_SUBMIT")) == "true", - UserId: os.Getenv("SQLFLOW_USER_ID"), - HiveLocation: os.Getenv("SQLFLOW_HIVE_LOCATION"), - HdfsNamenodeAddr: os.Getenv("SQLFLOW_HDFS_NAMENODE_ADDR"), - HdfsUser: os.Getenv("JUPYTER_HADOOP_USER"), - HdfsPass: os.Getenv("JUPYTER_HADOOP_PASS"), - } + sess := makeSessionFromEnv() pbIRStr, err := sql.ParseSQLStatement(strings.Join(scanedInput, ""), sess) if err != nil { return "", err diff --git a/pkg/sql/model.go b/pkg/sql/model.go index cd764b9f97..9e480b2430 100644 --- a/pkg/sql/model.go +++ b/pkg/sql/model.go @@ -39,7 +39,6 @@ func (m *model) save(db *DB, table string, session *pb.Session) (e error) { if e != nil { return fmt.Errorf("cannot create sqlfs file %s: %v", table, e) } - defer sqlf.Close() // Use a bytes.Buffer as the gob message container to separate // the message from the following tarball. @@ -58,6 +57,10 @@ func (m *model) save(db *DB, table string, session *pb.Session) (e error) { if e := cmd.Run(); e != nil { return fmt.Errorf("tar stderr: %v\ntar cmd %v", errBuf.String(), e) } + + if e := sqlf.Close(); e != nil { + return fmt.Errorf("close sqlfs error: %v", e) + } return nil } From 1ed77eab838bba2af6792cbf7f681d902fa645c5 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 3 Dec 2019 17:54:34 +0800 Subject: [PATCH 2/5] update --- cmd/repl/repl.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/cmd/repl/repl.go b/cmd/repl/repl.go index cfadd6c44a..a67b84ba2d 100644 --- a/cmd/repl/repl.go +++ b/cmd/repl/repl.go @@ -64,7 +64,7 @@ func header(head map[string]interface{}) ([]string, error) { return cols, nil } -func render(rsp interface{}, table *tablewriter.Table) (bool, error) { +func render(rsp interface{}, table *tablewriter.Table) bool { isTable := false switch s := rsp.(type) { case map[string]interface{}: // table header @@ -81,17 +81,16 @@ func render(rsp interface{}, table *tablewriter.Table) (bool, error) { table.Append(row) isTable = true case error: - log.Fatalf("run sql statement failed, error: %v", v) + log.Fatalf("run sql statement failed, error: %v", s) case sql.EndOfExecution: - return isTable, nil + return isTable case string: fmt.Println(s) - return false, nil + return false default: - - fmt.Println(s) + log.Fatal("unrecognized response type: %v", s) } - return isTable, nil + return isTable } func flagPassed(name ...string) bool { @@ -118,7 +117,7 @@ func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds strin for rsp := range stream.ReadAll() { isTable, err := render(rsp, table) if err != nil { - return fmt.Errorf("") + return err } // pagination. avoid exceed memory if isTable && table.NumLines() == tablePageSize { From 42ca1963b55d4b2f6134378c831582d1ab6a2680 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 3 Dec 2019 19:05:12 +0800 Subject: [PATCH 3/5] update --- cmd/repl/repl.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/cmd/repl/repl.go b/cmd/repl/repl.go index a67b84ba2d..b9590f3bcd 100644 --- a/cmd/repl/repl.go +++ b/cmd/repl/repl.go @@ -115,12 +115,8 @@ func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds strin stream := sql.RunSQLProgram(stmt, db, modelDir, sess) for rsp := range stream.ReadAll() { - isTable, err := render(rsp, table) - if err != nil { - return err - } // pagination. avoid exceed memory - if isTable && table.NumLines() == tablePageSize { + if render(rsp, table) && table.NumLines() == tablePageSize { table.Render() tableRendered = true table.ClearRows() From 48652fe2d6d85e77ac2ac72a37c03c94ac0a2044 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 3 Dec 2019 19:22:39 +0800 Subject: [PATCH 4/5] defer sqfs.Close --- pkg/sql/model.go | 1 + pkg/sqlfs/hive_writer.go | 22 ++++++++++++++++++++++ pkg/sqlfs/writer.go | 18 ++++-------------- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/pkg/sql/model.go b/pkg/sql/model.go index b92c2a2337..42d8db5ad4 100644 --- a/pkg/sql/model.go +++ b/pkg/sql/model.go @@ -86,6 +86,7 @@ func (m *model) saveDB(db *DB, table string, session *pb.Session) (e error) { if e != nil { return fmt.Errorf("cannot create sqlfs file %s: %v", table, e) } + defer sqlf.Close() // Use a bytes.Buffer as the gob message container to separate // the message from the following tarball. diff --git a/pkg/sqlfs/hive_writer.go b/pkg/sqlfs/hive_writer.go index 527a53595f..7210de2368 100644 --- a/pkg/sqlfs/hive_writer.go +++ b/pkg/sqlfs/hive_writer.go @@ -14,8 +14,10 @@ package sqlfs import ( + "database/sql" "encoding/base64" "fmt" + "io/ioutil" "os" "os/exec" @@ -29,6 +31,23 @@ type HiveWriter struct { session *pb.Session } +// NewHiveWriter returns a Hive Writer object +func NewHiveWriter(db *sql.DB, table string, session *pb.Session) (*HiveWriter, error) { + csvFile, e := ioutil.TempFile("/tmp", "sqlflow-sqlfs") + if e != nil { + return nil, fmt.Errorf("create temporary csv file failed: %v", e) + } + return &HiveWriter{ + Writer: Writer{ + db: db, + table: table, + buf: make([]byte, 0, bufSize), + flushID: 0, + }, + csvFile: csvFile, + session: session}, nil +} + // Write write bytes to sqlfs and returns (num_bytes, error) func (w *HiveWriter) Write(p []byte) (n int, e error) { n = 0 @@ -51,6 +70,9 @@ func (w *HiveWriter) Write(p []byte) (n int, e error) { // Close the connection of the sqlfs func (w *HiveWriter) Close() error { + if w.db == nil { + return nil + } defer func() { w.csvFile.Close() os.Remove(w.csvFile.Name()) diff --git a/pkg/sqlfs/writer.go b/pkg/sqlfs/writer.go index 2ca77f32ee..e00f5c62b9 100644 --- a/pkg/sqlfs/writer.go +++ b/pkg/sqlfs/writer.go @@ -18,7 +18,6 @@ import ( "encoding/base64" "fmt" "io" - "io/ioutil" pb "sqlflow.org/sqlflow/pkg/proto" ) @@ -45,20 +44,11 @@ func Create(db *sql.DB, driver, table string, session *pb.Session) (io.WriteClos } if driver == "hive" { - // HiveWriter implement can archive better performance - csvFile, e := ioutil.TempFile("/tmp", "sqlflow-sqlfs") - if e != nil { - return nil, fmt.Errorf("create temporary csv file failed: %v", e) + w, err := NewHiveWriter(db, table, session) + if err != nil { + return nil, fmt.Errorf("create: %v", err) } - return &HiveWriter{ - Writer: Writer{ - db: db, - table: table, - buf: make([]byte, 0, bufSize), - flushID: 0, - }, - csvFile: csvFile, - session: session}, nil + return w, nil } // default writer implement return &Writer{db, table, make([]byte, 0, bufSize), 0}, nil From 520af56258d14612d5301d7a8251f01c923bb3a2 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 3 Dec 2019 19:57:34 +0800 Subject: [PATCH 5/5] update --- cmd/repl/repl.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/repl/repl.go b/cmd/repl/repl.go index b9590f3bcd..2824027f07 100644 --- a/cmd/repl/repl.go +++ b/cmd/repl/repl.go @@ -88,7 +88,7 @@ func render(rsp interface{}, table *tablewriter.Table) bool { fmt.Println(s) return false default: - log.Fatal("unrecognized response type: %v", s) + log.Fatalf("unrecognized response type: %v", s) } return isTable }