Submit Argo workflow #1252

Merged · 4 commits · merged Nov 26, 2019
Changes from 2 commits
11 changes: 8 additions & 3 deletions cmd/sqlflowserver/main.go
@@ -48,7 +48,7 @@ func newServer(caCrt, caKey string) (*grpc.Server, error) {
 	return s, nil
 }
 
-func start(modelDir, caCrt, caKey string, port int) {
+func start(modelDir, caCrt, caKey string, port int, isArgoMode bool) {
 	s, err := newServer(caCrt, caKey)
 	if err != nil {
 		log.Fatalf("failed to create new gRPC Server: %v", err)
@@ -59,8 +59,12 @@ func start(modelDir, caCrt, caKey string, port int) {
 			os.Mkdir(modelDir, os.ModePerm)
 		}
 	}
+	if isArgoMode {
+		proto.RegisterSQLFlowServer(s, server.NewServer(sf.SubmitWorkflow, modelDir))
+	} else {
+		proto.RegisterSQLFlowServer(s, server.NewServer(sf.RunSQLProgram, modelDir))
+	}
 
-	proto.RegisterSQLFlowServer(s, server.NewServer(sf.RunSQLProgram, modelDir))
 	listenString := fmt.Sprintf(":%d", port)
 
 	lis, err := net.Listen("tcp", listenString)
@@ -81,6 +85,7 @@ func main() {
 	caCrt := flag.String("ca-crt", "", "CA certificate file.")
 	caKey := flag.String("ca-key", "", "CA private key file.")
 	port := flag.Int("port", 50051, "TCP port to listen on.")
+	isArgoMode := flag.Bool("argo-mode", false, "Enable Argo workflow mode.")
Collaborator: maybe we also need to configure the k8s API endpoint address in the future?

Author: Good point; I will add it when implementing the Fetch interface.

 	flag.Parse()
-	start(*modelDir, *caCrt, *caKey, *port)
+	start(*modelDir, *caCrt, *caKey, *port, *isArgoMode)
 }
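Note that `sf.RunSQLProgram` and `sf.SubmitWorkflow` share the signature `func(string, *sf.DB, string, *pb.Session) *sf.PipeReader`, so the branch in `start` could equivalently assign the chosen function to a variable and register once. A minimal sketch, assuming the imports already present in `cmd/sqlflowserver/main.go`:

```go
// Sketch only: equivalent mode selection with a single registration call.
// Both functions have identical signatures, so runFn can hold either.
runFn := sf.RunSQLProgram
if isArgoMode {
	runFn = sf.SubmitWorkflow
}
proto.RegisterSQLFlowServer(s, server.NewServer(runFn, modelDir))
```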
10 changes: 5 additions & 5 deletions cmd/sqlflowserver/main_test.go
@@ -265,7 +265,7 @@ func TestEnd2EndMySQL(t *testing.T) {
 		t.Fatalf("failed to generate CA pair %v", err)
 	}
 
-	go start(modelDir, caCrt, caKey, unitestPort)
+	go start(modelDir, caCrt, caKey, unitestPort, false)
 	waitPortReady(fmt.Sprintf("localhost:%d", unitestPort), 0)
 	err = prepareTestData(dbConnStr)
 	if err != nil {
@@ -359,7 +359,7 @@ func TestEnd2EndHive(t *testing.T) {
 		t.Skip("Skipping hive tests")
 	}
 	dbConnStr = "hive://root:[email protected]:10000/iris?auth=NOSASL"
-	go start(modelDir, caCrt, caKey, unitestPort)
+	go start(modelDir, caCrt, caKey, unitestPort, false)
 	waitPortReady(fmt.Sprintf("localhost:%d", unitestPort), 0)
 	err = prepareTestData(dbConnStr)
 	if err != nil {
@@ -396,7 +396,7 @@ func TestEnd2EndMaxCompute(t *testing.T) {
 	SK := os.Getenv("MAXCOMPUTE_SK")
 	endpoint := os.Getenv("MAXCOMPUTE_ENDPOINT")
 	dbConnStr = fmt.Sprintf("maxcompute://%s:%s@%s", AK, SK, endpoint)
-	go start(modelDir, caCrt, caKey, unitestPort)
+	go start(modelDir, caCrt, caKey, unitestPort, false)
 	waitPortReady(fmt.Sprintf("localhost:%d", unitestPort), 0)
 	err = prepareTestData(dbConnStr)
 	if err != nil {
@@ -440,7 +440,7 @@ func TestEnd2EndMaxComputeALPS(t *testing.T) {
 		t.Fatalf("prepare test dataset failed: %v", err)
 	}
 
-	go start(modelDir, caCrt, caKey, unitestPort)
+	go start(modelDir, caCrt, caKey, unitestPort, false)
 	waitPortReady(fmt.Sprintf("localhost:%d", unitestPort), 0)
 
 	t.Run("CaseTrainALPS", CaseTrainALPS)
@@ -479,7 +479,7 @@ func TestEnd2EndMaxComputeElasticDL(t *testing.T) {
 		t.Fatalf("prepare test dataset failed: %v", err)
 	}
 
-	go start(modelDir, caCrt, caKey, unitestPort)
+	go start(modelDir, caCrt, caKey, unitestPort, false)
 	waitPortReady(fmt.Sprintf("localhost:%d", unitestPort), 0)
 
 	t.Run("CaseTrainElasticDL", CaseTrainElasticDL)
27 changes: 15 additions & 12 deletions pkg/server/proto/sqlflow.proto
@@ -10,27 +10,29 @@ service SQLFlow {
   // SQL statements like `SELECT ...`, `DESCRIBE ...` returns a rowset.
   // The rowset might be big. In such cases, Query returns a stream
   // of RunResponse
+  //
+  // SQLFlow implements the Run interface in two modes:
+  //
+  // 1. Local mode
Collaborator: will the "local mode" move to the Couler "docker mode"?

Author: The local mode here is the current implementation. It will move to docker mode once docker mode is ready in Couler.

+  // The SQLFlow server executes the SQL statements on the local host.
+  //
   // SQL statements like `USE database`, `DELETE` returns only a success
   // message.
   //
   // SQL statement like `SELECT ... TO TRAIN/PREDICT ...` returns a stream of
   // messages which indicate the training/predicting progress
+  //
+  // 2. Argo Workflow mode
+  // The SQLFlow server submits an Argo workflow to a Kubernetes cluster,
+  // and returns a stream of messages indicating the workflow ID and the
+  // submission progress.
+  //
+  // The SQLFlow gRPC client should fetch the logs of the workflow by
+  // calling the Fetch interface in a polling manner.
   rpc Run (Request) returns (stream Response);
 
-  // Submit a SQLFlow Job which contains a SQL program to SQLFlow server.
-  //
-  // A SQL program contains one or more SQL statments.
-  // Each of these SQL statments can be a standard SQL like:
-  // `SELECT ... FROM ...;`, `DESCRIBE ...`,
-  // or an extended SQLFlow SQL like:
-  // `SELECT ... TO TRAIN/PREDICT/EXPLAIN ...`.
-  //
-  // Submit returns a Job message which contains the SQLFlow Job ID.
-  rpc Submit (Request) returns (Job);
-
   // Fetch fetches the SQLFlow job status and logs in a polling manner.
-  rpc Fetch (Job) returns(JobStatus);
+  rpc Fetch (Job) returns (JobStatus);
 }
 
 message Job {
@@ -78,6 +80,7 @@ message Response {
     Row row = 2;
     Message message = 3;
     EndOfExecution eoe = 4;
+    Job job = 5;
   }
 }
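To make the two-mode contract above concrete, here is a client-side sketch of the intended Argo-mode flow: call `Run`, watch the response stream for a `Job`, then poll `Fetch` with the job ID. This is illustrative only; the import path and the connection setup are assumptions, and since this diff does not show the `JobStatus` fields, the status is simply logged.

```go
package main

import (
	"context"
	"io"
	"log"

	pb "sqlflow.org/sqlflow/pkg/server/proto" // assumed import path
)

// runAndPoll sketches the client side of the Argo workflow mode: Run streams
// back a Job message, whose ID is then polled through Fetch.
func runAndPoll(ctx context.Context, client pb.SQLFlowClient, req *pb.Request) error {
	stream, err := client.Run(ctx, req)
	if err != nil {
		return err
	}
	for {
		res, err := stream.Recv()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
		// In Argo workflow mode the stream carries a Job with the workflow ID.
		if job := res.GetJob(); job != nil {
			status, err := client.Fetch(ctx, &pb.Job{Id: job.Id})
			if err != nil {
				return err
			}
			log.Printf("job %s status: %v", job.Id, status)
		}
	}
}
```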
13 changes: 6 additions & 7 deletions pkg/server/sqlflowserver.go
@@ -34,7 +34,8 @@ import (
 )
 
 // NewServer returns a server instance
-func NewServer(run func(string, *sf.DB, string, *pb.Session) *sf.PipeReader, modelDir string) *Server {
+func NewServer(run func(string, *sf.DB, string, *pb.Session) *sf.PipeReader,
+	modelDir string) *Server {
 	return &Server{run: run, modelDir: modelDir}
 }
@@ -44,12 +45,6 @@ type Server struct {
 	modelDir string
 }
 
-// Submit implements `rpc Submit (Request) returns (Job)`
-func (s *Server) Submit(ctx context.Context, in *pb.Request) (*pb.Job, error) {
-	job := &pb.Job{}
-	return job, nil
-}
-
 // Fetch implements `rpc Fetch (Job) returns (JobStatus)`
 func (s *Server) Fetch(ctx context.Context, job *pb.Job) (*pb.JobStatus, error) {
 	js := &pb.JobStatus{}
@@ -59,6 +54,7 @@ func (s *Server) Fetch(ctx context.Context, job *pb.Job) (*pb.JobStatus, error)
 // Run implements `rpc Run (Request) returns (stream Response)`
 func (s *Server) Run(req *pb.Request, stream pb.SQLFlow_RunServer) error {
 	var db *sf.DB
+
 	var err error
 	if db, err = sf.NewDB(req.Session.DbConnStr); err != nil {
 		return fmt.Errorf("create DB failed: %v", err)
@@ -82,6 +78,9 @@ func (s *Server) Run(req *pb.Request, stream pb.SQLFlow_RunServer) error {
 			res, err = encodeRow(s)
 		case string:
 			res, err = encodeMessage(s)
+		case sf.WorkflowJob:
+			job := r.(sf.WorkflowJob)
+			res = &pb.Response{Response: &pb.Response_Job{Job: &pb.Job{Id: job.JobID}}}
Collaborator: If you also update pysqlflow, please refer to this PR in the pysqlflow PR.

Author: Sure, there is a task in the TODO list #1066.

 		case sf.EndOfExecution:
 			// if sqlStatements have only one field, do **NOT** return EndOfExecution message.
 			if len(sqlStatements) > 1 {
42 changes: 41 additions & 1 deletion pkg/sql/executor_ir.go
@@ -40,6 +40,11 @@ type EndOfExecution struct {
 	Statement string
 }
 
+// WorkflowJob indicates the Argo workflow ID
+type WorkflowJob struct {
+	JobID string
+}

 var envSubmitter = os.Getenv("SQLFLOW_submitter")
 
 // SubmitterType is the type of SQLFlow submitter
@@ -88,14 +93,49 @@ func RunSQLProgram(sqlProgram string, db *DB, modelDir string, session *pb.Session) *PipeReader {
 	return rd
 }

+// SubmitWorkflow submits an Argo workflow
+func SubmitWorkflow(sqlProgram string, db *DB, modelDir string, session *pb.Session) *PipeReader {
+	rd, wr := Pipe()
+	go func() {
+		defer wr.Close()
+		err := submitWorkflow(wr, sqlProgram, db, modelDir, session)
+		if err != nil {
+			log.Errorf("submit Workflow error: %v", err)
+		}
+	}()
+	return rd
+}

+func submitWorkflow(wr *PipeWriter, sqlProgram string, db *DB, modelDir string, session *pb.Session) error {
+	sqls, err := parse(db.driverName, sqlProgram)
+	if err != nil {
+		return err
+	}
+
+	connStr := fmt.Sprintf("%s://%s", db.driverName, db.dataSourceName)
+	_, err = programToIR(sqls, connStr, modelDir, true, false)
+	if err != nil {
+		return err
+	}
+
+	// TODO(yancey1989):
+	// 1. call codegen_couler.go to generate a Couler program.
+	// 2. compile the Couler program into Argo YAML.
+	// 3. submit the Argo YAML and fetch the workflow ID.
+
+	return wr.Write(WorkflowJob{
+		JobID: "sqlflow-workflow",
+	})
+}

 func runSQLProgram(wr *PipeWriter, sqlProgram string, db *DB, modelDir string, session *pb.Session) error {
 	sqls, err := parse(db.driverName, sqlProgram)
 	if err != nil {
 		return err
 	}
 
 	connStr := fmt.Sprintf("%s://%s", db.driverName, db.dataSourceName)
-	programIR, err := programToIR(sqls, connStr, modelDir, submitter() != SubmitterPAI)
+	programIR, err := programToIR(sqls, connStr, modelDir, submitter() != SubmitterPAI, true /*enableFeatureDerivation*/)
 	if err != nil {
 		return err
 	}
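A caller consumes the returned `PipeReader` the same way the `Run` handler does. A minimal sketch, mirroring the test below, where `sqlProgram`, `db`, and `session` stand in for a real SQL program, a configured `*DB`, and a `*pb.Session`:

```go
// Sketch: reading the workflow ID off the pipe returned by SubmitWorkflow.
rd := SubmitWorkflow(sqlProgram, db, "" /*modelDir*/, session)
for r := range rd.ReadAll() {
	if job, ok := r.(WorkflowJob); ok {
		fmt.Println("Argo workflow ID:", job.JobID)
	}
}
```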
17 changes: 17 additions & 0 deletions pkg/sql/executor_ir_test.go
@@ -224,3 +224,20 @@ func TestLogChanWriter_Write(t *testing.T) {
 	_, more := <-c
 	a.False(more)
 }

+func TestSubmitWorkflow(t *testing.T) {
+	a := assert.New(t)
+	modelDir := ""
+	a.NotPanics(func() {
+		rd := SubmitWorkflow(testXGBoostTrainSelectIris, testDB, modelDir, getDefaultSession())
+		for r := range rd.ReadAll() {
+			switch r.(type) {
+			case WorkflowJob:
+				job := r.(WorkflowJob)
+				a.Equal(job.JobID, "sqlflow-workflow")
+			default:
+				a.Fail("SubmitWorkflow should return JobID")
+			}
+		}
+	})
+}
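Assuming a standard checkout of the repository, something like `go test ./pkg/sql -run TestSubmitWorkflow` should exercise this test on its own (package path assumed).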
10 changes: 8 additions & 2 deletions pkg/sql/ir_generator.go
@@ -716,13 +716,19 @@ func parseResultTable(intoStatement string) (string, string, error) {
 }
 
 // programToIR generates a list of IRs from a SQL program
-func programToIR(sqls []statementParseResult, connStr, modelDir string, getTrainIRFromModel bool) (codegen.SQLProgramIR, error) {
+func programToIR(sqls []statementParseResult, connStr, modelDir string, getTrainIRFromModel bool, enableFeatureDerivation bool) (codegen.SQLProgramIR, error) {
 	IRs := codegen.SQLProgramIR{}
 	for _, sql := range sqls {
 		if sql.extended != nil {
 			parsed := sql.extended
 			if parsed.train {
-				ir, err := generateTrainIRWithInferredColumns(parsed, connStr)
+				var ir *codegen.TrainIR
+				var err error
+				if enableFeatureDerivation {
+					ir, err = generateTrainIRWithInferredColumns(parsed, connStr)
+				} else {
+					ir, err = generateTrainIR(parsed, connStr)
+				}
 				if err != nil {
 					return nil, err
 				}