From 02763ca430f7190fc80d003b51a2e4510c35f9cc Mon Sep 17 00:00:00 2001
From: Zach Nagengast
Date: Thu, 11 Jul 2024 11:55:58 -0700
Subject: [PATCH] Fix resampling large files (#183)

* Update resampling logic to handle chunking properly
* Cleanup logging
* Optimize memory usage when resampling
* Add filter to input prompt text
* Correct timestamp filter logic for #170
* Filter out zero length segments - when calculating word timestamps - resolves #170
* Add method for async audio loading
* Fix async load audio function
* Fix tests
* Fix tests
* Fix tests
* Revert timestamp filter changes
* Temporarily remove xcpretty for tests
* Check suspected test crash
* Remove errant test case for japanese options
* Add bigger range for early stopping test
* Reset progress between runs
* Fix progress resetting and improve example app transcription handling
* Update tests
* Minimize crash risk for early stop checks
* Fix finalize text
* Add source text to language label
---
 .github/workflows/unit-tests.yml              |   2 +-
 .../WhisperAX.xcodeproj/project.pbxproj       |   4 +-
 .../WhisperAX/Views/ContentView.swift         | 103 +++++---
 Sources/WhisperKit/Core/AudioProcessor.swift  | 228 +++++++++++++-----
 Sources/WhisperKit/Core/Models.swift          |  57 +++--
 Sources/WhisperKit/Core/TextDecoder.swift     |  16 +-
 Sources/WhisperKit/Core/TranscribeTask.swift  |  10 +-
 Sources/WhisperKit/Core/Utils.swift           | 108 ++++++++-
 Sources/WhisperKit/Core/WhisperKit.swift      |  40 ++-
 Sources/WhisperKitCLI/TranscribeCLI.swift     |   4 +-
 Tests/WhisperKitTests/FunctionalTests.swift   |  29 ---
 Tests/WhisperKitTests/RegressionTests.swift   |  29 +++
 Tests/WhisperKitTests/UnitTests.swift         |  75 +++++-
 13 files changed, 536 insertions(+), 169 deletions(-)

diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml
index 0af57b1..83ec13e 100644
--- a/.github/workflows/unit-tests.yml
+++ b/.github/workflows/unit-tests.yml
@@ -77,4 +77,4 @@ jobs:
       run: |
         set -o pipefail
         xcodebuild clean build-for-testing -scheme whisperkit-Package -destination '${{ matrix.run-config['clean-destination'] }}' | xcpretty
-        xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination '${{ matrix.run-config['test-destination'] }}' | xcpretty
+        xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination '${{ matrix.run-config['test-destination'] }}'
diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj b/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj
index 0ec5ed3..bfb9069 100644
--- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj
+++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj
@@ -890,7 +890,7 @@
 				LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
 				"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
 				MACOSX_DEPLOYMENT_TARGET = 14.0;
-				MARKETING_VERSION = 0.3.1;
+				MARKETING_VERSION = 0.3.2;
 				PRODUCT_BUNDLE_IDENTIFIER = "com.argmax.whisperkit.WhisperAX${DEVELOPMENT_TEAM}";
 				PRODUCT_NAME = "$(TARGET_NAME)";
 				SDKROOT = auto;
@@ -936,7 +936,7 @@
 				LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
 				"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
 				MACOSX_DEPLOYMENT_TARGET = 14.0;
-				MARKETING_VERSION = 0.3.1;
+				MARKETING_VERSION = 0.3.2;
 				PRODUCT_BUNDLE_IDENTIFIER = com.argmax.whisperkit.WhisperAX;
 				PRODUCT_NAME = "$(TARGET_NAME)";
 				SDKROOT = auto;
diff --git a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
index 61c1496..1fa7096 100644
---
a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift +++ b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift @@ -97,7 +97,7 @@ struct ContentView: View { @State private var showAdvancedOptions: Bool = false @State private var transcriptionTask: Task? = nil @State private var selectedCategoryId: MenuItem.ID? - @State private var transcribeFileTask: Task? = nil + @State private var transcribeTask: Task? = nil struct MenuItem: Identifiable, Hashable { var id = UUID() @@ -122,7 +122,7 @@ struct ContentView: View { // MARK: Views func resetState() { - transcribeFileTask?.cancel() + transcribeTask?.cancel() isRecording = false isTranscribing = false whisperKit?.audioProcessor.stopRecording() @@ -311,15 +311,27 @@ struct ContentView: View { .textSelection(.enabled) .padding() if let whisperKit, - !isRecording, - !isTranscribing, - whisperKit.progress.fractionCompleted > 0, + !isStreamMode, + isTranscribing, + let task = transcribeTask, + !task.isCancelled, whisperKit.progress.fractionCompleted < 1 { - ProgressView(whisperKit.progress) - .progressViewStyle(.linear) - .labelsHidden() - .padding(.horizontal) + HStack { + ProgressView(whisperKit.progress) + .progressViewStyle(.linear) + .labelsHidden() + .padding(.horizontal) + + Button { + transcribeTask?.cancel() + transcribeTask = nil + } label: { + Image(systemName: "xmark.circle.fill") + .foregroundColor(.secondary) + } + .buttonStyle(BorderlessButtonStyle()) + } } } } @@ -706,7 +718,7 @@ struct ContentView: View { } .disabled(!(whisperKit?.modelVariant.isMultilingual ?? false)) } label: { - Label("Language", systemImage: "globe") + Label("Source Language", systemImage: "globe") } .padding(.horizontal) .padding(.top) @@ -1149,12 +1161,14 @@ struct ContentView: View { func transcribeFile(path: String) { resetState() whisperKit?.audioProcessor = AudioProcessor() - self.transcribeFileTask = Task { + self.transcribeTask = Task { + isTranscribing = true do { try await transcribeCurrentFile(path: path) } catch { print("File selection error: \(error.localizedDescription)") } + isTranscribing = false } } @@ -1218,12 +1232,34 @@ struct ContentView: View { // If not looping, transcribe the full buffer if !loop { - Task { + self.transcribeTask = Task { + isTranscribing = true do { try await transcribeCurrentBuffer() } catch { print("Error: \(error.localizedDescription)") } + finalizeText() + isTranscribing = false + } + } + + finalizeText() + } + + func finalizeText() { + // Finalize unconfirmed text + Task { + await MainActor.run { + if hypothesisText != "" { + confirmedText += hypothesisText + hypothesisText = "" + } + + if unconfirmedSegments.count > 0 { + confirmedSegments.append(contentsOf: unconfirmedSegments) + unconfirmedSegments = [] + } } } } @@ -1231,8 +1267,14 @@ struct ContentView: View { // MARK: - Transcribe Logic func transcribeCurrentFile(path: String) async throws { - let audioFileBuffer = try AudioProcessor.loadAudio(fromPath: path) - let audioFileSamples = AudioProcessor.convertBufferToArray(buffer: audioFileBuffer) + // Load and convert buffer in a limited scope + let audioFileSamples = try await Task { + try autoreleasepool { + let audioFileBuffer = try AudioProcessor.loadAudio(fromPath: path) + return AudioProcessor.convertBufferToArray(buffer: audioFileBuffer) + } + }.value + let transcription = try await transcribeAudioSamples(audioFileSamples) await MainActor.run { @@ -1258,7 +1300,7 @@ struct ContentView: View { let languageCode = Constants.languages[selectedLanguage, default: Constants.defaultLanguageCode] let task: 
DecodingTask = selectedTask == "transcribe" ? .transcribe : .translate - let seekClip: [Float] = [] + let seekClip: [Float] = [lastConfirmedSegmentEndSeconds] let options = DecodingOptions( verbose: true, @@ -1271,6 +1313,7 @@ struct ContentView: View { usePrefillCache: enableCachePrefill, skipSpecialTokens: !enableSpecialCharacters, withoutTimestamps: !enableTimestamps, + wordTimestamps: true, clipTimestamps: seekClip, chunkingStrategy: chunkingStrategy ) @@ -1279,7 +1322,7 @@ struct ContentView: View { let decodingCallback: ((TranscriptionProgress) -> Bool?) = { (progress: TranscriptionProgress) in DispatchQueue.main.async { let fallbacks = Int(progress.timings.totalDecodingFallbacks) - let chunkId = progress.windowId + let chunkId = isStreamMode ? 0 : progress.windowId // First check if this is a new window for the same chunk, append if so var updatedChunk = (chunkText: [progress.text], fallbacks: fallbacks) @@ -1292,7 +1335,7 @@ struct ContentView: View { // This is either a new window or a fallback (only in streaming mode) if fallbacks == currentChunk.fallbacks && isStreamMode { // New window (since fallbacks havent changed) - updatedChunk.chunkText = currentChunk.chunkText + [progress.text] + updatedChunk.chunkText = [updatedChunk.chunkText.first ?? "" + progress.text] } else { // Fallback, overwrite the previous bad text updatedChunk.chunkText[currentChunk.chunkText.endIndex - 1] = progress.text @@ -1419,6 +1462,7 @@ struct ContentView: View { // Run realtime transcribe using word timestamps for segmentation let transcription = try await transcribeEagerMode(Array(currentBuffer)) await MainActor.run { + currentText = "" self.tokensPerSecond = transcription?.timings.tokensPerSecond ?? 0 self.firstTokenTime = transcription?.timings.firstTokenTime ?? 0 self.pipelineStart = transcription?.timings.pipelineStart ?? 
0 @@ -1464,10 +1508,13 @@ struct ContentView: View { // Update lastConfirmedSegmentEnd based on the last confirmed segment if let lastConfirmedSegment = confirmedSegmentsArray.last, lastConfirmedSegment.end > lastConfirmedSegmentEndSeconds { lastConfirmedSegmentEndSeconds = lastConfirmedSegment.end + print("Last confirmed segment end: \(lastConfirmedSegmentEndSeconds)") // Add confirmed segments to the confirmedSegments array - if !self.confirmedSegments.contains(confirmedSegmentsArray) { - self.confirmedSegments.append(contentsOf: confirmedSegmentsArray) + for segment in confirmedSegmentsArray { + if !self.confirmedSegments.contains(segment: segment) { + self.confirmedSegments.append(segment) + } } } @@ -1584,18 +1631,20 @@ struct ContentView: View { eagerResults.append(transcription) } } + + await MainActor.run { + let finalWords = confirmedWords.map { $0.word }.joined() + confirmedText = finalWords + + // Accept the final hypothesis because it is the last of the available audio + let lastHypothesis = lastAgreedWords + findLongestDifferentSuffix(prevWords, hypothesisWords) + hypothesisText = lastHypothesis.map { $0.word }.joined() + } } catch { Logging.error("[EagerMode] Error: \(error)") + finalizeText() } - await MainActor.run { - let finalWords = confirmedWords.map { $0.word }.joined() - confirmedText = finalWords - - // Accept the final hypothesis because it is the last of the available audio - let lastHypothesis = lastAgreedWords + findLongestDifferentSuffix(prevWords, hypothesisWords) - hypothesisText = lastHypothesis.map { $0.word }.joined() - } let mergedResult = mergeTranscriptionResults(eagerResults, confirmedWords: confirmedWords) diff --git a/Sources/WhisperKit/Core/AudioProcessor.swift b/Sources/WhisperKit/Core/AudioProcessor.swift index ff05b07..41fe096 100644 --- a/Sources/WhisperKit/Core/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/AudioProcessor.swift @@ -20,9 +20,12 @@ public struct AudioDevice: Identifiable, Hashable { public protocol AudioProcessing { /// Loads audio data from a specified file path. - /// - Parameter audioFilePath: The file path of the audio file. + /// - Parameters: + /// - audioFilePath: The file path of the audio file. + /// - startTime: Optional start time in seconds to read from + /// - endTime: Optional end time in seconds to read until /// - Returns: `AVAudioPCMBuffer` containing the audio data. - static func loadAudio(fromPath audioFilePath: String) throws -> AVAudioPCMBuffer + static func loadAudio(fromPath audioFilePath: String, startTime: Double?, endTime: Double?) throws -> AVAudioPCMBuffer /// Loads and converts audio data from a specified file paths. /// - Parameter audioPaths: The file paths of the audio files. @@ -71,6 +74,16 @@ public protocol AudioProcessing { /// Overrideable default methods for AudioProcessing public extension AudioProcessing { + /// Loads and converts audio data from a specified file paths. + /// - Parameter audioPaths: The file paths of the audio files. + /// - Returns: `AVAudioPCMBuffer` containing the audio data. + @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) + static func loadAudioAsync(fromPath audioFilePath: String) async throws -> AVAudioPCMBuffer { + return try await Task { + return try AudioProcessor.loadAudio(fromPath: audioFilePath) + }.value + } + func startRecordingLive(inputDeviceID: DeviceID? = nil, callback: (([Float]) -> Void)?) 
throws { try startRecordingLive(inputDeviceID: inputDeviceID, callback: callback) } @@ -172,7 +185,7 @@ public class AudioProcessor: NSObject, AudioProcessing { // MARK: - Loading and conversion - public static func loadAudio(fromPath audioFilePath: String) throws -> AVAudioPCMBuffer { + public static func loadAudio(fromPath audioFilePath: String, startTime: Double? = 0, endTime: Double? = nil) throws -> AVAudioPCMBuffer { guard FileManager.default.fileExists(atPath: audioFilePath) else { throw WhisperError.loadAudioFailed("Resource path does not exist \(audioFilePath)") } @@ -184,24 +197,44 @@ public class AudioProcessor: NSObject, AudioProcessing { let channelCount = audioFile.fileFormat.channelCount let frameLength = AVAudioFrameCount(audioFile.length) - let outputBuffer: AVAudioPCMBuffer + // Calculate the frame range based on the start and end seconds + let startFrame = AVAudioFramePosition((startTime ?? 0) * sampleRate) + let endFrame: AVAudioFramePosition + if let end = endTime { + endFrame = min(AVAudioFramePosition(end * sampleRate), AVAudioFramePosition(audioFile.length)) + } else { + endFrame = AVAudioFramePosition(audioFile.length) + } + + let frameCount = AVAudioFrameCount(endFrame - startFrame) + + // Seek to the start frame + audioFile.framePosition = startFrame + + var outputBuffer: AVAudioPCMBuffer? + // If the audio file already meets the desired format, read directly into the output buffer if sampleRate == 16000 && channelCount == 1 { - guard let buffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: frameLength) else { + guard let buffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: frameCount) else { throw WhisperError.loadAudioFailed("Unable to create audio buffer") } - try audioFile.read(into: buffer) + try audioFile.read(into: buffer, frameCount: frameCount) outputBuffer = buffer } else { // Audio needs resampling to 16khz - guard let buffer = resampleAudio(fromFile: audioFile, toSampleRate: 16000, channelCount: 1) else { - throw WhisperError.loadAudioFailed("Unable to resample audio") - } - outputBuffer = buffer + outputBuffer = resampleAudio(fromFile: audioFile, toSampleRate: 16000, channelCount: 1, frameCount: frameCount) + } + + if let outputBuffer = outputBuffer { + Logging.debug("Audio source details - Sample Rate: \(sampleRate) Hz, Channel Count: \(channelCount), Frame Length: \(frameLength), Duration: \(Double(frameLength) / sampleRate)s") + Logging.debug("Audio buffer details - Sample Rate: \(outputBuffer.format.sampleRate) Hz, Channel Count: \(outputBuffer.format.channelCount), Frame Length: \(outputBuffer.frameLength), Duration: \(Double(outputBuffer.frameLength) / outputBuffer.format.sampleRate)s") + + logCurrentMemoryUsage("After loadAudio function") + + return outputBuffer + } else { + throw WhisperError.loadAudioFailed("Failed to process audio buffer") } - Logging.debug("Audio source details - Sample Rate: \(sampleRate) Hz, Channel Count: \(channelCount), Frame Length: \(frameLength), Duration: \(Double(frameLength) / sampleRate)s") - Logging.debug("Audio buffer details - Sample Rate: \(outputBuffer.format.sampleRate) Hz, Channel Count: \(outputBuffer.format.channelCount), Frame Length: \(outputBuffer.frameLength), Duration: \(Double(outputBuffer.frameLength) / outputBuffer.format.sampleRate)s") - return outputBuffer } public static func loadAudio(at audioPaths: [String]) async -> [Result<[Float], Swift.Error>] { @@ -226,59 +259,102 @@ public class AudioProcessor: NSObject, AudioProcessing { } } - 
public static func resampleAudio(fromFile audioFile: AVAudioFile, toSampleRate sampleRate: Double, channelCount: AVAudioChannelCount) -> AVAudioPCMBuffer? { - let newFrameLength = Int64((sampleRate / audioFile.fileFormat.sampleRate) * Double(audioFile.length)) - let outputFormat = AVAudioFormat(standardFormatWithSampleRate: sampleRate, channels: channelCount)! - guard let converter = AVAudioConverter(from: audioFile.processingFormat, to: outputFormat) else { - Logging.error("Failed to create audio converter") + /// Resamples audio from a file to a specified sample rate and channel count. + /// - Parameters: + /// - audioFile: The input audio file. + /// - sampleRate: The desired output sample rate. + /// - channelCount: The desired output channel count. + /// - frameCount: The desired frames to read from the input audio file. (default: all). + /// - maxReadFrameSize: Maximum number of frames to read at once (default: 10 million). + /// - Returns: Resampled audio as an AVAudioPCMBuffer, or nil if resampling fails. + public static func resampleAudio( + fromFile audioFile: AVAudioFile, + toSampleRate sampleRate: Double, + channelCount: AVAudioChannelCount, + frameCount: AVAudioFrameCount? = nil, + maxReadFrameSize: AVAudioFrameCount = 1_323_000 // 30s of audio at commonly found 44.1khz sample rate + ) -> AVAudioPCMBuffer? { + let inputFormat = audioFile.fileFormat + let inputFrameCount = frameCount ?? AVAudioFrameCount(audioFile.length) + let inputDuration = Double(inputFrameCount) / inputFormat.sampleRate + + guard let outputFormat = AVAudioFormat(standardFormatWithSampleRate: sampleRate, channels: channelCount) else { + Logging.error("Failed to create output audio format") return nil } - let frameCount = AVAudioFrameCount(audioFile.length) + Logging.debug("Resampling \(String(format: "%.2f", inputDuration)) seconds of audio") - // Read audio in 100mb increments to reduce the memory spike for large audio files - let maxReadFrameSize: AVAudioFrameCount = 100_000_000 - guard let inputBuffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: min(frameCount, maxReadFrameSize)), - let outputBuffer = AVAudioPCMBuffer(pcmFormat: outputFormat, frameCapacity: AVAudioFrameCount(newFrameLength)) - else { - Logging.error("Unable to create buffers, likely due to unsupported file format") + // Create the output buffer with full capacity + guard let outputBuffer = AVAudioPCMBuffer(pcmFormat: outputFormat, frameCapacity: AVAudioFrameCount(inputDuration * outputFormat.sampleRate)) else { + Logging.error("Failed to create output buffer") return nil } - while audioFile.framePosition < frameCount { - do { - let maxReadFrameCount = min(frameCount - UInt32(audioFile.framePosition), maxReadFrameSize) - try audioFile.read(into: inputBuffer, frameCount: maxReadFrameCount) - } catch { - Logging.error("Error reading audio file: \(error)") - return nil - } + let inputBuffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: maxReadFrameSize)! 
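+        // Read the source file in chunks of at most `maxReadFrameSize` frames, resample
+        // each chunk to the output format, and append it to `outputBuffer`; chunked reads
+        // bound peak memory usage when resampling very large files.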
+ + while audioFile.framePosition < inputFrameCount { + let remainingFrames = inputFrameCount - AVAudioFrameCount(audioFile.framePosition) + let framesToRead = min(remainingFrames, maxReadFrameSize) - let inputBlock: AVAudioConverterInputBlock = { _, outStatus in - if inputBuffer.frameLength == 0 { - outStatus.pointee = .endOfStream + let currentPositionInSeconds = Double(audioFile.framePosition) / inputFormat.sampleRate + let nextPositionInSeconds = (Double(audioFile.framePosition) + Double(framesToRead)) / inputFormat.sampleRate + Logging.debug("Resampling \(String(format: "%.2f", currentPositionInSeconds))s - \(String(format: "%.2f", nextPositionInSeconds))s") + + do { + try audioFile.read(into: inputBuffer, frameCount: framesToRead) + guard let resampledChunk = resampleAudio(fromBuffer: inputBuffer, + toSampleRate: outputFormat.sampleRate, + channelCount: outputFormat.channelCount) else { + Logging.error("Failed to resample audio chunk") return nil - } else { - outStatus.pointee = .haveData - return inputBuffer } - } - var error: NSError? - let status = converter.convert(to: outputBuffer, error: &error, withInputFrom: inputBlock) - switch status { - case .error: - if let conversionError = error { - Logging.error("Error converting audio file: \(conversionError)") + // Append the resampled chunk to the output buffer + guard outputBuffer.appendContents(of: resampledChunk) else { + Logging.error("Failed to append audio chunk") + return nil } + } catch { + Logging.error("Error reading audio file: \(error)") return nil - default: break } } return outputBuffer } + /// Resamples an audio buffer to a specified sample rate and channel count. + /// - Parameters: + /// - inputBuffer: The input audio buffer. + /// - sampleRate: The desired output sample rate. + /// - channelCount: The desired output channel count. + /// - Returns: Resampled audio as an AVAudioPCMBuffer, or nil if resampling fails. + public static func resampleAudio(fromBuffer inputBuffer: AVAudioPCMBuffer, toSampleRate sampleRate: Double, channelCount: AVAudioChannelCount) -> AVAudioPCMBuffer? { + guard let outputFormat = AVAudioFormat(standardFormatWithSampleRate: sampleRate, channels: channelCount) else { + Logging.error("Failed to create output audio format") + return nil + } + + guard let converter = AVAudioConverter(from: inputBuffer.format, to: outputFormat) else { + Logging.error("Failed to create audio converter") + return nil + } + + do { + return try Self.resampleBuffer(inputBuffer, with: converter) + } catch { + Logging.error("Failed to resample buffer: \(error)") + return nil + } + } + + /// Resamples an audio buffer using the provided converter. + /// - Parameters: + /// - buffer: The input audio buffer. + /// - converter: The audio converter to use for resampling. + /// - Returns: Resampled audio as an AVAudioPCMBuffer. + /// - Throws: WhisperError if resampling fails. public static func resampleBuffer(_ buffer: AVAudioPCMBuffer, with converter: AVAudioConverter) throws -> AVAudioPCMBuffer { guard let convertedBuffer = AVAudioPCMBuffer( pcmFormat: converter.outputFormat, @@ -288,11 +364,22 @@ public class AudioProcessor: NSObject, AudioProcessing { } let inputBlock: AVAudioConverterInputBlock = { _, outStatus in - outStatus.pointee = .haveData - return buffer + if buffer.frameLength == 0 { + outStatus.pointee = .endOfStream + return nil + } else { + outStatus.pointee = .haveData + return buffer + } + } + + var error: NSError? 
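+        // Perform the conversion and propagate any converter error to the caller
+        // rather than ignoring the returned status.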
+ let status = converter.convert(to: convertedBuffer, error: &error, withInputFrom: inputBlock) + + if status == .error, let conversionError = error { + throw WhisperError.audioProcessingFailed("Error converting audio: \(conversionError)") } - converter.convert(to: convertedBuffer, error: nil, withInputFrom: inputBlock) return convertedBuffer } @@ -412,11 +499,42 @@ public class AudioProcessor: NSObject, AudioProcessing { return max(0, min(normalizedEnergy, 1)) } - public static func convertBufferToArray(buffer: AVAudioPCMBuffer) -> [Float] { - let start = buffer.floatChannelData?[0] - let count = Int(buffer.frameLength) - let convertedArray = Array(UnsafeBufferPointer(start: start, count: count)) - return convertedArray + public static func convertBufferToArray(buffer: AVAudioPCMBuffer, chunkSize: Int = 1024) -> [Float] { + guard let channelData = buffer.floatChannelData else { + return [] + } + + let frameLength = Int(buffer.frameLength) + let startPointer = channelData[0] + + var result: [Float] = [] + result.reserveCapacity(frameLength) // Reserve the capacity to avoid multiple allocations + + var currentFrame = 0 + while currentFrame < frameLength { + let remainingFrames = frameLength - currentFrame + let currentChunkSize = min(chunkSize, remainingFrames) + + var chunk = [Float](repeating: 0, count: currentChunkSize) + + chunk.withUnsafeMutableBufferPointer { bufferPointer in + vDSP_mmov( + startPointer.advanced(by: currentFrame), + bufferPointer.baseAddress!, + vDSP_Length(currentChunkSize), + 1, + vDSP_Length(currentChunkSize), + 1 + ) + } + + result.append(contentsOf: chunk) + currentFrame += currentChunkSize + + memset(startPointer.advanced(by: currentFrame - currentChunkSize), 0, currentChunkSize * MemoryLayout.size) + } + + return result } public static func requestRecordPermission() async -> Bool { diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift index 50247eb..0911d1b 100644 --- a/Sources/WhisperKit/Core/Models.swift +++ b/Sources/WhisperKit/Core/Models.swift @@ -520,35 +520,34 @@ public struct TranscriptionResult: Codable { let decodingLoopInfo = formatTimeWithPercentage(timings.decodingLoop, totalLoops, fullDecodingDuration) // Logging - Logging.info("---- Transcription Timings ----") - - Logging.info("Audio Load: \(audioLoadTime)") - Logging.info("Audio Processing: \(audioProcTime)") - Logging.info("Mels: \(logmelsTime)") - Logging.info("Encoding: \(encodingTime)") - Logging.info("Matrices Init: \(decodingInitTime)") - Logging.info("Prefill: \(prefillInfo)") - Logging.info("Decoding: \(predictionsInfo)") - Logging.info("Non-inference: \(nonPredTimeInfo)") - Logging.info("- Logit Filtering: \(filteringInfo)") - Logging.info("- Sampling: \(samplingInfo)") - Logging.info("- Kv Caching: \(kvCachingInfo)") - Logging.info("- Word Timestamps: \(wordTimestampInfo)") - Logging.info("- Windowing: \(windowingInfo)") - Logging.info("Fallbacks: \(fallbackInfo)") - Logging.info("Decoding Full Loop: \(decodingLoopInfo)") - Logging.info("-------------------------------") - - // Summary statistics - Logging.info("Model Load Time: \(String(format: "%.2f", timings.modelLoading)) seconds") - Logging.info("Inference Duration (Global): \(String(format: "%.2f", timings.fullPipeline)) seconds") - Logging.info("- Decoding Loop (Avg/window): \(String(format: "%.2f", decodeTimePerWindow)) seconds") - Logging.info("- Audio Windows: \(String(format: "%.2f", timings.totalAudioProcessingRuns))") - Logging.info("Time to first token: \(String(format: "%.2f", 
timeToFirstToken)) seconds") - Logging.info("Total Tokens: \(totalTokens)") - Logging.info("Tokens per Second: \(String(format: "%.2f", tokensPerSecond)) tok/s") - Logging.info("Real Time Factor: \(String(format: "%.3f", rtf))") - Logging.info("Fallbacks: \(timings.totalDecodingFallbacks)") + Logging.info(""" + ---- Transcription Timings ---- + Audio Load: \(audioLoadTime) + Audio Processing: \(audioProcTime) + Mels: \(logmelsTime) + Encoding: \(encodingTime) + Matrices Init: \(decodingInitTime) + Prefill: \(prefillInfo) + Decoding: \(predictionsInfo) + Non-inference: \(nonPredTimeInfo) + - Logit Filtering: \(filteringInfo) + - Sampling: \(samplingInfo) + - Kv Caching: \(kvCachingInfo) + - Word Timestamps: \(wordTimestampInfo) + - Windowing: \(windowingInfo) + Fallbacks: \(fallbackInfo) + Decoding Full Loop: \(decodingLoopInfo) + ------------------------------- + Model Load Time: \(String(format: "%.2f", timings.modelLoading)) seconds + Inference Duration (Global): \(String(format: "%.2f", timings.fullPipeline)) seconds + - Decoding Loop (Avg/window): \(String(format: "%.2f", decodeTimePerWindow)) seconds + - Audio Windows: \(String(format: "%.2f", timings.totalAudioProcessingRuns)) + Time to first token: \(String(format: "%.2f", timeToFirstToken)) seconds + Total Tokens: \(totalTokens) + Tokens per Second: \(String(format: "%.2f", tokensPerSecond)) tok/s + Real Time Factor: \(String(format: "%.3f", rtf)) + Fallbacks: \(timings.totalDecodingFallbacks) + """) } } diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index 1866179..9ea7edf 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -204,13 +204,13 @@ public extension TextDecoding { // Add prompt tokens if let promptTokens = options.promptTokens { let maxPromptLen = (Constants.maxTokenContext / 2) - 1 - let trimmedPromptTokens = Array(promptTokens.suffix(maxPromptLen)) + let trimmedPromptTokens = Array(promptTokens.suffix(maxPromptLen)).filter { $0 < tokenizer.specialTokens.specialTokenBegin } prefillTokens = [tokenizer.specialTokens.startOfPreviousToken] + trimmedPromptTokens + prefillTokens } // Add prefix tokens if let prefixTokens = options.prefixTokens { - let trimmedPrefixTokens = Array(prefixTokens.suffix(Constants.maxTokenContext / 2)) + let trimmedPrefixTokens = Array(prefixTokens.suffix(Constants.maxTokenContext / 2)).filter { $0 < tokenizer.specialTokens.specialTokenBegin } prefillTokens.append(contentsOf: trimmedPrefixTokens) } } @@ -589,7 +589,9 @@ open class TextDecoder: TextDecoding, WhisperMLModel { var hasAlignment = false var isFirstTokenLogProbTooLow = false let windowUUID = UUID() - shouldEarlyStop[windowUUID] = false + DispatchQueue.main.sync { + shouldEarlyStop[windowUUID] = false + } for tokenIndex in prefilledIndex.. $0.start } + // Update seek point with new (more accurate) segments if let lastSpeechTimestamp = currentSegments?.last?.end { seek = max(seek, Int(lastSpeechTimestamp * Float(WhisperKit.sampleRate))) @@ -242,11 +244,17 @@ final class TranscribeTask { maxTokenContext: decodeOptions?.sampleLength ?? 
Constants.maxTokenContext ) + // Update the progress let clipProgress = min(seek, seekClipEnd) - seekClipStart progress.completedUnitCount = previousSeekProgress + Int64(clipProgress) } } + // Transcription completed + progress.completedUnitCount = progress.totalUnitCount + + // MARK: - Decode with Fallback Logic + func decodeWithFallback( encoderSegment encoderOutput: MLMultiArray, decodingOptions options: DecodingOptions, diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index e4b1f68..6e14944 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -31,6 +31,12 @@ extension Array where Element == Result<[TranscriptionResult], Swift.Error> { } } +public extension Array where Element == TranscriptionSegment { + func contains(segment: TranscriptionSegment) -> Bool { + return self.contains { $0.start == segment.start } + } +} + extension MLMultiArray { /// Calculate the linear offset by summing the products of each dimension’s index with the dimension’s stride. /// More info [here](https://developer.apple.com/documentation/coreml/mlmultiarray/2879231-subscript) @@ -181,6 +187,74 @@ extension String { } } +extension AVAudioPCMBuffer { + // Appends the contents of another buffer to the current buffer + func appendContents(of buffer: AVAudioPCMBuffer) -> Bool { + return appendContents(of: buffer, startingFrame: 0, frameCount: buffer.frameLength) + } + + // Appends a specific range of frames from another buffer to the current buffer + func appendContents(of buffer: AVAudioPCMBuffer, startingFrame: AVAudioFramePosition, frameCount: AVAudioFrameCount) -> Bool { + guard format == buffer.format else { + Logging.debug("Format mismatch") + return false + } + + guard startingFrame + AVAudioFramePosition(frameCount) <= AVAudioFramePosition(buffer.frameLength) else { + Logging.debug("Insufficient audio in buffer") + return false + } + + guard frameLength + frameCount <= frameCapacity else { + Logging.debug("Insufficient space in buffer") + return false + } + + guard let destination = floatChannelData, let source = buffer.floatChannelData else { + Logging.debug("Failed to access float channel data") + return false + } + + let calculatedStride = stride + let destinationPointer = destination.pointee.advanced(by: calculatedStride * Int(frameLength)) + let sourcePointer = source.pointee.advanced(by: calculatedStride * Int(startingFrame)) + + memcpy(destinationPointer, sourcePointer, Int(frameCount) * calculatedStride * MemoryLayout.size) + + frameLength += frameCount + return true + } + + // Convenience initializer to concatenate multiple buffers into one + convenience init?(concatenating buffers: [AVAudioPCMBuffer]) { + guard !buffers.isEmpty else { + Logging.debug("Buffers array should not be empty") + return nil + } + + let totalFrames = buffers.reduce(0) { $0 + $1.frameLength } + + guard let firstBuffer = buffers.first else { + Logging.debug("Failed to get the first buffer") + return nil + } + + self.init(pcmFormat: firstBuffer.format, frameCapacity: totalFrames) + + for buffer in buffers { + if !appendContents(of: buffer) { + Logging.debug("Failed to append buffer") + return nil + } + } + } + + // Computed property to determine the stride for float channel data + private var stride: Int { + return Int(format.streamDescription.pointee.mBytesPerFrame) / MemoryLayout.size + } +} + // MARK: - Helpers @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) @@ -666,6 +740,28 @@ public func compressionRatio(of text: String) -> Float { } 
} +public func logCurrentMemoryUsage(_ message: String) { + let memoryUsage = getMemoryUsage() + Logging.debug("\(message) - Memory usage: \(memoryUsage) MB") +} + +public func getMemoryUsage() -> UInt64 { + var info = mach_task_basic_info() + var count = mach_msg_type_number_t(MemoryLayout.size) / 4 + + let kerr: kern_return_t = withUnsafeMutablePointer(to: &info) { + $0.withMemoryRebound(to: integer_t.self, capacity: 1) { + task_info(mach_task_self_, task_flavor_t(MACH_TASK_BASIC_INFO), $0, &count) + } + } + + guard kerr == KERN_SUCCESS else { + return 0 // If the call fails, return 0 + } + + return info.resident_size / 1024 / 1024 // Convert to MB +} + // MARK: - Singletons public class Logging { @@ -675,6 +771,8 @@ public class Logging { public typealias LoggingCallback = (_ message: String) -> Void var loggingCallback: LoggingCallback? + private let logger = OSLog(subsystem: Bundle.main.bundleIdentifier ?? "com.argmax.whisperkit", category: "WhisperKit") + public enum LogLevel: Int { case debug = 1 case info = 2 @@ -688,30 +786,30 @@ public class Logging { private init() {} - public func log(_ items: Any..., separator: String = " ", terminator: String = "\n") { + public func log(_ items: Any..., separator: String = " ", terminator: String = "\n", type: OSLogType) { let message = items.map { "\($0)" }.joined(separator: separator) if let logger = loggingCallback { logger(message) } else { - print("[WhisperKit] \(message)", terminator: terminator) + os_log("%{public}@", log: logger, type: type, message) } } public static func debug(_ items: Any..., separator: String = " ", terminator: String = "\n") { if shared.logLevel.shouldLog(level: .debug) { - shared.log(items, separator: separator, terminator: terminator) + shared.log(items, separator: separator, terminator: terminator, type: .debug) } } public static func info(_ items: Any..., separator: String = " ", terminator: String = "\n") { if shared.logLevel.shouldLog(level: .info) { - shared.log(items, separator: separator, terminator: terminator) + shared.log(items, separator: separator, terminator: terminator, type: .info) } } public static func error(_ items: Any..., separator: String = " ", terminator: String = "\n") { if shared.logLevel.shouldLog(level: .error) { - shared.log(items, separator: separator, terminator: terminator) + shared.log(items, separator: separator, terminator: terminator, type: .error) } } } diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index d0b513f..ea74946 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -34,7 +34,7 @@ open class WhisperKit { /// Progress public private(set) var currentTimings: TranscriptionTimings - public let progress = Progress() + public private(set) var progress = Progress() /// Configuration public var modelFolder: URL? 
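// A minimal usage sketch of the new `loadAudio(fromPath:startTime:endTime:)` overload
// together with the memory helpers added in Utils.swift above; the file path and the
// 30-second window here are placeholder assumptions, not part of the patch itself.
import AVFoundation
import WhisperKit

func loadAudioSliceWithMemoryLogging(path: String) throws -> AVAudioPCMBuffer {
    logCurrentMemoryUsage("Before loadAudio")
    // Only the requested slice is read and resampled, in bounded chunks, which keeps
    // the resident memory footprint low even for very large source files.
    let buffer = try AudioProcessor.loadAudio(fromPath: path, startTime: 0, endTime: 30)
    logCurrentMemoryUsage("After loadAudio (\(buffer.frameLength) frames)")
    return buffer
}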
@@ -624,6 +624,7 @@ open class WhisperKit { // Append the results of each batch to the final result array result.append(contentsOf: partialResult) } + return result } @@ -767,11 +768,11 @@ open class WhisperKit { // Tokenizer required for decoding throw WhisperError.tokenizerUnavailable() } - try Task.checkCancellation() let childProgress = Progress() progress.totalUnitCount += 1 progress.addChild(childProgress, withPendingUnitCount: 1) + let transcribeTask = TranscribeTask( currentTimings: currentTimings, progress: childProgress, @@ -781,14 +782,33 @@ open class WhisperKit { textDecoder: textDecoder, tokenizer: tokenizer ) - let transcribeTaskResult = try await transcribeTask.run( - audioArray: audioArray, - decodeOptions: decodeOptions, - callback: callback - ) - if let decodeOptions, decodeOptions.verbose { - transcribeTaskResult.logTimings() + + do { + try Task.checkCancellation() + + let transcribeTaskResult = try await transcribeTask.run( + audioArray: audioArray, + decodeOptions: decodeOptions, + callback: callback + ) + + if let decodeOptions, decodeOptions.verbose { + transcribeTaskResult.logTimings() + } + + if progress.isFinished { + // Reset progress if it is completed + progress = Progress() + } + + return [transcribeTaskResult] + } catch { + // Handle cancellation + if error is CancellationError { + // Reset progress when cancelled + progress = Progress() + } + throw error } - return [transcribeTaskResult] } } diff --git a/Sources/WhisperKitCLI/TranscribeCLI.swift b/Sources/WhisperKitCLI/TranscribeCLI.swift index 81e27bd..578645c 100644 --- a/Sources/WhisperKitCLI/TranscribeCLI.swift +++ b/Sources/WhisperKitCLI/TranscribeCLI.swift @@ -82,12 +82,12 @@ struct TranscribeCLI: AsyncParsableCommand { } var options = decodingOptions(task: task) - if let promptText = cliArguments.prompt, let tokenizer = whisperKit.tokenizer { + if let promptText = cliArguments.prompt, promptText.count > 0, let tokenizer = whisperKit.tokenizer { options.promptTokens = tokenizer.encode(text: " " + promptText.trimmingCharacters(in: .whitespaces)).filter { $0 < tokenizer.specialTokens.specialTokenBegin } options.usePrefillPrompt = true } - if let prefixText = cliArguments.prefix, let tokenizer = whisperKit.tokenizer { + if let prefixText = cliArguments.prefix, prefixText.count > 0, let tokenizer = whisperKit.tokenizer { options.prefixTokens = tokenizer.encode(text: " " + prefixText.trimmingCharacters(in: .whitespaces)).filter { $0 < tokenizer.specialTokens.specialTokenBegin } options.usePrefillPrompt = true } diff --git a/Tests/WhisperKitTests/FunctionalTests.swift b/Tests/WhisperKitTests/FunctionalTests.swift index 16cba98..3cd2246 100644 --- a/Tests/WhisperKitTests/FunctionalTests.swift +++ b/Tests/WhisperKitTests/FunctionalTests.swift @@ -13,35 +13,6 @@ final class FunctionalTests: XCTestCase { ) } - func testOutputAll() async throws { - let modelPaths = try allModelPaths() - - for modelPath in modelPaths { - let modelName = modelPath.split(separator: "/").last! 
- print("[Integration] Testing model \(modelName)") - let audioFilePath = try XCTUnwrap( - Bundle.module.path(forResource: "jfk", ofType: "wav"), - "Audio file not found" - ) - - let whisperKit = try await WhisperKit( - modelFolder: modelPath, - verbose: true, - logLevel: .debug - ) - - let transcriptionResult: [TranscriptionResult] = try await whisperKit.transcribe(audioPath: audioFilePath) - let transcriptionResultText = transcriptionResult.text - - print("[Integration] \(transcriptionResultText)") - XCTAssertEqual( - transcriptionResultText.normalized, - " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.".normalized, - "Transcription result does not match expected result for model \(modelName)" - ) - } - } - func testRealTimeFactorTiny() async throws { let modelPath = try tinyModelPath() diff --git a/Tests/WhisperKitTests/RegressionTests.swift b/Tests/WhisperKitTests/RegressionTests.swift index 68778e7..15f90ea 100644 --- a/Tests/WhisperKitTests/RegressionTests.swift +++ b/Tests/WhisperKitTests/RegressionTests.swift @@ -116,6 +116,35 @@ final class RegressionTests: XCTestCase { } } + func testOutputAll() async throws { + let modelPaths = try allModelPaths() + + for modelPath in modelPaths { + let modelName = modelPath.split(separator: "/").last! + print("[Integration] Testing model \(modelName)") + let audioFilePath = try XCTUnwrap( + Bundle.module.path(forResource: "jfk", ofType: "wav"), + "Audio file not found" + ) + + let whisperKit = try await WhisperKit( + modelFolder: modelPath, + verbose: true, + logLevel: .debug + ) + + let transcriptionResult: [TranscriptionResult] = try await whisperKit.transcribe(audioPath: audioFilePath) + let transcriptionResultText = transcriptionResult.text + + print("[Integration] \(transcriptionResultText)") + XCTAssertEqual( + transcriptionResultText.normalized, + " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.".normalized, + "Transcription result does not match expected result for model \(modelName)" + ) + } + } + func testRegressionAndLatencyForAllModels() async throws { var allModels: [String] = [] var failureInfo: [String: String] = [:] diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift index c096642..b30cad4 100644 --- a/Tests/WhisperKitTests/UnitTests.swift +++ b/Tests/WhisperKitTests/UnitTests.swift @@ -39,6 +39,16 @@ final class UnitTests: XCTestCase { XCTAssertNotNil(audioBuffer, "Failed to load audio file at path: \(audioFilePath)") XCTAssertEqual(audioBuffer.format.sampleRate, 16000) XCTAssertEqual(audioBuffer.format.channelCount, 1) + XCTAssertEqual(audioBuffer.frameLength, 176000) + XCTAssertEqual(audioBuffer.frameLength, 11 * 16000) + + let audioBufferWithStartTime = try AudioProcessor.loadAudio(fromPath: audioFilePath, startTime: 1.2) + XCTAssertEqual(audioBufferWithStartTime.frameLength, AVAudioFrameCount(156800)) + XCTAssertEqual(audioBufferWithStartTime.frameLength, AVAudioFrameCount(16000 * (11 - 1.2))) + + let audioBufferWithStartTimeAndEndTime = try AudioProcessor.loadAudio(fromPath: audioFilePath, startTime: 1.2, endTime: 3.4) + XCTAssertEqual(audioBufferWithStartTimeAndEndTime.frameLength, AVAudioFrameCount(35200)) + XCTAssertEqual(audioBufferWithStartTimeAndEndTime.frameLength, AVAudioFrameCount(16000 * (3.4 - 1.2))) } func testAudioPad() { @@ -74,6 +84,62 @@ final class UnitTests: XCTestCase { XCTAssertEqual(resampledAudio?.format.channelCount, targetChannelCount, 
"Resampled audio channels is not as expected") } + func testAudioResampleFromFile() throws { + Logging.shared.logLevel = .debug + + let audioFileURL = try XCTUnwrap( + Bundle.module.url(forResource: "jfk", withExtension: "wav"), + "Audio file not found" + ) + let audioFile = try AVAudioFile(forReading: audioFileURL) + + let targetSampleRate = 16000.0 + let targetChannelCount: AVAudioChannelCount = 1 + let smallMaxReadFrameSize: AVAudioFrameCount = 10_000 // Small chunk size to test chunking logic + + let resampledAudio = AudioProcessor.resampleAudio( + fromFile: audioFile, + toSampleRate: targetSampleRate, + channelCount: targetChannelCount, + maxReadFrameSize: smallMaxReadFrameSize + ) + + XCTAssertNotNil(resampledAudio, "Failed to resample audio with small chunks") + XCTAssertEqual(resampledAudio?.format.sampleRate, targetSampleRate, "Resampled audio sample rate is not as expected") + XCTAssertEqual(resampledAudio?.format.channelCount, targetChannelCount, "Resampled audio channels is not as expected") + + // Check if the duration is approximately the same (allowing for small differences due to resampling) + let originalDuration = Double(audioFile.length) / audioFile.fileFormat.sampleRate + let resampledDuration = Double(resampledAudio!.frameLength) / targetSampleRate + XCTAssertEqual(originalDuration, resampledDuration, accuracy: 0.1, "Resampled audio duration should be close to original") + + // Read the entire original file into a buffer + audioFile.framePosition = 0 + guard let originalBuffer = AVAudioPCMBuffer(pcmFormat: audioFile.processingFormat, frameCapacity: AVAudioFrameCount(audioFile.length)) else { + XCTFail("Failed to create original buffer") + return + } + try audioFile.read(into: originalBuffer) + + // Compare the audio samples + let originalData = originalBuffer.floatChannelData?[0] + let resampledData = resampledAudio?.floatChannelData?[0] + + guard let originalSamples = originalData, let resampledSamples = resampledData else { + XCTFail("Failed to access audio sample data") + return + } + + var maxDifference: Float = 0 + for i in 0..