Skip to content

Commit

Permalink
Improve legacy Cog compliance
Browse files Browse the repository at this point in the history
* Run legacy Cog with PYTHONUNBUFFERED
* Send "processing" instead of "start" webhook
* Simplify webhook sequence assertions
  • Loading branch information
nevillelyh committed Dec 13, 2024
1 parent bd8d0f0 commit 4fae8b5
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 387 deletions.
13 changes: 4 additions & 9 deletions internal/server/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,14 @@ func (pr *PendingPrediction) appendLogLine(line string) {

func (pr *PendingPrediction) sendWebhook(event WebhookEvent) {
pr.mu.Lock()
defer func() {
// Only async and iterator predict writes new response per output item with status = "processing"
// For blocking or non-iterator cases, set it here immediately after sending "starting" webhook
if pr.response.Status == PredictionStarting {
pr.response.Status = PredictionProcessing
}
pr.mu.Unlock()
}()
defer pr.mu.Unlock()
if pr.request.Webhook == "" {
return
}
if len(pr.request.WebhookEventsFilter) > 0 && !slices.Contains(pr.request.WebhookEventsFilter, event) {
return
}
if pr.response.Status == PredictionProcessing {
if event == WebhookLogs || event == WebhookOutput {
if time.Since(pr.lastUpdated) < 500*time.Millisecond {
return
}
Expand Down Expand Up @@ -392,6 +385,8 @@ func (r *Runner) handleResponses() {

if pr.response.Status == PredictionStarting {
log.Infow("prediction started", "id", pr.request.Id, "status", pr.response.Status)
// Compat: legacy Cog never reports "starting" status; its start event is sent as "processing"
pr.response.Status = PredictionProcessing
pr.sendWebhook(WebhookStart)
} else if pr.response.Status == PredictionProcessing {
log.Infow("prediction processing", "id", pr.request.Id, "status", pr.response.Status)
Expand Down
179 changes: 18 additions & 161 deletions internal/tests/async_prediction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,8 @@ func TestAsyncPredictionSucceeded(t *testing.T) {

ct.AsyncPrediction(map[string]any{"i": 1, "s": "bar"})
wr := ct.WaitForWebhookCompletion()
if *legacyCog {
assert.Len(t, wr, 3)
logs := ""
// Compat: legacy Cog sends no "starting" event
ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs)
// Compat: legacy Cog buffers logging?
logs += "starting prediction\n"
ct.AssertResponse(wr[1], server.PredictionProcessing, "*bar*", logs)
logs += "prediction in progress 1/1\n"
logs += "completed prediction\n"
ct.AssertResponse(wr[2], server.PredictionSucceeded, "*bar*", logs)
} else {
assert.Len(t, wr, 5)
logs := ""
ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs)
logs += "starting prediction\n"
ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs)
logs += "prediction in progress 1/1\n"
ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs)
logs += "completed prediction\n"
ct.AssertResponse(wr[3], server.PredictionProcessing, nil, logs)
ct.AssertResponse(wr[4], server.PredictionSucceeded, "*bar*", logs)
}
logs := "starting prediction\nprediction in progress 1/1\ncompleted prediction\n"
ct.AssertResponses(wr, server.PredictionSucceeded, "*bar*", logs)

ct.Shutdown()
assert.NoError(t, ct.Cleanup())
Expand All @@ -64,29 +43,8 @@ func TestAsyncPredictionWithIdSucceeded(t *testing.T) {

ct.AsyncPredictionWithId("p01", map[string]any{"i": 1, "s": "bar"})
wr := ct.WaitForWebhookCompletion()
if *legacyCog {
assert.Len(t, wr, 3)
logs := ""
// Compat: legacy Cog sends no "starting" event
ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs)
// Compat: legacy Cog buffers logging?
logs += "starting prediction\n"
ct.AssertResponse(wr[1], server.PredictionProcessing, "*bar*", logs)
logs += "prediction in progress 1/1\n"
logs += "completed prediction\n"
ct.AssertResponse(wr[2], server.PredictionSucceeded, "*bar*", logs)
} else {
assert.Len(t, wr, 5)
logs := ""
ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs)
logs += "starting prediction\n"
ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs)
logs += "prediction in progress 1/1\n"
ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs)
logs += "completed prediction\n"
ct.AssertResponse(wr[3], server.PredictionProcessing, nil, logs)
ct.AssertResponse(wr[4], server.PredictionSucceeded, "*bar*", logs)
}
logs := "starting prediction\nprediction in progress 1/1\ncompleted prediction\n"
ct.AssertResponses(wr, server.PredictionSucceeded, "*bar*", logs)

ct.Shutdown()
assert.NoError(t, ct.Cleanup())
Expand All @@ -104,38 +62,8 @@ func TestAsyncPredictionFailure(t *testing.T) {

ct.AsyncPrediction(map[string]any{"i": 1, "s": "bar"})
wr := ct.WaitForWebhookCompletion()
if *legacyCog {
assert.Len(t, wr, 3)
logs := ""
// Compat: legacy Cog sends no "starting" event
ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs)
assert.Equal(t, server.PredictionProcessing, wr[1].Status)
assert.Equal(t, nil, wr[1].Output)
// Compat: legacy Cog includes worker stacktrace
assert.Contains(t, wr[1].Logs, "Traceback")
// Compat: legacy Cog buffers logging?
logs += "starting prediction\n"
logs += "prediction in progress 1/1\n"
logs += "prediction failed\n"
assert.Equal(t, server.PredictionFailed, wr[2].Status)
assert.Equal(t, nil, wr[2].Output)
// Compat: legacy Cog includes worker stacktrace
assert.Contains(t, wr[2].Logs, "Traceback")
assert.Contains(t, wr[2].Logs, logs)
assert.Equal(t, "prediction failed", wr[2].Error)
} else {
assert.Len(t, wr, 5)
logs := ""
ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs)
logs += "starting prediction\n"
ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs)
logs += "prediction in progress 1/1\n"
ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs)
logs += "prediction failed\n"
ct.AssertResponse(wr[3], server.PredictionProcessing, nil, logs)
ct.AssertResponse(wr[4], server.PredictionFailed, nil, logs)
assert.Equal(t, "prediction failed", wr[4].Error)
}
logs := "starting prediction\nprediction in progress 1/1\nprediction failed\n"
ct.AssertResponses(wr, server.PredictionFailed, nil, logs)

ct.Shutdown()
assert.NoError(t, ct.Cleanup())
Expand All @@ -154,45 +82,14 @@ func TestAsyncPredictionCrash(t *testing.T) {

ct.AsyncPrediction(map[string]any{"i": 1, "s": "bar"})
wr := ct.WaitForWebhookCompletion()
logs := "starting prediction\nprediction in progress 1/1\nprediction crashed\n"
ct.AssertResponses(wr, server.PredictionFailed, nil, logs)
if *legacyCog {
assert.Len(t, wr, 3)
logs := ""
// Compat: legacy Cog sends no "starting" event
ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs)
assert.Equal(t, server.PredictionProcessing, wr[1].Status)
assert.Equal(t, nil, wr[1].Output)
// Compat: legacy Cog includes worker stacktrace
assert.Contains(t, wr[1].Logs, "Traceback")
// Compat: legacy Cog buffers logging?
logs += "starting prediction\n"
logs += "prediction in progress 1/1\n"
logs += "prediction crashed\n"
assert.Equal(t, server.PredictionFailed, wr[2].Status)
assert.Equal(t, nil, wr[2].Output)
// Compat: legacy Cog includes worker stacktrace
assert.Contains(t, wr[2].Logs, "Traceback")
assert.Contains(t, wr[2].Logs, logs)
// Compat: legacy Cog cannot handle worker crash
errMsg := "Prediction failed for an unknown reason. It might have run out of memory? (exitcode 1)"
assert.Equal(t, errMsg, wr[2].Error)
assert.Equal(t, "DEFUNCT", ct.HealthCheck().Status)
assert.Equal(t, "Prediction failed for an unknown reason. It might have run out of memory? (exitcode 1)", wr[len(wr)-1].Error)
} else {
assert.Len(t, wr, 5)
logs := ""
ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs)
logs += "starting prediction\n"
ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs)
logs += "prediction in progress 1/1\n"
ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs)
logs += "prediction crashed\n"
ct.AssertResponse(wr[3], server.PredictionProcessing, nil, logs)
assert.Equal(t, server.PredictionFailed, wr[4].Status)
assert.Equal(t, nil, wr[4].Output)
assert.Contains(t, wr[4].Logs, logs)
assert.Contains(t, wr[4].Logs, "SystemExit: 1\n")
assert.Equal(t, "prediction failed", wr[4].Error)
assert.Equal(t, "DEFUNCT", ct.HealthCheck().Status)
assert.Equal(t, "prediction failed", wr[len(wr)-1].Error)
}
assert.Equal(t, "DEFUNCT", ct.HealthCheck().Status)

ct.Shutdown()
assert.NoError(t, ct.Cleanup())
Expand All @@ -212,36 +109,17 @@ func TestAsyncPredictionCanceled(t *testing.T) {
pid := "p01"
ct.AsyncPredictionWithId(pid, map[string]any{"i": 60, "s": "bar"})
if *legacyCog {
// Compat: legacy Cog buffers logging?
// Compat: legacy Cog does not send output webhook
time.Sleep(time.Second)
ct.Cancel(pid)
wr := ct.WaitForWebhookCompletion()
assert.Len(t, wr, 3)
logs := ""
// Compat: legacy Cog sends no "starting" event
ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs)
logs += "starting prediction\n"
ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs)
// Compat: legacy Cog buffers logging?
logs += "prediction in progress 1/60\n"
logs += "prediction canceled\n"
ct.AssertResponse(wr[2], server.PredictionCanceled, nil, logs)
} else {
ct.WaitForWebhook(func(response server.PredictionResponse) bool {
return strings.Contains(response.Logs, "prediction in progress 1/60\n")
})
ct.Cancel(pid)
wr := ct.WaitForWebhookCompletion()
assert.Len(t, wr, 4)
logs := ""
ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs)
logs += "starting prediction\n"
ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs)
logs += "prediction in progress 1/60\n"
ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs)
logs += "prediction canceled\n"
ct.AssertResponse(wr[3], server.PredictionCanceled, nil, logs)
}
ct.Cancel(pid)
wr := ct.WaitForWebhookCompletion()
logs := "starting prediction\nprediction in progress 1/60\nprediction canceled\n"
ct.AssertResponses(wr, server.PredictionCanceled, nil, logs)

ct.Shutdown()
assert.NoError(t, ct.Cleanup())
Expand All @@ -268,29 +146,8 @@ func TestAsyncPredictionConcurrency(t *testing.T) {
assert.Equal(t, http.StatusConflict, resp.StatusCode)

wr := ct.WaitForWebhookCompletion()
if *legacyCog {
assert.Len(t, wr, 3)
logs := ""
// Compat: legacy Cog sends no "starting" event
ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs)
// Compat: legacy Cog buffers logging?
logs += "starting prediction\n"
ct.AssertResponse(wr[1], server.PredictionProcessing, "*bar*", logs)
logs += "prediction in progress 1/1\n"
logs += "completed prediction\n"
ct.AssertResponse(wr[2], server.PredictionSucceeded, "*bar*", logs)
} else {
assert.Len(t, wr, 5)
logs := ""
ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs)
logs += "starting prediction\n"
ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs)
logs += "prediction in progress 1/1\n"
ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs)
logs += "completed prediction\n"
ct.AssertResponse(wr[3], server.PredictionProcessing, nil, logs)
ct.AssertResponse(wr[4], server.PredictionSucceeded, "*bar*", logs)
}
logs := "starting prediction\nprediction in progress 1/1\ncompleted prediction\n"
ct.AssertResponses(wr, server.PredictionSucceeded, "*bar*", logs)

ct.Shutdown()
assert.NoError(t, ct.Cleanup())
Expand Down
55 changes: 9 additions & 46 deletions internal/tests/async_predictor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,30 +35,10 @@ func TestAsyncPredictorConcurrency(t *testing.T) {
bazR = append(bazR, r)
}
}
assert.Len(t, barR, 5)
assert.Len(t, bazR, 6)

barLogs := ""
ct.AssertResponse(barR[0], server.PredictionStarting, nil, barLogs)
barLogs += "starting async prediction\n"
ct.AssertResponse(barR[1], server.PredictionProcessing, nil, barLogs)
barLogs += "prediction in progress 1/1\n"
ct.AssertResponse(barR[2], server.PredictionProcessing, nil, barLogs)
barLogs += "completed async prediction\n"
ct.AssertResponse(barR[3], server.PredictionProcessing, nil, barLogs)
ct.AssertResponse(barR[4], server.PredictionSucceeded, "*bar*", barLogs)

bazLogs := ""
ct.AssertResponse(bazR[0], server.PredictionStarting, nil, bazLogs)
bazLogs += "starting async prediction\n"
ct.AssertResponse(bazR[1], server.PredictionProcessing, nil, bazLogs)
bazLogs += "prediction in progress 1/2\n"
ct.AssertResponse(bazR[2], server.PredictionProcessing, nil, bazLogs)
bazLogs += "prediction in progress 2/2\n"
ct.AssertResponse(bazR[3], server.PredictionProcessing, nil, bazLogs)
bazLogs += "completed async prediction\n"
ct.AssertResponse(bazR[4], server.PredictionProcessing, nil, bazLogs)
ct.AssertResponse(bazR[5], server.PredictionSucceeded, "*baz*", bazLogs)
barLogs := "starting async prediction\nprediction in progress 1/1\ncompleted async prediction\n"
ct.AssertResponses(barR, server.PredictionSucceeded, "*bar*", barLogs)
bazLogs := "starting async prediction\nprediction in progress 1/2\nprediction in progress 2/2\ncompleted async prediction\n"
ct.AssertResponses(bazR, server.PredictionSucceeded, "*baz*", bazLogs)

ct.Shutdown()
assert.NoError(t, ct.Cleanup())
Expand All @@ -76,34 +56,17 @@ func TestAsyncPredictorCanceled(t *testing.T) {
pid := "p01"
ct.AsyncPredictionWithId(pid, map[string]any{"i": 60, "s": "bar"})
if *legacyCog {
// Compat: legacy Cog buffers logging?
// Compat: legacy Cog does not send output webhook
time.Sleep(time.Second)
ct.Cancel(pid)
wr := ct.WaitForWebhookCompletion()
assert.Len(t, wr, 3)
logs := ""
ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs)
logs += "starting async prediction\n"
logs += "prediction in progress 1/60\n"
logs += "prediction canceled\n"
ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs)
ct.AssertResponse(wr[2], server.PredictionCanceled, nil, logs)
} else {
ct.WaitForWebhook(func(response server.PredictionResponse) bool {
return strings.Contains(response.Logs, "prediction in progress 1/60\n")
})
ct.Cancel(pid)
wr := ct.WaitForWebhookCompletion()
assert.Len(t, wr, 4)
logs := ""
ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs)
logs += "starting async prediction\n"
ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs)
logs += "prediction in progress 1/60\n"
ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs)
logs += "prediction canceled\n"
ct.AssertResponse(wr[3], server.PredictionCanceled, nil, logs)
}
ct.Cancel(pid)
wr := ct.WaitForWebhookCompletion()
logs := "starting async prediction\nprediction in progress 1/60\nprediction canceled\n"
ct.AssertResponses(wr, server.PredictionCanceled, nil, logs)

ct.Shutdown()
assert.NoError(t, ct.Cleanup())
Expand Down
27 changes: 27 additions & 0 deletions internal/tests/cog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ func (ct *CogTest) legacyCmd() *exec.Cmd {
cmd.Dir = tmpDir
cmd.Env = os.Environ()
cmd.Env = append(cmd.Env, fmt.Sprintf("PORT=%d", ct.serverPort))
cmd.Env = append(cmd.Env, "PYTHONUNBUFFERED=1")
cmd.Env = append(cmd.Env, ct.extraEnvs...)
return cmd
}
Expand Down Expand Up @@ -372,3 +373,29 @@ func (ct *CogTest) AssertResponse(
assert.Equal(ct.t, output, response.Output)
assert.Equal(ct.t, logs, response.Logs)
}

// AssertResponses checks a complete sequence of webhook responses for one
// prediction.
//
// Every response except the last must have status "processing", and its logs
// must contain the logs of the response before it, since logs only grow. The
// last response must carry finalStatus and finalOutput. Its logs must equal
// finalLogs, except for failed predictions, where they only need to contain
// finalLogs because legacy Cog adds a Traceback.
//
// An empty responses slice is reported as a failure, because the final status
// cannot be verified without at least one response.
func (ct *CogTest) AssertResponses(
	responses []server.PredictionResponse,
	finalStatus server.PredictionStatus,
	finalOutput any,
	finalLogs string) {
	if !assert.NotEmpty(ct.t, responses) {
		return
	}
	last := len(responses) - 1

	// Intermediate responses: always "processing", with monotonically growing logs
	prevLogs := ""
	for _, r := range responses[:last] {
		assert.Equal(ct.t, server.PredictionProcessing, r.Status)
		// Logs are incremental
		assert.Contains(ct.t, r.Logs, prevLogs)
		prevLogs = r.Logs
	}

	// Terminal response: expected value goes first so failure messages are labeled correctly
	final := responses[last]
	assert.Equal(ct.t, finalStatus, final.Status)
	assert.Equal(ct.t, finalOutput, final.Output)
	if final.Status == server.PredictionFailed {
		// Compat: legacy Cog includes Traceback in failed logs
		assert.Contains(ct.t, final.Logs, finalLogs)
	} else {
		assert.Equal(ct.t, finalLogs, final.Logs)
	}
}
Loading

0 comments on commit 4fae8b5

Please sign in to comment.