diff --git a/spec/unit/rust-crypto/rust-crypto.spec.ts b/spec/unit/rust-crypto/rust-crypto.spec.ts index 52916b43642..fe8366cc133 100644 --- a/spec/unit/rust-crypto/rust-crypto.spec.ts +++ b/spec/unit/rust-crypto/rust-crypto.spec.ts @@ -46,7 +46,7 @@ import { } from "../../../src"; import { mkEvent } from "../../test-utils/test-utils"; import { CryptoBackend } from "../../../src/common-crypto/CryptoBackend"; -import { IEventDecryptionResult } from "../../../src/@types/crypto"; +import { IEventDecryptionResult, IMegolmSessionData } from "../../../src/@types/crypto"; import { OutgoingRequestProcessor } from "../../../src/rust-crypto/OutgoingRequestProcessor"; import { AccountDataClient, @@ -1260,6 +1260,34 @@ describe("RustCrypto", () => { }, }); }); + + it("ignores invalid keys when restoring from backup", async () => { + const rustCrypto = await makeTestRustCrypto(); + const olmMachine: OlmMachine = rustCrypto["olmMachine"]; + + await olmMachine.enableBackupV1( + (testData.SIGNED_BACKUP_DATA.auth_data as Curve25519AuthData).public_key, + testData.SIGNED_BACKUP_DATA.version!, + ); + + const backup = Array.from(testData.MEGOLM_SESSION_DATA_ARRAY); + // in addition to correct keys, we restore an invalid key + backup.push({ room_id: "!roomid", session_id: "sessionid" } as IMegolmSessionData); + const progressCallback = jest.fn(); + await rustCrypto.importBackedUpRoomKeys(backup, { progressCallback }); + expect(progressCallback).toHaveBeenCalledWith({ + total: 3, + successes: 0, + stage: "load_keys", + failures: 1, + }); + expect(progressCallback).toHaveBeenCalledWith({ + total: 3, + successes: 1, + stage: "load_keys", + failures: 1, + }); + }); }); }); diff --git a/src/rust-crypto/backup.ts b/src/rust-crypto/backup.ts index 21964eb3afb..d7ca7d85a42 100644 --- a/src/rust-crypto/backup.ts +++ b/src/rust-crypto/backup.ts @@ -212,15 +212,18 @@ export class RustBackupManager extends TypedEventEmitter { - const importOpt: ImportRoomKeyProgressData = { - total: Number(total), - successes: Number(progress), - stage: "load_keys", - failures: 0, - }; - opts?.progressCallback?.(importOpt); - }); + await this.olmMachine.importBackedUpRoomKeys( + keysByRoom, + (progress: BigInt, total: BigInt, failures: BigInt): void => { + const importOpt: ImportRoomKeyProgressData = { + total: Number(total), + successes: Number(progress), + stage: "load_keys", + failures: Number(failures), + }; + opts?.progressCallback?.(importOpt); + }, + ); } private keyBackupCheckInProgress: Promise | null = null;