From aa4bb901c8a09c8b5320ff4bb9fe64a095614d4e Mon Sep 17 00:00:00 2001 From: Zach Nagengast Date: Thu, 30 May 2024 05:55:37 -0700 Subject: [PATCH] Use hashmap to track early stopping (#155) --- .../WhisperAX/WhisperAX.xcodeproj/project.pbxproj | 8 ++++---- Examples/WhisperAX/WhisperAX/Views/ContentView.swift | 4 ++++ Sources/WhisperKit/Core/TextDecoder.swift | 12 ++++++++---- Sources/WhisperKit/Core/Utils.swift | 2 +- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj b/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj index 906a95d..58ede85 100644 --- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj +++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj @@ -608,14 +608,14 @@ GENERATE_INFOPLIST_FILE = YES; INFOPLIST_KEY_NSMicrophoneUsageDescription = "Required to record audio from the microphone for transcription."; INFOPLIST_KEY_UISupportedInterfaceOrientations = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown"; - INFOPLIST_KEY_WKCompanionAppBundleIdentifier = com.argmax.whisperkit.WhisperAX; + INFOPLIST_KEY_WKCompanionAppBundleIdentifier = "com.argmax.whisperkit.WhisperAX${DEVELOPMENT_TEAM}"; INFOPLIST_KEY_WKRunsIndependentlyOfCompanionApp = YES; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", ); MARKETING_VERSION = 0.1.2; - PRODUCT_BUNDLE_IDENTIFIER = com.argmax.whisperkit.WhisperAX.watchapp; + PRODUCT_BUNDLE_IDENTIFIER = "com.argmax.whisperkit.WhisperAX${DEVELOPMENT_TEAM}.watchapp"; PRODUCT_NAME = "WhisperAX Watch App"; PROVISIONING_PROFILE_SPECIFIER = ""; SDKROOT = watchos; @@ -893,7 +893,7 @@ LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks"; "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; MACOSX_DEPLOYMENT_TARGET = 14.0; - MARKETING_VERSION = 0.3.0; + MARKETING_VERSION = 0.3.1; PRODUCT_BUNDLE_IDENTIFIER = "com.argmax.whisperkit.WhisperAX${DEVELOPMENT_TEAM}"; PRODUCT_NAME = "$(TARGET_NAME)"; SDKROOT = auto; @@ -939,7 +939,7 @@ LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks"; "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; MACOSX_DEPLOYMENT_TARGET = 14.0; - MARKETING_VERSION = 0.3.0; + MARKETING_VERSION = 0.3.1; PRODUCT_BUNDLE_IDENTIFIER = com.argmax.whisperkit.WhisperAX; PRODUCT_NAME = "$(TARGET_NAME)"; SDKROOT = auto; diff --git a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift index 8bc4b8f..bada3ca 100644 --- a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift +++ b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift @@ -1308,10 +1308,12 @@ struct ContentView: View { let checkTokens: [Int] = currentTokens.suffix(checkWindow) let compressionRatio = compressionRatio(of: checkTokens) if compressionRatio > options.compressionRatioThreshold! { + Logging.debug("Early stopping due to compression threshold") return false } } if progress.avgLogprob! < options.logProbThreshold! { + Logging.debug("Early stopping due to logprob threshold") return false } return nil @@ -1519,10 +1521,12 @@ struct ContentView: View { let checkTokens: [Int] = currentTokens.suffix(checkWindow) let compressionRatio = compressionRatio(of: checkTokens) if compressionRatio > options.compressionRatioThreshold! { + Logging.debug("Early stopping due to compression threshold") return false } } if progress.avgLogprob! < options.logProbThreshold! { + Logging.debug("Early stopping due to logprob threshold") return false } diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index 96e83ad..0012a45 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -344,7 +344,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel { public var tokenizer: WhisperTokenizer? public var prefillData: WhisperMLModel? public var isModelMultilingual: Bool = false - public var shouldEarlyStop: Bool = false + public var shouldEarlyStop = [UUID: Bool]() private var languageLogitsFilter: LanguageLogitsFilter? public var supportsWordTimestamps: Bool { @@ -588,7 +588,8 @@ open class TextDecoder: TextDecoding, WhisperMLModel { Logging.debug("Running main loop for a maximum of \(loopCount) iterations, starting at index \(prefilledIndex)") var hasAlignment = false var isFirstTokenLogProbTooLow = false - shouldEarlyStop = false + let windowUUID = UUID() + shouldEarlyStop[windowUUID] = false for tokenIndex in prefilledIndex.. String { let task = Process() let pipe = Pipe()