diff --git a/cmd/sqlflowserver/main.go b/cmd/sqlflowserver/main.go
index ebeee1d78f..4a1a2d3c6e 100644
--- a/cmd/sqlflowserver/main.go
+++ b/cmd/sqlflowserver/main.go
@@ -48,7 +48,7 @@ func newServer(caCrt, caKey string) (*grpc.Server, error) {
 	return s, nil
 }
 
-func start(modelDir, caCrt, caKey string, port int) {
+func start(modelDir, caCrt, caKey string, port int, isArgoMode bool) {
 	s, err := newServer(caCrt, caKey)
 	if err != nil {
 		log.Fatalf("failed to create new gRPC Server: %v", err)
 	}
@@ -59,8 +59,12 @@ func start(modelDir, caCrt, caKey string, port int) {
 			os.Mkdir(modelDir, os.ModePerm)
 		}
 	}
+	if isArgoMode {
+		proto.RegisterSQLFlowServer(s, server.NewServer(sf.SubmitWorkflow, modelDir))
+	} else {
+		proto.RegisterSQLFlowServer(s, server.NewServer(sf.RunSQLProgram, modelDir))
+	}
 
-	proto.RegisterSQLFlowServer(s, server.NewServer(sf.RunSQLProgram, modelDir))
 	listenString := fmt.Sprintf(":%d", port)
 
 	lis, err := net.Listen("tcp", listenString)
@@ -81,6 +85,7 @@ func main() {
 	caCrt := flag.String("ca-crt", "", "CA certificate file.")
 	caKey := flag.String("ca-key", "", "CA private key file.")
 	port := flag.Int("port", 50051, "TCP port to listen on.")
+	isArgoMode := flag.Bool("argo-mode", false, "Enable Argo workflow mode.")
 	flag.Parse()
-	start(*modelDir, *caCrt, *caKey, *port)
+	start(*modelDir, *caCrt, *caKey, *port, *isArgoMode)
 }
diff --git a/cmd/sqlflowserver/main_test.go b/cmd/sqlflowserver/main_test.go
index 6ef875b22d..28603f8bff 100644
--- a/cmd/sqlflowserver/main_test.go
+++ b/cmd/sqlflowserver/main_test.go
@@ -265,7 +265,7 @@ func TestEnd2EndMySQL(t *testing.T) {
 		t.Fatalf("failed to generate CA pair %v", err)
 	}
 
-	go start(modelDir, caCrt, caKey, unitestPort)
+	go start(modelDir, caCrt, caKey, unitestPort, false)
 	waitPortReady(fmt.Sprintf("localhost:%d", unitestPort), 0)
 	err = prepareTestData(dbConnStr)
 	if err != nil {
@@ -359,7 +359,7 @@ func TestEnd2EndHive(t *testing.T) {
 		t.Skip("Skipping hive tests")
 	}
 	dbConnStr = "hive://root:root@127.0.0.1:10000/iris?auth=NOSASL"
-	go start(modelDir, caCrt, caKey, unitestPort)
+	go start(modelDir, caCrt, caKey, unitestPort, false)
 	waitPortReady(fmt.Sprintf("localhost:%d", unitestPort), 0)
 	err = prepareTestData(dbConnStr)
 	if err != nil {
@@ -396,7 +396,7 @@ func TestEnd2EndMaxCompute(t *testing.T) {
 	SK := os.Getenv("MAXCOMPUTE_SK")
 	endpoint := os.Getenv("MAXCOMPUTE_ENDPOINT")
 	dbConnStr = fmt.Sprintf("maxcompute://%s:%s@%s", AK, SK, endpoint)
-	go start(modelDir, caCrt, caKey, unitestPort)
+	go start(modelDir, caCrt, caKey, unitestPort, false)
 	waitPortReady(fmt.Sprintf("localhost:%d", unitestPort), 0)
 	err = prepareTestData(dbConnStr)
 	if err != nil {
@@ -440,7 +440,7 @@ func TestEnd2EndMaxComputeALPS(t *testing.T) {
 		t.Fatalf("prepare test dataset failed: %v", err)
 	}
 
-	go start(modelDir, caCrt, caKey, unitestPort)
+	go start(modelDir, caCrt, caKey, unitestPort, false)
 	waitPortReady(fmt.Sprintf("localhost:%d", unitestPort), 0)
 
 	t.Run("CaseTrainALPS", CaseTrainALPS)
@@ -479,7 +479,7 @@ func TestEnd2EndMaxComputeElasticDL(t *testing.T) {
 		t.Fatalf("prepare test dataset failed: %v", err)
 	}
 
-	go start(modelDir, caCrt, caKey, unitestPort)
+	go start(modelDir, caCrt, caKey, unitestPort, false)
 	waitPortReady(fmt.Sprintf("localhost:%d", unitestPort), 0)
 
 	t.Run("CaseTrainElasticDL", CaseTrainElasticDL)
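Note on the dispatch above: `start` can hand either entry point to `server.NewServer` because `sf.RunSQLProgram` and `sf.SubmitWorkflow` share the same signature. A minimal, self-contained sketch of that pattern follows; every type below is a stand-in for illustration, not the real `sf`/`pb` type:

```go
package main

import "fmt"

// Stand-ins for sf.DB, pb.Session, and sf.PipeReader. The point: both
// RunSQLProgram and SubmitWorkflow satisfy one function type, so the
// server picks its submitter by passing a plain function value.
type (
	db         struct{}
	session    struct{}
	pipeReader struct{}
)

type runFunc func(sql string, d *db, modelDir string, s *session) *pipeReader

func runSQLProgram(sql string, d *db, modelDir string, s *session) *pipeReader {
	fmt.Println("local mode:", sql)
	return &pipeReader{}
}

func submitWorkflow(sql string, d *db, modelDir string, s *session) *pipeReader {
	fmt.Println("Argo workflow mode:", sql)
	return &pipeReader{}
}

func main() {
	isArgoMode := true
	var run runFunc = runSQLProgram
	if isArgoMode {
		run = submitWorkflow
	}
	run("SELECT 1;", &db{}, "/tmp/models", &session{})
}
```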
diff --git a/pkg/server/proto/sqlflow.proto b/pkg/server/proto/sqlflow.proto
index 71d8a4fa10..b301662231 100644
--- a/pkg/server/proto/sqlflow.proto
+++ b/pkg/server/proto/sqlflow.proto
@@ -10,27 +10,29 @@ service SQLFlow {
   // SQL statements like `SELECT ...`, `DESCRIBE ...` returns a rowset.
   // The rowset might be big. In such cases, Query returns a stream
   // of RunResponse
+  //
+  // SQLFlow implements the Run interface with two modes:
   //
+  // 1. Local mode
+  // The SQLFlow server executes the SQL statements on the local host.
+  //
   // SQL statements like `USE database`, `DELETE` returns only a success
   // message.
   //
   // SQL statement like `SELECT ... TO TRAIN/PREDICT ...` returns a stream of
   // messages which indicates the training/predicting progress
+  //
+  // 2. Argo Workflow mode
+  // The SQLFlow server submits an Argo workflow to a Kubernetes cluster,
+  // and returns a stream of messages indicating the workflow ID and the
+  // submission progress.
+  //
+  // The SQLFlow gRPC client should fetch the logs of the workflow by
+  // calling the Fetch interface in a polling manner.
   rpc Run (Request) returns (stream Response);
 
-  // Submit a SQLFlow Job which contains a SQL program to SQLFlow server.
-  //
-  // A SQL program contains one or more SQL statments.
-  // Each of these SQL statments can be a standard SQL like:
-  // `SELECT ... FROM ...;`, `DESCRIBE ...`,
-  // or an extended SQLFlow SQL like:
-  // `SELECT ... TO TRAIN/PREDICT/EXPLAIN ...`.
-  //
-  // Submit returns a Job message which contains the SQLFlow Job ID.
-  rpc Submit (Request) returns (Job);
-  // Fetch fetchs the SQLFlow job status and logs in a polling manner.
-  rpc Fetch (Job) returns(JobStatus);
+  rpc Fetch (Job) returns (JobStatus);
 }
 
 message Job {
@@ -78,6 +80,7 @@ message Response {
     Row row = 2;
     Message message = 3;
     EndOfExecution eoe = 4;
+    Job job = 5;
   }
 }
diff --git a/pkg/server/sqlflowserver.go b/pkg/server/sqlflowserver.go
index 864a78d484..7436b6b2ff 100644
--- a/pkg/server/sqlflowserver.go
+++ b/pkg/server/sqlflowserver.go
@@ -34,7 +34,8 @@ import (
 )
 
 // NewServer returns a server instance
-func NewServer(run func(string, *sf.DB, string, *pb.Session) *sf.PipeReader, modelDir string) *Server {
+func NewServer(run func(string, *sf.DB, string, *pb.Session) *sf.PipeReader,
+	modelDir string) *Server {
 	return &Server{run: run, modelDir: modelDir}
 }
 
@@ -44,12 +45,6 @@ type Server struct {
 	modelDir string
 }
 
-// Submit implements `rpc Submit (Request) returns (Job)`
-func (s *Server) Submit(ctx context.Context, in *pb.Request) (*pb.Job, error) {
-	job := &pb.Job{}
-	return job, nil
-}
-
 // Fetch implements `rpc Fetch (Job) returns(JobStatus)`
 func (s *Server) Fetch(ctx context.Context, job *pb.Job) (*pb.JobStatus, error) {
 	js := &pb.JobStatus{}
@@ -59,6 +54,7 @@ func (s *Server) Fetch(ctx context.Context, job *pb.Job) (*pb.JobStatus, error)
 // Run implements `rpc Run (Request) returns (stream Response)`
 func (s *Server) Run(req *pb.Request, stream pb.SQLFlow_RunServer) error {
 	var db *sf.DB
+	var err error
 
 	if db, err = sf.NewDB(req.Session.DbConnStr); err != nil {
 		return fmt.Errorf("create DB failed: %v", err)
@@ -82,6 +78,9 @@ func (s *Server) Run(req *pb.Request, stream pb.SQLFlow_RunServer) error {
 			res, err = encodeRow(s)
 		case string:
 			res, err = encodeMessage(s)
+		case sf.WorkflowJob:
+			job := r.(sf.WorkflowJob)
+			res = &pb.Response{Response: &pb.Response_Job{Job: &pb.Job{Id: job.JobID}}}
 		case sf.EndOfExecution:
 			// if sqlStatements have only one field, do **NOT** return EndOfExecution message.
 			if len(sqlStatements) > 1 {
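For reference, here is what the two-mode contract looks like from the client side in Argo Workflow mode: read the `Job` message off the `Run` stream, then poll `Fetch`. This is a hypothetical sketch; the stub import path and the `Sql` field of `Request` are assumptions about the generated code, not part of this patch, while `GetJob()` and `Job.Id` follow from the `Response_Job` oneof and `pb.Job{Id: ...}` used above:

```go
package main

import (
	"context"
	"fmt"
	"io"
	"log"
	"time"

	"google.golang.org/grpc"

	pb "sqlflow.org/sqlflow/pkg/server/proto" // assumed stub path
)

func main() {
	conn, err := grpc.Dial("localhost:50051", grpc.WithInsecure())
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer conn.Close()
	client := pb.NewSQLFlowClient(conn)

	// Request.Sql is an assumed field name for the SQL program.
	stream, err := client.Run(context.Background(), &pb.Request{Sql: "SELECT ... TO TRAIN ...;"})
	if err != nil {
		log.Fatalf("Run: %v", err)
	}
	var jobID string
	for {
		res, err := stream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			log.Fatalf("Recv: %v", err)
		}
		// In Argo mode the stream carries a Job message with the workflow ID.
		if job := res.GetJob(); job != nil {
			jobID = job.Id
		}
	}
	// Poll Fetch for status and logs. JobStatus fields are not shown in this
	// patch, so a real client would break once the status is terminal.
	for i := 0; i < 10; i++ {
		status, err := client.Fetch(context.Background(), &pb.Job{Id: jobID})
		if err != nil {
			log.Fatalf("Fetch: %v", err)
		}
		fmt.Println(status)
		time.Sleep(2 * time.Second)
	}
}
```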
diff --git a/pkg/sql/executor_ir.go b/pkg/sql/executor_ir.go
index 0bb6ca9c47..9a04ce8a47 100644
--- a/pkg/sql/executor_ir.go
+++ b/pkg/sql/executor_ir.go
@@ -40,6 +40,11 @@ type EndOfExecution struct {
 	Statement string
 }
 
+// WorkflowJob indicates the Argo workflow ID
+type WorkflowJob struct {
+	JobID string
+}
+
 var envSubmitter = os.Getenv("SQLFLOW_submitter")
 
 // SubmitterType is the type of SQLFlow submitter
@@ -88,6 +93,41 @@ func RunSQLProgram(sqlProgram string, db *DB, modelDir string, session *pb.Sessi
 	return rd
 }
 
+// SubmitWorkflow submits an Argo workflow
+func SubmitWorkflow(sqlProgram string, db *DB, modelDir string, session *pb.Session) *PipeReader {
+	rd, wr := Pipe()
+	go func() {
+		defer wr.Close()
+		err := submitWorkflow(wr, sqlProgram, db, modelDir, session)
+		if err != nil {
+			log.Errorf("submit Workflow error: %v", err)
+		}
+	}()
+	return rd
+}
+
+func submitWorkflow(wr *PipeWriter, sqlProgram string, db *DB, modelDir string, session *pb.Session) error {
+	sqls, err := parse(db.driverName, sqlProgram)
+	if err != nil {
+		return err
+	}
+
+	connStr := fmt.Sprintf("%s://%s", db.driverName, db.dataSourceName)
+	_, err = programToIR(sqls, connStr, modelDir, true, false)
+	if err != nil {
+		return err
+	}
+
+	// TODO(yancey1989):
+	// 1. call codegen_couler.go to generate a Couler program.
+	// 2. compile the Couler program into Argo YAML.
+	// 3. submit the Argo YAML and fetch the workflow ID.
+
+	return wr.Write(WorkflowJob{
+		JobID: "sqlflow-workflow",
+	})
+}
+
 func runSQLProgram(wr *PipeWriter, sqlProgram string, db *DB, modelDir string, session *pb.Session) error {
 	sqls, err := parse(db.driverName, sqlProgram)
 	if err != nil {
@@ -95,7 +135,7 @@ func runSQLProgram(wr *PipeWriter, sqlProgram string, db *DB, modelDir string, s
 	}
 
 	connStr := fmt.Sprintf("%s://%s", db.driverName, db.dataSourceName)
-	programIR, err := programToIR(sqls, connStr, modelDir, submitter() != SubmitterPAI)
+	programIR, err := programToIR(sqls, connStr, modelDir, submitter() != SubmitterPAI, true /*enableFeatureDerivation*/)
 	if err != nil {
 		return err
 	}
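The TODO in `submitWorkflow` leaves the Couler codegen and the Argo submission unimplemented; for now the function writes a fixed `WorkflowJob` so the end-to-end plumbing can be tested. One way step 3 might eventually look, sketched under the assumption that the compiled Argo YAML is submitted with `kubectl create -o name`; nothing below is part of this patch:

```go
package main

import (
	"fmt"
	"os/exec"
	"strings"
)

// submitArgoYAML sketches step 3 of the TODO: submit compiled Argo YAML and
// return the generated workflow name. The kubectl invocation and its output
// format are assumptions for illustration only.
func submitArgoYAML(yamlPath string) (string, error) {
	// `kubectl create -f <file> -o name` prints e.g.
	// "workflow.argoproj.io/sqlflow-couler-xxxxx".
	out, err := exec.Command("kubectl", "create", "-f", yamlPath, "-o", "name").CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("submit Argo workflow: %v\n%s", err, out)
	}
	name := strings.TrimSpace(string(out))
	return name[strings.LastIndex(name, "/")+1:], nil
}

func main() {
	id, err := submitArgoYAML("/tmp/workflow.yaml")
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("workflow ID:", id)
}
```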
diff --git a/pkg/sql/executor_ir_test.go b/pkg/sql/executor_ir_test.go
index 5cef0c7820..1dcd1f7567 100644
--- a/pkg/sql/executor_ir_test.go
+++ b/pkg/sql/executor_ir_test.go
@@ -224,3 +224,20 @@ func TestLogChanWriter_Write(t *testing.T) {
 	_, more := <-c
 	a.False(more)
 }
+
+func TestSubmitWorkflow(t *testing.T) {
+	a := assert.New(t)
+	modelDir := ""
+	a.NotPanics(func() {
+		rd := SubmitWorkflow(testXGBoostTrainSelectIris, testDB, modelDir, getDefaultSession())
+		for r := range rd.ReadAll() {
+			switch r.(type) {
+			case WorkflowJob:
+				job := r.(WorkflowJob)
+				a.Equal(job.JobID, "sqlflow-workflow")
+			default:
+				a.Fail("SubmitWorkflow should return JobID")
+			}
+		}
+	})
+}
diff --git a/pkg/sql/ir_generator.go b/pkg/sql/ir_generator.go
index 7fbeb4287a..1c1f128c04 100644
--- a/pkg/sql/ir_generator.go
+++ b/pkg/sql/ir_generator.go
@@ -716,13 +716,19 @@ func parseResultTable(intoStatement string) (string, string, error) {
 }
 
 // programToIR generate a list of IRs from a SQL program
-func programToIR(sqls []statementParseResult, connStr, modelDir string, getTrainIRFromModel bool) (ir.SQLProgram, error) {
+func programToIR(sqls []statementParseResult, connStr, modelDir string, getTrainIRFromModel bool, enableFeatureDerivation bool) (ir.SQLProgram, error) {
 	IRs := ir.SQLProgram{}
 	for _, sql := range sqls {
 		if sql.extended != nil {
 			parsed := sql.extended
 			if parsed.train {
-				ir, err := generateTrainIRWithInferredColumns(parsed, connStr)
+				var ir *ir.TrainClause
+				var err error
+				if enableFeatureDerivation {
+					ir, err = generateTrainIRWithInferredColumns(parsed, connStr)
+				} else {
+					ir, err = generateTrainIR(parsed, connStr)
+				}
 				if err != nil {
 					return nil, err
 				}
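Aside on the type switch used in `Server.Run` and `TestSubmitWorkflow`: binding the value in the switch header would avoid the second assertion `r.(WorkflowJob)` with identical behavior. A self-contained illustration of the idiom; the local `WorkflowJob` here simply mirrors the one added in executor_ir.go:

```go
package main

import "fmt"

type WorkflowJob struct{ JobID string }

func main() {
	// A pipe like the one in SubmitWorkflow yields untyped results; binding
	// the asserted value in the switch (v := r.(type)) gives each case the
	// concrete type directly, without a second assertion inside the case.
	results := []interface{}{WorkflowJob{JobID: "sqlflow-workflow"}, "a log line"}
	for _, r := range results {
		switch v := r.(type) {
		case WorkflowJob:
			fmt.Println("job ID:", v.JobID)
		default:
			fmt.Printf("other message: %v\n", v)
		}
	}
}
```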