Skip to content

Commit

Permalink
fix repl always return 0 (#1286)
Browse files Browse the repository at this point in the history
* fix repl hive

* update

* update

* defer sqfs.Close

* update
  • Loading branch information
Yancey1989 authored and wangkuiyi committed Dec 4, 2019
1 parent 59bc02a commit fd63870
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 32 deletions.
43 changes: 25 additions & 18 deletions cmd/repl/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,14 @@ func render(rsp interface{}, table *tablewriter.Table) bool {
table.Append(row)
isTable = true
case error:
if os.Getenv("SQLFLOW_log_dir") != "" { // To avoid printing duplicated error message to console
log.New(os.Stderr, "", 0).Printf("ERROR: %v\n", s)
}
log.Fatalf("run sql statement failed, error: %v", s)
case sql.EndOfExecution:
return isTable
default:
case string:
fmt.Println(s)
return false
default:
log.Fatalf("unrecognized response type: %v", s)
}
return isTable
}
Expand All @@ -104,14 +105,15 @@ func flagPassed(name ...string) bool {
return found
}

func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds string) {
func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds string) error {
if !isTerminal {
fmt.Println("sqlflow>", stmt)
}
tableRendered := false
table := tablewriter.NewWriter(os.Stdout)
sess := makeSessionFromEnv()

stream := sql.RunSQLProgram(stmt, db, modelDir, &pb.Session{})
stream := sql.RunSQLProgram(stmt, db, modelDir, sess)
for rsp := range stream.ReadAll() {
// pagination. avoid exceed memory
if render(rsp, table) && table.NumLines() == tablePageSize {
Expand All @@ -123,6 +125,7 @@ func runStmt(stmt string, isTerminal bool, modelDir string, db *sql.DB, ds strin
if table.NumLines() > 0 || !tableRendered {
table.Render()
}
return nil
}

func repl(scanner *bufio.Scanner, modelDir string, db *sql.DB, ds string) {
Expand All @@ -132,9 +135,23 @@ func repl(scanner *bufio.Scanner, modelDir string, db *sql.DB, ds string) {
if err == io.EOF && stmt == "" {
return
}
runStmt(stmt, false, modelDir, db, ds)
if err := runStmt(stmt, false, modelDir, db, ds); err != nil {
log.Fatalf("run SQL statment failed: %v", err)
}
}
}

// makeSessionFromEnv builds a pb.Session for the local command-line
// client, populating every field from process environment variables.
// SQLFLOW_EXIT_ON_SUBMIT is treated as a boolean: any capitalization
// of "true" enables it, everything else (including unset) disables it.
func makeSessionFromEnv() *pb.Session {
	sess := &pb.Session{}
	sess.Token = os.Getenv("SQLFLOW_USER_TOKEN")
	sess.DbConnStr = os.Getenv("SQLFLOW_DATASOURCE")
	sess.ExitOnSubmit = strings.ToLower(os.Getenv("SQLFLOW_EXIT_ON_SUBMIT")) == "true"
	sess.UserId = os.Getenv("SQLFLOW_USER_ID")
	sess.HiveLocation = os.Getenv("SQLFLOW_HIVE_LOCATION")
	sess.HdfsNamenodeAddr = os.Getenv("SQLFLOW_HDFS_NAMENODE_ADDR")
	// The HDFS credentials come from Jupyter-scoped variables, matching
	// the environment that the SQLFlow Jupyter integration exports.
	sess.HdfsUser = os.Getenv("JUPYTER_HADOOP_USER")
	sess.HdfsPass = os.Getenv("JUPYTER_HADOOP_PASS")
	return sess
}

func parseSQLFromStdin(stdin io.Reader) (string, error) {
Expand All @@ -150,17 +167,7 @@ func parseSQLFromStdin(stdin io.Reader) (string, error) {
if sqlflowDatasource == "" {
return "", fmt.Errorf("no SQLFLOW_DATASOURCE env provided")
}

sess := &pb.Session{
Token: os.Getenv("SQLFLOW_USER_TOKEN"),
DbConnStr: os.Getenv("SQLFLOW_DATASOURCE"),
ExitOnSubmit: strings.ToLower(os.Getenv("SQLFLOW_EXIT_ON_SUBMIT")) == "true",
UserId: os.Getenv("SQLFLOW_USER_ID"),
HiveLocation: os.Getenv("SQLFLOW_HIVE_LOCATION"),
HdfsNamenodeAddr: os.Getenv("SQLFLOW_HDFS_NAMENODE_ADDR"),
HdfsUser: os.Getenv("JUPYTER_HADOOP_USER"),
HdfsPass: os.Getenv("JUPYTER_HADOOP_PASS"),
}
sess := makeSessionFromEnv()
pbIRStr, err := sql.ParseSQLStatement(strings.Join(scanedInput, "\n"), sess)
if err != nil {
return "", err
Expand Down
4 changes: 4 additions & 0 deletions pkg/sql/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ func (m *model) saveDB(db *DB, table string, session *pb.Session) (e error) {
if e := cmd.Run(); e != nil {
return fmt.Errorf("tar stderr: %v\ntar cmd %v", errBuf.String(), e)
}

if e := sqlf.Close(); e != nil {
return fmt.Errorf("close sqlfs error: %v", e)
}
return nil
}

Expand Down
22 changes: 22 additions & 0 deletions pkg/sqlfs/hive_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
package sqlfs

import (
"database/sql"
"encoding/base64"
"fmt"
"io/ioutil"
"os"
"os/exec"

Expand All @@ -29,6 +31,23 @@ type HiveWriter struct {
session *pb.Session
}

// NewHiveWriter returns a Hive Writer object
func NewHiveWriter(db *sql.DB, table string, session *pb.Session) (*HiveWriter, error) {
csvFile, e := ioutil.TempFile("/tmp", "sqlflow-sqlfs")
if e != nil {
return nil, fmt.Errorf("create temporary csv file failed: %v", e)
}
return &HiveWriter{
Writer: Writer{
db: db,
table: table,
buf: make([]byte, 0, bufSize),
flushID: 0,
},
csvFile: csvFile,
session: session}, nil
}

// Write write bytes to sqlfs and returns (num_bytes, error)
func (w *HiveWriter) Write(p []byte) (n int, e error) {
n = 0
Expand All @@ -51,6 +70,9 @@ func (w *HiveWriter) Write(p []byte) (n int, e error) {

// Close the connection of the sqlfs
func (w *HiveWriter) Close() error {
if w.db == nil {
return nil
}
defer func() {
w.csvFile.Close()
os.Remove(w.csvFile.Name())
Expand Down
18 changes: 4 additions & 14 deletions pkg/sqlfs/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"encoding/base64"
"fmt"
"io"
"io/ioutil"

pb "sqlflow.org/sqlflow/pkg/proto"
)
Expand All @@ -45,20 +44,11 @@ func Create(db *sql.DB, driver, table string, session *pb.Session) (io.WriteClos
}

if driver == "hive" {
// HiveWriter implementation can achieve better performance
csvFile, e := ioutil.TempFile("/tmp", "sqlflow-sqlfs")
if e != nil {
return nil, fmt.Errorf("create temporary csv file failed: %v", e)
w, err := NewHiveWriter(db, table, session)
if err != nil {
return nil, fmt.Errorf("create: %v", err)
}
return &HiveWriter{
Writer: Writer{
db: db,
table: table,
buf: make([]byte, 0, bufSize),
flushID: 0,
},
csvFile: csvFile,
session: session}, nil
return w, nil
}
// default Writer implementation
return &Writer{db, table, make([]byte, 0, bufSize), 0}, nil
Expand Down

0 comments on commit fd63870

Please sign in to comment.