From c20943d3f5d6f64981e9eaed8a3a9c3a62124fc5 Mon Sep 17 00:00:00 2001 From: Zach Nagengast Date: Wed, 1 May 2024 01:25:15 -0700 Subject: [PATCH] Cleanup (#132) * Lint tests * Lint library * Lint examples * Fix log prob alignment and timing * Allow tokenizer to be loaded from disk if it exists already --- .../WhisperAX/Views/ContentView.swift | 35 ++++----- .../WhisperAXTests/WhisperAXTests.swift | 2 - .../WhisperAXUITests/WhisperAXUITests.swift | 1 - .../WhisperAXUITestsLaunchTests.swift | 1 - .../WhisperAXExampleView.swift | 3 +- .../WhisperAX_Watch_AppTests.swift | 4 +- .../WhisperAX_Watch_AppUITests.swift | 1 - ...hisperAX_Watch_AppUITestsLaunchTests.swift | 1 - Sources/WhisperKit/Core/LogitsFilter.swift | 20 +++--- Sources/WhisperKit/Core/Models.swift | 50 ++++++------- Sources/WhisperKit/Core/SegmentSeeker.swift | 9 +-- Sources/WhisperKit/Core/TextDecoder.swift | 45 ++++++------ Sources/WhisperKit/Core/TranscribeTask.swift | 2 +- Sources/WhisperKit/Core/Utils.swift | 40 +++++++++-- Sources/WhisperKit/Core/WhisperKit.swift | 14 ++-- Sources/WhisperKitCLI/CLIArguments.swift | 2 +- Sources/WhisperKitCLI/Transcribe.swift | 14 ++-- Tests/WhisperKitTests/FunctionalTests.swift | 32 ++++----- Tests/WhisperKitTests/MemoryTestUtils.swift | 71 ++++++++++--------- Tests/WhisperKitTests/RegressionTests.swift | 67 ++++++++--------- Tests/WhisperKitTests/TestUtils.swift | 25 +++---- Tests/WhisperKitTests/UnitTests.swift | 30 ++++---- 22 files changed, 245 insertions(+), 224 deletions(-) diff --git a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift index 4ded0f4..4fc76a1 100644 --- a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift +++ b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift @@ -29,7 +29,7 @@ struct ContentView: View { @State private var availableModels: [String] = [] @State private var availableLanguages: [String] = [] @State private var disabledModels: [String] = WhisperKit.recommendedModels().disabled - + @AppStorage("selectedAudioInput") private var selectedAudioInput: String = "No Audio Input" @AppStorage("selectedModel") private var selectedModel: String = WhisperKit.recommendedModels().default @AppStorage("selectedTab") private var selectedTab: String = "Transcribe" @@ -73,7 +73,6 @@ struct ContentView: View { @State private var unconfirmedSegments: [TranscriptionSegment] = [] @State private var unconfirmedText: [String] = [] - // MARK: Eager mode properties @State private var eagerResults: [TranscriptionResult?] 
= [] @@ -274,7 +273,8 @@ struct ContentView: View { !isRecording, !isTranscribing, whisperKit.progress.fractionCompleted > 0, - whisperKit.progress.fractionCompleted < 1 { + whisperKit.progress.fractionCompleted < 1 + { ProgressView(whisperKit.progress) .progressViewStyle(.linear) .labelsHidden() @@ -314,7 +314,7 @@ struct ContentView: View { .progressViewStyle(CircularProgressViewStyle()) .scaleEffect(0.5) } - + Button(action: { deleteModel() }, label: { @@ -405,14 +405,15 @@ struct ContentView: View { if let audioDevices = audioDevices, !audioDevices.isEmpty, selectedAudioInput == "No Audio Input", - let device = audioDevices.first { + let device = audioDevices.first + { selectedAudioInput = device.name } } #endif } } - + var controlsView: some View { VStack { basicSettingsView @@ -887,13 +888,12 @@ struct ContentView: View { } }) } - + await MainActor.run { loadingProgressValue = specializationProgressRatio modelState = .downloaded } - if let modelFolder = folder { whisperKit.modelFolder = modelFolder @@ -936,7 +936,7 @@ struct ContentView: View { if !localModels.contains(model) { localModels.append(model) } - + availableLanguages = Constants.languages.map { $0.key }.sorted() loadingProgressValue = 1.0 modelState = whisperKit.modelState @@ -944,18 +944,18 @@ struct ContentView: View { } } } - + func deleteModel() { if localModels.contains(selectedModel) { let modelFolder = URL(fileURLWithPath: localModelPath).appendingPathComponent(selectedModel) - + do { try FileManager.default.removeItem(at: modelFolder) - + if let index = localModels.firstIndex(of: selectedModel) { localModels.remove(at: index) } - + modelState = .unloaded } catch { print("Error deleting model: \(error)") @@ -1058,18 +1058,19 @@ struct ContentView: View { print("Microphone access was not granted.") return } - + var deviceId: DeviceID? #if os(macOS) if self.selectedAudioInput != "No Audio Input", let devices = self.audioDevices, - let device = devices.first(where: {$0.name == selectedAudioInput}) { + let device = devices.first(where: { $0.name == selectedAudioInput }) + { deviceId = device.id } // There is no built-in microphone if deviceId == nil { - throw WhisperError.microphoneUnavailable() + throw WhisperError.microphoneUnavailable() } #endif @@ -1403,7 +1404,7 @@ struct ContentView: View { return nil } - Logging.info("[EagerMode] \(lastAgreedSeconds)-\(Double(samples.count)/16000.0) seconds") + Logging.info("[EagerMode] \(lastAgreedSeconds)-\(Double(samples.count) / 16000.0) seconds") let streamingAudio = samples var streamOptions = options diff --git a/Examples/WhisperAX/WhisperAXTests/WhisperAXTests.swift b/Examples/WhisperAX/WhisperAXTests/WhisperAXTests.swift index 7585ddb..616cd7a 100644 --- a/Examples/WhisperAX/WhisperAXTests/WhisperAXTests.swift +++ b/Examples/WhisperAX/WhisperAXTests/WhisperAXTests.swift @@ -4,7 +4,6 @@ import XCTest final class WhisperAXTests: XCTestCase { - override func setUpWithError() throws { // Put setup code here. This method is called before the invocation of each test method in the class. } @@ -27,5 +26,4 @@ final class WhisperAXTests: XCTestCase { // Put the code you want to measure the time of here. 
} } - } diff --git a/Examples/WhisperAX/WhisperAXUITests/WhisperAXUITests.swift b/Examples/WhisperAX/WhisperAXUITests/WhisperAXUITests.swift index cdeaeed..c48bada 100644 --- a/Examples/WhisperAX/WhisperAXUITests/WhisperAXUITests.swift +++ b/Examples/WhisperAX/WhisperAXUITests/WhisperAXUITests.swift @@ -4,7 +4,6 @@ import XCTest final class WhisperAXUITests: XCTestCase { - override func setUpWithError() throws { // Put setup code here. This method is called before the invocation of each test method in the class. diff --git a/Examples/WhisperAX/WhisperAXUITests/WhisperAXUITestsLaunchTests.swift b/Examples/WhisperAX/WhisperAXUITests/WhisperAXUITestsLaunchTests.swift index 3f88b33..00b1373 100644 --- a/Examples/WhisperAX/WhisperAXUITests/WhisperAXUITestsLaunchTests.swift +++ b/Examples/WhisperAX/WhisperAXUITests/WhisperAXUITestsLaunchTests.swift @@ -4,7 +4,6 @@ import XCTest final class WhisperAXUITestsLaunchTests: XCTestCase { - override class var runsForEachTargetApplicationUIConfiguration: Bool { true } diff --git a/Examples/WhisperAX/WhisperAXWatchApp/WhisperAXExampleView.swift b/Examples/WhisperAX/WhisperAXWatchApp/WhisperAXExampleView.swift index b6196b4..a4efe46 100644 --- a/Examples/WhisperAX/WhisperAXWatchApp/WhisperAXExampleView.swift +++ b/Examples/WhisperAX/WhisperAXWatchApp/WhisperAXExampleView.swift @@ -249,7 +249,8 @@ struct WhisperAXWatchView: View { let currentTranscription = (confirmedSegments.map { $0.text } + unconfirmedSegments.map { $0.text }).joined(separator: " ") ShareLink(item: currentTranscription, label: { Image(systemName: "square.and.arrow.up") - }) } + }) + } ToolbarItem(placement: .bottomBar) { Button { withAnimation { diff --git a/Examples/WhisperAX/WhisperAXWatchAppTests/WhisperAX_Watch_AppTests.swift b/Examples/WhisperAX/WhisperAXWatchAppTests/WhisperAX_Watch_AppTests.swift index e814f80..183d972 100644 --- a/Examples/WhisperAX/WhisperAXWatchAppTests/WhisperAX_Watch_AppTests.swift +++ b/Examples/WhisperAX/WhisperAXWatchAppTests/WhisperAX_Watch_AppTests.swift @@ -1,11 +1,10 @@ // For licensing see accompanying LICENSE.md file. // Copyright © 2024 Argmax, Inc. All rights reserved. -import XCTest @testable import Basic_Watch_App +import XCTest final class WhisperAX_Watch_AppTests: XCTestCase { - override func setUpWithError() throws { // Put setup code here. This method is called before the invocation of each test method in the class. } @@ -28,5 +27,4 @@ final class WhisperAX_Watch_AppTests: XCTestCase { // Put the code you want to measure the time of here. } } - } diff --git a/Examples/WhisperAX/WhisperAXWatchAppUITests/WhisperAX_Watch_AppUITests.swift b/Examples/WhisperAX/WhisperAXWatchAppUITests/WhisperAX_Watch_AppUITests.swift index 888d7f3..bd2dafd 100644 --- a/Examples/WhisperAX/WhisperAXWatchAppUITests/WhisperAX_Watch_AppUITests.swift +++ b/Examples/WhisperAX/WhisperAXWatchAppUITests/WhisperAX_Watch_AppUITests.swift @@ -4,7 +4,6 @@ import XCTest final class WhisperAX_Watch_AppUITests: XCTestCase { - override func setUpWithError() throws { // Put setup code here. This method is called before the invocation of each test method in the class. 
diff --git a/Examples/WhisperAX/WhisperAXWatchAppUITests/WhisperAX_Watch_AppUITestsLaunchTests.swift b/Examples/WhisperAX/WhisperAXWatchAppUITests/WhisperAX_Watch_AppUITestsLaunchTests.swift index 8dac648..d1f6eeb 100644 --- a/Examples/WhisperAX/WhisperAXWatchAppUITests/WhisperAX_Watch_AppUITestsLaunchTests.swift +++ b/Examples/WhisperAX/WhisperAXWatchAppUITests/WhisperAX_Watch_AppUITestsLaunchTests.swift @@ -4,7 +4,6 @@ import XCTest final class WhisperAX_Watch_AppUITestsLaunchTests: XCTestCase { - override class var runsForEachTargetApplicationUIConfiguration: Bool { true } diff --git a/Sources/WhisperKit/Core/LogitsFilter.swift b/Sources/WhisperKit/Core/LogitsFilter.swift index 4b31150..724785c 100644 --- a/Sources/WhisperKit/Core/LogitsFilter.swift +++ b/Sources/WhisperKit/Core/LogitsFilter.swift @@ -40,7 +40,7 @@ open class SuppressBlankFilter: LogitsFiltering { self.sampleBegin = sampleBegin self.suppressTokenIndexes = [ [0, 0, specialTokens.whitespaceToken as NSNumber], - [0, 0, specialTokens.endToken as NSNumber] + [0, 0, specialTokens.endToken as NSNumber], ] } @@ -75,10 +75,11 @@ open class TimestampRulesFilter: LogitsFiltering { public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray { guard let sampleBegin = sampleBegin(for: tokens), - sampleBegin > tokens.count else { + sampleBegin > tokens.count + else { return logits } - + // suppress <|notimestamps|> which is handled by `withoutTimestamps` logits.fill(indexes: [[0, 0, specialTokens.noTimestampsToken as NSNumber]], with: -FloatType.infinity) @@ -244,7 +245,6 @@ open class TimestampRulesFilter: LogitsFiltering { } } - @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) open class LanguageLogitsFilter: LogitsFiltering { let allLanguageTokens: Set @@ -259,19 +259,19 @@ open class LanguageLogitsFilter: LogitsFiltering { self.nonLanguageTokenIndexes = LanguageLogitsFilter.getNonLanguageTokenIndexes(logitsDim: self.logitsDim, allLanguageTokens: self.allLanguageTokens) } - // Retain the logits that correspond to language tokens and suppress non-language tokens + /// Retain the logits that correspond to language tokens and suppress non-language tokens public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray { - guard tokens.count == sampleBegin else{ + guard tokens.count == sampleBegin else { return logits } logits.fill(indexes: nonLanguageTokenIndexes, with: -FloatType.infinity) return logits } - - private static func getNonLanguageTokenIndexes(logitsDim: Int, allLanguageTokens: Set) -> [[NSNumber]]{ + + private static func getNonLanguageTokenIndexes(logitsDim: Int, allLanguageTokens: Set) -> [[NSNumber]] { var indexes: [[NSNumber]] = [] - for i in 0.. 
specialTokens.specialTokenBegin } ) } - + private func splitTokensOnUnicode(tokens: [Int]) -> (words: [String], wordTokens: [[Int]]) { let decodedFull = tokenizer.decode(tokens: tokens) let replacementString = "\u{fffd}" - + var words: [String] = [] var wordTokens: [[Int]] = [] var currentTokens: [Int] = [] var unicodeOffset = 0 - + for token in tokens { currentTokens.append(token) let decoded = tokenizer.decode(tokens: currentTokens) - + var hasUnicodeInFullString = false if let range = decoded.range(of: replacementString) { hasUnicodeInFullString = decodedFull[range] == replacementString } - + if !decoded.contains(replacementString) || hasUnicodeInFullString { words.append(decoded) wordTokens.append(currentTokens) @@ -1116,15 +1116,15 @@ struct WhisperTokenizerWrapper: WhisperTokenizer { unicodeOffset += decoded.count } } - + return (words, wordTokens) } - + private func splitTokensOnSpaces(tokens: [Int]) -> (words: [String], wordTokens: [[Int]]) { let (subwords, subwordTokensList) = splitTokensOnUnicode(tokens: tokens) var words: [String] = [] var wordTokens: [[Int]] = [] - + for (subword, subwordTokens) in zip(subwords, subwordTokensList) { let special = subwordTokens.first! >= specialTokens.specialTokenBegin let withSpace = subword.hasPrefix(" ") @@ -1140,10 +1140,10 @@ struct WhisperTokenizerWrapper: WhisperTokenizer { wordTokens[words.count - 1].append(contentsOf: subwordTokens) } } - + return (words, wordTokens) } - + private func isPunctuation(_ text: String, tokenRange: Range, tag: NLTag?) -> Bool { let punctuationCharacters = CharacterSet.punctuationCharacters let token = String(text[tokenRange]) @@ -1154,18 +1154,18 @@ struct WhisperTokenizerWrapper: WhisperTokenizer { } return false } - + /// Decodes token ids into individual words and per-word subtokens /// - Parameter tokenIds: Array of tokens to decode and then split /// - Returns: Tuple containing and array of the split words and all tokens for each word func splitToWordTokens(tokenIds: [Int]) -> (words: [String], wordTokens: [[Int]]) { let decodedWords = tokenizer.decode(tokens: tokenIds.filter { $0 < specialTokens.specialTokenBegin }) - + // Detect language of input text let recognizer = NLLanguageRecognizer() recognizer.processString(decodedWords) let languageCode = recognizer.dominantLanguage?.rawValue - + if ["zh", "ja", "th", "lo", "my", "yue"].contains(languageCode) { return splitTokensOnUnicode(tokens: tokenIds) } else { @@ -1178,43 +1178,43 @@ extension WhisperTokenizerWrapper: Tokenizer { func tokenize(text: String) -> [String] { tokenizer.tokenize(text: text) } - + func encode(text: String) -> [Int] { tokenizer.encode(text: text) } - + func decode(tokens: [Int]) -> String { tokenizer.decode(tokens: tokens) } - + func convertTokenToId(_ token: String) -> Int? { tokenizer.convertTokenToId(token) } - + func convertIdToToken(_ id: Int) -> String? { tokenizer.convertIdToToken(id) } - + var bosToken: String? { tokenizer.bosToken } - + var bosTokenId: Int? { tokenizer.bosTokenId } - + var eosToken: String? { tokenizer.eosToken } - + var eosTokenId: Int? { tokenizer.eosTokenId } - + var unknownToken: String? { tokenizer.unknownToken } - + var unknownTokenId: Int? 
{ tokenizer.unknownTokenId } diff --git a/Sources/WhisperKit/Core/SegmentSeeker.swift b/Sources/WhisperKit/Core/SegmentSeeker.swift index eaa96eb..33a45dc 100644 --- a/Sources/WhisperKit/Core/SegmentSeeker.swift +++ b/Sources/WhisperKit/Core/SegmentSeeker.swift @@ -88,7 +88,6 @@ open class SegmentSeeker: SegmentSeeking { let singleTimestampEnding = lastThreeTokens == [false, true, false] let noTimestampEnding = lastThreeTokens == [false, false, false] - // find all end indexes of time token pairs var sliceIndexes = [Int]() @@ -297,7 +296,8 @@ open class SegmentSeeker: SegmentSeeking { // Check if the previous word starts with a whitespace character and is part of the prepended punctuations if let firstChar = previousWord.word.unicodeScalars.first, CharacterSet.whitespaces.contains(firstChar), - prepended.contains(previousWord.word.trimmingCharacters(in: .whitespaces)) { + prepended.contains(previousWord.word.trimmingCharacters(in: .whitespaces)) + { currentWord.word = previousWord.word + currentWord.word currentWord.tokens = previousWord.tokens + currentWord.tokens prependedAlignment[prependedAlignment.count - 1] = currentWord @@ -530,7 +530,8 @@ open class SegmentSeeker: SegmentSeeking { // Logic for the first word if firstWord.end - lastSpeechTimestamp > constrainedMedianDuration * 4 && (firstWord.end - firstWord.start > maxDuration || - (wordsInSegment.count > 1 && wordsInSegment[1].end - firstWord.start > maxDuration * 2)) { + (wordsInSegment.count > 1 && wordsInSegment[1].end - firstWord.start > maxDuration * 2)) + { if wordsInSegment.count > 1 && wordsInSegment[1].end - wordsInSegment[1].start > maxDuration { let boundary = max(wordsInSegment[1].end / 2, wordsInSegment[1].end - maxDuration) wordsInSegment[0].end = boundary @@ -555,7 +556,7 @@ open class SegmentSeeker: SegmentSeeking { lastSpeechTimestamp = updatedSegment.end } - + updatedSegment.words = wordsInSegment updatedSegments.append(updatedSegment) } diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index ca6eb25..c8b9205 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -16,7 +16,7 @@ public protocol TextDecoding { var kvCacheMaxSequenceLength: Int? { get } var windowSize: Int? { get } var embedSize: Int? { get } - + func predictLogits( inputIds: MLMultiArray, cacheLength: MLMultiArray, @@ -77,7 +77,6 @@ public protocol TextDecoding { @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) public extension TextDecoding { - @available(*, deprecated, message: "Subject to removal in a future version. 
Use `decodeText(from:using:sampler:options:callback:) async throws -> DecodingResult` instead.") func decodeText( from encoderOutput: MLMultiArray, @@ -410,7 +409,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel { let prefilledIndex = 0 let currentTokens: [Int] = [tokenizer.specialTokens.startOfTranscriptToken] var logProbs: [Float] = Array(repeating: 0, count: prefilledIndex + 1) - + guard let logitsSize = logitsSize else { throw WhisperError.modelsUnavailable("Failed to read logits size from model") } @@ -423,7 +422,8 @@ open class TextDecoder: TextDecoding, WhisperMLModel { LanguageLogitsFilter( allLanguageTokens: tokenizer.allLanguageTokens, logitsDim: logitsSize, - sampleBegin: prefilledIndex) + sampleBegin: prefilledIndex + ) ) let tokenIndex = 0 @@ -433,12 +433,12 @@ open class TextDecoder: TextDecoding, WhisperMLModel { // Set the current token as model input decoderInputs.inputIds[0] = NSNumber(value: nextToken) decoderInputs.cacheLength[0] = NSNumber(value: tokenIndex) - + // MARK: Decoding Inference - + // Predict next token let inferenceTime = Date() - + Logging.debug("Detecting language...") let predictedLogits = try await self.predictLogits( inputIds: decoderInputs.inputIds, @@ -454,30 +454,30 @@ open class TextDecoder: TextDecoding, WhisperMLModel { Logging.error("Unable to decode logits") throw WhisperError.decodingLogitsFailed() } - + let decodingInferenceTime = Date().timeIntervalSince(inferenceTime) timings.decodingPredictions += decodingInferenceTime - + // MARK: Non-inference - + // Update predicted token as current var logits = decoderOutput.logits! for filter in logitsFilters { logits = filter.filterLogits(logits, withTokens: currentTokens) } - + // MARK: Sampling - + let samplingStartTime = Date() - + let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs) - + nextToken = sampleResult.tokens.last! logProbs = sampleResult.logProbs - + let samplingTime = Date().timeIntervalSince(samplingStartTime) timings.decodingSampling += samplingTime - + let detectedLanguage = tokenizer.decode(tokens: [nextToken]).dropFirst(2).dropLast(2) var decodingResult = DecodingResult.emptyResults decodingResult.timings = timings @@ -549,7 +549,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel { let isPrefill = tokenIndex < intialPromptIndex - 1 // Prefill stops at the last token of the initial prompt let isFirstToken = tokenIndex == prefilledIndex - + // Check if current index is part of the initial prompt if tokenIndex < intialPromptIndex { nextToken = currentTokens[tokenIndex] @@ -607,7 +607,6 @@ open class TextDecoder: TextDecoding, WhisperMLModel { nextToken = sampleResult.tokens.last! let nextTokenLogProb = sampleResult.logProbs.last! 
- logProbs = sampleResult.logProbs let samplingTime = Date().timeIntervalSince(samplingStartTime) timings.decodingSampling += samplingTime @@ -618,13 +617,16 @@ open class TextDecoder: TextDecoding, WhisperMLModel { } else { false } - let isSegmentCompleted = + let isSegmentCompleted = sampleResult.completed || currentTokens.count >= Constants.maxTokenContext - 1 || isFirstTokenLogProbTooLow if isSegmentCompleted { // Completed segment, stop the loop + timings.decodingNonPrediction += Date().timeIntervalSince(nonInferenceStartTime) + timings.decodingLoop += Date().timeIntervalSince(loopStart) + timings.totalDecodingLoops += 1 break } else { // MARK: KV Caching @@ -740,7 +742,8 @@ open class TextDecoder: TextDecoding, WhisperMLModel { if options.language == nil { // Find the first token that is a recognized language token if let predictedLanguageIndex = filteredTokens.firstIndex(where: { tokenizer.allLanguageTokens.contains($0) }), - predictedLanguageIndex < tokenProbs.count { + predictedLanguageIndex < tokenProbs.count + { let predictedLanguageToken = filteredTokens[predictedLanguageIndex] // Decode the predicted language token to get the language language = tokenizer.decode(tokens: [predictedLanguageToken]).trimmingCharacters(in: CharacterSet(charactersIn: "<|>")) @@ -799,7 +802,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel { let formattedString = String(format: "%9.6f | %9.6f | %9.6f | %11.0f | %12.0f | %d", decoderInputs.keyCache[i].floatValue, decoderInputs.valueCache[i].floatValue, - decoderInputs.alignmentWeights[i*1500].floatValue, + decoderInputs.alignmentWeights[i * 1500].floatValue, decoderInputs.kvCacheUpdateMask[i].floatValue, decoderInputs.decoderKeyPaddingMask[i].floatValue, i) diff --git a/Sources/WhisperKit/Core/TranscribeTask.swift b/Sources/WhisperKit/Core/TranscribeTask.swift index 6dfd583..8393660 100644 --- a/Sources/WhisperKit/Core/TranscribeTask.swift +++ b/Sources/WhisperKit/Core/TranscribeTask.swift @@ -1,8 +1,8 @@ // For licensing see accompanying LICENSE.md file. // Copyright © 2024 Argmax, Inc. All rights reserved. -import Foundation import CoreML +import Foundation /// Responsible for transcribing audio chunk to text using the provided models and configurations. @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index 355dc06..37884d5 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -5,8 +5,8 @@ import AVFoundation import CoreML import Foundation import Hub -import Tokenizers import os.signpost +import Tokenizers #if canImport(UIKit) import UIKit #elseif canImport(AppKit) @@ -18,7 +18,7 @@ import AppKit extension Array { func chunked(into size: Int) -> [[Element]] { return stride(from: 0, to: count, by: size).map { - Array(self[$0 ..< Swift.min($0 + size, count)]) + Array(self[$0.. String { return inputPath } -func loadTokenizer( +public func loadTokenizer( for pretrained: ModelVariant, tokenizerFolder: URL? 
= nil, useBackgroundSession: Bool = false ) async throws -> WhisperTokenizer { let tokenizerName = tokenizerNameForVariant(pretrained) let hubApi = HubApi(downloadBase: tokenizerFolder, useBackgroundSession: useBackgroundSession) + + // Attempt to load tokenizer from local folder if specified + let resolvedTokenizerFolder = hubApi.localRepoLocation(HubApi.Repo(id: tokenizerName)) + let tokenizerConfigPath = resolvedTokenizerFolder.appendingPathComponent("tokenizer.json") + + // Check if 'tokenizer.json' exists in the folder + if FileManager.default.fileExists(atPath: tokenizerConfigPath.path) { + do { + let localConfig = LanguageModelConfigurationFromHub(modelFolder: resolvedTokenizerFolder, hubApi: hubApi) + if let tokenizerConfig = try await localConfig.tokenizerConfig { + let tokenizerData = try await localConfig.tokenizerData + let whisperTokenizer = try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData) + Logging.debug("Loading tokenizer from local folder") + return WhisperTokenizerWrapper(tokenizer: whisperTokenizer) + } else { + // tokenizerConfig is nil, fall through to load from Hub + Logging.debug("Tokenizer configuration not found in local config") + } + } catch { + // Error during the local loading process and fall through to load from Hub + Logging.debug("Error loading local tokenizer: \(error)") + } + } + + // Fallback to loading from the Hub if local loading is not possible or fails + Logging.debug("Loading tokenizer from Hub") return try await WhisperTokenizerWrapper( tokenizer: AutoTokenizer.from( pretrained: tokenizerName, @@ -463,7 +489,7 @@ public func mergeTranscriptionResults(_ results: [TranscriptionResult?], confirm // Average first token times if let pipelineStart = validResults.first?.timings.pipelineStart { - let averageFirstTokenTime = validResults.map { (($0.timings.firstTokenTime) - ($0.timings.pipelineStart)) }.reduce(0, +) / Double(validResults.count) + let averageFirstTokenTime = validResults.map { ($0.timings.firstTokenTime) - ($0.timings.pipelineStart) }.reduce(0, +) / Double(validResults.count) mergedTimings.pipelineStart = pipelineStart mergedTimings.firstTokenTime = pipelineStart + averageFirstTokenTime } diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index 5bb7e2f..485d4d7 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -202,7 +202,7 @@ open class WhisperKit { callback(progress) } } - + let modelFolderName = modelFolder.appending(path: variantPath) return modelFolderName } catch { @@ -404,11 +404,11 @@ open class WhisperKit { var transcribeResultIndex = 0 for audioResult in loadedAudioResult { switch audioResult { - case .success: - result.append(transcribeResults[transcribeResultIndex]) - transcribeResultIndex += 1 - case .failure(let error): - result.append(.failure(error)) + case .success: + result.append(transcribeResults[transcribeResultIndex]) + transcribeResultIndex += 1 + case let .failure(error): + result.append(.failure(error)) } } return result @@ -526,7 +526,7 @@ open class WhisperKit { if self.modelState != .loaded { try await loadModels() } - + guard let tokenizer else { // Tokenizer required for decoding throw WhisperError.tokenizerUnavailable() diff --git a/Sources/WhisperKitCLI/CLIArguments.swift b/Sources/WhisperKitCLI/CLIArguments.swift index 113f24e..cc8116f 100644 --- a/Sources/WhisperKitCLI/CLIArguments.swift +++ b/Sources/WhisperKitCLI/CLIArguments.swift @@ -6,7 +6,7 @@ import ArgumentParser 
struct CLIArguments: ParsableArguments { @Option(help: "Paths to audio files") var audioPath = [String]() - + @Option(help: "Path to a folder containing audio files") var audioFolder: String? diff --git a/Sources/WhisperKitCLI/Transcribe.swift b/Sources/WhisperKitCLI/Transcribe.swift index 4790976..db9a75b 100644 --- a/Sources/WhisperKitCLI/Transcribe.swift +++ b/Sources/WhisperKitCLI/Transcribe.swift @@ -21,7 +21,7 @@ struct Transcribe: AsyncParsableCommand { throw ValidationError("Invalid language code \"\(language)\". Supported languages: \(Constants.languages.values)") } } - + if cliArguments.audioPath.isEmpty && !cliArguments.stream { guard let audioFolder = cliArguments.audioFolder else { throw ValidationError("Either audioPath or audioFolder must be provided.") @@ -33,7 +33,7 @@ struct Transcribe: AsyncParsableCommand { let fileExtension = fileName.lowercased().components(separatedBy: ".").last return audioExtensions.contains(fileExtension ?? "") } - + cliArguments.audioPath = audioFiles.map { audioFolder + "/" + $0 } } } @@ -75,7 +75,7 @@ struct Transcribe: AsyncParsableCommand { } var options = decodingOptions(task: task) - if let promptText = cliArguments.prompt, let tokenizer = whisperKit.tokenizer { + if let promptText = cliArguments.prompt, let tokenizer = whisperKit.tokenizer { options.promptTokens = tokenizer.encode(text: " " + promptText.trimmingCharacters(in: .whitespaces)).filter { $0 < tokenizer.specialTokens.specialTokenBegin } } @@ -90,10 +90,10 @@ struct Transcribe: AsyncParsableCommand { for (audioPath, result) in zip(resolvedAudioPaths, transcribeResult) { switch result { - case .success(let transcribeResult): - processTranscriptionResult(audioPath: audioPath, transcribeResult: transcribeResult.first) - case .failure(let error): - print("Error when transcribing \(audioPath): \(error)") + case let .success(transcribeResult): + processTranscriptionResult(audioPath: audioPath, transcribeResult: transcribeResult.first) + case let .failure(error): + print("Error when transcribing \(audioPath): \(error)") } } } diff --git a/Tests/WhisperKitTests/FunctionalTests.swift b/Tests/WhisperKitTests/FunctionalTests.swift index be9ede2..62824fd 100644 --- a/Tests/WhisperKitTests/FunctionalTests.swift +++ b/Tests/WhisperKitTests/FunctionalTests.swift @@ -125,21 +125,21 @@ final class FunctionalTests: XCTestCase { } func testBatchTranscribeAudioPaths() async throws { - let audioPaths = [ - try XCTUnwrap( + let audioPaths = try [ + XCTUnwrap( Bundle.module.path(forResource: "jfk", ofType: "wav"), "Audio file not found" ), - try XCTUnwrap( + XCTUnwrap( Bundle.module.path(forResource: "es_test_clip", ofType: "wav"), "Audio file not found" ), - try XCTUnwrap( + XCTUnwrap( Bundle.module.path(forResource: "ja_test_clip", ofType: "wav"), "Audio file not found" - ) + ), ] - let whisperKit = try await WhisperKit(modelFolder: try tinyModelPath()) + let whisperKit = try await WhisperKit(modelFolder: tinyModelPath()) let transcriptionResults: [Result<[TranscriptionResult], Swift.Error>] = await whisperKit.transcribe(audioPaths: audioPaths) XCTAssertEqual(transcriptionResults.count, 3) @@ -159,15 +159,15 @@ final class FunctionalTests: XCTestCase { } func testBatchTranscribeAudioPathsWithErrors() async throws { - let audioPaths = [ + let audioPaths = try [ "/path/to/file1.wav", - try XCTUnwrap( + XCTUnwrap( Bundle.module.path(forResource: "jfk", ofType: "wav"), "Audio file not found" ), - "/path/to/file2.wav" + "/path/to/file2.wav", ] - let whisperKit = try await WhisperKit(modelFolder: try 
tinyModelPath()) + let whisperKit = try await WhisperKit(modelFolder: tinyModelPath()) let transcriptionResults: [Result<[TranscriptionResult], Swift.Error>] = await whisperKit.transcribe(audioPaths: audioPaths) XCTAssertEqual(transcriptionResults.count, 3) @@ -186,25 +186,25 @@ final class FunctionalTests: XCTestCase { } func testBatchTranscribeAudioArrays() async throws { - let audioPaths = [ - try XCTUnwrap( + let audioPaths = try [ + XCTUnwrap( Bundle.module.path(forResource: "jfk", ofType: "wav"), "Audio file not found" ), - try XCTUnwrap( + XCTUnwrap( Bundle.module.path(forResource: "es_test_clip", ofType: "wav"), "Audio file not found" ), - try XCTUnwrap( + XCTUnwrap( Bundle.module.path(forResource: "ja_test_clip", ofType: "wav"), "Audio file not found" - ) + ), ] let audioArrays = try audioPaths .map { try AudioProcessor.loadAudio(fromPath: $0) } .map { AudioProcessor.convertBufferToArray(buffer: $0) } - let whisperKit = try await WhisperKit(modelFolder: try tinyModelPath()) + let whisperKit = try await WhisperKit(modelFolder: tinyModelPath()) let transcriptionResults: [Result<[TranscriptionResult], Swift.Error>] = await whisperKit.transcribe(audioArrays: audioArrays) XCTAssertEqual(transcriptionResults.count, 3) diff --git a/Tests/WhisperKitTests/MemoryTestUtils.swift b/Tests/WhisperKitTests/MemoryTestUtils.swift index 4c64495..6a6f403 100644 --- a/Tests/WhisperKitTests/MemoryTestUtils.swift +++ b/Tests/WhisperKitTests/MemoryTestUtils.swift @@ -2,23 +2,25 @@ import Foundation import WhisperKit // MARK: RegressionStats + class RegressionStats: JSONCodable { let testInfo: TestInfo let memoryStats: MemoryStats let latencyStats: LatencyStats - + init(testInfo: TestInfo, memoryStats: MemoryStats, latencyStats: LatencyStats) { self.testInfo = testInfo self.memoryStats = memoryStats self.latencyStats = latencyStats } - + func jsonData() throws -> Data { return try JSONEncoder().encode(self) } } // MARK: TestInfo + class TestInfo: JSONCodable { let device, audioFile: String let model: String @@ -26,7 +28,7 @@ class TestInfo: JSONCodable { let timeElapsedInSeconds: TimeInterval let timings: TranscriptionTimings? let transcript: String? - + init(device: String, audioFile: String, model: String, date: String, timeElapsedInSeconds: TimeInterval, timings: TranscriptionTimings?, transcript: String?) 
{ self.device = device self.audioFile = audioFile @@ -39,12 +41,13 @@ class TestInfo: JSONCodable { } // MARK: TestReport -struct TestReport: JSONCodable{ + +struct TestReport: JSONCodable { let device: String let modelsTested: [String] - let failureInfo: [String:String] - - init(device: String, modelsTested: [String], failureInfo: [String:String]) { + let failureInfo: [String: String] + + init(device: String, modelsTested: [String], failureInfo: [String: String]) { self.device = device self.modelsTested = modelsTested self.failureInfo = failureInfo @@ -52,20 +55,21 @@ struct TestReport: JSONCodable{ } // MARK: Stats + class Stats: JSONCodable { var measurements: [Measurement] let units: String var totalNumberOfMeasurements: Int - + init(measurements: [Measurement], units: String, totalNumberOfMeasurements: Int) { self.measurements = measurements self.units = units self.totalNumberOfMeasurements = totalNumberOfMeasurements } - - func measure(from values: [Float], timeElapsed: TimeInterval){ + + func measure(from values: [Float], timeElapsed: TimeInterval) { var measurement: Measurement - if let min = values.min(),let max = values.max(){ + if let min = values.min(), let max = values.max() { measurement = Measurement( min: min, max: max, @@ -80,65 +84,66 @@ class Stats: JSONCodable { } // MARK: LatencyStats -class LatencyStats: Stats{ + +class LatencyStats: Stats { override init(measurements: [Measurement] = [], units: String, totalNumberOfMeasurements: Int = 0) { super.init(measurements: measurements, units: units, totalNumberOfMeasurements: totalNumberOfMeasurements) } - + required init(from decoder: any Decoder) throws { fatalError("init(from:) has not been implemented") } - - func calculate(from total: Double, runs: Int) -> Double{ + + func calculate(from total: Double, runs: Int) -> Double { return runs > 0 ? total / Double(runs) : -1 } } -class MemoryStats: Stats{ +class MemoryStats: Stats { var preTranscribeMemory: Float var postTranscribeMemory: Float - + init(measurements: [Measurement] = [], units: String, totalNumberOfMeasurements: Int = 0, preTranscribeMemory: Float, postTranscribeMemory: Float) { self.preTranscribeMemory = preTranscribeMemory self.postTranscribeMemory = postTranscribeMemory super.init(measurements: measurements, units: units, totalNumberOfMeasurements: totalNumberOfMeasurements) } - + required init(from decoder: any Decoder) throws { fatalError("init(from:) has not been implemented") } - - // Implement the encode(to:) method + + /// Implement the encode(to:) method override func encode(to encoder: Encoder) throws { var container = encoder.container(keyedBy: CodingKeys.self) try super.encode(to: encoder) try container.encode(preTranscribeMemory, forKey: .preTranscribeMemory) try container.encode(postTranscribeMemory, forKey: .postTranscribeMemory) } - - // Coding keys for MemoryStats properties + + /// Coding keys for MemoryStats properties enum CodingKeys: String, CodingKey { case preTranscribeMemory case postTranscribeMemory } } -struct Measurement: JSONCodable{ +struct Measurement: JSONCodable { let min, max, average: Float let numberOfMeasurements: Int let timeElapsed: TimeInterval } -protocol JSONCodable: Codable { -} -extension JSONCodable{ +protocol JSONCodable: Codable {} + +extension JSONCodable { func jsonData() throws -> Data { return try JSONEncoder().encode(self) } } extension Data { - var prettyPrintedJSONString: NSString? { /// NSString gives us a nice sanitized debugDescription + var prettyPrintedJSONString: NSString? 
{ // NSString gives us a nice sanitized debugDescription guard let object = try? JSONSerialization.jsonObject(with: self, options: []), let data = try? JSONSerialization.data(withJSONObject: object, options: [.prettyPrinted, .sortedKeys]), let prettyPrintedString = NSString(data: data, encoding: String.Encoding.utf8.rawValue) else { return nil } @@ -148,14 +153,14 @@ extension Data { } // MARK: - SystemMemoryChecker + @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -class SystemMemoryChecker: NSObject{ - +class SystemMemoryChecker: NSObject { static func getMemoryUsed() -> UInt64 { // The `TASK_VM_INFO_COUNT` and `TASK_VM_INFO_REV1_COUNT` macros are too // complex for the Swift C importer, so we have to define them ourselves. let TASK_VM_INFO_COUNT = mach_msg_type_number_t(MemoryLayout.size / MemoryLayout.size) - guard let offset = MemoryLayout.offset(of: \task_vm_info_data_t.min_address) else {return 0} + guard let offset = MemoryLayout.offset(of: \task_vm_info_data_t.min_address) else { return 0 } let TASK_VM_INFO_REV1_COUNT = mach_msg_type_number_t(offset / MemoryLayout.size) var info = task_vm_info_data_t() var count = TASK_VM_INFO_COUNT @@ -167,10 +172,10 @@ class SystemMemoryChecker: NSObject{ guard kr == KERN_SUCCESS, count >= TASK_VM_INFO_REV1_COUNT - else { return 0} - + else { return 0 } + let usedBytes = Float(info.phys_footprint) - let usedBytesInt: UInt64 = UInt64(usedBytes) + let usedBytesInt = UInt64(usedBytes) let usedMB = usedBytesInt / 1024 / 1024 return usedMB } diff --git a/Tests/WhisperKitTests/RegressionTests.swift b/Tests/WhisperKitTests/RegressionTests.swift index 18d18cb..6e10a02 100644 --- a/Tests/WhisperKitTests/RegressionTests.swift +++ b/Tests/WhisperKitTests/RegressionTests.swift @@ -5,13 +5,12 @@ import XCTest @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) final class RegressionTests: XCTestCase { - var audioFileURL: URL? 
- + override func setUp() { super.setUp() - - if self.audioFileURL == nil{ + + if self.audioFileURL == nil { let expectation = XCTestExpectation(description: "Download test audio") downloadTestAudio { success in if success { @@ -41,8 +40,8 @@ final class RegressionTests: XCTestCase { } } } - - func testAndMeasureModelPerformance(model: String, device: String) async throws{ + + func testAndMeasureModelPerformance(model: String, device: String) async throws { let audioFilePath = try XCTUnwrap( self.audioFileURL?.path(), "Audio file not found" @@ -50,10 +49,10 @@ final class RegressionTests: XCTestCase { let startTime = Date() let iso8601DateTimeString = ISO8601DateFormatter().string(from: Date()) - + var currentMemoryValues = [Float]() var currentTPSValues = [Float]() - + let memoryStats = MemoryStats( measurements: [], units: "MB", totalNumberOfMeasurements: 0, @@ -64,20 +63,20 @@ final class RegressionTests: XCTestCase { measurements: [], units: "Tokens/Sec", totalNumberOfMeasurements: 0 ) - var count: Int = 0 - + var count = 0 + let callback = { - (result:TranscriptionProgress) -> Bool in + (result: TranscriptionProgress) -> Bool in count += 1 let currentMemory = SystemMemoryChecker.getMemoryUsed() let currentTPS = result.timings.tokensPerSecond - if currentMemory != 0{ + if currentMemory != 0 { currentMemoryValues.append(Float(currentMemory)) } - if !currentTPS.isNaN{ + if !currentTPS.isNaN { currentTPSValues.append(Float(currentTPS)) } - if count % 100 == 1{ + if count % 100 == 1 { let timeElapsed = Date().timeIntervalSince(startTime) memoryStats.measure(from: currentMemoryValues, timeElapsed: timeElapsed) latencyStats.measure(from: currentTPSValues, timeElapsed: timeElapsed) @@ -86,16 +85,16 @@ final class RegressionTests: XCTestCase { } return true } - + let whisperKit = try await WhisperKit(model: model) memoryStats.preTranscribeMemory = Float(SystemMemoryChecker.getMemoryUsed()) - + let transcriptionResult = try await XCTUnwrapAsync( await whisperKit.transcribe(audioPath: audioFilePath, callback: callback), "Transcription failed" ) XCTAssert(transcriptionResult.text.isEmpty == false, "Transcription failed") - + memoryStats.postTranscribeMemory = Float(SystemMemoryChecker.getMemoryUsed()) let testInfo = TestInfo( device: device, @@ -107,51 +106,47 @@ final class RegressionTests: XCTestCase { transcript: transcriptionResult.text ) let json = RegressionStats(testInfo: testInfo, memoryStats: memoryStats, latencyStats: latencyStats) - do{ + do { let attachment = try XCTAttachment(data: json.jsonData(), uniformTypeIdentifier: "json") attachment.lifetime = .keepAlways attachment.name = "\(device)_\(model)_\(iso8601DateTimeString).json" add(attachment) - } - catch{ + } catch { XCTFail("Failed with error: \(error)") } } - - func testRegressionAndLatencyForAllModels() async throws{ + + func testRegressionAndLatencyForAllModels() async throws { var allModels: [String] = [] - var failureInfo: [String:String] = [:] + var failureInfo: [String: String] = [:] var currentDevice = WhisperKit.deviceName() let iso8601DateTimeString = ISO8601DateFormatter().string(from: Date()) - + #if os(macOS) && arch(arm64) currentDevice = Process.processor #endif - - do{ + + do { allModels = try await WhisperKit.fetchAvailableModels() - } - catch{ + } catch { XCTFail("Failed to fetch available models: \(error.localizedDescription)") } - - for model in allModels{ - do{ + + for model in allModels { + do { try await testAndMeasureModelPerformance(model: model, device: currentDevice) - } - catch{ + } catch { 
failureInfo[model] = error.localizedDescription } } let testReport = TestReport(device: currentDevice, modelsTested: allModels, failureInfo: failureInfo) - do{ + do { let attachment = try XCTAttachment(data: testReport.jsonData(), uniformTypeIdentifier: "json") attachment.lifetime = .keepAlways attachment.name = "\(currentDevice)_summary_\(iso8601DateTimeString).json" add(attachment) - }catch{ + } catch { XCTFail("Failed with error: \(error)") } } - } diff --git a/Tests/WhisperKitTests/TestUtils.swift b/Tests/WhisperKitTests/TestUtils.swift index 49f4507..62aeb4c 100644 --- a/Tests/WhisperKitTests/TestUtils.swift +++ b/Tests/WhisperKitTests/TestUtils.swift @@ -173,7 +173,7 @@ extension XCTestCase { if try isGitLFSPointerFile(url: proxyFileToCheck) { continue } - + // Check if the directory name contains the quantization pattern // Only test large quantized models let dirName = folderURL.lastPathComponent @@ -184,15 +184,16 @@ extension XCTestCase { } return modelPaths } - - // Function to check if the beginning of the file matches a Git LFS pointer pattern + + /// Function to check if the beginning of the file matches a Git LFS pointer pattern func isGitLFSPointerFile(url: URL) throws -> Bool { let fileHandle = try FileHandle(forReadingFrom: url) // Read the first few bytes of the file to get enough for the Git LFS pointer signature let data = fileHandle.readData(ofLength: 512) // Read first 512 bytes fileHandle.closeFile() if let string = String(data: data, encoding: .utf8), - string.starts(with: "version https://git-lfs.github.com/") { + string.starts(with: "version https://git-lfs.github.com/") + { return true } return false @@ -238,19 +239,19 @@ extension SpecialTokens { extension Result { var isSuccess: Bool { switch self { - case .success: - return true - case .failure: - return false + case .success: + return true + case .failure: + return false } } func whisperError() -> WhisperError? { switch self { - case .success: - return nil - case .failure(let error): - return error as? WhisperError + case .success: + return nil + case let .failure(error): + return error as? 
WhisperError } } } diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift index 62ac364..7627aa9 100644 --- a/Tests/WhisperKitTests/UnitTests.swift +++ b/Tests/WhisperKitTests/UnitTests.swift @@ -3,17 +3,16 @@ import AVFoundation import CoreML -import Tokenizers import Hub import NaturalLanguage +import Tokenizers @testable import WhisperKit import XCTest @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) final class UnitTests: XCTestCase { - // MARK: - Model Loading Test - + func testInit() async throws { try await XCTUnwrapAsync( await WhisperKit(prewarm: false, load: false, download: false), @@ -181,7 +180,7 @@ final class UnitTests: XCTestCase { ) ) } - + func testDecoderLogProbThresholdDecodingFallback() async throws { let decodingOptions = DecodingOptions( withoutTimestamps: true, @@ -598,7 +597,7 @@ final class UnitTests: XCTestCase { let languageCode = recognizer.dominantLanguage!.rawValue XCTAssertEqual( - languageCode, + languageCode, option.language, "Text language \"\(languageCode)\" at index \(i) did not match expected language \"\(option.language)\"" ) @@ -666,7 +665,7 @@ final class UnitTests: XCTestCase { func testTemperatureIncrement() async throws { let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug) - + // Generate random audio samples let audioSamples = (0..<(30 * 16000)).map { _ in Float.random(in: -0.7...0.7) } @@ -842,13 +841,13 @@ final class UnitTests: XCTestCase { let result5 = tokensFilter5.filterLogits(logits5, withTokens: [1, 2, 3]) XCTAssertEqual(result5.data(for: 2), [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]) } - - func testLanguageLogitsFilter() throws{ + + func testLanguageLogitsFilter() throws { let tokensFilter1 = LanguageLogitsFilter(allLanguageTokens: [2, 4, 6], logitsDim: 7, sampleBegin: 0) let logits1 = try MLMultiArray.logits([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]) let result1 = tokensFilter1.filterLogits(logits1, withTokens: []) XCTAssertEqual(result1.data(for: 2), [-.infinity, -.infinity, 0.3, -.infinity, 0.5, -.infinity, 0.7]) - + let tokensFilter2 = LanguageLogitsFilter(allLanguageTokens: [2, 4, 6], logitsDim: 7, sampleBegin: 0) let logits2 = try MLMultiArray.logits([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]) let result2 = tokensFilter2.filterLogits(logits2, withTokens: [1]) @@ -907,7 +906,7 @@ final class UnitTests: XCTestCase { let logits1 = try MLMultiArray.logits([1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2]) let result1 = tokensFilter1.filterLogits(logits1, withTokens: []) XCTAssertEqual(result1.data(for: 2), [1.1, 5.2, 0.3, 0.4, 0.2, 0.1, 0.2]) - + let tokensFilter2 = TimestampRulesFilter( specialTokens: .default( endToken: 3, @@ -1238,7 +1237,7 @@ final class UnitTests: XCTestCase { WordTiming(word: " do", tokens: [360], start: 9.44, end: 9.64, probability: 0.87), WordTiming(word: " for", tokens: [337], start: 9.64, end: 9.86, probability: 0.95), WordTiming(word: " your", tokens: [428], start: 9.86, end: 10.06, probability: 0.96), - WordTiming(word: " country.", tokens: [1941, 13], start: 10.06, end: 10.5, probability: 0.91) + WordTiming(word: " country.", tokens: [1941, 13], start: 10.06, end: 10.5, probability: 0.91), ] XCTAssertEqual(wordTimings.count, expectedWordTimings.count, "Number of word timings should match") @@ -1261,7 +1260,7 @@ final class UnitTests: XCTestCase { let audioFile = "jfk.wav" let modelPath = try tinyModelPath() - let whisperKit = try await WhisperKit(modelFolder: modelPath,/* computeOptions: computeOptions,*/ verbose: true, logLevel: .debug) + let 
whisperKit = try await WhisperKit(modelFolder: modelPath, /* computeOptions: computeOptions,*/ verbose: true, logLevel: .debug) let startTime = Date() let audioComponents = audioFile.components(separatedBy: ".") @@ -1272,7 +1271,6 @@ final class UnitTests: XCTestCase { let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioFileURL) let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer) - var results: [TranscriptionResult?] = [] var prevResult: TranscriptionResult? var lastAgreedSeconds: Float = 0.0 @@ -1284,7 +1282,7 @@ final class UnitTests: XCTestCase { for seekSample in stride(from: 0, to: audioArray.count, by: 32000) { let endSample = min(seekSample + 32000, audioArray.count) - Logging.info("[testStreamingTimestamps] \(lastAgreedSeconds)-\(Double(endSample)/16000.0) seconds") + Logging.info("[testStreamingTimestamps] \(lastAgreedSeconds)-\(Double(endSample) / 16000.0) seconds") let simulatedStreamingAudio = Array(audioArray[.. \(Double(endSample)/16000.0) \(currentWords)") + Logging.info("[testStreamingTimestamps] Current: \(lastAgreedSeconds) -> \(Double(endSample) / 16000.0) \(currentWords)") } else { Logging.info("[testStreamingTimestamps] Using same last agreed time \(lastAgreedSeconds)") skipAppend = true } - - } prevResult = result }
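
Reviewer sketch of the headline behavioral change ("Allow tokenizer to be loaded from disk if it exists already"): the updated `loadTokenizer` in `Sources/WhisperKit/Core/Utils.swift` now checks the Hub download location for an existing `tokenizer.json` before reaching for the network. The condensed sketch below is illustrative rather than the literal patch: `resolveTokenizer` is a hypothetical name, and the real function additionally logs each branch and wraps the result in `WhisperTokenizerWrapper`.

```swift
import Foundation
import Hub        // swift-transformers: HubApi, LanguageModelConfigurationFromHub
import Tokenizers // swift-transformers: AutoTokenizer, PreTrainedTokenizer

/// Prefer an on-disk tokenizer when one was previously downloaded;
/// fall back to fetching from the Hub when it is missing or unreadable.
func resolveTokenizer(named tokenizerName: String, downloadBase: URL? = nil) async throws -> Tokenizer {
    let hubApi = HubApi(downloadBase: downloadBase)
    let repoFolder = hubApi.localRepoLocation(HubApi.Repo(id: tokenizerName))
    let tokenizerConfigPath = repoFolder.appendingPathComponent("tokenizer.json")

    if FileManager.default.fileExists(atPath: tokenizerConfigPath.path) {
        do {
            let localConfig = LanguageModelConfigurationFromHub(modelFolder: repoFolder, hubApi: hubApi)
            if let tokenizerConfig = try await localConfig.tokenizerConfig {
                // Both config files resolved locally; no network access required.
                let tokenizerData = try await localConfig.tokenizerData
                return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
            }
            // tokenizerConfig was nil: fall through to the Hub.
        } catch {
            // A corrupt or partial local copy is non-fatal: fall through to the Hub.
        }
    }
    return try await AutoTokenizer.from(pretrained: tokenizerName, hubApi: hubApi)
}
```

The ordering matters for offline use: once a model has been run with connectivity, subsequent launches resolve the tokenizer without touching the network, which is why the local check happens before `AutoTokenizer.from(pretrained:)` rather than after.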