diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 44e82aed08..af638951d2 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -158,7 +158,35 @@ func SendEvent( } } - evTime, err := httputil.ParseTSParam(req) + + // HandleEventTimestamp processes the ts parameter based on whether the request is from an appservice. + func HandleEventTimestamp(req *http.Request) (time.Time, error) { + if isAppService(req) { + evTime, err := httputil.ParseTSParam(req) + if err != nil { + return time.Time{}, err // Return error for further handling + } + return evTime, nil + } + + // If not from an appservice, use the current time or other default handling + return time.Now(), nil + } + + // Check if the request is from an appservice + func isAppService(req *http.Request) error { + evTime, err := HandleEventTimestamp(req) + if err != nil { + // Handle error, e.g., return a 400 Bad Request + return httputil.LogThenError(req, err) + } + + // Use evTime for the event timestamp + // Proceed with your original logic... + + return nil + } + if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -291,6 +319,26 @@ func updatePowerLevels(req *http.Request, r map[string]interface{}, roomID strin return nil } +//If appservices use a specific access token, you can check the request’s authorization header for this token. +func isAppService(req *http.Request) bool { + // Check if the request contains an access token for appservices + accessToken := req.Header.Get("Authorization") + + // This is an example; you need to replace it with your appservice token checking logic + if strings.HasPrefix(accessToken, "Bearer") { + token := strings.TrimPrefix(accessToken, "Bearer ") + return isValidAppServiceToken(token) + } + + return false +} + +func isValidAppServiceToken(token string) bool { + // Placeholder function: implement logic to validate if the token belongs to an appservice + // For example, you could compare against a list of known appservice tokens + return token == "your_appservice_token" +} + // stateEqual compares the new and the existing state event content. If they are equal, returns a *util.JSONResponse // with the existing event_id, making this an idempotent request. func stateEqual(ctx context.Context, rsAPI api.ClientRoomserverAPI, eventType, stateKey, roomID string, newContent map[string]interface{}) *util.JSONResponse { diff --git a/test/ts_param_test.go b/test/ts_param_test.go new file mode 100644 index 0000000000..95aa60fc60 --- /dev/null +++ b/test/ts_param_test.go @@ -0,0 +1,59 @@ +// ts_param_test.go +package testing + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + "your_project_path/routing" // Adjust this import path +) + +func createRequestWithTS(ts string, isAppService bool) *http.Request { + req := httptest.NewRequest("POST", "/_matrix/client/r0/rooms/!roomid:domain/send/m.room.message", nil) + q := req.URL.Query() + if ts != "" { + q.Add("ts", ts) + } + req.URL.RawQuery = q.Encode() + + if isAppService { + req.Header.Set("Authorization", "Bearer your_appservice_token") + } else { + req.Header.Set("Authorization", "Bearer regular_user_token") + } + return req +} + +func TestHandleEventTimestamp_ValidAppService(t *testing.T) { + req := createRequestWithTS("1657890000000", true) + evTime, err := routing.HandleEventTimestamp(req) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + expectedTime := time.Unix(1657890000, 0) + if !evTime.Equal(expectedTime) { + t.Errorf("Expected time %v, got %v", expectedTime, evTime) + } +} + +func TestHandleEventTimestamp_InvalidTS(t *testing.T) { + req := createRequestWithTS("invalid_ts", true) + _, err := routing.HandleEventTimestamp(req) + if err == nil { + t.Fatal("Expected an error, got none") + } +} + +func TestHandleEventTimestamp_NonAppService(t *testing.T) { + req := createRequestWithTS("1657890000000", false) + evTime, err := routing.HandleEventTimestamp(req) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + if time.Now().Sub(evTime) > time.Second { + t.Errorf("Expected current time, got %v", evTime) + } +}