Skip to content

Commit

Permalink
Cleanup (#132)
Browse files Browse the repository at this point in the history
* Lint tests

* Lint library

* Lint examples

* Fix log prob alignment and timing

* Allow tokenizer to be loaded from disk if it exists already
  • Loading branch information
ZachNagengast authored May 1, 2024
1 parent c770b54 commit c20943d
Show file tree
Hide file tree
Showing 22 changed files with 245 additions and 224 deletions.
35 changes: 18 additions & 17 deletions Examples/WhisperAX/WhisperAX/Views/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ struct ContentView: View {
@State private var availableModels: [String] = []
@State private var availableLanguages: [String] = []
@State private var disabledModels: [String] = WhisperKit.recommendedModels().disabled

@AppStorage("selectedAudioInput") private var selectedAudioInput: String = "No Audio Input"
@AppStorage("selectedModel") private var selectedModel: String = WhisperKit.recommendedModels().default
@AppStorage("selectedTab") private var selectedTab: String = "Transcribe"
Expand Down Expand Up @@ -73,7 +73,6 @@ struct ContentView: View {
@State private var unconfirmedSegments: [TranscriptionSegment] = []
@State private var unconfirmedText: [String] = []


// MARK: Eager mode properties

@State private var eagerResults: [TranscriptionResult?] = []
Expand Down Expand Up @@ -274,7 +273,8 @@ struct ContentView: View {
!isRecording,
!isTranscribing,
whisperKit.progress.fractionCompleted > 0,
whisperKit.progress.fractionCompleted < 1 {
whisperKit.progress.fractionCompleted < 1
{
ProgressView(whisperKit.progress)
.progressViewStyle(.linear)
.labelsHidden()
Expand Down Expand Up @@ -314,7 +314,7 @@ struct ContentView: View {
.progressViewStyle(CircularProgressViewStyle())
.scaleEffect(0.5)
}

Button(action: {
deleteModel()
}, label: {
Expand Down Expand Up @@ -405,14 +405,15 @@ struct ContentView: View {
if let audioDevices = audioDevices,
!audioDevices.isEmpty,
selectedAudioInput == "No Audio Input",
let device = audioDevices.first {
let device = audioDevices.first
{
selectedAudioInput = device.name
}
}
#endif
}
}

var controlsView: some View {
VStack {
basicSettingsView
Expand Down Expand Up @@ -887,13 +888,12 @@ struct ContentView: View {
}
})
}

await MainActor.run {
loadingProgressValue = specializationProgressRatio
modelState = .downloaded
}


if let modelFolder = folder {
whisperKit.modelFolder = modelFolder

Expand Down Expand Up @@ -936,26 +936,26 @@ struct ContentView: View {
if !localModels.contains(model) {
localModels.append(model)
}

availableLanguages = Constants.languages.map { $0.key }.sorted()
loadingProgressValue = 1.0
modelState = whisperKit.modelState
}
}
}
}

func deleteModel() {
if localModels.contains(selectedModel) {
let modelFolder = URL(fileURLWithPath: localModelPath).appendingPathComponent(selectedModel)

do {
try FileManager.default.removeItem(at: modelFolder)

if let index = localModels.firstIndex(of: selectedModel) {
localModels.remove(at: index)
}

modelState = .unloaded
} catch {
print("Error deleting model: \(error)")
Expand Down Expand Up @@ -1058,18 +1058,19 @@ struct ContentView: View {
print("Microphone access was not granted.")
return
}

var deviceId: DeviceID?
#if os(macOS)
if self.selectedAudioInput != "No Audio Input",
let devices = self.audioDevices,
let device = devices.first(where: {$0.name == selectedAudioInput}) {
let device = devices.first(where: { $0.name == selectedAudioInput })
{
deviceId = device.id
}

// There is no built-in microphone
if deviceId == nil {
throw WhisperError.microphoneUnavailable()
throw WhisperError.microphoneUnavailable()
}
#endif

Expand Down Expand Up @@ -1403,7 +1404,7 @@ struct ContentView: View {
return nil
}

Logging.info("[EagerMode] \(lastAgreedSeconds)-\(Double(samples.count)/16000.0) seconds")
Logging.info("[EagerMode] \(lastAgreedSeconds)-\(Double(samples.count) / 16000.0) seconds")

let streamingAudio = samples
var streamOptions = options
Expand Down
2 changes: 0 additions & 2 deletions Examples/WhisperAX/WhisperAXTests/WhisperAXTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import XCTest

final class WhisperAXTests: XCTestCase {

override func setUpWithError() throws {
// Put setup code here. This method is called before the invocation of each test method in the class.
}
Expand All @@ -27,5 +26,4 @@ final class WhisperAXTests: XCTestCase {
// Put the code you want to measure the time of here.
}
}

}
1 change: 0 additions & 1 deletion Examples/WhisperAX/WhisperAXUITests/WhisperAXUITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import XCTest

final class WhisperAXUITests: XCTestCase {

override func setUpWithError() throws {
// Put setup code here. This method is called before the invocation of each test method in the class.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import XCTest

final class WhisperAXUITestsLaunchTests: XCTestCase {

override class var runsForEachTargetApplicationUIConfiguration: Bool {
true
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ struct WhisperAXWatchView: View {
let currentTranscription = (confirmedSegments.map { $0.text } + unconfirmedSegments.map { $0.text }).joined(separator: " ")
ShareLink(item: currentTranscription, label: {
Image(systemName: "square.and.arrow.up")
}) }
})
}
ToolbarItem(placement: .bottomBar) {
Button {
withAnimation {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.

import XCTest
@testable import Basic_Watch_App
import XCTest

final class WhisperAX_Watch_AppTests: XCTestCase {

override func setUpWithError() throws {
// Put setup code here. This method is called before the invocation of each test method in the class.
}
Expand All @@ -28,5 +27,4 @@ final class WhisperAX_Watch_AppTests: XCTestCase {
// Put the code you want to measure the time of here.
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import XCTest

final class WhisperAX_Watch_AppUITests: XCTestCase {

override func setUpWithError() throws {
// Put setup code here. This method is called before the invocation of each test method in the class.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import XCTest

final class WhisperAX_Watch_AppUITestsLaunchTests: XCTestCase {

override class var runsForEachTargetApplicationUIConfiguration: Bool {
true
}
Expand Down
20 changes: 10 additions & 10 deletions Sources/WhisperKit/Core/LogitsFilter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ open class SuppressBlankFilter: LogitsFiltering {
self.sampleBegin = sampleBegin
self.suppressTokenIndexes = [
[0, 0, specialTokens.whitespaceToken as NSNumber],
[0, 0, specialTokens.endToken as NSNumber]
[0, 0, specialTokens.endToken as NSNumber],
]
}

Expand Down Expand Up @@ -75,10 +75,11 @@ open class TimestampRulesFilter: LogitsFiltering {

public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
guard let sampleBegin = sampleBegin(for: tokens),
sampleBegin > tokens.count else {
sampleBegin > tokens.count
else {
return logits
}

// suppress <|notimestamps|> which is handled by `withoutTimestamps`
logits.fill(indexes: [[0, 0, specialTokens.noTimestampsToken as NSNumber]], with: -FloatType.infinity)

Expand Down Expand Up @@ -244,7 +245,6 @@ open class TimestampRulesFilter: LogitsFiltering {
}
}


@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
open class LanguageLogitsFilter: LogitsFiltering {
let allLanguageTokens: Set<Int>
Expand All @@ -259,19 +259,19 @@ open class LanguageLogitsFilter: LogitsFiltering {
self.nonLanguageTokenIndexes = LanguageLogitsFilter.getNonLanguageTokenIndexes(logitsDim: self.logitsDim, allLanguageTokens: self.allLanguageTokens)
}

// Retain the logits that correspond to language tokens and suppress non-language tokens
/// Retain the logits that correspond to language tokens and suppress non-language tokens
public func filterLogits(_ logits: MLMultiArray, withTokens tokens: [Int]) -> MLMultiArray {
guard tokens.count == sampleBegin else{
guard tokens.count == sampleBegin else {
return logits
}
logits.fill(indexes: nonLanguageTokenIndexes, with: -FloatType.infinity)
return logits
}
private static func getNonLanguageTokenIndexes(logitsDim: Int, allLanguageTokens: Set<Int>) -> [[NSNumber]]{

private static func getNonLanguageTokenIndexes(logitsDim: Int, allLanguageTokens: Set<Int>) -> [[NSNumber]] {
var indexes: [[NSNumber]] = []
for i in 0..<logitsDim{
if !allLanguageTokens.contains(i){
for i in 0..<logitsDim {
if !allLanguageTokens.contains(i) {
indexes.append([0, 0, i as NSNumber])
}
}
Expand Down
Loading

0 comments on commit c20943d

Please sign in to comment.