diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj b/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj index bfb9069..4c92d93 100644 --- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj +++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj @@ -869,7 +869,7 @@ CURRENT_PROJECT_VERSION = 1; DEAD_CODE_STRIPPING = YES; DEVELOPMENT_ASSET_PATHS = "\"WhisperAX/Preview Content\""; - DEVELOPMENT_TEAM = PP83DTRKSA; + DEVELOPMENT_TEAM = 88ACA86N96; ENABLE_HARDENED_RUNTIME = YES; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 41e3727..bc58b75 100644 --- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,6 +1,15 @@ { - "originHash" : "cd17206b47bb810af9459722192530e3838d8e6629a970988e32a432aaa05f6e", + "originHash" : "420a1723357da21f9e31b01403fd3d66df6e400a752d242d05b2c3d5667e3c33", "pins" : [ + { + "identity" : "jinja", + "kind" : "remoteSourceControl", + "location" : "https://github.com/maiqingqiang/Jinja", + "state" : { + "revision" : "b435eb62b0d3d5f34167ec70a128355486981712", + "version" : "1.0.5" + } + }, { "identity" : "networkimage", "kind" : "remoteSourceControl", @@ -15,8 +24,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-argument-parser.git", "state" : { - "revision" : "c8ed701b513cf5177118a175d85fbbbcd707ab41", - "version" : "1.3.0" + "revision" : "41982a3656a71c768319979febd796c6fd111d5c", + "version" : "1.5.0" } }, { @@ -24,8 +33,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/gonzalezreal/swift-markdown-ui.git", "state" : { - "revision" : "ae799d015a5374708f7b4c85f3294c05f2a564e2", - "version" : "2.3.0" + "revision" : "55441810c0f678c78ed7e2ebd46dde89228e02fc", + "version" : "2.4.0" } }, { @@ -33,8 +42,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers.git", "state" : { - "revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe", - "version" : "0.1.7" + "revision" : "0f2306713d48a75b862026ebb291926793773f52", + "version" : "0.1.12" } } ], diff --git a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift index 2a182fb..b164111 100644 --- a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift +++ b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift @@ -1206,9 +1206,10 @@ struct ContentView: View { #endif try? audioProcessor.startRecordingLive(inputDeviceID: deviceId) { _ in - DispatchQueue.main.async { - bufferEnergy = whisperKit?.audioProcessor.relativeEnergy ?? [] - bufferSeconds = Double(whisperKit?.audioProcessor.audioSamples.count ?? 0) / Double(WhisperKit.sampleRate) + Task { @MainActor in + bufferEnergy = await whisperKit?.audioProcessor.getRelativeEnergy() ?? [] + let audioSamples = await whisperKit?.audioProcessor.getAudioSamples() ?? 
[] + bufferSeconds = Double(audioSamples.count) / Double(WhisperKit.sampleRate) } } @@ -1406,7 +1407,7 @@ struct ContentView: View { guard let whisperKit = whisperKit else { return } // Retrieve the current audio buffer from the audio processor - let currentBuffer = whisperKit.audioProcessor.audioSamples + let currentBuffer = whisperKit.audioProcessor.getAudioSamples() // Calculate the size and duration of the next buffer segment let nextBufferSize = currentBuffer.count - lastBufferSize @@ -1424,8 +1425,9 @@ struct ContentView: View { } if useVAD { + let relativeEnergy = whisperKit.audioProcessor.getRelativeEnergy() let voiceDetected = AudioProcessor.isVoiceDetected( - in: whisperKit.audioProcessor.relativeEnergy, + in: relativeEnergy, nextBufferInSeconds: nextBufferSeconds, silenceThreshold: Float(silenceThreshold) ) diff --git a/Package.resolved b/Package.resolved index 6cccf25..87fb996 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,12 +1,21 @@ { "pins" : [ + { + "identity" : "jinja", + "kind" : "remoteSourceControl", + "location" : "https://github.com/maiqingqiang/Jinja", + "state" : { + "revision" : "4ffa95ce02e013c992287e19e3bbd620b6cc233a", + "version" : "1.0.4" + } + }, { "identity" : "swift-argument-parser", "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-argument-parser.git", "state" : { - "revision" : "c8ed701b513cf5177118a175d85fbbbcd707ab41", - "version" : "1.3.0" + "revision" : "41982a3656a71c768319979febd796c6fd111d5c", + "version" : "1.5.0" } }, { @@ -14,8 +23,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers.git", "state" : { - "revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe", - "version" : "0.1.7" + "revision" : "0f2306713d48a75b862026ebb291926793773f52", + "version" : "0.1.12" } } ], diff --git a/Package.swift b/Package.swift index f3f111e..3515d89 100644 --- a/Package.swift +++ b/Package.swift @@ -20,8 +20,8 @@ let package = Package( ), ], dependencies: [ - .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.7"), - .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"), + .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.12"), + .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.5.0"), ], targets: [ .target( diff --git a/Sources/WhisperKit/Core/Audio/AudioChunker.swift b/Sources/WhisperKit/Core/Audio/AudioChunker.swift index 467bfd6..3b5e091 100644 --- a/Sources/WhisperKit/Core/Audio/AudioChunker.swift +++ b/Sources/WhisperKit/Core/Audio/AudioChunker.swift @@ -43,12 +43,12 @@ public extension AudioChunking { /// A audio chunker that splits audio into smaller pieces based on voice activity detection @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -open class VADAudioChunker: AudioChunking { +public struct VADAudioChunker: AudioChunking { /// prevent hallucinations at the end of the clip by stopping up to 1.0s early private let windowPadding: Int - private let vad: VoiceActivityDetector + private let vad: any VoiceActivityDetectable - public init(windowPadding: Int = 16000, vad: VoiceActivityDetector? = nil) { + public init(windowPadding: Int = 16000, vad: (any VoiceActivityDetectable)? = nil) { self.windowPadding = windowPadding self.vad = vad ?? 
EnergyVAD() } @@ -81,12 +81,12 @@ open class VADAudioChunker: AudioChunking { // Typically this will be the full audio file, unless seek points are explicitly provided var startIndex = seekClipStart while startIndex < seekClipEnd - windowPadding { - let currentFrameLength = startIndex - seekClipStart - if startIndex >= currentFrameLength, startIndex < 0 { + // Check that startIndex is within the array bounds + if startIndex >= audioArray.count || startIndex < 0 { throw WhisperError.audioProcessingFailed("startIndex is outside the buffer size") } - // Make sure we still need chunking for this seek clip, otherwise use the original seek clip end + // Adjust the end index based on VAD or maxChunkLength var endIndex = seekClipEnd if startIndex + maxChunkLength < endIndex { // Adjust the end index based on VAD @@ -97,6 +97,8 @@ ) } + // Ensure endIndex is within the array bounds + endIndex = min(endIndex, audioArray.count) guard endIndex > startIndex else { break } @@ -108,4 +110,5 @@ } return chunkedAudio } + } diff --git a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift index c3958cb..cd86e81 100644 --- a/Sources/WhisperKit/Core/Audio/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/Audio/AudioProcessor.swift @@ -13,12 +13,12 @@ public typealias DeviceID = AudioDeviceID public typealias DeviceID = String #endif -public struct AudioDevice: Identifiable, Hashable { +public struct AudioDevice: Identifiable, Hashable, Sendable { public let id: DeviceID public let name: String } -public protocol AudioProcessing { +public protocol AudioProcessing: Actor { /// Loads audio data from a specified file path. /// - Parameters: /// - audioFilePath: The file path of the audio file. @@ -47,13 +47,13 @@ ) -> MLMultiArray? /// Stores the audio samples to be transcribed - var audioSamples: ContiguousArray { get } + func getAudioSamples() -> ContiguousArray /// Empties the audio samples array, keeping the last `keep` samples func purgeAudioSamples(keepingLast keep: Int) /// A measure of current buffer's energy in dB normalized from 0 - 1 based on the quietest buffer's energy in a specified window - var relativeEnergy: [Float] { get } + func getRelativeEnergy() -> [Float] /// How many past buffers of audio to use to calculate relative energy /// The lowest average energy value in the buffer within this amount of previous buffers will used as the silence baseline @@ -95,7 +95,7 @@ static func padOrTrimAudio(fromArray audioArray: [Float], startAt startIndex: Int = 0, toLength frameLength: Int = 480_000, saveSegment: Bool = false) -> MLMultiArray? { let currentFrameLength = audioArray.count - if startIndex >= currentFrameLength, startIndex < 0 { + if startIndex >= currentFrameLength || startIndex < 0 { Logging.error("startIndex is outside the buffer size") return nil } @@ -169,19 +169,34 @@ } @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -public class AudioProcessor: NSObject, AudioProcessing { +public actor AudioProcessor: @preconcurrency AudioProcessing { private var lastInputDevice: DeviceID? public var audioEngine: AVAudioEngine?
- public var audioSamples: ContiguousArray = [] - public var audioEnergy: [(rel: Float, avg: Float, max: Float, min: Float)] = [] public var relativeEnergyWindow: Int = 20 - public var relativeEnergy: [Float] { - return self.audioEnergy.map { $0.rel } - } public var audioBufferCallback: (([Float]) -> Void)? public var maxBufferLength = WhisperKit.sampleRate * WhisperKit.chunkLength // 30 seconds of audio at 16,000 Hz public var minBufferLength = Int(Double(WhisperKit.sampleRate) * 0.1) // 0.1 second of audio at 16,000 Hz + + public init() { + + } + + private var audioSamples: ContiguousArray = [] + + public func getAudioSamples() -> ContiguousArray { + self.audioSamples + } + + private var audioEnergy: [(rel: Float, avg: Float, max: Float, min: Float)] = [] + + public func getAudioEnergy() -> [(rel: Float, avg: Float, max: Float, min: Float)] { + self.audioEnergy + } + + public func getRelativeEnergy() -> [Float] { + self.audioEnergy.map(\.rel) + } // MARK: - Loading and conversion @@ -437,7 +452,7 @@ public class AudioProcessor: NSObject, AudioProcessing { /// - Returns: an array of tuples indicating the start and end indices of non-silent chunks public static func calculateNonSilentChunks( in signal: [Float] - ) -> [(startIndex: Int, endIndex: Int)] { + ) -> [SampleRange] { EnergyVAD().calculateActiveChunks(in: signal) } @@ -486,73 +501,40 @@ public class AudioProcessor: NSObject, AudioProcessing { var rmsEnergy: Float = 0.0 var minEnergy: Float = 0.0 var maxEnergy: Float = 0.0 - - // Calculate the root mean square of the signal vDSP_rmsqv(signal, 1, &rmsEnergy, vDSP_Length(signal.count)) - - // Calculate the maximum sample value of the signal - vDSP_maxmgv(signal, 1, &maxEnergy, vDSP_Length(signal.count)) - - // Calculate the minimum sample value of the signal - vDSP_minmgv(signal, 1, &minEnergy, vDSP_Length(signal.count)) - + vDSP_maxv(signal, 1, &maxEnergy, vDSP_Length(signal.count)) + vDSP_minv(signal, 1, &minEnergy, vDSP_Length(signal.count)) return (rmsEnergy, maxEnergy, minEnergy) } public static func calculateRelativeEnergy(of signal: [Float], relativeTo reference: Float?) -> Float { let signalEnergy = calculateAverageEnergy(of: signal) - - // Make sure reference is greater than 0 - // Default 1e-3 measured empirically in a silent room let referenceEnergy = max(1e-8, reference ?? 
1e-3) - - // Convert to dB let dbEnergy = 20 * log10(signalEnergy) let refEnergy = 20 * log10(referenceEnergy) - - // Normalize based on reference - // NOTE: since signalEnergy elements are floats from 0 to 1, max (full volume) is always 0dB let normalizedEnergy = rescale(value: dbEnergy, min: refEnergy, max: 0) - - // Clamp from 0 to 1 return max(0, min(normalizedEnergy, 1)) } - public static func convertBufferToArray(buffer: AVAudioPCMBuffer, chunkSize: Int = 1024) -> [Float] { + public static func convertBufferToArray(buffer: AVAudioPCMBuffer) -> [Float] { guard let channelData = buffer.floatChannelData else { return [] } - let frameLength = Int(buffer.frameLength) let startPointer = channelData[0] - - var result: [Float] = [] - result.reserveCapacity(frameLength) // Reserve the capacity to avoid multiple allocations - - var currentFrame = 0 - while currentFrame < frameLength { - let remainingFrames = frameLength - currentFrame - let currentChunkSize = min(chunkSize, remainingFrames) - - var chunk = [Float](repeating: 0, count: currentChunkSize) - - chunk.withUnsafeMutableBufferPointer { bufferPointer in - vDSP_mmov( - startPointer.advanced(by: currentFrame), - bufferPointer.baseAddress!, - vDSP_Length(currentChunkSize), - 1, - vDSP_Length(currentChunkSize), - 1 - ) - } - - result.append(contentsOf: chunk) - currentFrame += currentChunkSize - - memset(startPointer.advanced(by: currentFrame - currentChunkSize), 0, currentChunkSize * MemoryLayout.size) + let result = [Float]( + unsafeUninitializedCapacity: frameLength + ) { bufferPointer, initializedCount in + vDSP_mmov( + startPointer, + bufferPointer.baseAddress!, + vDSP_Length(frameLength), + 1, + vDSP_Length(frameLength), + 1 + ) + initializedCount = frameLength } - return result } @@ -672,9 +654,13 @@ public class AudioProcessor: NSObject, AudioProcessing { return devices } #endif - + deinit { - stopRecording() + audioEngine?.stop() + audioEngine = nil + + audioSamples.removeAll() + audioEnergy.removeAll() } } @@ -685,17 +671,24 @@ public extension AudioProcessor { /// We have a new buffer, process and store it. 
/// NOTE: Assumes audio is 16khz mono func processBuffer(_ buffer: [Float]) { + let bufferCount = buffer.count + let previousCount = audioSamples.count + audioSamples.reserveCapacity(previousCount + bufferCount) audioSamples.append(contentsOf: buffer) - // Find the lowest average energy of the last 20 buffers ~2 seconds - let minAvgEnergy = self.audioEnergy.suffix(20).reduce(Float.infinity) { min($0, $1.avg) } - let relativeEnergy = Self.calculateRelativeEnergy(of: buffer, relativeTo: minAvgEnergy) + // Calculate energy + let recentAudioEnergy = self.audioEnergy.suffix(relativeEnergyWindow) + let minAvgEnergy: Float + if recentAudioEnergy.isEmpty { + minAvgEnergy = 1e-8 // Default minimum energy value + } else { + minAvgEnergy = recentAudioEnergy.reduce(Float.infinity) { min($0, $1.avg) } + } - // Update energy for buffers with valid data + let relativeEnergy = Self.calculateRelativeEnergy(of: buffer, relativeTo: minAvgEnergy) let signalEnergy = Self.calculateEnergy(of: buffer) - let newEnergy = (relativeEnergy, signalEnergy.avg, signalEnergy.max, signalEnergy.min) + let newEnergy = (rel: relativeEnergy, avg: signalEnergy.avg, max: signalEnergy.max, min: signalEnergy.min) self.audioEnergy.append(newEnergy) - // Call the callback with the new buffer audioBufferCallback?(buffer) @@ -779,7 +772,6 @@ let bufferSize = AVAudioFrameCount(minBufferLength) // 100ms - 400ms supported inputNode.installTap(onBus: 0, bufferSize: bufferSize, format: nodeFormat) { [weak self] (buffer: AVAudioPCMBuffer, _: AVAudioTime) in - guard let self = self else { return } var buffer = buffer if !buffer.format.sampleRate.isEqual(to: Double(WhisperKit.sampleRate)) { do { @@ -789,20 +781,29 @@ return } } - - let newBufferArray = Self.convertBufferToArray(buffer: buffer) - self.processBuffer(newBufferArray) + let targetBuffer = buffer + let newBufferArray = Self.convertBufferToArray(buffer: targetBuffer) + Task { [weak self] in + guard let self = self else { return } + await self.processBuffer(newBufferArray) + } } audioEngine.prepare() try audioEngine.start() - + return audioEngine } - + func purgeAudioSamples(keepingLast keep: Int) { - if audioSamples.count > keep { - audioSamples.removeFirst(audioSamples.count - keep) + let samplesToRemove = audioSamples.count - keep + if samplesToRemove > 0 { + audioSamples.removeFirst(samplesToRemove) + } + + let energiesToRemove = samplesToRemove / minBufferLength + if energiesToRemove > 0 { + audioEnergy.removeFirst(min(energiesToRemove, audioEnergy.count)) } } diff --git a/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift b/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift index f91ba53..2c06957 100644 --- a/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift +++ b/Sources/WhisperKit/Core/Audio/AudioStreamTranscriber.swift @@ -5,7 +5,7 @@ import Foundation @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) public extension AudioStreamTranscriber { - struct State { + struct State: Sendable { public var isRecording: Bool = false public var currentFallbacks: Int = 0 public var lastBufferSize: Int = 0 @@ -79,7 +79,7 @@ return } state.isRecording = true - try audioProcessor.startRecordingLive { [weak self] _ in + try await audioProcessor.startRecordingLive { [weak self] _ in Task { [weak self] in await self?.onAudioBufferCallback() } @@ -90,7 +90,9 @@ public func stopStreamTranscription() { state.isRecording = false - 
audioProcessor.stopRecording() + Task { + await audioProcessor.stopRecording() + } Logging.info("Realtime transcription has ended") } @@ -99,14 +101,16 @@ public actor AudioStreamTranscriber { do { try await transcribeCurrentBuffer() } catch { - Logging.error("Error: \(error.localizedDescription)") + Logging.error("Error: \(#file) \(error.localizedDescription)") break } } } private func onAudioBufferCallback() { - state.bufferEnergy = audioProcessor.relativeEnergy + Task { + state.bufferEnergy = await audioProcessor.getRelativeEnergy() + } } private func onProgressCallback(_ progress: TranscriptionProgress) { @@ -124,7 +128,7 @@ public actor AudioStreamTranscriber { private func transcribeCurrentBuffer() async throws { // Retrieve the current audio buffer from the audio processor - let currentBuffer = audioProcessor.audioSamples + let currentBuffer = await audioProcessor.getAudioSamples() // Calculate the size and duration of the next buffer segment let nextBufferSize = currentBuffer.count - state.lastBufferSize @@ -139,8 +143,9 @@ public actor AudioStreamTranscriber { } if useVAD { + let relativeEnergy = await audioProcessor.getRelativeEnergy() let voiceDetected = AudioProcessor.isVoiceDetected( - in: audioProcessor.relativeEnergy, + in: relativeEnergy, nextBufferInSeconds: nextBufferSeconds, silenceThreshold: silenceThreshold ) diff --git a/Sources/WhisperKit/Core/Audio/EnergyVAD.swift b/Sources/WhisperKit/Core/Audio/EnergyVAD.swift index 3c8f0e7..d8bc97a 100644 --- a/Sources/WhisperKit/Core/Audio/EnergyVAD.swift +++ b/Sources/WhisperKit/Core/Audio/EnergyVAD.swift @@ -5,46 +5,39 @@ import Foundation /// Voice activity detection based on energy threshold @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -final class EnergyVAD: VoiceActivityDetector { - var energyThreshold: Float - - /// Initialize a new EnergyVAD instance - /// - Parameters: - /// - sampleRate: Audio sample rate - /// - frameLength: Frame length in seconds - /// - frameOverlap: frame overlap in seconds, this will include `frameOverlap` length audio into the `frameLength` and is helpful to catch audio that starts exactly at chunk boundaries - /// - energyThreshold: minimal energy threshold - convenience init( +public struct EnergyVAD: VoiceActivityDetectable { + public let sampleRate: Int + public let frameLengthSamples: Int + public let frameOverlapSamples: Int + public var energyThreshold: Float + + public init( sampleRate: Int = WhisperKit.sampleRate, frameLength: Float = 0.1, frameOverlap: Float = 0.0, energyThreshold: Float = 0.02 ) { - self.init( - sampleRate: sampleRate, - // Compute frame length and overlap in number of samples - frameLengthSamples: Int(frameLength * Float(sampleRate)), - frameOverlapSamples: Int(frameOverlap * Float(sampleRate)), - energyThreshold: energyThreshold - ) + self.sampleRate = sampleRate + self.frameLengthSamples = Int(frameLength * Float(sampleRate)) + self.frameOverlapSamples = Int(frameOverlap * Float(sampleRate)) + self.energyThreshold = energyThreshold } - required init( + init( sampleRate: Int = 16000, frameLengthSamples: Int, frameOverlapSamples: Int = 0, energyThreshold: Float = 0.02 ) { + self.sampleRate = sampleRate + self.frameLengthSamples = frameLengthSamples + self.frameOverlapSamples = frameOverlapSamples self.energyThreshold = energyThreshold - super.init(sampleRate: sampleRate, frameLengthSamples: frameLengthSamples, frameOverlapSamples: frameOverlapSamples) } - - override func voiceActivity(in waveform: [Float]) -> [Bool] { + + public func voiceActivity(in 
waveform: [Float]) -> [Bool] { let chunkRatio = Double(waveform.count) / Double(frameLengthSamples) - - // Round up if uneven, the final chunk will not be a full `frameLengthSamples` long let count = Int(chunkRatio.rounded(.up)) - let chunkedVoiceActivity = AudioProcessor.calculateVoiceActivityInChunks( of: waveform, chunkCount: count, @@ -52,7 +45,6 @@ final class EnergyVAD: VoiceActivityDetector { frameOverlapSamples: frameOverlapSamples, energyThreshold: energyThreshold ) - return chunkedVoiceActivity } } diff --git a/Sources/WhisperKit/Core/Audio/SampleRange.swift b/Sources/WhisperKit/Core/Audio/SampleRange.swift new file mode 100644 index 0000000..87469a3 --- /dev/null +++ b/Sources/WhisperKit/Core/Audio/SampleRange.swift @@ -0,0 +1,10 @@ +// +// SampleRange.swift +// whisperkit +// +// Created by Norikazu Muramoto on 2024/10/06. +// + +public typealias FrameRange = (start: Int, end: Int) +public typealias SampleRange = (startIndex: Int, endIndex: Int) +public typealias TimestampRange = (startTime: Float, endTime: Float) diff --git a/Sources/WhisperKit/Core/Audio/VoiceActivityDetectable.swift b/Sources/WhisperKit/Core/Audio/VoiceActivityDetectable.swift new file mode 100644 index 0000000..d48e2c4 --- /dev/null +++ b/Sources/WhisperKit/Core/Audio/VoiceActivityDetectable.swift @@ -0,0 +1,124 @@ +// +// VoiceActivityDetectable.swift +// whisperkit +// +// Created by Norikazu Muramoto on 2024/10/03. +// + +/// Protocol defining the interface for Voice Activity Detection (VAD) +@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) +public protocol VoiceActivityDetectable: Sendable { + var sampleRate: Int { get } + var frameLengthSamples: Int { get } + var frameOverlapSamples: Int { get } + + func voiceActivity(in waveform: [Float]) -> [Bool] + func calculateActiveChunks(in waveform: [Float]) -> [SampleRange] + func voiceActivityIndexToAudioSampleIndex(_ index: Int) -> Int + func voiceActivityIndexToSeconds(_ index: Int) -> Float + func findLongestSilence(in vadResult: [Bool]) -> SampleRange? + func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] + func calculateNonSilentSeekClips(in waveform: [Float]) -> [FrameRange] + func calculateSeekTimestamps(in waveform: [Float]) -> [TimestampRange] +} + +extension VoiceActivityDetectable { + + public func calculateActiveChunks(in waveform: [Float]) -> [SampleRange] { + let vad = voiceActivity(in: waveform) + var result = [SampleRange]() + var currentStartIndex: Int? + + for (index, vadChunk) in vad.enumerated() { + if vadChunk { + let chunkStart = index * frameLengthSamples + let chunkEnd = min(chunkStart + frameLengthSamples, waveform.count) + + if currentStartIndex != nil { + result[result.count - 1].endIndex = chunkEnd + } else { + currentStartIndex = chunkStart + result.append((startIndex: chunkStart, endIndex: chunkEnd)) + } + } else { + currentStartIndex = nil + } + } + + return result + } + + public func voiceActivityIndexToAudioSampleIndex(_ index: Int) -> Int { + return index * frameLengthSamples + } + + public func voiceActivityIndexToSeconds(_ index: Int) -> Float { + return Float(voiceActivityIndexToAudioSampleIndex(index)) / Float(sampleRate) + } + + public func findLongestSilence(in vadResult: [Bool]) -> SampleRange? { + var longestStartIndex: Int? + var longestEndIndex: Int? 
+ var longestCount = 0 + var index = 0 + while index < vadResult.count { + if vadResult[index] { + index += 1 + } else { + var endIndex = index + while endIndex < vadResult.count, !vadResult[endIndex] { + endIndex += 1 + } + let count = endIndex - index + if count > longestCount { + longestCount = count + longestStartIndex = index + longestEndIndex = endIndex + } + index = endIndex + } + } + if let longestStartIndex, let longestEndIndex { + return (startIndex: longestStartIndex, endIndex: longestEndIndex) + } else { + return nil + } + } + + // MARK - Utility + + public func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] { + let nonSilentChunks = calculateActiveChunks(in: waveform) + var clipTimestamps = [Float]() + + for chunk in nonSilentChunks { + let startTimestamp = Float(chunk.startIndex) / Float(sampleRate) + let endTimestamp = Float(chunk.endIndex) / Float(sampleRate) + + clipTimestamps.append(contentsOf: [startTimestamp, endTimestamp]) + } + + return clipTimestamps + } + + public func calculateNonSilentSeekClips(in waveform: [Float]) -> [FrameRange] { + let clipTimestamps = voiceActivityClipTimestamps(in: waveform) + let options = DecodingOptions(clipTimestamps: clipTimestamps) + let seekClips = prepareSeekClips(contentFrames: waveform.count, decodeOptions: options) + return seekClips + } + + public func calculateSeekTimestamps(in waveform: [Float]) -> [TimestampRange] { + let nonSilentChunks = calculateActiveChunks(in: waveform) + var seekTimestamps = [TimestampRange]() + + for chunk in nonSilentChunks { + let startTimestamp = Float(chunk.startIndex) / Float(sampleRate) + let endTimestamp = Float(chunk.endIndex) / Float(sampleRate) + + seekTimestamps.append(contentsOf: [(startTime: startTimestamp, endTime: endTimestamp)]) + } + + return seekTimestamps + } +} diff --git a/Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift b/Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift index bb7ef62..dd4c529 100644 --- a/Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift +++ b/Sources/WhisperKit/Core/Audio/VoiceActivityDetector.swift @@ -6,22 +6,11 @@ import Foundation /// A base class for Voice Activity Detection (VAD), used to identify and separate segments of audio that contain human speech from those that do not. /// Subclasses must implement the `voiceActivity(in:)` method to provide specific voice activity detection functionality. @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -open class VoiceActivityDetector { - /// The sample rate of the audio signal, in samples per second. +public struct VoiceActivityDetector: VoiceActivityDetectable { public let sampleRate: Int - - /// The length of each frame in samples. public let frameLengthSamples: Int - - /// The number of samples overlapping between consecutive frames. public let frameOverlapSamples: Int - - /// Initializes a new `VoiceActivityDetector` instance with the specified parameters. - /// - Parameters: - /// - sampleRate: The sample rate of the audio signal in samples per second. Defaults to 16000. - /// - frameLengthSamples: The length of each frame in samples. - /// - frameOverlapSamples: The number of samples overlapping between consecutive frames. Defaults to 0. - /// - Note: Subclasses should override the `voiceActivity(in:)` method to provide specific VAD functionality. 
+ public init( sampleRate: Int = 16000, frameLengthSamples: Int, @@ -31,126 +20,8 @@ open class VoiceActivityDetector { self.frameLengthSamples = frameLengthSamples self.frameOverlapSamples = frameOverlapSamples } - - /// Analyzes the provided audio waveform to determine which segments contain voice activity. - /// - Parameter waveform: An array of `Float` values representing the audio waveform. - /// - Returns: An array of `Bool` values where `true` indicates the presence of voice activity and `false` indicates silence. - open func voiceActivity(in waveform: [Float]) -> [Bool] { - fatalError("`voiceActivity` must be implemented by subclass") - } - - /// Calculates and returns a list of active audio chunks, each represented by a start and end index. - /// - Parameter waveform: An array of `Float` values representing the audio waveform. - /// - Returns: An array of tuples where each tuple contains the start and end indices of an active audio chunk. - public func calculateActiveChunks(in waveform: [Float]) -> [(startIndex: Int, endIndex: Int)] { - let vad: [Bool] = voiceActivity(in: waveform) - var result = [(startIndex: Int, endIndex: Int)]() - - // Temporary variables to hold the start of the current non-silent segment - var currentStartIndex: Int? - - for (index, vadChunk) in vad.enumerated() { - if vadChunk { - let chunkStart = index * frameLengthSamples - let chunkEnd = min(chunkStart + frameLengthSamples, waveform.count) - - if currentStartIndex != nil { - // If we already have a starting point, just update the end point in the last added segment - result[result.count - 1].endIndex = chunkEnd - } else { - // If there is no current start, this is a new segment - currentStartIndex = chunkStart - result.append((startIndex: chunkStart, endIndex: chunkEnd)) - } - } else { - // Reset currentStartIndex when encountering a silent chunk - currentStartIndex = nil - } - } - - return result - } - - /// Converts a voice activity index to the corresponding audio sample index. - /// - Parameter index: The voice activity index to convert. - /// - Returns: The corresponding audio sample index. - public func voiceActivityIndexToAudioSampleIndex(_ index: Int) -> Int { - return index * frameLengthSamples - } - - public func voiceActivityIndexToSeconds(_ index: Int) -> Float { - return Float(voiceActivityIndexToAudioSampleIndex(index)) / Float(sampleRate) - } - - /// Identifies the longest continuous period of silence within the provided voice activity detection results. - /// - Parameter vadResult: An array of `Bool` values representing voice activity detection results. - /// - Returns: A tuple containing the start and end indices of the longest silence period, or `nil` if no silence is found. - public func findLongestSilence(in vadResult: [Bool]) -> (startIndex: Int, endIndex: Int)? { - var longestStartIndex: Int? - var longestEndIndex: Int? 
- var longestCount = 0 - var index = 0 - while index < vadResult.count { - let value = vadResult[index] - if value { - // found non-silence, skip - index += 1 - } else { - // found beginning of silence, find the end - var endIndex = index - while endIndex < vadResult.count, !vadResult[endIndex] { - endIndex += 1 - } - let count = endIndex - index - if count > longestCount { - longestCount = count - longestStartIndex = index - longestEndIndex = endIndex - } - index = endIndex - } - } - if let longestStartIndex, let longestEndIndex { - return (startIndex: longestStartIndex, endIndex: longestEndIndex) - } else { - return nil - } - } - - // MARK - Utility - - func voiceActivityClipTimestamps(in waveform: [Float]) -> [Float] { - let nonSilentChunks = calculateActiveChunks(in: waveform) - var clipTimestamps = [Float]() - - for chunk in nonSilentChunks { - let startTimestamp = Float(chunk.startIndex) / Float(sampleRate) - let endTimestamp = Float(chunk.endIndex) / Float(sampleRate) - - clipTimestamps.append(contentsOf: [startTimestamp, endTimestamp]) - } - - return clipTimestamps - } - - func calculateNonSilentSeekClips(in waveform: [Float]) -> [(start: Int, end: Int)] { - let clipTimestamps = voiceActivityClipTimestamps(in: waveform) - let options = DecodingOptions(clipTimestamps: clipTimestamps) - let seekClips = prepareSeekClips(contentFrames: waveform.count, decodeOptions: options) - return seekClips - } - - func calculateSeekTimestamps(in waveform: [Float]) -> [(startTime: Float, endTime: Float)] { - let nonSilentChunks = calculateActiveChunks(in: waveform) - var seekTimestamps = [(startTime: Float, endTime: Float)]() - - for chunk in nonSilentChunks { - let startTimestamp = Float(chunk.startIndex) / Float(sampleRate) - let endTimestamp = Float(chunk.endIndex) / Float(sampleRate) - - seekTimestamps.append(contentsOf: [(startTime: startTimestamp, endTime: endTimestamp)]) - } - - return seekTimestamps + + public func voiceActivity(in waveform: [Float]) -> [Bool] { + fatalError("voiceActivity(in:) must be implemented by conforming types") } } diff --git a/Sources/WhisperKit/Core/AudioEncoder.swift b/Sources/WhisperKit/Core/AudioEncoder.swift index 06337cd..c9c9358 100644 --- a/Sources/WhisperKit/Core/AudioEncoder.swift +++ b/Sources/WhisperKit/Core/AudioEncoder.swift @@ -22,16 +22,14 @@ public class AudioEncoder: AudioEncoding, WhisperMLModel { guard let inputDescription = model?.modelDescription.outputDescriptionsByName["encoder_output_embeds"] else { return nil } guard inputDescription.type == .multiArray else { return nil } guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil } - let shape = shapeConstraint.shape.map { $0.intValue } - return shape[1] + return shapeConstraint.shape[0].intValue } public var sequenceLength: Int? { guard let inputDescription = model?.modelDescription.outputDescriptionsByName["encoder_output_embeds"] else { return nil } guard inputDescription.type == .multiArray else { return nil } guard let shapeConstraint = inputDescription.multiArrayConstraint else { return nil } - let shape = shapeConstraint.shape.map { $0.intValue } - return shape[3] + return shapeConstraint.shape[1].intValue } public init() {} diff --git a/Sources/WhisperKit/Core/Configurations.swift b/Sources/WhisperKit/Core/Configurations.swift index c7a38b3..77d47ce 100644 --- a/Sources/WhisperKit/Core/Configurations.swift +++ b/Sources/WhisperKit/Core/Configurations.swift @@ -143,7 +143,7 @@ public struct DecodingOptions { public var noSpeechThreshold: Float? 
public var concurrentWorkerCount: Int public var chunkingStrategy: ChunkingStrategy? - public var voiceActivityDetector: VoiceActivityDetector? + public var voiceActivityDetector: (any VoiceActivityDetectable)? public init( verbose: Bool = false, @@ -172,7 +172,7 @@ public struct DecodingOptions { noSpeechThreshold: Float? = 0.6, concurrentWorkerCount: Int = 16, chunkingStrategy: ChunkingStrategy? = nil, - voiceActivityDetector: VoiceActivityDetector? = nil + voiceActivityDetector: (any VoiceActivityDetectable)? = nil ) { self.verbose = verbose self.task = task diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift index 3e05132..7bf1a50 100644 --- a/Sources/WhisperKit/Core/Models.swift +++ b/Sources/WhisperKit/Core/Models.swift @@ -167,14 +167,14 @@ public struct ModelComputeOptions { // MARK: - Chunking -public struct AudioChunk { +public struct AudioChunk: Sendable { public var seekOffsetIndex: Int public var audioSamples: [Float] } // MARK: - Decoding -public enum DecodingTask: CustomStringConvertible, CaseIterable { +public enum DecodingTask: CustomStringConvertible, CaseIterable, Sendable { case transcribe case translate @@ -247,13 +247,13 @@ public struct DecodingCache { } } -public enum ChunkingStrategy: String, CaseIterable { +public enum ChunkingStrategy: String, CaseIterable, Sendable { case none case vad } @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -public struct DecodingFallback { +public struct DecodingFallback: Sendable { public var needsFallback: Bool public var fallbackReason: String @@ -390,7 +390,7 @@ public enum WhisperError: Error, LocalizedError, Equatable { // Structs -public struct TranscriptionResult: Codable { +public struct TranscriptionResult: Codable, Sendable { public var text: String public var segments: [TranscriptionSegment] public var language: String @@ -478,7 +478,7 @@ public extension TranscriptionResult { } } -public struct TranscriptionSegment: Hashable, Codable { +public struct TranscriptionSegment: Hashable, Codable, Sendable { public var id: Int = 0 public var seek: Int = 0 public var start: Float = 0.0 @@ -493,7 +493,7 @@ public struct TranscriptionSegment: Hashable, Codable { public var words: [WordTiming]? = nil } -public struct WordTiming: Hashable, Codable { +public struct WordTiming: Hashable, Codable, Sendable { public var word: String public var tokens: [Int] public var start: Float @@ -501,7 +501,7 @@ public struct WordTiming: Hashable, Codable { public var probability: Float } -public struct TranscriptionProgress { +public struct TranscriptionProgress: Sendable { public var timings: TranscriptionTimings public var text: String public var tokens: [Int] @@ -533,7 +533,7 @@ public struct TranscriptionProgress { /// - Note: This callback should be lightweight and return as quickly as possible to avoid extra decoding loops public typealias TranscriptionCallback = ((TranscriptionProgress) -> Bool?)? 
-public struct TranscriptionTimings: Codable { +public struct TranscriptionTimings: Codable, Sendable { public var pipelineStart: CFAbsoluteTime public var firstTokenTime: CFAbsoluteTime public var inputAudioSeconds: TimeInterval @@ -1155,6 +1155,15 @@ struct WhisperTokenizerWrapper: WhisperTokenizer { } extension WhisperTokenizerWrapper: Tokenizer { + + func applyChatTemplate(messages: [[String : String]]) throws -> [Int] { + try tokenizer.applyChatTemplate(messages: messages) + } + + func applyChatTemplate(messages: [[String : String]], chatTemplate: String?, addGenerationPrompt: Bool, truncation: Bool, maxLength: Int?) throws -> [Int] { + try tokenizer.applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: addGenerationPrompt, truncation: truncation, maxLength: maxLength) + } + func tokenize(text: String) -> [String] { tokenizer.tokenize(text: text) } @@ -1166,6 +1175,10 @@ extension WhisperTokenizerWrapper: Tokenizer { func decode(tokens: [Int]) -> String { tokenizer.decode(tokens: tokens) } + + func encode(text: String, addSpecialTokens: Bool) -> [Int] { + tokenizer.encode(text: text, addSpecialTokens: addSpecialTokens) + } func convertTokenToId(_ token: String) -> Int? { tokenizer.convertTokenToId(token) diff --git a/Sources/WhisperKit/Core/ResultWriter.swift b/Sources/WhisperKit/Core/ResultWriter.swift index 00d694c..cfea6c9 100644 --- a/Sources/WhisperKit/Core/ResultWriter.swift +++ b/Sources/WhisperKit/Core/ResultWriter.swift @@ -3,7 +3,7 @@ import Foundation -public protocol ResultWriting { +public protocol ResultWriting: Sendable { var outputDir: String { get } func write(result: TranscriptionResult, to file: String, options: [String: Any]?) -> Result func formatTime(seconds: Float, alwaysIncludeHours: Bool, decimalMarker: String) -> String @@ -37,7 +37,7 @@ public extension ResultWriting { } } -open class WriteJSON: ResultWriting { +public struct WriteJSON: ResultWriting { public let outputDir: String public init(outputDir: String) { @@ -66,7 +66,7 @@ open class WriteJSON: ResultWriting { } } -open class WriteSRT: ResultWriting { +public struct WriteSRT: ResultWriting { public let outputDir: String public init(outputDir: String) { @@ -101,7 +101,7 @@ open class WriteSRT: ResultWriting { } } -open class WriteVTT: ResultWriting { +public struct WriteVTT: ResultWriting { public let outputDir: String public init(outputDir: String) { diff --git a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift index 33a45dc..3aa35ff 100644 --- a/Sources/WhisperKit/Core/Text/SegmentSeeker.swift +++ b/Sources/WhisperKit/Core/Text/SegmentSeeker.swift @@ -19,7 +19,7 @@ public protocol SegmentSeeking { specialToken: Int, tokenizer: WhisperTokenizer ) -> (Int, [TranscriptionSegment]?) - + func addWordTimestamps( segments: [TranscriptionSegment], alignmentWeights: MLMultiArray, @@ -35,12 +35,11 @@ public protocol SegmentSeeking { } @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -open class SegmentSeeker: SegmentSeeking { +public struct SegmentSeeker: SegmentSeeking { public init() {} - + // MARK: - Seek & Segments - - // TODO: simplify this interface + public func findSeekPointAndSegments( decodingResult: DecodingResult, options: DecodingOptions, @@ -52,77 +51,61 @@ open class SegmentSeeker: SegmentSeeking { specialToken: Int, tokenizer: WhisperTokenizer ) -> (Int, [TranscriptionSegment]?) 
{ - // check if we need to skip this segment entirely - // if so, reset currentSegments, continue to next window, otherwise: + // Check whether this segment needs to be skipped entirely var seek = currentSeek let timeOffset = Float(seek) / Float(sampleRate) let secondsPerTimeToken = WhisperKit.secondsPerTimeToken + if let threshold = options.noSpeechThreshold { - // check no speech threshold for segment var shouldSkip = decodingResult.noSpeechProb > threshold - - // check avg logprob threshold for segment + if let logProbThreshold = options.logProbThreshold, - decodingResult.avgLogProb > logProbThreshold - { - // Confidence in overall segment overrides no speech threshold + decodingResult.avgLogProb > logProbThreshold { shouldSkip = false } - + if shouldSkip { - // skip one full segment, this one is silent seek += segmentSize return (seek, nil) } } - + var currentSegments: [TranscriptionSegment] = [] - - // loop through all consecutive timestamps and turn them into `TranscriptionSegments` + let currentTokens = decodingResult.tokens let currentLogProbs = decodingResult.tokenLogProbs let isTimestampToken = currentTokens.map { $0 >= timeToken } - - // check if single or double timestamp ending - let lastThreeTokens = isTimestampToken.suffix(3) - let singleTimestampEnding = lastThreeTokens == [false, true, false] - let noTimestampEnding = lastThreeTokens == [false, false, false] - - // find all end indexes of time token pairs + var sliceIndexes = [Int]() - var previousTokenIsTimestamp = false - for (currentTokenIsTimestampIndex, currentTokenIsTimestamp) in isTimestampToken.enumerated() { + for (currentIndex, currentTokenIsTimestamp) in isTimestampToken.enumerated() { if previousTokenIsTimestamp && currentTokenIsTimestamp { - sliceIndexes.append(currentTokenIsTimestampIndex) + sliceIndexes.append(currentIndex) } previousTokenIsTimestamp = currentTokenIsTimestamp } - - // Window contains multiple consecutive timestamps, split into sub-segments + if !sliceIndexes.isEmpty { - // If the last timestamp is not consecutive, we need to add it as the final slice manually - if singleTimestampEnding { - let singleTimestampEndingIndex = isTimestampToken.lastIndex(where: { $0 })! - sliceIndexes.append(singleTimestampEndingIndex + 1) - } else if noTimestampEnding { - sliceIndexes.append(currentTokens.count) - } - + let lastTimestampIndex = isTimestampToken.lastIndex(of: true) ?? currentTokens.count - 1 + sliceIndexes.append(lastTimestampIndex + 1) + var lastSliceStart = 0 for currentSliceEnd in sliceIndexes { let slicedTokens = Array(currentTokens[lastSliceStart..= timeToken } - - let startTimestampSeconds = Float(timestampTokens.first! - timeToken) * secondsPerTimeToken - let endTimestampSeconds = Float(timestampTokens.last! - timeToken) * secondsPerTimeToken - - // Decode segment text + + guard let firstTimestamp = timestampTokens.first, + let lastTimestamp = timestampTokens.last else { continue } + + let startTimestampSeconds = Float(firstTimestamp - timeToken) * secondsPerTimeToken + let endTimestampSeconds = Float(lastTimestamp - timeToken) * secondsPerTimeToken + + // Decode the segment text let wordTokens = slicedTokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin } let slicedTextTokens = options.skipSpecialTokens ? 
wordTokens : slicedTokens let sliceText = tokenizer.decode(tokens: slicedTextTokens) - + let newSegment = TranscriptionSegment( id: allSegmentsCount + currentSegments.count, seek: seek, @@ -139,34 +122,29 @@ currentSegments.append(newSegment) lastSliceStart = currentSliceEnd } - - // Seek to the last timestamp in the segment - if !noTimestampEnding { - let lastTimestampToken = currentTokens[lastSliceStart - (singleTimestampEnding ? 1 : 0)] - timeToken - let lastTimestampSeconds = Float(lastTimestampToken) * secondsPerTimeToken + + // Advance the seek point to the last timestamp in the segment + if let lastTimestampToken = currentTokens[lastSliceStart - 1] as Int? { + let lastTimestampSeconds = Float(lastTimestampToken - timeToken) * secondsPerTimeToken let lastTimestampSamples = Int(lastTimestampSeconds * Float(sampleRate)) seek += lastTimestampSamples } else { seek += segmentSize } } else { - // Model is not giving any consecutive timestamps, so lump all the current tokens together - var durationSeconds = Float(segmentSize) / Float(sampleRate) - - // Find any timestamp that is not 0.00 - let timestampTokens = currentTokens.filter { $0 > timeToken } - - // If there are no consecutive timestamps at all, check if there is at least one timestamp at the end - // If there is at least one, use that to record a more accurate end time - if !timestampTokens.isEmpty, let lastTimestamp = timestampTokens.last { + // Handle the case where there are no consecutive timestamps + let durationSeconds: Float + if let lastTimestamp = currentTokens.last(where: { $0 > timeToken }) { durationSeconds = Float(lastTimestamp - timeToken) * secondsPerTimeToken + } else { + durationSeconds = Float(segmentSize) / Float(sampleRate) } - - // Decode segment text + + // Decode the segment text let wordTokens = decodingResult.tokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin } let segmentTextTokens = options.skipSpecialTokens ? wordTokens : decodingResult.tokens let segmentText = tokenizer.decode(tokens: segmentTextTokens) - + let newSegment = TranscriptionSegment( id: allSegmentsCount + currentSegments.count, seek: seek, @@ -181,154 +159,139 @@ noSpeechProb: decodingResult.noSpeechProb ) currentSegments.append(newSegment) - - // Model has told us there is no more speech in this segment, move on to next + seek += segmentSize - // TODO: use this logic instead once we handle no speech - // seek += Int(durationSeconds * Float(sampleRate)) } - + return (seek, currentSegments) } - + // MARK: - Word Timestamps - - /// Matrix is a 2D array of alignment weights of shape (n, m) where n is the number of rows representing text tokens - /// and m is the number of columns representing audio tokens + func dynamicTimeWarping(withMatrix matrix: MLMultiArray) throws -> (textIndices: [Int], timeIndices: [Int]) { guard matrix.shape.count == 2, let numberOfRows = matrix.shape[0] as? Int, - let numberOfColumns = matrix.shape[1] as? Int - else { + let numberOfColumns = matrix.shape[1] as? 
Int else { throw WhisperError.segmentingFailed("Invalid alignment matrix shape") } - - // Initialize cost matrix and trace matrix - var costMatrix = Array(repeating: Array(repeating: Double.infinity, count: numberOfColumns + 1), count: numberOfRows + 1) - var traceMatrix = Array(repeating: Array(repeating: -1, count: numberOfColumns + 1), count: numberOfRows + 1) - - costMatrix[0][0] = 0 + + // Treat the MLMultiArray as Float16 + let elementCount = numberOfRows * numberOfColumns + let pointer = matrix.dataPointer.bindMemory(to: UInt16.self, capacity: elementCount) + + // Read the data, converting from Float16 to Double + var matrixData = [Double](repeating: 0.0, count: elementCount) + for i in 0.. (Double, Int) { let c0 = costDiagonal + matrixValue let c1 = costUp + matrixValue let c2 = costLeft + matrixValue - - if c0 < c1 && c0 < c2 { + + if c0 <= c1 && c0 <= c2 { return (c0, 0) - } else if c1 < c0 && c1 < c2 { + } else if c1 <= c0 && c1 <= c2 { return (c1, 1) } else { return (c2, 2) } } - - func backtrace(fromTraceMatrix traceMatrix: [[Int]]) -> (textIndices: [Int], timeIndices: [Int]) { - var i = traceMatrix.count - 1 - var j = traceMatrix[0].count - 1 - + + func backtrace(fromDirectionMatrix directionMatrix: [Int], numberOfRows: Int, numberOfColumns: Int) -> (textIndices: [Int], timeIndices: [Int]) { + var i = numberOfRows + var j = numberOfColumns + var textIndices = [Int]() var timeIndices = [Int]() - + + let width = numberOfColumns + 1 + while i > 0 || j > 0 { textIndices.append(i - 1) timeIndices.append(j - 1) - - switch traceMatrix[i][j] { - case 0: - i -= 1 - j -= 1 - case 1: - i -= 1 - case 2: - j -= 1 - default: - break + + let dir = directionMatrix[i * width + j] + switch dir { + case 0: + i -= 1 + j -= 1 + case 1: + i -= 1 + case 2: + j -= 1 + default: + break } } - + return (textIndices.reversed(), timeIndices.reversed()) } - + func mergePunctuations(alignment: [WordTiming], prepended: String, appended: String) -> [WordTiming] { - var prependedAlignment = [WordTiming]() - var appendedAlignment = [WordTiming]() - - // Include the first word if it's not a prepended punctuation - if !alignment.isEmpty && !prepended.contains(alignment[0].word.trimmingCharacters(in: .whitespaces)) { - prependedAlignment.append(alignment[0]) - } - - // Merge prepended punctuations - for i in 1.. [WordTiming] { - // TODO: Use accelerate framework for these two, they take roughly the same time let (textIndices, timeIndices) = try dynamicTimeWarping(withMatrix: alignmentWeights) let (words, wordTokens) = tokenizer.splitToWordTokens(tokenIds: wordTokenIds) - + if wordTokens.count <= 1 { return [] } - - // Calculate start times and end times - var startTimes: [Float] = [0.0] - var endTimes = [Float]() - var currentTokenIndex = textIndices.first ?? 0 - for index in 0.. [TranscriptionSegment]? 
{ - // Initialize arrays to hold the extracted and filtered data + // Prepare the data needed for alignment var wordTokenIds = [Int]() var filteredLogProbs = [Float]() var filteredIndices = [Int]() var lastSpeechTimestamp = lastSpeechTimestamp - - // Iterate through each segment + var indexOffset = 0 for segment in segments { for (index, token) in segment.tokens.enumerated() { wordTokenIds.append(token) filteredIndices.append(index + indexOffset) // Add the index to filteredIndices - - // Assuming tokenLogProbs is structured as [[Int: Float]] + if let logProb = segment.tokenLogProbs[index][token] { filteredLogProbs.append(logProb) } } - - // Update the indexOffset as we start a new segment indexOffset += segment.tokens.count } - - // Filter alignmentWeights using filteredIndices + + // Filter alignmentWeights efficiently let shape = alignmentWeights.shape guard let columnCount = shape.last?.intValue else { throw WhisperError.segmentingFailed("Invalid shape in alignmentWeights") } - - let filteredAlignmentWeights = initMLMultiArray(shape: [filteredIndices.count, columnCount] as [NSNumber], dataType: alignmentWeights.dataType, initialValue: FloatType(0)) - - alignmentWeights.withUnsafeMutableBytes { weightsPointer, weightsStride in - filteredAlignmentWeights.withUnsafeMutableBytes { filteredWeightsPointer, filteredWeightsStride in - for (newIndex, originalIndex) in filteredIndices.enumerated() { - let sourcePointer = weightsPointer.baseAddress!.advanced(by: Int(originalIndex * columnCount * MemoryLayout.stride)) - let destinationPointer = filteredWeightsPointer.baseAddress!.advanced(by: Int(newIndex * columnCount * MemoryLayout.stride)) - - memcpy(destinationPointer, sourcePointer, columnCount * MemoryLayout.stride) - } - } - } - - Logging.debug("Alignment weights shape: \(filteredAlignmentWeights.shape)") - + + let filteredAlignmentWeights = try filterAlignmentWeights( + alignmentWeights: alignmentWeights, + filteredIndices: filteredIndices, + rowCount: filteredIndices.count, + columnCount: columnCount + ) + var alignment = try findAlignment( wordTokenIds: wordTokenIds, alignmentWeights: filteredAlignmentWeights, @@ -463,104 +396,70 @@ timings: timings ) - // TODO: This section is considered a "hack" in the source repo - // Reference: https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/timing.py#L305 - var wordDurations = alignment.map { $0.end - $0.start } - wordDurations = wordDurations.filter { $0 > 0 } - - let medianDuration: Float = wordDurations.isEmpty ? 0.0 : wordDurations.sorted(by: <)[wordDurations.count / 2] - let constrainedMedianDuration = min(0.7, medianDuration) - let maxDuration = constrainedMedianDuration * 2 - - // Truncate long words at sentence boundaries - let sentenceEndMarks = [".", "。", "!", "!", "?", "?"] - if !wordDurations.isEmpty { - for i in 1.. 
maxDuration { - if sentenceEndMarks.contains(alignment[i].word) { - alignment[i].end = alignment[i].start + maxDuration - } else if i > 0, sentenceEndMarks.contains(alignment[i - 1].word) { - alignment[i].start = alignment[i].end - maxDuration - } - } - } - } - - // Process alignment for punctuations let mergedAlignment = mergePunctuations(alignment: alignment, prepended: prependPunctuations, appended: appendPunctuations) - + var wordIndex = 0 let timeOffset = Float(seek) / Float(WhisperKit.sampleRate) var updatedSegments = [TranscriptionSegment]() - + for segment in segments { var savedTokens = 0 let textTokens = segment.tokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin } var wordsInSegment = [WordTiming]() - - for timing in mergedAlignment[wordIndex...] where savedTokens < textTokens.count { + + while wordIndex < mergedAlignment.count && savedTokens < textTokens.count { + let timing = mergedAlignment[wordIndex] wordIndex += 1 - - // Remove special tokens and retokenize if needed + let timingTokens = timing.tokens.filter { $0 < tokenizer.specialTokens.specialTokenBegin } if timingTokens.isEmpty { continue } - - let start = (timeOffset + timing.start).rounded(2) - let end = (timeOffset + timing.end).rounded(2) - let probability = timing.probability.rounded(2) + + let start = (timeOffset + timing.start).rounded(toPlaces: 2) + let end = (timeOffset + timing.end).rounded(toPlaces: 2) + let probability = timing.probability.rounded(toPlaces: 2) let wordTiming = WordTiming(word: timing.word, tokens: timingTokens, start: start, end: end, probability: probability) wordsInSegment.append(wordTiming) - + savedTokens += timingTokens.count } - - // Create an updated segment with the word timings + var updatedSegment = segment - - // TODO: This section is considered a "hack" in the source repo - // Reference: https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/timing.py#L342 - // Truncate long words at segment boundaries - if let firstWord = wordsInSegment.first, let lastWord = wordsInSegment.last { - // Logic for the first word - if firstWord.end - lastSpeechTimestamp > constrainedMedianDuration * 4 && - (firstWord.end - firstWord.start > maxDuration || - (wordsInSegment.count > 1 && wordsInSegment[1].end - firstWord.start > maxDuration * 2)) - { - if wordsInSegment.count > 1 && wordsInSegment[1].end - wordsInSegment[1].start > maxDuration { - let boundary = max(wordsInSegment[1].end / 2, wordsInSegment[1].end - maxDuration) - wordsInSegment[0].end = boundary - wordsInSegment[1].start = boundary - } - wordsInSegment[0].start = max(lastSpeechTimestamp, firstWord.end - maxDuration) - } - - // Prefer segment-level start timestamp if the first word is too long. - if segment.start < firstWord.end && segment.start - 0.5 > firstWord.start { - wordsInSegment[0].start = max(0, min(firstWord.end - constrainedMedianDuration, segment.start)) - } else { - updatedSegment.start = firstWord.start - } - - // Prefer segment-level end timestamp if the last word is too long. 
- if updatedSegment.end > lastWord.start && segment.end + 0.5 < lastWord.end { - wordsInSegment[wordsInSegment.count - 1].end = max(lastWord.start + constrainedMedianDuration, segment.end) - } else { - updatedSegment.end = lastWord.end - } - - lastSpeechTimestamp = updatedSegment.end - } - updatedSegment.words = wordsInSegment updatedSegments.append(updatedSegment) } - + return updatedSegments } + + private func filterAlignmentWeights( + alignmentWeights: MLMultiArray, + filteredIndices: [Int], + rowCount: Int, + columnCount: Int + ) throws -> MLMultiArray { + let filteredAlignmentWeights = try MLMultiArray(shape: [rowCount, columnCount] as [NSNumber], dataType: .float16) + let sourcePointer = alignmentWeights.dataPointer.bindMemory(to: UInt16.self, capacity: alignmentWeights.count) + let destinationPointer = filteredAlignmentWeights.dataPointer.bindMemory(to: UInt16.self, capacity: filteredAlignmentWeights.count) + + for (newIndex, originalIndex) in filteredIndices.enumerated() { + let sourceRow = sourcePointer.advanced(by: originalIndex * columnCount) + let destinationRow = destinationPointer.advanced(by: newIndex * columnCount) + destinationRow.update(from: sourceRow, count: columnCount) + } + + return filteredAlignmentWeights + } +} + +extension Float { + func rounded(toPlaces places: Int) -> Float { + let divisor = pow(10, Float(places)) + return (self * divisor).rounded() / divisor + } } diff --git a/Sources/WhisperKit/Core/Text/TokenSampler.swift b/Sources/WhisperKit/Core/Text/TokenSampler.swift index ce15cd5..1dd9b6e 100644 --- a/Sources/WhisperKit/Core/Text/TokenSampler.swift +++ b/Sources/WhisperKit/Core/Text/TokenSampler.swift @@ -10,7 +10,7 @@ public protocol TokenSampling { func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult } -public struct SamplingResult { +public struct SamplingResult: Sendable { public var tokens: [Int] public var logProbs: [Float] public var completed: Bool @@ -21,73 +21,73 @@ open class GreedyTokenSampler: TokenSampling { public var temperature: FloatType public var eotToken: Int public var decodingOptions: DecodingOptions - + public init(temperature: FloatType, eotToken: Int, decodingOptions: DecodingOptions) { self.temperature = temperature self.eotToken = eotToken self.decodingOptions = decodingOptions } - + public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult { var softmaxOutput: BNNSNDArrayDescriptor? var argmaxOutput: BNNSNDArrayDescriptor? var softmaxInput: BNNSNDArrayDescriptor? var softmaxInputNeedsDeallocate = false - + var nextToken: Int? - + do { let logitsRawPointer = UnsafeMutableRawBufferPointer( start: logits.dataPointer, count: logits.count * MemoryLayout.stride ) - + let logitsDescriptor = BNNSNDArrayDescriptor( data: logitsRawPointer, scalarType: FloatType.self, shape: .vector(logits.count, stride: 1) )! - + softmaxInput = logitsDescriptor - + // Scale logits by temperature if > 0 if temperature != 0.0 { let scaledLogits = BNNSNDArrayDescriptor.allocateUninitialized( scalarType: FloatType.self, shape: .vector(logits.count, stride: 1) ) - + try! 
BNNS.applyActivation( activation: BNNS.ActivationFunction.linear(alpha: Float(1 / temperature)), input: logitsDescriptor, output: scaledLogits, batchSize: 1 ) - + softmaxInput = scaledLogits softmaxInputNeedsDeallocate = true } - + // Always softmax once softmaxOutput = BNNSNDArrayDescriptor.allocateUninitialized( scalarType: Float.self, shape: .vector(logits.count, stride: 1) ) - + try BNNS.applyActivation( activation: BNNS.ActivationFunction.softmax, input: softmaxInput!, output: softmaxOutput!, batchSize: 1 ) - + if temperature != 0.0 { // top-k multinomial sampling let k = decodingOptions.topK - + let bestValues = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float.self, shape: .vector(k, stride: 1)) let bestIndices = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int32.self, shape: .vector(k, stride: 1)) - + try! BNNS.applyTopK( k: k, input: softmaxOutput!, @@ -96,13 +96,13 @@ open class GreedyTokenSampler: TokenSampling { axis: 0, batchSize: 1 ) - + let bestValuesResult = bestValues.makeArray(of: Float.self)! let bestIndicesResult = bestIndices.makeArray(of: Int32.self)! - + bestValues.deallocate() bestIndices.deallocate() - + // multinomial sample from top-k let sumOfbestIndicesResult = bestValuesResult.reduce(0, +) let rnd = Float.random(in: 0.. SamplingResult { var finalTokens = tokens var finalLogProbs = logProbs @@ -164,7 +164,7 @@ open class GreedyTokenSampler: TokenSampling { finalTokens.append(eotToken) finalLogProbs.append(0) } - + return SamplingResult(tokens: finalTokens, logProbs: finalLogProbs, completed: true) } } @@ -175,7 +175,7 @@ open class BeamSearchTokenSampler: TokenSampling { public var patience: Float var maxCandidates: Int var finishedSequences: [Float] - + public init( beamSize: Int, eotToken: Int, @@ -191,18 +191,132 @@ open class BeamSearchTokenSampler: TokenSampling { fatalError("Invalid beam size \(beamSize) or patience \(patience)") } } - + public func reset() { finishedSequences = [] } - + public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult { // TODO: Implement fatalError("Not implemented: \(#function)") } - + public func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult { // TODO: Implement fatalError("Not implemented: \(#function)") } } + +@available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, *) +open class NTokenSampler: TokenSampling { + public var temperature: Float + public var eotToken: Int + public var decodingOptions: DecodingOptions + + public init(temperature: Float, eotToken: Int, decodingOptions: DecodingOptions) { + self.temperature = temperature + self.eotToken = eotToken + self.decodingOptions = decodingOptions + } + + public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult { + // Ensure the MLMultiArray holds Float32 data + guard logits.dataType == .float32 else { + fatalError("Logits MLMultiArray must be of type Float32") + } + + let logitsCount = logits.count + + // Access the raw logits data + let logitsPointer = logits.dataPointer.bindMemory(to: Float.self, capacity: logitsCount) + let logitsBuffer = UnsafeBufferPointer(start: logitsPointer, count: logitsCount) + var logitsArray = [Float](logitsBuffer) + + // Scale the logits when the temperature is nonzero + if temperature != 0.0 { + let tempReciprocal = 1.0 / temperature + vDSP_vsmul(logitsArray, 1, [tempReciprocal], &logitsArray, 1, vDSP_Length(logitsCount)) + } + + // Compute the softmax + var softmaxOutput = [Float](repeating: 0, count: logitsCount) + computeSoftmax(logitsArray, result: &softmaxOutput) + + var nextToken: Int = 0
+ var nextLogprob: Float = 0.0 + + if temperature != 0.0 { + // Top-k sampling + let k = min(decodingOptions.topK, logitsCount) + + // Pair values with their indices and sort descending + let indices = Array(0..<logitsCount) + let sortedPairs = zip(softmaxOutput, indices).sorted { $0.0 > $1.0 } + let topKPairs = sortedPairs.prefix(k) + + let topKValues = topKPairs.map { $0.0 } + let topKIndices = topKPairs.map { $0.1 } + + // Normalize the top-k probabilities + let sumTopK = topKValues.reduce(0, +) + let normalizedTopKValues = topKValues.map { $0 / sumTopK } + + // Sample from the top-k distribution + let randomValue = Float.random(in: 0..<1) + var cumulativeProbability: Float = 0.0 + for (i, probability) in normalizedTopKValues.enumerated() { + cumulativeProbability += probability + if randomValue < cumulativeProbability { + nextToken = topKIndices[i] + nextLogprob = log(probability) + break + } + } + } else { + // Argmax sampling + var maxValue: Float = 0 + var maxIndex: vDSP_Length = 0 + vDSP_maxvi(softmaxOutput, 1, &maxValue, &maxIndex, vDSP_Length(logitsCount)) + nextToken = Int(maxIndex) + nextLogprob = log(maxValue) + } + + let nextTokens = tokens + [nextToken] + let nextLogprobs = logProbs + [nextLogprob] + let completed = nextToken == eotToken + + return SamplingResult(tokens: nextTokens, logProbs: nextLogprobs, completed: completed) + } + + public func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult { + var finalTokens = tokens + var finalLogProbs = logProbs + if tokens.last != eotToken { + finalTokens.append(eotToken) + finalLogProbs.append(0) + } + + return SamplingResult(tokens: finalTokens, logProbs: finalLogProbs, completed: true) + } + + // Helper that computes the softmax efficiently + func computeSoftmax(_ input: [Float], result: inout [Float]) { + var input = input + + // Subtract the maximum value to avoid overflow + var maxValue: Float = 0 + vDSP_maxv(input, 1, &maxValue, vDSP_Length(input.count)) + var negativeMax = -maxValue + vDSP_vsadd(input, 1, &negativeMax, &input, 1, vDSP_Length(input.count)) + + // Apply the exponential + vvexpf(&result, input, [Int32(input.count)]) + + // Sum the exponentials + var sumOfExponents: Float = 0 + vDSP_sve(result, 1, &sumOfExponents, vDSP_Length(input.count)) + + // Divide by the sum to obtain probabilities + vDSP_vsdiv(result, 1, &sumOfExponents, &result, 1, vDSP_Length(input.count)) + } + } diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index 8713510..1b5276f 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -259,7 +259,7 @@ extension AVAudioPCMBuffer { // MARK: - Helpers @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -func prepareSeekClips(contentFrames: Int, decodeOptions: DecodingOptions?) -> [(start: Int, end: Int)] { +func prepareSeekClips(contentFrames: Int, decodeOptions: DecodingOptions?) -> [FrameRange] { let options = decodeOptions ?? DecodingOptions() var seekPoints: [Int] = options.clipTimestamps.map { Int(round($0 * Float(WhisperKit.sampleRate))) } if seekPoints.count == 0 { @@ -270,7 +270,7 @@ func prepareSeekClips(contentFrames: Int, decodeOptions: DecodingOptions?) -> [( seekPoints.append(contentFrames) } - var seekClips: [(start: Int, end: Int)] = [] + var seekClips: [FrameRange] = [] for i in stride(from: 0, to: seekPoints.count, by: 2) { let start = seekPoints[i] let end = i + 1 < seekPoints.count ?
seekPoints[i + 1] : contentFrames @@ -836,7 +836,7 @@ public class Logging { } extension Logging { - enum AudioEncoding { + enum AudioEncoding: Sendable { static let logger = Logger( subsystem: Constants.Logging.subsystem, category: "AudioEncoding" @@ -846,7 +846,7 @@ extension Logging { } extension Logging { - enum FeatureExtractor { + enum FeatureExtractor: Sendable { static let logger = Logger( subsystem: Constants.Logging.subsystem, category: "FeatureExtractor" @@ -856,7 +856,7 @@ extension Logging { } extension Logging { - enum TranscribeTask { + enum TranscribeTask: Sendable { static let logger = Logger( subsystem: Constants.Logging.subsystem, category: "TranscribeTask" diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index c1b66d5..9aaae39 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -16,7 +16,7 @@ open class WhisperKit { public private(set) var modelState: ModelState = .unloaded public var modelCompute: ModelComputeOptions public var tokenizer: WhisperTokenizer? - + /// Protocols public var audioProcessor: any AudioProcessing public var featureExtractor: any FeatureExtracting @@ -24,23 +24,23 @@ open class WhisperKit { public var textDecoder: any TextDecoding public var logitsFilters: [any LogitsFiltering] public var segmentSeeker: any SegmentSeeking - + /// Shapes public static let sampleRate: Int = 16000 public static let hopLength: Int = 160 public static let chunkLength: Int = 30 // seconds public static let windowSamples: Int = 480_000 // sampleRate * chunkLength public static let secondsPerTimeToken = Float(0.02) - + /// Progress public private(set) var currentTimings: TranscriptionTimings public private(set) var progress = Progress() - + /// Configuration public var modelFolder: URL? public var tokenizerFolder: URL? public private(set) var useBackgroundDownloadSession: Bool - + public init(_ config: WhisperKitConfig = WhisperKitConfig()) async throws { modelCompute = config.computeOptions ?? ModelComputeOptions() audioProcessor = config.audioProcessor ?? AudioProcessor() @@ -53,7 +53,7 @@ open class WhisperKit { useBackgroundDownloadSession = config.useBackgroundDownloadSession currentTimings = TranscriptionTimings() Logging.shared.logLevel = config.verbose ? config.logLevel : .none - + try await setupModels( model: config.model, downloadBase: config.downloadBase, @@ -61,19 +61,19 @@ open class WhisperKit { modelFolder: config.modelFolder, download: config.download ) - + if let prewarm = config.prewarm, prewarm { Logging.info("Prewarming models...") try await prewarmModels() } - + // If load is not passed in, load based on whether a modelFolder is passed if config.load ?? (config.modelFolder != nil) { Logging.info("Loading models...") try await loadModels() } } - + public convenience init( model: String? = nil, downloadBase: URL? 
= nil, @@ -113,21 +113,21 @@ open class WhisperKit { load: load, download: download, useBackgroundDownloadSession: useBackgroundDownloadSession - ) + ) try await self.init(config) } - + // MARK: - Model Loading - + public static func recommendedModels() -> (default: String, disabled: [String]) { let deviceName = Self.deviceName() Logging.debug("Running on \(deviceName)") - + let defaultModel = modelSupport(for: deviceName).default let disabledModels = modelSupport(for: deviceName).disabled return (defaultModel, disabledModels) } - + public static func deviceName() -> String { var utsname = utsname() uname(&utsname) @@ -138,14 +138,14 @@ open class WhisperKit { } return deviceName } - + public static func fetchAvailableModels(from repo: String = "argmaxinc/whisperkit-coreml", matching: [String] = ["openai_*", "distil-whisper_*"]) async throws -> [String] { let hubApi = HubApi() let modelFiles = try await hubApi.getFilenames(from: repo, matching: matching) - + return formatModelFiles(modelFiles) } - + public static func formatModelFiles(_ modelFiles: [String]) -> [String] { let modelFilters = ModelVariant.allCases.map { "\($0.description)\($0.description.contains("large") ? "" : "/")" } // Include quantized models for large let modelVariants = modelFiles.map { $0.components(separatedBy: "/")[0] + "/" } @@ -156,32 +156,32 @@ open class WhisperKit { } return count > 0 }) - + let availableModels = filteredVariants.map { variant -> String in variant.trimmingFromEnd(character: "/", upto: 1) } - + // Sorting order based on enum let sizeOrder = ModelVariant.allCases.map { $0.description } - + let sortedModels = availableModels.sorted { firstModel, secondModel in // Extract the base size without any additional qualifiers let firstModelBase = sizeOrder.first(where: { firstModel.contains($0) }) ?? "" let secondModelBase = sizeOrder.first(where: { secondModel.contains($0) }) ?? "" - + if firstModelBase == secondModelBase { // If base sizes are the same, sort alphabetically return firstModel < secondModel } else { // Sort based on the size order return sizeOrder.firstIndex(of: firstModelBase) ?? sizeOrder.count - < sizeOrder.firstIndex(of: secondModelBase) ?? sizeOrder.count + < sizeOrder.firstIndex(of: secondModelBase) ?? sizeOrder.count } } - + return sortedModels } - + public static func download( variant: String, downloadBase: URL? = nil, @@ -196,9 +196,9 @@ open class WhisperKit { Logging.debug("Searching for models matching \"\(modelSearchPath)\" in \(repo)") let modelFiles = try await hubApi.getFilenames(from: repo, matching: [modelSearchPath]) var uniquePaths = Set(modelFiles.map { $0.components(separatedBy: "/").first! }) - + var variantPath: String? = nil - + if uniquePaths.count == 1 { variantPath = uniquePaths.first } else { @@ -208,17 +208,17 @@ open class WhisperKit { Logging.debug("Searching for models matching \"\(adjustedModelSearchPath)\" in \(repo)") let adjustedModelFiles = try await hubApi.getFilenames(from: repo, matching: [adjustedModelSearchPath]) uniquePaths = Set(adjustedModelFiles.map { $0.components(separatedBy: "/").first! 
}) - + if uniquePaths.count == 1 { variantPath = uniquePaths.first } } - + guard let variantPath else { // If there is still ambiguity, throw an error throw WhisperError.modelsUnavailable("Multiple models found matching \"\(modelSearchPath)\"") } - + Logging.debug("Downloading model \(variantPath)...") let modelFolder = try await hubApi.snapshot(from: repo, matching: [modelSearchPath]) { progress in Logging.debug(progress) @@ -226,7 +226,7 @@ open class WhisperKit { callback(progress) } } - + let modelFolderName = modelFolder.appending(path: variantPath) return modelFolderName } catch { @@ -234,7 +234,7 @@ open class WhisperKit { throw error } } - + /// Sets up the model folder either from a local path or by downloading from a repository. open func setupModels( model: String?, @@ -245,7 +245,7 @@ open class WhisperKit { ) async throws { // Determine the model variant to use let modelVariant = model ?? WhisperKit.recommendedModels().default - + // If a local model folder is provided, use it; otherwise, download the model if let folder = modelFolder { self.modelFolder = URL(fileURLWithPath: folder) @@ -267,36 +267,36 @@ open class WhisperKit { } } } - + open func prewarmModels() async throws { try await loadModels(prewarmMode: true) } - + open func loadModels( prewarmMode: Bool = false ) async throws { modelState = prewarmMode ? .prewarming : .loading - + let modelLoadStart = CFAbsoluteTimeGetCurrent() - + guard let path = modelFolder else { throw WhisperError.modelsUnavailable("Model folder is not set.") } - + Logging.debug("Loading models from \(path.path) with prewarmMode: \(prewarmMode)") - + // Find either mlmodelc or mlpackage models let logmelUrl = detectModelURL(inFolder: path, named: "MelSpectrogram") let encoderUrl = detectModelURL(inFolder: path, named: "AudioEncoder") let decoderUrl = detectModelURL(inFolder: path, named: "TextDecoder") let decoderPrefillUrl = detectModelURL(inFolder: path, named: "TextDecoderContextPrefill") - + for item in [logmelUrl, encoderUrl, decoderUrl] { if !FileManager.default.fileExists(atPath: item.path) { throw WhisperError.modelsUnavailable("Model file not found at \(item.path)") } } - + if let featureExtractor = featureExtractor as? WhisperMLModel { Logging.debug("Loading feature extractor") try await featureExtractor.loadModel( @@ -306,7 +306,7 @@ open class WhisperKit { ) Logging.debug("Loaded feature extractor") } - + if FileManager.default.fileExists(atPath: decoderPrefillUrl.path) { Logging.debug("Loading text decoder prefill data") textDecoder.prefillData = TextDecoderContextPrefill() @@ -317,7 +317,7 @@ open class WhisperKit { ) Logging.debug("Loaded text decoder prefill data") } - + if let textDecoder = textDecoder as? WhisperMLModel { Logging.debug("Loading text decoder") let decoderLoadStart = CFAbsoluteTimeGetCurrent() @@ -327,30 +327,30 @@ open class WhisperKit { prewarmMode: prewarmMode ) currentTimings.decoderLoadTime = CFAbsoluteTimeGetCurrent() - decoderLoadStart - + Logging.debug("Loaded text decoder in \(String(format: "%.2f", currentTimings.decoderLoadTime))s") } - + if let audioEncoder = audioEncoder as? 
WhisperMLModel { Logging.debug("Loading audio encoder") let encoderLoadStart = CFAbsoluteTimeGetCurrent() - + try await audioEncoder.loadModel( at: encoderUrl, computeUnits: modelCompute.audioEncoderCompute, prewarmMode: prewarmMode ) currentTimings.encoderLoadTime = CFAbsoluteTimeGetCurrent() - encoderLoadStart - + Logging.debug("Loaded audio encoder in \(String(format: "%.2f", currentTimings.encoderLoadTime))s") } - + if prewarmMode { modelState = .prewarmed currentTimings.prewarmLoadTime = CFAbsoluteTimeGetCurrent() - modelLoadStart return } - + // Check model dimensions to assign appropriate tokenizer guard let logitsDim = textDecoder.logitsSize, let encoderDim = audioEncoder.embedSize else { throw WhisperError.tokenizerUnavailable() @@ -359,55 +359,67 @@ open class WhisperKit { modelVariant = detectVariant(logitsDim: logitsDim, encoderDim: encoderDim) Logging.debug("Loading tokenizer for \(modelVariant)") let tokenizerLoadStart = CFAbsoluteTimeGetCurrent() - + let tokenizer = try await loadTokenizer( for: modelVariant, tokenizerFolder: tokenizerFolder, useBackgroundSession: useBackgroundDownloadSession ) currentTimings.tokenizerLoadTime = CFAbsoluteTimeGetCurrent() - tokenizerLoadStart - + self.tokenizer = tokenizer textDecoder.tokenizer = tokenizer Logging.debug("Loaded tokenizer in \(String(format: "%.2f", currentTimings.tokenizerLoadTime))s") - + modelState = .loaded - + currentTimings.modelLoading = CFAbsoluteTimeGetCurrent() - modelLoadStart + currentTimings.prewarmLoadTime - + Logging.info("Loaded models for whisper size: \(modelVariant) in \(String(format: "%.2f", currentTimings.modelLoading))s") } - + open func unloadModels() async { modelState = .unloading - + for model in [featureExtractor, audioEncoder, textDecoder] { if let model = model as? WhisperMLModel { model.unloadModel() } } - + modelState = .unloaded - + Logging.info("Unloaded all models") } - + open func clearState() { - audioProcessor.stopRecording() + Task { + await audioProcessor.stopRecording() + } currentTimings = TranscriptionTimings() } - + deinit { - audioProcessor.stopRecording() + modelState = .unloading + if let featureExtractor = featureExtractor as? WhisperMLModel { + featureExtractor.unloadModel() + } + if let audioEncoder = audioEncoder as? WhisperMLModel { + audioEncoder.unloadModel() + } + if let textDecoder = textDecoder as? WhisperMLModel { + textDecoder.unloadModel() + } + modelState = .unloaded } - + /// Pass in your own logging callback here open func loggingCallback(_ callback: Logging.LoggingCallback?) { Logging.shared.loggingCallback = callback } - + // MARK: - Detect language - + /// Detects the language of the audio file at the specified path. /// /// - Parameter audioPath: The file path of the audio file. @@ -417,35 +429,35 @@ open class WhisperKit { ) async throws -> (language: String, langProbs: [String: Float]) { let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath) let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer) - return try await detectLangauge(audioArray: audioArray) + return try await detectLanguage(audioArray: audioArray) } - + /// Detects the language of the audio samples in the provided array. /// /// - Parameter audioArray: An array of audio samples. /// - Returns: A tuple containing the detected language and the language log probabilities. 
- open func detectLangauge( + open func detectLanguage( audioArray: [Float] ) async throws -> (language: String, langProbs: [String: Float]) { if modelState != .loaded { try await loadModels() } - + // Ensure the model is multilingual, as language detection is only supported for these models guard textDecoder.isModelMultilingual else { throw WhisperError.decodingFailed("Language detection not supported for this model") } - + // Tokenizer required for decoding guard let tokenizer else { throw WhisperError.tokenizerUnavailable() } - + let options = DecodingOptions(verbose: Logging.shared.logLevel != .none) let decoderInputs = try textDecoder.prepareDecoderInputs(withPrompt: [tokenizer.specialTokens.startOfTranscriptToken]) decoderInputs.kvCacheUpdateMask[0] = 1.0 decoderInputs.decoderKeyPaddingMask[0] = 0.0 - + // Detect language using up to the first 30 seconds guard let audioSamples = AudioProcessor.padOrTrimAudio(fromArray: audioArray, startAt: 0, toLength: WhisperKit.windowSamples) else { throw WhisperError.transcriptionFailed("Audio samples are nil") @@ -456,7 +468,7 @@ open class WhisperKit { guard let encoderOutput = try await audioEncoder.encodeFeatures(melOutput) else { throw WhisperError.transcriptionFailed("Encoder output is nil") } - + let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: tokenizer.specialTokens.endToken, decodingOptions: options) guard let languageDecodingResult: DecodingResult = try? await textDecoder.detectLanguage( from: encoderOutput, @@ -467,12 +479,12 @@ open class WhisperKit { ) else { throw WhisperError.decodingFailed("Language detection failed") } - + return (language: languageDecodingResult.language, langProbs: languageDecodingResult.languageProbs) } - + // MARK: - Transcribe multiple audio files - + /// Convenience method to transcribe multiple audio files asynchronously and return the results as an array of optional arrays of `TranscriptionResult`. /// - Returns: An array of optional arrays containing `TranscriptionResult`. open func transcribe( @@ -488,7 +500,7 @@ open class WhisperKit { let results = transcribeResults.toOptionalArrays() return results } - + /// Transcribes multiple audio files asynchronously and returns the results as an array of tuples containing the file path and the `Result` object. /// /// This method processes the provided audio file paths by loading the audio data and then transcribing the audio arrays. @@ -507,45 +519,45 @@ open class WhisperKit { ) async -> [Result<[TranscriptionResult], Swift.Error>] { // Start timing the audio loading and conversion process let loadAudioStart = Date() - + // Load and extract audio data from the provided file paths let loadedAudioResult = await AudioProcessor.loadAudio(at: audioPaths) let audioArrays = loadedAudioResult.compactMap { try? 
$0.get() } - + // Calculate the time taken to load and convert audio let loadAndConvertTime = Date().timeIntervalSince(loadAudioStart) currentTimings.audioLoading = loadAndConvertTime Logging.debug("Total Audio Loading and Converting Time: \(loadAndConvertTime)") - + // Transcribe the loaded audio arrays let transcribeResults = await transcribeWithResults( audioArrays: audioArrays, decodeOptions: decodeOptions, callback: callback ) - + // Initialize the result array to hold final transcription results var result = [Result<[TranscriptionResult], Swift.Error>]() var transcribeResultIndex = 0 - + // Iterate over loadedAudioResult and map each to the corresponding transcription result for audioResult in loadedAudioResult { switch audioResult { - case .success: - // Append transcription result if audio loading was successful (may still contain failure) - result.append(transcribeResults[transcribeResultIndex]) - transcribeResultIndex += 1 - case let .failure(error): - // Append failure result if audio loading failed - result.append(.failure(error)) + case .success: + // Append transcription result if audio loading was successful (may still contain failure) + result.append(transcribeResults[transcribeResultIndex]) + transcribeResultIndex += 1 + case let .failure(error): + // Append failure result if audio loading failed + result.append(.failure(error)) } } - + return result } - + // MARK: - Transcribe multiple audio arrays - + /// Convenience method to transcribe multiple audio arrays asynchronously and return the results as an array of optional arrays of `TranscriptionResult`. /// - Returns: An array of optional arrays containing `TranscriptionResult`. open func transcribe( @@ -558,10 +570,10 @@ open class WhisperKit { decodeOptions: decodeOptions, callback: callback ) - + return transcribeResults.toOptionalArrays() } - + /// Transcribes multiple audio arrays asynchronously and returns the results as an array of `Result` objects. /// /// This method processes the provided audio arrays by dividing them into batches based on the concurrent worker count @@ -587,7 +599,7 @@ open class WhisperKit { callback: callback ) } - + /// Method to transcribe multiple audio arrays asynchronously with optional associated decoding options and return the results as an array of `Result` objects. /// - Parameters: /// - audioArrays: An array of arrays, each containing audio @@ -601,18 +613,18 @@ open class WhisperKit { callback: TranscriptionCallback = nil ) async -> [Result<[TranscriptionResult], Swift.Error>] { var result = [Result<[TranscriptionResult], Swift.Error>]() - + guard audioArrays.count == decodeOptionsArray.count else { return [.failure(WhisperError.transcriptionFailed("The number of audio arrays and decoding options must be balanced."))] } - + // Determine the number of concurrent workers from decodeOptions based on the maximum value or default to 0 let concurrentWorkerCount = decodeOptionsArray.map { $0?.concurrentWorkerCount ?? 0 }.max() ?? 0 - + // Chunk the audio arrays based on the number of concurrent workers // If concurrentWorkerCount is 0, all audio arrays are processed in one batch let batchedAudioArrays = concurrentWorkerCount == 0 ? 
[audioArrays] : audioArrays.batched(into: concurrentWorkerCount) - + for (batchIndex, audioArrayBatch) in batchedAudioArrays.enumerated() { // Use withTaskGroup to manage concurrent transcription tasks let partialResult = await withTaskGroup(of: [(index: Int, result: Result<[TranscriptionResult], Swift.Error>)].self) { taskGroup -> [Result<[TranscriptionResult], Swift.Error>] in @@ -623,10 +635,10 @@ open class WhisperKit { batchedProgress.windowId = audioIndex + batchIndex * audioArrayBatch.count return callback?(batchedProgress) } - + // Setup decoding options for the current audio array let batchedDecodeOptions = decodeOptionsArray[audioIndex] - + // Add a new task to the task group for each audio array taskGroup.addTask { do { @@ -643,29 +655,29 @@ open class WhisperKit { } } } - + // Collect results from all completed tasks in the task group var batchResult = [(index: Int, result: Result<[TranscriptionResult], Swift.Error>)]() for await result in taskGroup { batchResult.append(contentsOf: result) } - + // Sort the results by index to maintain the original order (they may not be in order due to concurrency) batchResult.sort(by: { $0.index < $1.index }) - + // Map the sorted batch results to a simple array of results return batchResult.map { $0.result } } - + // Append the results of each batch to the final result array result.append(contentsOf: partialResult) } - + return result } - + // MARK: - Transcribe single audio file - + @available(*, deprecated, message: "Subject to removal in a future version. Use `transcribe(audioPath:decodeOptions:callback:) async throws -> [TranscriptionResult]` instead.") @_disfavoredOverload open func transcribe( @@ -676,7 +688,7 @@ open class WhisperKit { let result: [TranscriptionResult] = try await transcribe(audioPath: audioPath, decodeOptions: decodeOptions, callback: callback) return result.first } - + /// Transcribes an audio file from the given path asynchronously. /// - Parameters: /// - audioPath: The file path to the audio file to be transcribed. @@ -693,24 +705,24 @@ open class WhisperKit { let loadAudioStart = Date() let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioPath) let loadTime = Date().timeIntervalSince(loadAudioStart) - + let convertAudioStart = Date() let audioArray = AudioProcessor.convertBufferToArray(buffer: audioBuffer) let convertTime = Date().timeIntervalSince(convertAudioStart) currentTimings.audioLoading = loadTime + convertTime Logging.debug("Audio loading time: \(loadTime), Audio convert time: \(convertTime)") - + let transcribeResults: [TranscriptionResult] = try await transcribe( audioArray: audioArray, decodeOptions: decodeOptions, callback: callback ) - + return transcribeResults } - + // MARK: - Transcribe single audio sample array - + /// Deprecated @available(*, deprecated, message: "Subject to removal in a future version. 
Use `transcribe(audioArray:decodeOptions:callback:) async throws -> [TranscriptionResult]` instead.") @_disfavoredOverload @@ -722,7 +734,7 @@ open class WhisperKit { let result: [TranscriptionResult] = try await transcribe(audioArray: audioArray, decodeOptions: decodeOptions, callback: callback) return result.first } - + /// Main entry point for transcribing audio /// - Parameters: /// - audioArray: Array of 16khz raw float audio samples @@ -736,11 +748,11 @@ open class WhisperKit { callback: TranscriptionCallback = nil ) async throws -> [TranscriptionResult] { var transcribeResults = [TranscriptionResult]() - + // Determine if the audio array requires chunking let isChunkable = audioArray.count > WhisperKit.windowSamples switch (isChunkable, decodeOptions?.chunkingStrategy) { - case (true, .vad): + case (true, .vad): // We have some audio that will require multiple windows and a strategy to chunk them let vad = decodeOptions?.voiceActivityDetector ?? EnergyVAD() let chunker = VADAudioChunker(vad: vad) @@ -749,35 +761,35 @@ open class WhisperKit { maxChunkLength: WhisperKit.windowSamples, decodeOptions: decodeOptions ) - - // Reset the seek times since we've already chunked the audio - var chunkedOptions = decodeOptions - chunkedOptions?.clipTimestamps = [] - let chunkedDecodeOptions = Array(repeating: chunkedOptions, count: audioChunks.count) - - // Send chunked samples to transcribe (note: this is recursive) - let chunkedResults: [Result<[TranscriptionResult], Swift.Error>] = await transcribeWithOptions( - audioArrays: audioChunks.map { $0.audioSamples }, - decodeOptionsArray: chunkedDecodeOptions, - callback: callback - ) - - // Update the seek offsets based on the audio chunks - let updatedTranscriptionResults = chunker.updateSeekOffsetsForResults( - chunkedResults: chunkedResults, - audioChunks: audioChunks - ) - - transcribeResults = updatedTranscriptionResults - default: - // Audio is short enough to transcribe in a single window and doesn't require chunking - transcribeResults = try await runTranscribeTask( - audioArray: audioArray, - decodeOptions: decodeOptions, - callback: callback - ) + + // Reset the seek times since we've already chunked the audio + var chunkedOptions = decodeOptions + chunkedOptions?.clipTimestamps = [] + let chunkedDecodeOptions = Array(repeating: chunkedOptions, count: audioChunks.count) + + // Send chunked samples to transcribe (note: this is recursive) + let chunkedResults: [Result<[TranscriptionResult], Swift.Error>] = await transcribeWithOptions( + audioArrays: audioChunks.map { $0.audioSamples }, + decodeOptionsArray: chunkedDecodeOptions, + callback: callback + ) + + // Update the seek offsets based on the audio chunks + let updatedTranscriptionResults = chunker.updateSeekOffsetsForResults( + chunkedResults: chunkedResults, + audioChunks: audioChunks + ) + + transcribeResults = updatedTranscriptionResults + default: + // Audio is short enough to transcribe in a single window and doesn't require chunking + transcribeResults = try await runTranscribeTask( + audioArray: audioArray, + decodeOptions: decodeOptions, + callback: callback + ) } - + if let decodeOptions, decodeOptions.verbose { Logging.info("Total Transcription Results: \(transcribeResults.count)") for (i, transcribeTaskResult) in transcribeResults.enumerated() { @@ -785,10 +797,10 @@ open class WhisperKit { transcribeTaskResult.logSegments() } } - + return transcribeResults } - + /// Runs the transcription task on a single audio sample array asynchronously. 
/// - Returns: An array of `TranscriptionResult`. /// - Throws: An error if the transcription fails or if the tokenizer is unavailable. @@ -800,16 +812,16 @@ open class WhisperKit { if modelState != .loaded { try await loadModels() } - + guard let tokenizer else { // Tokenizer required for decoding throw WhisperError.tokenizerUnavailable() } - + let childProgress = Progress() progress.totalUnitCount += 1 progress.addChild(childProgress, withPendingUnitCount: 1) - + let transcribeTask = TranscribeTask( currentTimings: currentTimings, progress: childProgress, @@ -819,25 +831,25 @@ open class WhisperKit { textDecoder: textDecoder, tokenizer: tokenizer ) - + do { try Task.checkCancellation() - + let transcribeTaskResult = try await transcribeTask.run( audioArray: audioArray, decodeOptions: decodeOptions, callback: callback ) - + if let decodeOptions, decodeOptions.verbose { transcribeTaskResult.logTimings() } - + if progress.isFinished { // Reset progress if it is completed progress = Progress() } - + return [transcribeTaskResult] } catch { // Handle cancellation diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift index 709e4ec..c92f222 100644 --- a/Tests/WhisperKitTests/UnitTests.swift +++ b/Tests/WhisperKitTests/UnitTests.swift @@ -1132,38 +1132,38 @@ final class UnitTests: XCTestCase { // When looking for silence boundaries, a smaller frame length is preferred let vadForSilence = EnergyVAD(frameLengthSamples: 320) let nonSilentChunks1 = vadForSilence.calculateActiveChunks(in: []) - XCTAssertEqual(nonSilentChunks1.map(\.startIndex), []) - XCTAssertEqual(nonSilentChunks1.map(\.endIndex), []) + XCTAssertEqual(nonSilentChunks1.map(\SampleRange.startIndex), []) + XCTAssertEqual(nonSilentChunks1.map(\SampleRange.endIndex), []) let nonSilentChunks2 = vadForSilence.calculateActiveChunks(in: Array(repeating: 0, count: 1600)) - XCTAssertEqual(nonSilentChunks2.map(\.startIndex), []) - XCTAssertEqual(nonSilentChunks2.map(\.endIndex), []) + XCTAssertEqual(nonSilentChunks2.map(\SampleRange.startIndex), []) + XCTAssertEqual(nonSilentChunks2.map(\SampleRange.endIndex), []) let nonSilentChunks3 = vadForSilence.calculateActiveChunks(in: Array(repeating: 1, count: 1600)) - XCTAssertEqual(nonSilentChunks3.map(\.startIndex), [0]) - XCTAssertEqual(nonSilentChunks3.map(\.endIndex), [1600]) + XCTAssertEqual(nonSilentChunks3.map(\SampleRange.startIndex), [0]) + XCTAssertEqual(nonSilentChunks3.map(\SampleRange.endIndex), [1600]) let nonSilentChunks4 = vadForSilence.calculateActiveChunks(in: Array(repeating: 0, count: 1600) + Array(repeating: 1, count: 1600)) - XCTAssertEqual(nonSilentChunks4.map(\.startIndex), [1600]) - XCTAssertEqual(nonSilentChunks4.map(\.endIndex), [3200]) + XCTAssertEqual(nonSilentChunks4.map(\SampleRange.startIndex), [1600]) + XCTAssertEqual(nonSilentChunks4.map(\SampleRange.endIndex), [3200]) let nonSilentChunksWithUnevenFrameLength1 = vadForSilence.calculateActiveChunks(in: Array(repeating: 1, count: 1601)) - XCTAssertEqual(nonSilentChunksWithUnevenFrameLength1.map(\.startIndex), [0]) - XCTAssertEqual(nonSilentChunksWithUnevenFrameLength1.map(\.endIndex), [1601]) + XCTAssertEqual(nonSilentChunksWithUnevenFrameLength1.map(\SampleRange.startIndex), [0]) + XCTAssertEqual(nonSilentChunksWithUnevenFrameLength1.map(\SampleRange.endIndex), [1601]) let nonSilentChunksWithUnevenFrameLength2 = vadForSilence.calculateActiveChunks(in: Array(repeating: 1, count: 1599)) - XCTAssertEqual(nonSilentChunksWithUnevenFrameLength2.map(\.startIndex), [0]) - 
XCTAssertEqual(nonSilentChunksWithUnevenFrameLength2.map(\.endIndex), [1599]) + XCTAssertEqual(nonSilentChunksWithUnevenFrameLength2.map(\SampleRange.startIndex), [0]) + XCTAssertEqual(nonSilentChunksWithUnevenFrameLength2.map(\SampleRange.endIndex), [1599]) let nonSilentChunksWithUnevenFrameLength3 = vadForSilence.calculateActiveChunks(in: Array(repeating: 1, count: 1599) + Array(repeating: 0, count: 1600)) - XCTAssertEqual(nonSilentChunksWithUnevenFrameLength3.map(\.startIndex), [0]) - XCTAssertEqual(nonSilentChunksWithUnevenFrameLength3.map(\.endIndex), [1600]) // frame length + XCTAssertEqual(nonSilentChunksWithUnevenFrameLength3.map(\SampleRange.startIndex), [0]) + XCTAssertEqual(nonSilentChunksWithUnevenFrameLength3.map(\SampleRange.endIndex), [1600]) // frame length // Even with a smaller frame lenth, sometimes we need an overlap to detect them when they are very close to the boundary let vadWithOverlap = EnergyVAD(frameLengthSamples: 320, frameOverlapSamples: 80) let nonSilentChunksWithOverlap = vadWithOverlap.calculateActiveChunks(in: Array(repeating: 0, count: 1600) + Array(repeating: 1, count: 1600)) - XCTAssertEqual(nonSilentChunksWithOverlap.map(\.startIndex), [1280]) - XCTAssertEqual(nonSilentChunksWithOverlap.map(\.endIndex), [3200]) + XCTAssertEqual(nonSilentChunksWithOverlap.map(\SampleRange.startIndex), [1280]) + XCTAssertEqual(nonSilentChunksWithOverlap.map(\SampleRange.endIndex), [3200]) // When specifically looking for speech instead of silence, a larger window is preferred let vadWithLargeWindow = EnergyVAD(frameLength: 0.2, frameOverlap: 0.1)
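For reference, the sketch below shows one way the NTokenSampler introduced in TokenSampler.swift above might be exercised in isolation. It is illustrative only and not part of this change: the vocabulary size, end-of-text token id, and the randomly filled logits are placeholder values, and it assumes the default DecodingOptions().

import CoreML
import WhisperKit

@available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, *)
func sampleOneTokenSketch() throws {
    // Placeholder sizes; real values come from the loaded model and tokenizer.
    let vocabSize = 51_865
    let eotToken = 50_257

    let sampler = NTokenSampler(
        temperature: 0.7,
        eotToken: eotToken,
        decodingOptions: DecodingOptions()
    )

    // Hypothetical Float32 logits; in WhisperKit these come from the text decoder.
    let logits = try MLMultiArray(shape: [NSNumber(value: vocabSize)], dataType: .float32)
    for i in 0..<vocabSize {
        logits[i] = NSNumber(value: Float.random(in: -1...1))
    }

    // One decoding step: appends the sampled token and its log-probability.
    let result = sampler.update(tokens: [], logits: logits, logProbs: [])
    print("Sampled token \(result.tokens.last ?? -1), completed: \(result.completed)")

    // Finalizing appends the end-of-text token if it was not already sampled.
    let final = sampler.finalize(tokens: result.tokens, logProbs: result.logProbs)
    print("Final token count: \(final.tokens.count)")
}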