From 4fae8b575a84d0de7d1585dabd8b6c72f4f54f21 Mon Sep 17 00:00:00 2001 From: Neville Li Date: Fri, 13 Dec 2024 17:46:21 -0500 Subject: [PATCH] Improve legacy Cog compliance * Run legacy Cog with PYTHONUNBUFFERED * Send "processing" instead of "start" webhook * Simplify webhook sequence assertions --- internal/server/runner.go | 13 +- internal/tests/async_prediction_test.go | 179 +++--------------------- internal/tests/async_predictor_test.go | 55 ++------ internal/tests/cog_test.go | 27 ++++ internal/tests/filter_test.go | 100 ++----------- internal/tests/iterator_test.go | 61 +------- internal/tests/path_test.go | 27 +--- internal/tests/prediction_test.go | 8 +- 8 files changed, 83 insertions(+), 387 deletions(-) diff --git a/internal/server/runner.go b/internal/server/runner.go index e4621b3..fac12e1 100644 --- a/internal/server/runner.go +++ b/internal/server/runner.go @@ -44,21 +44,14 @@ func (pr *PendingPrediction) appendLogLine(line string) { func (pr *PendingPrediction) sendWebhook(event WebhookEvent) { pr.mu.Lock() - defer func() { - // Only async and iterator predict writes new response per output item with status = "processing" - // For blocking or non-iterator cases, set it here immediately after sending "starting" webhook - if pr.response.Status == PredictionStarting { - pr.response.Status = PredictionProcessing - } - pr.mu.Unlock() - }() + defer pr.mu.Unlock() if pr.request.Webhook == "" { return } if len(pr.request.WebhookEventsFilter) > 0 && !slices.Contains(pr.request.WebhookEventsFilter, event) { return } - if pr.response.Status == PredictionProcessing { + if event == WebhookLogs || event == WebhookOutput { if time.Since(pr.lastUpdated) < 500*time.Millisecond { return } @@ -392,6 +385,8 @@ func (r *Runner) handleResponses() { if pr.response.Status == PredictionStarting { log.Infow("prediction started", "id", pr.request.Id, "status", pr.response.Status) + // Compat: legacy Cog never sends "start" event + pr.response.Status = PredictionProcessing pr.sendWebhook(WebhookStart) } else if pr.response.Status == PredictionProcessing { log.Infow("prediction processing", "id", pr.request.Id, "status", pr.response.Status) diff --git a/internal/tests/async_prediction_test.go b/internal/tests/async_prediction_test.go index 79e47e8..0328c71 100644 --- a/internal/tests/async_prediction_test.go +++ b/internal/tests/async_prediction_test.go @@ -25,29 +25,8 @@ func TestAsyncPredictionSucceeded(t *testing.T) { ct.AsyncPrediction(map[string]any{"i": 1, "s": "bar"}) wr := ct.WaitForWebhookCompletion() - if *legacyCog { - assert.Len(t, wr, 3) - logs := "" - // Compat: legacy Cog sends no "starting" event - ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs) - // Compat: legacy Cog buffers logging? - logs += "starting prediction\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, "*bar*", logs) - logs += "prediction in progress 1/1\n" - logs += "completed prediction\n" - ct.AssertResponse(wr[2], server.PredictionSucceeded, "*bar*", logs) - } else { - assert.Len(t, wr, 5) - logs := "" - ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs) - logs += "starting prediction\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) - logs += "prediction in progress 1/1\n" - ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs) - logs += "completed prediction\n" - ct.AssertResponse(wr[3], server.PredictionProcessing, nil, logs) - ct.AssertResponse(wr[4], server.PredictionSucceeded, "*bar*", logs) - } + logs := "starting prediction\nprediction in progress 1/1\ncompleted prediction\n" + ct.AssertResponses(wr, server.PredictionSucceeded, "*bar*", logs) ct.Shutdown() assert.NoError(t, ct.Cleanup()) @@ -64,29 +43,8 @@ func TestAsyncPredictionWithIdSucceeded(t *testing.T) { ct.AsyncPredictionWithId("p01", map[string]any{"i": 1, "s": "bar"}) wr := ct.WaitForWebhookCompletion() - if *legacyCog { - assert.Len(t, wr, 3) - logs := "" - // Compat: legacy Cog sends no "starting" event - ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs) - // Compat: legacy Cog buffers logging? - logs += "starting prediction\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, "*bar*", logs) - logs += "prediction in progress 1/1\n" - logs += "completed prediction\n" - ct.AssertResponse(wr[2], server.PredictionSucceeded, "*bar*", logs) - } else { - assert.Len(t, wr, 5) - logs := "" - ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs) - logs += "starting prediction\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) - logs += "prediction in progress 1/1\n" - ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs) - logs += "completed prediction\n" - ct.AssertResponse(wr[3], server.PredictionProcessing, nil, logs) - ct.AssertResponse(wr[4], server.PredictionSucceeded, "*bar*", logs) - } + logs := "starting prediction\nprediction in progress 1/1\ncompleted prediction\n" + ct.AssertResponses(wr, server.PredictionSucceeded, "*bar*", logs) ct.Shutdown() assert.NoError(t, ct.Cleanup()) @@ -104,38 +62,8 @@ func TestAsyncPredictionFailure(t *testing.T) { ct.AsyncPrediction(map[string]any{"i": 1, "s": "bar"}) wr := ct.WaitForWebhookCompletion() - if *legacyCog { - assert.Len(t, wr, 3) - logs := "" - // Compat: legacy Cog sends no "starting" event - ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs) - assert.Equal(t, server.PredictionProcessing, wr[1].Status) - assert.Equal(t, nil, wr[1].Output) - // Compat: legacy Cog includes worker stacktrace - assert.Contains(t, wr[1].Logs, "Traceback") - // Compat: legacy Cog buffers logging? - logs += "starting prediction\n" - logs += "prediction in progress 1/1\n" - logs += "prediction failed\n" - assert.Equal(t, server.PredictionFailed, wr[2].Status) - assert.Equal(t, nil, wr[2].Output) - // Compat: legacy Cog includes worker stacktrace - assert.Contains(t, wr[2].Logs, "Traceback") - assert.Contains(t, wr[2].Logs, logs) - assert.Equal(t, "prediction failed", wr[2].Error) - } else { - assert.Len(t, wr, 5) - logs := "" - ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs) - logs += "starting prediction\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) - logs += "prediction in progress 1/1\n" - ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs) - logs += "prediction failed\n" - ct.AssertResponse(wr[3], server.PredictionProcessing, nil, logs) - ct.AssertResponse(wr[4], server.PredictionFailed, nil, logs) - assert.Equal(t, "prediction failed", wr[4].Error) - } + logs := "starting prediction\nprediction in progress 1/1\nprediction failed\n" + ct.AssertResponses(wr, server.PredictionFailed, nil, logs) ct.Shutdown() assert.NoError(t, ct.Cleanup()) @@ -154,45 +82,14 @@ func TestAsyncPredictionCrash(t *testing.T) { ct.AsyncPrediction(map[string]any{"i": 1, "s": "bar"}) wr := ct.WaitForWebhookCompletion() + logs := "starting prediction\nprediction in progress 1/1\nprediction crashed\n" + ct.AssertResponses(wr, server.PredictionFailed, nil, logs) if *legacyCog { - assert.Len(t, wr, 3) - logs := "" - // Compat: legacy Cog sends no "starting" event - ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs) - assert.Equal(t, server.PredictionProcessing, wr[1].Status) - assert.Equal(t, nil, wr[1].Output) - // Compat: legacy Cog includes worker stacktrace - assert.Contains(t, wr[1].Logs, "Traceback") - // Compat: legacy Cog buffers logging? - logs += "starting prediction\n" - logs += "prediction in progress 1/1\n" - logs += "prediction crashed\n" - assert.Equal(t, server.PredictionFailed, wr[2].Status) - assert.Equal(t, nil, wr[2].Output) - // Compat: legacy Cog includes worker stacktrace - assert.Contains(t, wr[2].Logs, "Traceback") - assert.Contains(t, wr[2].Logs, logs) - // Compat: legacy Cog cannot handle worker crash - errMsg := "Prediction failed for an unknown reason. It might have run out of memory? (exitcode 1)" - assert.Equal(t, errMsg, wr[2].Error) - assert.Equal(t, "DEFUNCT", ct.HealthCheck().Status) + assert.Equal(t, "Prediction failed for an unknown reason. It might have run out of memory? (exitcode 1)", wr[len(wr)-1].Error) } else { - assert.Len(t, wr, 5) - logs := "" - ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs) - logs += "starting prediction\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) - logs += "prediction in progress 1/1\n" - ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs) - logs += "prediction crashed\n" - ct.AssertResponse(wr[3], server.PredictionProcessing, nil, logs) - assert.Equal(t, server.PredictionFailed, wr[4].Status) - assert.Equal(t, nil, wr[4].Output) - assert.Contains(t, wr[4].Logs, logs) - assert.Contains(t, wr[4].Logs, "SystemExit: 1\n") - assert.Equal(t, "prediction failed", wr[4].Error) - assert.Equal(t, "DEFUNCT", ct.HealthCheck().Status) + assert.Equal(t, "prediction failed", wr[len(wr)-1].Error) } + assert.Equal(t, "DEFUNCT", ct.HealthCheck().Status) ct.Shutdown() assert.NoError(t, ct.Cleanup()) @@ -212,36 +109,17 @@ func TestAsyncPredictionCanceled(t *testing.T) { pid := "p01" ct.AsyncPredictionWithId(pid, map[string]any{"i": 60, "s": "bar"}) if *legacyCog { - // Compat: legacy Cog buffers logging? + // Compat: legacy Cog does not send output webhook time.Sleep(time.Second) - ct.Cancel(pid) - wr := ct.WaitForWebhookCompletion() - assert.Len(t, wr, 3) - logs := "" - // Compat: legacy Cog sends no "starting" event - ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs) - logs += "starting prediction\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) - // Compat: legacy Cog buffers logging? - logs += "prediction in progress 1/60\n" - logs += "prediction canceled\n" - ct.AssertResponse(wr[2], server.PredictionCanceled, nil, logs) } else { ct.WaitForWebhook(func(response server.PredictionResponse) bool { return strings.Contains(response.Logs, "prediction in progress 1/60\n") }) - ct.Cancel(pid) - wr := ct.WaitForWebhookCompletion() - assert.Len(t, wr, 4) - logs := "" - ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs) - logs += "starting prediction\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) - logs += "prediction in progress 1/60\n" - ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs) - logs += "prediction canceled\n" - ct.AssertResponse(wr[3], server.PredictionCanceled, nil, logs) } + ct.Cancel(pid) + wr := ct.WaitForWebhookCompletion() + logs := "starting prediction\nprediction in progress 1/60\nprediction canceled\n" + ct.AssertResponses(wr, server.PredictionCanceled, nil, logs) ct.Shutdown() assert.NoError(t, ct.Cleanup()) @@ -268,29 +146,8 @@ func TestAsyncPredictionConcurrency(t *testing.T) { assert.Equal(t, http.StatusConflict, resp.StatusCode) wr := ct.WaitForWebhookCompletion() - if *legacyCog { - assert.Len(t, wr, 3) - logs := "" - // Compat: legacy Cog sends no "starting" event - ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs) - // Compat: legacy Cog buffers logging? - logs += "starting prediction\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, "*bar*", logs) - logs += "prediction in progress 1/1\n" - logs += "completed prediction\n" - ct.AssertResponse(wr[2], server.PredictionSucceeded, "*bar*", logs) - } else { - assert.Len(t, wr, 5) - logs := "" - ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs) - logs += "starting prediction\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) - logs += "prediction in progress 1/1\n" - ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs) - logs += "completed prediction\n" - ct.AssertResponse(wr[3], server.PredictionProcessing, nil, logs) - ct.AssertResponse(wr[4], server.PredictionSucceeded, "*bar*", logs) - } + logs := "starting prediction\nprediction in progress 1/1\ncompleted prediction\n" + ct.AssertResponses(wr, server.PredictionSucceeded, "*bar*", logs) ct.Shutdown() assert.NoError(t, ct.Cleanup()) diff --git a/internal/tests/async_predictor_test.go b/internal/tests/async_predictor_test.go index e444109..fd47815 100644 --- a/internal/tests/async_predictor_test.go +++ b/internal/tests/async_predictor_test.go @@ -35,30 +35,10 @@ func TestAsyncPredictorConcurrency(t *testing.T) { bazR = append(bazR, r) } } - assert.Len(t, barR, 5) - assert.Len(t, bazR, 6) - - barLogs := "" - ct.AssertResponse(barR[0], server.PredictionStarting, nil, barLogs) - barLogs += "starting async prediction\n" - ct.AssertResponse(barR[1], server.PredictionProcessing, nil, barLogs) - barLogs += "prediction in progress 1/1\n" - ct.AssertResponse(barR[2], server.PredictionProcessing, nil, barLogs) - barLogs += "completed async prediction\n" - ct.AssertResponse(barR[3], server.PredictionProcessing, nil, barLogs) - ct.AssertResponse(barR[4], server.PredictionSucceeded, "*bar*", barLogs) - - bazLogs := "" - ct.AssertResponse(bazR[0], server.PredictionStarting, nil, bazLogs) - bazLogs += "starting async prediction\n" - ct.AssertResponse(bazR[1], server.PredictionProcessing, nil, bazLogs) - bazLogs += "prediction in progress 1/2\n" - ct.AssertResponse(bazR[2], server.PredictionProcessing, nil, bazLogs) - bazLogs += "prediction in progress 2/2\n" - ct.AssertResponse(bazR[3], server.PredictionProcessing, nil, bazLogs) - bazLogs += "completed async prediction\n" - ct.AssertResponse(bazR[4], server.PredictionProcessing, nil, bazLogs) - ct.AssertResponse(bazR[5], server.PredictionSucceeded, "*baz*", bazLogs) + barLogs := "starting async prediction\nprediction in progress 1/1\ncompleted async prediction\n" + ct.AssertResponses(barR, server.PredictionSucceeded, "*bar*", barLogs) + bazLogs := "starting async prediction\nprediction in progress 1/2\nprediction in progress 2/2\ncompleted async prediction\n" + ct.AssertResponses(bazR, server.PredictionSucceeded, "*baz*", bazLogs) ct.Shutdown() assert.NoError(t, ct.Cleanup()) @@ -76,34 +56,17 @@ func TestAsyncPredictorCanceled(t *testing.T) { pid := "p01" ct.AsyncPredictionWithId(pid, map[string]any{"i": 60, "s": "bar"}) if *legacyCog { - // Compat: legacy Cog buffers logging? + // Compat: legacy Cog does not send output webhook time.Sleep(time.Second) - ct.Cancel(pid) - wr := ct.WaitForWebhookCompletion() - assert.Len(t, wr, 3) - logs := "" - ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs) - logs += "starting async prediction\n" - logs += "prediction in progress 1/60\n" - logs += "prediction canceled\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) - ct.AssertResponse(wr[2], server.PredictionCanceled, nil, logs) } else { ct.WaitForWebhook(func(response server.PredictionResponse) bool { return strings.Contains(response.Logs, "prediction in progress 1/60\n") }) - ct.Cancel(pid) - wr := ct.WaitForWebhookCompletion() - assert.Len(t, wr, 4) - logs := "" - ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs) - logs += "starting async prediction\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) - logs += "prediction in progress 1/60\n" - ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs) - logs += "prediction canceled\n" - ct.AssertResponse(wr[3], server.PredictionCanceled, nil, logs) } + ct.Cancel(pid) + wr := ct.WaitForWebhookCompletion() + logs := "starting async prediction\nprediction in progress 1/60\nprediction canceled\n" + ct.AssertResponses(wr, server.PredictionCanceled, nil, logs) ct.Shutdown() assert.NoError(t, ct.Cleanup()) diff --git a/internal/tests/cog_test.go b/internal/tests/cog_test.go index 2c7cc76..556a38a 100644 --- a/internal/tests/cog_test.go +++ b/internal/tests/cog_test.go @@ -183,6 +183,7 @@ func (ct *CogTest) legacyCmd() *exec.Cmd { cmd.Dir = tmpDir cmd.Env = os.Environ() cmd.Env = append(cmd.Env, fmt.Sprintf("PORT=%d", ct.serverPort)) + cmd.Env = append(cmd.Env, "PYTHONUNBUFFERED=1") cmd.Env = append(cmd.Env, ct.extraEnvs...) return cmd } @@ -372,3 +373,29 @@ func (ct *CogTest) AssertResponse( assert.Equal(ct.t, output, response.Output) assert.Equal(ct.t, logs, response.Logs) } + +func (ct *CogTest) AssertResponses( + responses []server.PredictionResponse, + finalStatus server.PredictionStatus, + finalOutput any, + finalLogs string) { + l := len(responses) + logs := "" + for i, r := range responses { + if i == l-1 { + assert.Equal(ct.t, r.Status, finalStatus) + assert.Equal(ct.t, r.Output, finalOutput) + if r.Status == server.PredictionFailed { + // Compat: legacy Cog includes Traceback in failed logs + assert.Contains(ct.t, r.Logs, finalLogs) + } else { + assert.Equal(ct.t, r.Logs, finalLogs) + } + } else { + assert.Equal(ct.t, r.Status, server.PredictionProcessing) + // Logs are incremental + assert.Contains(ct.t, r.Logs, logs) + logs = r.Logs + } + } +} diff --git a/internal/tests/filter_test.go b/internal/tests/filter_test.go index d0d314e..4049d65 100644 --- a/internal/tests/filter_test.go +++ b/internal/tests/filter_test.go @@ -25,35 +25,13 @@ func TestPredictionFilterAll(t *testing.T) { }) wr := ct.WaitForWebhookCompletion() if *legacyCog { - assert.Len(t, wr, 5) - logs := "" - // Compat: legacy Cog sends no "starting" event - ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs) - ct.AssertResponse(wr[1], server.PredictionProcessing, []any{"*bar-0*"}, logs) - ct.AssertResponse(wr[2], server.PredictionProcessing, []any{"*bar-0*", "*bar-1*"}, logs) // Compat: legacy Cog buffers logging? - logs += "starting prediction\n" - ct.AssertResponse(wr[3], server.PredictionProcessing, []any{"*bar-0*", "*bar-1*"}, logs) - logs += "prediction in progress 1/2\n" - logs += "prediction in progress 2/2\n" - logs += "completed prediction\n" - ct.AssertResponse(wr[4], server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) + assert.Len(t, wr, 7) } else { assert.Len(t, wr, 8) - logs := "" - ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs) - logs += "starting prediction\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) - logs += "prediction in progress 1/2\n" - ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs) - ct.AssertResponse(wr[3], server.PredictionProcessing, []any{"*bar-0*"}, logs) - logs += "prediction in progress 2/2\n" - ct.AssertResponse(wr[4], server.PredictionProcessing, []any{"*bar-0*"}, logs) - ct.AssertResponse(wr[5], server.PredictionProcessing, []any{"*bar-0*", "*bar-1*"}, logs) - logs += "completed prediction\n" - ct.AssertResponse(wr[6], server.PredictionProcessing, []any{"*bar-0*", "*bar-1*"}, logs) - ct.AssertResponse(wr[7], server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) } + logs := "starting prediction\nprediction in progress 1/2\nprediction in progress 2/2\ncompleted prediction\n" + ct.AssertResponses(wr, server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) ct.Shutdown() assert.NoError(t, ct.Cleanup()) @@ -73,11 +51,7 @@ func TestPredictionFilterCompleted(t *testing.T) { }) wr := ct.WaitForWebhookCompletion() assert.Len(t, wr, 1) - logs := "" - logs += "starting prediction\n" - logs += "prediction in progress 1/2\n" - logs += "prediction in progress 2/2\n" - logs += "completed prediction\n" + logs := "starting prediction\nprediction in progress 1/2\nprediction in progress 2/2\ncompleted prediction\n" ct.AssertResponse(wr[0], server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) ct.Shutdown() @@ -99,17 +73,8 @@ func TestPredictionFilterStartedCompleted(t *testing.T) { }) wr := ct.WaitForWebhookCompletion() assert.Len(t, wr, 2) - logs := "" - if *legacyCog { - // Compat: legacy Cog sends no "starting" event - ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs) - } else { - ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs) - } - logs += "starting prediction\n" - logs += "prediction in progress 1/2\n" - logs += "prediction in progress 2/2\n" - logs += "completed prediction\n" + ct.AssertResponse(wr[0], server.PredictionProcessing, nil, "") + logs := "starting prediction\nprediction in progress 1/2\nprediction in progress 2/2\ncompleted prediction\n" ct.AssertResponse(wr[1], server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) ct.Shutdown() @@ -130,29 +95,10 @@ func TestPredictionFilterOutput(t *testing.T) { server.WebhookCompleted, }) wr := ct.WaitForWebhookCompletion() - if *legacyCog { - assert.Len(t, wr, 3) - logs := "" - // Compat: legacy Cog sends no "starting" event - ct.AssertResponse(wr[0], server.PredictionProcessing, []any{"*bar-0*"}, logs) - ct.AssertResponse(wr[1], server.PredictionProcessing, []any{"*bar-0*", "*bar-1*"}, logs) - // Compat: legacy Cog buffers logging? - logs += "starting prediction\n" - logs += "prediction in progress 1/2\n" - logs += "prediction in progress 2/2\n" - logs += "completed prediction\n" - ct.AssertResponse(wr[2], server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) - } else { - assert.Len(t, wr, 3) - logs := "" - logs += "starting prediction\n" - logs += "prediction in progress 1/2\n" - ct.AssertResponse(wr[0], server.PredictionProcessing, []any{"*bar-0*"}, logs) - logs += "prediction in progress 2/2\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, []any{"*bar-0*", "*bar-1*"}, logs) - logs += "completed prediction\n" - ct.AssertResponse(wr[2], server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) - } + + assert.Len(t, wr, 3) + logs := "starting prediction\nprediction in progress 1/2\nprediction in progress 2/2\ncompleted prediction\n" + ct.AssertResponses(wr, server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) ct.Shutdown() assert.NoError(t, ct.Cleanup()) @@ -172,29 +118,9 @@ func TestPredictionFilterLogs(t *testing.T) { server.WebhookCompleted, }) wr := ct.WaitForWebhookCompletion() - if *legacyCog { - assert.Len(t, wr, 2) - logs := "" - logs += "starting prediction\n" - ct.AssertResponse(wr[0], server.PredictionProcessing, []any{"*bar-0*", "*bar-1*"}, logs) - logs += "prediction in progress 1/2\n" - logs += "prediction in progress 2/2\n" - logs += "completed prediction\n" - ct.AssertResponse(wr[1], server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) - - } else { - assert.Len(t, wr, 5) - logs := "" - logs += "starting prediction\n" - ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs) - logs += "prediction in progress 1/2\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) - logs += "prediction in progress 2/2\n" - ct.AssertResponse(wr[2], server.PredictionProcessing, []any{"*bar-0*"}, logs) - logs += "completed prediction\n" - ct.AssertResponse(wr[3], server.PredictionProcessing, []any{"*bar-0*", "*bar-1*"}, logs) - ct.AssertResponse(wr[4], server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) - } + assert.Len(t, wr, 5) + logs := "starting prediction\nprediction in progress 1/2\nprediction in progress 2/2\ncompleted prediction\n" + ct.AssertResponses(wr, server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) ct.Shutdown() assert.NoError(t, ct.Cleanup()) diff --git a/internal/tests/iterator_test.go b/internal/tests/iterator_test.go index c11032c..0928242 100644 --- a/internal/tests/iterator_test.go +++ b/internal/tests/iterator_test.go @@ -35,36 +35,8 @@ func testPredictionIteratorSucceeded(t *testing.T, module string) { ct.AsyncPrediction(map[string]any{"i": 2, "s": "bar"}) wr := ct.WaitForWebhookCompletion() - if *legacyCog { - assert.Len(t, wr, 5) - logs := "" - // Compat: legacy Cog sends no "starting" event - ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs) - ct.AssertResponse(wr[1], server.PredictionProcessing, []any{"*bar-0*"}, logs) - ct.AssertResponse(wr[2], server.PredictionProcessing, []any{"*bar-0*", "*bar-1*"}, logs) - // Compat: legacy Cog buffers logging? - logs += "starting prediction\n" - ct.AssertResponse(wr[3], server.PredictionProcessing, []any{"*bar-0*", "*bar-1*"}, logs) - logs += "prediction in progress 1/2\n" - logs += "prediction in progress 2/2\n" - logs += "completed prediction\n" - ct.AssertResponse(wr[4], server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) - } else { - assert.Len(t, wr, 8) - logs := "" - ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs) - logs += "starting prediction\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) - logs += "prediction in progress 1/2\n" - ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs) - ct.AssertResponse(wr[3], server.PredictionProcessing, []any{"*bar-0*"}, logs) - logs += "prediction in progress 2/2\n" - ct.AssertResponse(wr[4], server.PredictionProcessing, []any{"*bar-0*"}, logs) - ct.AssertResponse(wr[5], server.PredictionProcessing, []any{"*bar-0*", "*bar-1*"}, logs) - logs += "completed prediction\n" - ct.AssertResponse(wr[6], server.PredictionProcessing, []any{"*bar-0*", "*bar-1*"}, logs) - ct.AssertResponse(wr[7], server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) - } + logs := "starting prediction\nprediction in progress 1/2\nprediction in progress 2/2\ncompleted prediction\n" + ct.AssertResponses(wr, server.PredictionSucceeded, []any{"*bar-0*", "*bar-1*"}, logs) ct.Shutdown() assert.NoError(t, ct.Cleanup()) @@ -95,31 +67,10 @@ func TestPredictionAsyncIteratorConcurrency(t *testing.T) { bazR = append(bazR, r) } } - assert.Len(t, barR, 6) - barLogs := "" - ct.AssertResponse(barR[0], server.PredictionStarting, nil, barLogs) - barLogs += "starting prediction\n" - ct.AssertResponse(barR[1], server.PredictionProcessing, nil, barLogs) - barLogs += "prediction in progress 1/1\n" - ct.AssertResponse(barR[2], server.PredictionProcessing, nil, barLogs) - ct.AssertResponse(barR[3], server.PredictionProcessing, []any{"*bar-0*"}, barLogs) - barLogs += "completed prediction\n" - ct.AssertResponse(barR[4], server.PredictionProcessing, []any{"*bar-0*"}, barLogs) - ct.AssertResponse(barR[5], server.PredictionSucceeded, []any{"*bar-0*"}, barLogs) - assert.Len(t, bazR, 8) - bazLogs := "" - ct.AssertResponse(bazR[0], server.PredictionStarting, nil, bazLogs) - bazLogs += "starting prediction\n" - ct.AssertResponse(bazR[1], server.PredictionProcessing, nil, bazLogs) - bazLogs += "prediction in progress 1/2\n" - ct.AssertResponse(bazR[2], server.PredictionProcessing, nil, bazLogs) - ct.AssertResponse(bazR[3], server.PredictionProcessing, []any{"*baz-0*"}, bazLogs) - bazLogs += "prediction in progress 2/2\n" - ct.AssertResponse(bazR[4], server.PredictionProcessing, []any{"*baz-0*"}, bazLogs) - ct.AssertResponse(bazR[5], server.PredictionProcessing, []any{"*baz-0*", "*baz-1*"}, bazLogs) - bazLogs += "completed prediction\n" - ct.AssertResponse(bazR[6], server.PredictionProcessing, []any{"*baz-0*", "*baz-1*"}, bazLogs) - ct.AssertResponse(bazR[7], server.PredictionSucceeded, []any{"*baz-0*", "*baz-1*"}, bazLogs) + barLogs := "starting prediction\nprediction in progress 1/1\ncompleted prediction\n" + ct.AssertResponses(barR, server.PredictionSucceeded, []any{"*bar-0*"}, barLogs) + bazLogs := "starting prediction\nprediction in progress 1/2\nprediction in progress 2/2\ncompleted prediction\n" + ct.AssertResponses(bazR, server.PredictionSucceeded, []any{"*baz-0*", "*baz-1*"}, bazLogs) ct.Shutdown() assert.NoError(t, ct.Cleanup()) diff --git a/internal/tests/path_test.go b/internal/tests/path_test.go index 44b3704..dd10bcb 100644 --- a/internal/tests/path_test.go +++ b/internal/tests/path_test.go @@ -85,29 +85,10 @@ func TestPredictionPathUploadUrlSucceeded(t *testing.T) { wr := ct.WaitForWebhookCompletion() ul := ct.GetUploads() - if *legacyCog { - assert.Len(t, wr, 3) - assert.Len(t, ul, 1) - logs := "" - // Compat: legacy Cog sends no "starting" event - ct.AssertResponse(wr[0], server.PredictionProcessing, nil, logs) - logs += "reading input file\n" - url := fmt.Sprintf("http://localhost:%d%s", ct.webhookPort, ul[0].Path) - ct.AssertResponse(wr[1], server.PredictionProcessing, url, logs) - logs += "writing output file\n" - ct.AssertResponse(wr[2], server.PredictionSucceeded, url, logs) - } else { - assert.Len(t, wr, 4) - assert.Len(t, ul, 1) - logs := "" - ct.AssertResponse(wr[0], server.PredictionStarting, nil, logs) - logs += "reading input file\n" - ct.AssertResponse(wr[1], server.PredictionProcessing, nil, logs) - logs += "writing output file\n" - ct.AssertResponse(wr[2], server.PredictionProcessing, nil, logs) - url := fmt.Sprintf("http://localhost:%d%s", ct.webhookPort, ul[0].Path) - ct.AssertResponse(wr[3], server.PredictionSucceeded, url, logs) - } + assert.Len(t, ul, 1) + logs := "reading input file\nwriting output file\n" + url := fmt.Sprintf("http://localhost:%d%s", ct.webhookPort, ul[0].Path) + ct.AssertResponses(wr, server.PredictionSucceeded, url, logs) body := string(ul[0].Body) assert.Contains(t, body, "*bar*") diff --git a/internal/tests/prediction_test.go b/internal/tests/prediction_test.go index a51f6e5..768709a 100644 --- a/internal/tests/prediction_test.go +++ b/internal/tests/prediction_test.go @@ -1,7 +1,6 @@ package tests import ( - "fmt" "io" "net/http" "testing" @@ -62,11 +61,8 @@ func TestPredictionFailure(t *testing.T) { assert.Equal(t, server.PredictionFailed, resp.Status) assert.Equal(t, nil, resp.Output) logs := "starting prediction\nprediction in progress 1/1\nprediction failed\n" - if *legacyCog { - assert.Contains(t, resp.Logs, fmt.Sprintf("Exception: prediction failed\n%s", logs)) - } else { - assert.Equal(t, logs, resp.Logs) - } + // Compat: legacy Cog also includes Traceback + assert.Contains(t, resp.Logs, logs) assert.Equal(t, "prediction failed", resp.Error) ct.Shutdown()