diff --git a/cmd/repl/repl.go b/cmd/repl/repl.go index a47a6ec277..2824027f07 100644 --- a/cmd/repl/repl.go +++ b/cmd/repl/repl.go @@ -81,13 +81,14 @@ func render(rsp interface{}, table *tablewriter.Table) bool { table.Append(row) isTable = true case error: - if os.Getenv("SQLFLOW_log_dir") != "" { // To avoid printing duplicated error message to console - log.New(os.Stderr, "", 0).Printf("ERROR: %v\n", s) - } + log.Fatalf("run sql statement failed, error: %v", s) case sql.EndOfExecution: return isTable - default: + case string: fmt.Println(s) + return false + default: + log.Fatalf("unrecognized response type: %v", s) } return isTable } @@ -104,14 +105,15 @@ func flagPassed(name ...string) bool { return found } -func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds string) { +func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds string) error { if !isTerminal { fmt.Println("sqlflow>", stmt) } tableRendered := false table := tablewriter.NewWriter(os.Stdout) + sess := makeSessionFromEnv() - stream := sql.RunSQLProgram(stmt, db, modelDir, &pb.Session{}) + stream := sql.RunSQLProgram(stmt, db, modelDir, sess) for rsp := range stream.ReadAll() { // pagination. 
avoid exceed memory if render(rsp, table) && table.NumLines() == tablePageSize { @@ -123,6 +125,7 @@ func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds strin if table.NumLines() > 0 || !tableRendered { table.Render() } + return nil } func repl(scanner *bufio.Scanner, modelDir string, db *sql.DB, ds string) { @@ -132,9 +135,23 @@ func repl(scanner *bufio.Scanner, modelDir string, db *sql.DB, ds string) { if err == io.EOF && stmt == "" { return } - runStmt(stmt, false, modelDir, db, ds) + if err := runStmt(stmt, false, modelDir, db, ds); err != nil { + log.Fatalf("run SQL statement failed: %v", err) + } } +} +func makeSessionFromEnv() *pb.Session { + return &pb.Session{ + Token: os.Getenv("SQLFLOW_USER_TOKEN"), + DbConnStr: os.Getenv("SQLFLOW_DATASOURCE"), + ExitOnSubmit: strings.ToLower(os.Getenv("SQLFLOW_EXIT_ON_SUBMIT")) == "true", + UserId: os.Getenv("SQLFLOW_USER_ID"), + HiveLocation: os.Getenv("SQLFLOW_HIVE_LOCATION"), + HdfsNamenodeAddr: os.Getenv("SQLFLOW_HDFS_NAMENODE_ADDR"), + HdfsUser: os.Getenv("JUPYTER_HADOOP_USER"), + HdfsPass: os.Getenv("JUPYTER_HADOOP_PASS"), + } } func parseSQLFromStdin(stdin io.Reader) (string, error) { @@ -150,17 +167,7 @@ func parseSQLFromStdin(stdin io.Reader) (string, error) { if sqlflowDatasource == "" { return "", fmt.Errorf("no SQLFLOW_DATASOURCE env provided") } - - sess := &pb.Session{ - Token: os.Getenv("SQLFLOW_USER_TOKEN"), - DbConnStr: os.Getenv("SQLFLOW_DATASOURCE"), - ExitOnSubmit: strings.ToLower(os.Getenv("SQLFLOW_EXIT_ON_SUBMIT")) == "true", - UserId: os.Getenv("SQLFLOW_USER_ID"), - HiveLocation: os.Getenv("SQLFLOW_HIVE_LOCATION"), - HdfsNamenodeAddr: os.Getenv("SQLFLOW_HDFS_NAMENODE_ADDR"), - HdfsUser: os.Getenv("JUPYTER_HADOOP_USER"), - HdfsPass: os.Getenv("JUPYTER_HADOOP_PASS"), - } + sess := makeSessionFromEnv() pbIRStr, err := sql.ParseSQLStatement(strings.Join(scanedInput, "\n"), sess) if err != nil { return "", err diff --git a/pkg/sql/model.go b/pkg/sql/model.go index
430ccdc4f9..42d8db5ad4 100644 --- a/pkg/sql/model.go +++ b/pkg/sql/model.go @@ -105,6 +105,10 @@ func (m *model) saveDB(db *DB, table string, session *pb.Session) (e error) { if e := cmd.Run(); e != nil { return fmt.Errorf("tar stderr: %v\ntar cmd %v", errBuf.String(), e) } + + if e := sqlf.Close(); e != nil { + return fmt.Errorf("close sqlfs error: %v", e) + } return nil } diff --git a/pkg/sqlfs/hive_writer.go b/pkg/sqlfs/hive_writer.go index 527a53595f..7210de2368 100644 --- a/pkg/sqlfs/hive_writer.go +++ b/pkg/sqlfs/hive_writer.go @@ -14,8 +14,10 @@ package sqlfs import ( + "database/sql" "encoding/base64" "fmt" + "io/ioutil" "os" "os/exec" @@ -29,6 +31,23 @@ type HiveWriter struct { session *pb.Session } +// NewHiveWriter returns a Hive Writer object +func NewHiveWriter(db *sql.DB, table string, session *pb.Session) (*HiveWriter, error) { + csvFile, e := ioutil.TempFile("/tmp", "sqlflow-sqlfs") + if e != nil { + return nil, fmt.Errorf("create temporary csv file failed: %v", e) + } + return &HiveWriter{ + Writer: Writer{ + db: db, + table: table, + buf: make([]byte, 0, bufSize), + flushID: 0, + }, + csvFile: csvFile, + session: session}, nil +} + // Write write bytes to sqlfs and returns (num_bytes, error) func (w *HiveWriter) Write(p []byte) (n int, e error) { n = 0 @@ -51,6 +70,9 @@ func (w *HiveWriter) Write(p []byte) (n int, e error) { // Close the connection of the sqlfs func (w *HiveWriter) Close() error { + if w.db == nil { + return nil + } defer func() { w.csvFile.Close() os.Remove(w.csvFile.Name()) diff --git a/pkg/sqlfs/writer.go b/pkg/sqlfs/writer.go index 2ca77f32ee..e00f5c62b9 100644 --- a/pkg/sqlfs/writer.go +++ b/pkg/sqlfs/writer.go @@ -18,7 +18,6 @@ import ( "encoding/base64" "fmt" "io" - "io/ioutil" pb "sqlflow.org/sqlflow/pkg/proto" ) @@ -45,20 +44,11 @@ func Create(db *sql.DB, driver, table string, session *pb.Session) (io.WriteClos } if driver == "hive" { - // HiveWriter implement can archive better performance - csvFile, e := 
ioutil.TempFile("/tmp", "sqlflow-sqlfs") - if e != nil { - return nil, fmt.Errorf("create temporary csv file failed: %v", e) + w, err := NewHiveWriter(db, table, session) + if err != nil { + return nil, fmt.Errorf("create: %v", err) } - return &HiveWriter{ - Writer: Writer{ - db: db, - table: table, - buf: make([]byte, 0, bufSize), - flushID: 0, - }, - csvFile: csvFile, - session: session}, nil + return w, nil } // default writer implement return &Writer{db, table, make([]byte, 0, bufSize), 0}, nil