diff --git a/spec/integ/crypto/megolm-backup.spec.ts b/spec/integ/crypto/megolm-backup.spec.ts index eee0eafebc2..2e0352197cb 100644 --- a/spec/integ/crypto/megolm-backup.spec.ts +++ b/spec/integ/crypto/megolm-backup.spec.ts @@ -46,6 +46,7 @@ import { IKeyBackup } from "../../../src/crypto/backup"; import { flushPromises } from "../../test-utils/flushPromises"; import { defer, IDeferred } from "../../../src/utils"; import { DecryptionFailureCode } from "../../../src/crypto-api"; +import { ImportRoomKeysOpts } from "../../../src/crypto-api"; const ROOM_ID = testData.TEST_ROOM_ID; @@ -311,6 +312,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe describe("recover from backup", () => { let aliceCrypto: Crypto.CryptoApi; + let importMockImpl: jest.Mock; beforeEach(async () => { fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); @@ -322,6 +324,22 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe // tell Alice to trust the dummy device that signed the backup await waitForDeviceList(); await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID); + + importMockImpl = jest + .fn() + .mockImplementation((keys: IMegolmSessionData[], version: String, opts?: ImportRoomKeysOpts) => { + // need to report progress + if (opts?.progressCallback) { + opts.progressCallback({ + stage: "load_keys", + successes: keys.length, + failures: 0, + total: keys.length, + }); + } + }); + // @ts-ignore - mock a private method for testing purpose + aliceCrypto.importBackedUpRoomKeys = importMockImpl; }); it("can restore from backup (Curve25519 version)", async function () { @@ -397,10 +415,6 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe } it("Should import full backup in chunks", async function () { - const importMockImpl = jest.fn(); - // @ts-ignore - mock a private method for testing purpose - aliceCrypto.importBackedUpRoomKeys = importMockImpl; - // We need several rooms with several sessions to test chunking const { response, expectedTotal } = createBackupDownloadResponse([45, 300, 345, 12, 130]); @@ -459,7 +473,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe throw new Error("test error"); }) // Ok for other chunks - .mockResolvedValue(undefined); + .mockImplementation(importMockImpl); const { response, expectedTotal } = createBackupDownloadResponse([100, 300]); @@ -498,9 +512,6 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe }); it("Should continue if some keys fails to decrypt", async function () { - // @ts-ignore - mock a private method for testing purpose - aliceCrypto.importBackedUpRoomKeys = jest.fn(); - const decryptionFailureCount = 2; const mockDecryptor = { @@ -540,6 +551,85 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe expect(result.imported).toStrictEqual(expectedTotal - decryptionFailureCount); }); + it("Should report failures when decryption works but import fails", async function () { + // @ts-ignore - mock a private method for testing purpose + aliceCrypto.importBackedUpRoomKeys = jest + .fn() + .mockImplementationOnce((keys: IMegolmSessionData[], version: String, opts?: ImportRoomKeysOpts) => { + // report 10 failures to import + opts!.progressCallback!({ + stage: "load_keys", + successes: 20, + failures: 10, + total: 30, + }); + return Promise.resolve(); + }) + // Ok for other chunks + .mockResolvedValue(importMockImpl); + + const { response, expectedTotal } = createBackupDownloadResponse([30]); + + fetchMock.get("express:/_matrix/client/v3/room_keys/keys", response); + + const check = await aliceCrypto.checkKeyBackupAndEnable(); + + const progressCallback = jest.fn(); + const result = await aliceClient.restoreKeyBackupWithRecoveryKey( + testData.BACKUP_DECRYPTION_KEY_BASE58, + undefined, + undefined, + check!.backupInfo!, + { + progressCallback, + }, + ); + + expect(result.total).toStrictEqual(expectedTotal); + // A chunk failed to import + expect(result.imported).toStrictEqual(20); + }); + + it("Should report failures when decryption works but import fails - per room variant", async function () { + // @ts-ignore - mock a private method for testing purpose + aliceCrypto.importBackedUpRoomKeys = jest + .fn() + .mockImplementationOnce((keys: IMegolmSessionData[], version: String, opts?: ImportRoomKeysOpts) => { + // report 10 failures to import + opts!.progressCallback!({ + stage: "load_keys", + successes: 20, + failures: 10, + total: 30, + }); + return Promise.resolve(); + }) + // Ok for other chunks + .mockResolvedValue(importMockImpl); + + const { response, expectedTotal } = createBackupDownloadResponse([30]); + const roomId = Object.keys(response.rooms)[0]; + + fetchMock.get(`express:/_matrix/client/v3/room_keys/keys/${roomId}`, response.rooms[roomId]); + + const check = await aliceCrypto.checkKeyBackupAndEnable(); + + const progressCallback = jest.fn(); + const result = await aliceClient.restoreKeyBackupWithRecoveryKey( + testData.BACKUP_DECRYPTION_KEY_BASE58, + roomId, + undefined, + check!.backupInfo!, + { + progressCallback, + }, + ); + + expect(result.total).toStrictEqual(expectedTotal); + // A chunk failed to import + expect(result.imported).toStrictEqual(20); + }); + it("recover specific session from backup", async function () { fetchMock.get( "express:/_matrix/client/v3/room_keys/keys/:room_id/:session_id", diff --git a/src/client.ts b/src/client.ts index 72e28f52916..b7401ec4a57 100644 --- a/src/client.ts +++ b/src/client.ts @@ -212,7 +212,13 @@ import { LocalNotificationSettings } from "./@types/local_notifications"; import { buildFeatureSupportMap, Feature, ServerSupport } from "./feature"; import { BackupDecryptor, CryptoBackend } from "./common-crypto/CryptoBackend"; import { RUST_SDK_STORE_PREFIX } from "./rust-crypto/constants"; -import { BootstrapCrossSigningOpts, CrossSigningKeyInfo, CryptoApi, ImportRoomKeysOpts } from "./crypto-api"; +import { + BootstrapCrossSigningOpts, + CrossSigningKeyInfo, + CryptoApi, + ImportRoomKeyProgressData, + ImportRoomKeysOpts, +} from "./crypto-api"; import { DeviceInfoMap } from "./crypto/DeviceList"; import { AddSecretStorageKeyOpts, @@ -3916,11 +3922,18 @@ export class MatrixClient extends TypedEventEmitter { // We have a chunk of decrypted keys: import them try { - const backupVersion = backupInfo.version!; + let success = 0; + let failures = 0; + const partialProgress = (stage: ImportRoomKeyProgressData): void => { + success = stage.successes ?? 0; + failures = stage.failures ?? 0; + }; await this.cryptoBackend!.importBackedUpRoomKeys(chunk, backupVersion, { untrusted, + progressCallback: partialProgress, }); - totalImported += chunk.length; + totalImported += success; + totalFailures += failures; } catch (e) { totalFailures += chunk.length; // We failed to import some keys, but we should still try to import the rest? @@ -3947,11 +3960,25 @@ export class MatrixClient extends TypedEventEmitter { + success = stage.successes ?? 0; + failures = stage.failures ?? 0; + }; + await this.cryptoBackend!.importBackedUpRoomKeys(keys, backupVersion, { + untrusted, + progressCallback: partialProgress, + }); + totalImported += success; + totalFailures += failures; + } catch (e) { + totalFailures += keys.length; + // We failed to import some keys, but we should still try to import the rest? + // Log the error and continue + logger.error("Error importing keys from backup", e); + } } else { totalKeyCount = 1; try { @@ -3967,6 +3994,7 @@ export class MatrixClient extends TypedEventEmitter { const importOpt: ImportRoomKeyProgressData = { @@ -265,6 +265,17 @@ export class RustBackupManager extends TypedEventEmitter | null = null;