diff --git a/python/setup.py b/python/setup.py index f37a98ff..fcb4bbe9 100644 --- a/python/setup.py +++ b/python/setup.py @@ -44,7 +44,7 @@ def get_tag(self): setuptools.setup( name="vosk", - version="0.3.30", + version="0.3.31", author="Alpha Cephei Inc", author_email="contact@alphacephei.com", description="Offline open source speech recognition API based on Kaldi and Vosk", diff --git a/src/kaldi_recognizer.cc b/src/kaldi_recognizer.cc index 7c7f7158..1d32107c 100644 --- a/src/kaldi_recognizer.cc +++ b/src/kaldi_recognizer.cc @@ -142,13 +142,14 @@ KaldiRecognizer::~KaldiRecognizer() { delete g_fst_; delete decode_fst_; delete spk_feature_; - delete lm_fst_; - delete info; - delete lm_to_subtract_det_backoff; - delete lm_to_subtract_det_scale; - delete lm_to_add_orig; - delete lm_to_add; + delete rnnlm_info_; + delete lm_to_subtract_; + delete lm_to_subtract_scale_; + delete carpa_to_add_; + delete carpa_to_add_scale_; + delete rnnlm_to_add_; + delete rnnlm_to_add_scale_; model_->Unref(); if (spk_model_) @@ -166,21 +167,18 @@ void KaldiRecognizer::InitState() void KaldiRecognizer::InitRescoring() { - if (model_->rnnlm_lm_fst_) { - float lm_scale = 0.5; - int lm_order = 4; - - info = new kaldi::rnnlm::RnnlmComputeStateInfo(model_->rnnlm_compute_opts, model_->rnnlm, model_->word_embedding_mat); - lm_to_subtract_det_backoff = new fst::BackoffDeterministicOnDemandFst(*model_->rnnlm_lm_fst_); - lm_to_subtract_det_scale = new fst::ScaleDeterministicOnDemandFst(-lm_scale, lm_to_subtract_det_backoff); - lm_to_add_orig = new kaldi::rnnlm::KaldiRnnlmDeterministicFst(lm_order, *info); - lm_to_add = new fst::ScaleDeterministicOnDemandFst(lm_scale, lm_to_add_orig); - - } else if (model_->std_lm_fst_) { - fst::CacheOptions cache_opts(true, 50000); - fst::ArcMapFstOptions mapfst_opts(cache_opts); - fst::StdToLatticeMapper mapper; - lm_fst_ = new fst::ArcMapFst >(*model_->std_lm_fst_, mapper, mapfst_opts); + if (model_->graph_lm_fst_) { + lm_to_subtract_ = new fst::BackoffDeterministicOnDemandFst(*model_->graph_lm_fst_); + lm_to_subtract_scale_ = new fst::ScaleDeterministicOnDemandFst(-1.0, lm_to_subtract_); + carpa_to_add_ = new ConstArpaLmDeterministicFst(model_->const_arpa_); + + if (model_->rnnlm_enabled_) { + int lm_order = 4; + rnnlm_info_ = new kaldi::rnnlm::RnnlmComputeStateInfo(model_->rnnlm_compute_opts, model_->rnnlm, model_->word_embedding_mat); + rnnlm_to_add_ = new kaldi::rnnlm::KaldiRnnlmDeterministicFst(lm_order, *rnnlm_info_); + rnnlm_to_add_scale_ = new fst::ScaleDeterministicOnDemandFst(0.5, rnnlm_to_add_); + carpa_to_add_scale_ = new fst::ScaleDeterministicOnDemandFst(-0.5, carpa_to_add_); + } } } @@ -592,38 +590,30 @@ const char* KaldiRecognizer::GetResult() kaldi::CompactLattice rlat; decoder_->GetLattice(true, &clat); - if (model_->rnnlm_lm_fst_) { - kaldi::ComposeLatticePrunedOptions compose_opts; - compose_opts.lattice_compose_beam = 3.0; - compose_opts.max_arcs = 3000; - + if (lm_to_subtract_scale_ && carpa_to_add_) { TopSortCompactLatticeIfNeeded(&clat); - fst::ComposeDeterministicOnDemandFst combined_lms(lm_to_subtract_det_scale, lm_to_add); - CompactLattice composed_clat; - ComposeCompactLatticePruned(compose_opts, clat, - &combined_lms, &rlat); - lm_to_add_orig->Clear(); - } else if (model_->std_lm_fst_) { - Lattice lat1; - - ConvertLattice(clat, &lat1); - fst::ScaleLattice(fst::GraphLatticeScale(-1.0), &lat1); - fst::ArcSort(&lat1, fst::OLabelCompare()); - kaldi::Lattice composed_lat; - fst::Compose(lat1, *lm_fst_, &composed_lat); - fst::Invert(&composed_lat); - kaldi::CompactLattice determinized_lat; - DeterminizeLattice(composed_lat, &determinized_lat); - fst::ScaleLattice(fst::GraphLatticeScale(-1), &determinized_lat); - fst::ArcSort(&determinized_lat, fst::OLabelCompare()); - - kaldi::ConstArpaLmDeterministicFst const_arpa_fst(model_->const_arpa_); - kaldi::CompactLattice composed_clat; - kaldi::ComposeCompactLatticeDeterministic(determinized_lat, &const_arpa_fst, &composed_clat); - kaldi::Lattice composed_lat1; - ConvertLattice(composed_clat, &composed_lat1); - fst::Invert(&composed_lat1); - DeterminizeLattice(composed_lat1, &rlat); + CompactLattice tlat; + fst::ComposeDeterministicOnDemandFst combined_lm(lm_to_subtract_scale_, carpa_to_add_); + ComposeCompactLatticeDeterministic(clat, &combined_lm, &tlat); + + if (rnnlm_to_add_scale_) { + ComposeLatticePrunedOptions compose_opts; + compose_opts.lattice_compose_beam = 3.0; + compose_opts.max_arcs = 3000; + TopSortCompactLatticeIfNeeded(&tlat); + fst::ComposeDeterministicOnDemandFst combined_rnnlm(carpa_to_add_scale_, rnnlm_to_add_scale_); + ComposeCompactLatticePruned(compose_opts, tlat, + &combined_rnnlm, &rlat); + rnnlm_to_add_->Clear(); + } else { + rlat = tlat; + } + + kaldi::Lattice slat; + ConvertLattice(rlat, &slat); + fst::Invert(&slat); + DeterminizeLattice(slat, &rlat); + } else { rlat = clat; } diff --git a/src/kaldi_recognizer.h b/src/kaldi_recognizer.h index 62ef2815..1424e7e5 100644 --- a/src/kaldi_recognizer.h +++ b/src/kaldi_recognizer.h @@ -82,15 +82,18 @@ class KaldiRecognizer { OnlineBaseFeature *spk_feature_ = nullptr; // Rescoring - fst::ArcMapFst > *lm_fst_ = nullptr; + fst::BackoffDeterministicOnDemandFst *lm_to_subtract_ = nullptr; + fst::ScaleDeterministicOnDemandFst *lm_to_subtract_scale_ = nullptr; + kaldi::ConstArpaLmDeterministicFst *carpa_to_add_ = nullptr; + fst::ScaleDeterministicOnDemandFst *carpa_to_add_scale_ = nullptr; // RNNLM rescoring - kaldi::rnnlm::RnnlmComputeStateInfo *info = nullptr; - fst::ScaleDeterministicOnDemandFst *lm_to_subtract_det_scale = nullptr; - fst::BackoffDeterministicOnDemandFst *lm_to_subtract_det_backoff = nullptr; - kaldi::rnnlm::KaldiRnnlmDeterministicFst* lm_to_add_orig = nullptr; - fst::DeterministicOnDemandFst *lm_to_add = nullptr; + kaldi::rnnlm::KaldiRnnlmDeterministicFst* rnnlm_to_add_ = nullptr; + fst::DeterministicOnDemandFst *rnnlm_to_add_scale_ = nullptr; + kaldi::rnnlm::RnnlmComputeStateInfo *rnnlm_info_ = nullptr; + + // Other int max_alternatives_ = 0; // Disable alternatives by default bool words_ = false; diff --git a/src/model.cc b/src/model.cc index 0ae3ba48..4bcee355 100644 --- a/src/model.cc +++ b/src/model.cc @@ -174,7 +174,6 @@ void Model::ConfigureV1() rnnlm_feat_embedding_rxfilename_ = model_path_str_ + "/rnnlm/feat_embedding.final.mat"; rnnlm_config_rxfilename_ = model_path_str_ + "/rnnlm/special_symbol_opts.conf"; rnnlm_lm_rxfilename_ = model_path_str_ + "/rnnlm/final.raw"; - rnnlm_lm_fst_rxfilename_ = model_path_str_ + "/rescore/G.fst"; } void Model::ConfigureV2() @@ -203,7 +202,6 @@ void Model::ConfigureV2() rnnlm_feat_embedding_rxfilename_ = model_path_str_ + "/rnnlm/feat_embedding.final.mat"; rnnlm_config_rxfilename_ = model_path_str_ + "/rnnlm/special_symbol_opts.conf"; rnnlm_lm_rxfilename_ = model_path_str_ + "/rnnlm/final.raw"; - rnnlm_lm_fst_rxfilename_ = model_path_str_ + "/rescore/G.fst"; } void Model::ReadDataFiles() @@ -296,12 +294,19 @@ void Model::ReadDataFiles() winfo_ = new kaldi::WordBoundaryInfo(opts, winfo_rxfilename_); } + if (stat(carpa_rxfilename_.c_str(), &buffer) == 0) { + + KALDI_LOG << "Loading subtract G.fst model from " << std_fst_rxfilename_; + graph_lm_fst_ = fst::ReadAndPrepareLmFst(std_fst_rxfilename_); + KALDI_LOG << "Loading CARPA model from " << carpa_rxfilename_; + ReadKaldiObject(carpa_rxfilename_, &const_arpa_); + } + // RNNLM Rescoring if (stat(rnnlm_lm_rxfilename_.c_str(), &buffer) == 0) { KALDI_LOG << "Loading RNNLM model from " << rnnlm_lm_rxfilename_; ReadKaldiObject(rnnlm_lm_rxfilename_, &rnnlm); - rnnlm_lm_fst_ = fst::ReadAndPrepareLmFst(rnnlm_lm_fst_rxfilename_); Matrix feature_embedding_mat; ReadKaldiObject(rnnlm_feat_embedding_rxfilename_, &feature_embedding_mat); SparseMatrix word_feature_mat; @@ -319,17 +324,9 @@ void Model::ReadDataFiles() ReadConfigFromFile(rnnlm_config_rxfilename_, &rnnlm_compute_opts); - } else if (stat(carpa_rxfilename_.c_str(), &buffer) == 0) { - - KALDI_LOG << "Loading CARPA model from " << carpa_rxfilename_; - std_lm_fst_ = fst::ReadFstKaldi(std_fst_rxfilename_); - fst::Project(std_lm_fst_, fst::ProjectType::OUTPUT); - if (std_lm_fst_->Properties(fst::kILabelSorted, true) == 0) { - fst::ILabelCompare ilabel_comp; - fst::ArcSort(std_lm_fst_, ilabel_comp); - } - ReadKaldiObject(carpa_rxfilename_, &const_arpa_); + rnnlm_enabled_ = true; } + } void Model::Ref() @@ -363,5 +360,5 @@ Model::~Model() { delete hclg_fst_; delete hcl_fst_; delete g_fst_; - delete std_lm_fst_; + delete graph_lm_fst_; } diff --git a/src/model.h b/src/model.h index 03c046ce..3b68f3c4 100644 --- a/src/model.h +++ b/src/model.h @@ -72,7 +72,6 @@ class Model { string rnnlm_word_feats_rxfilename_; string rnnlm_feat_embedding_rxfilename_; string rnnlm_config_rxfilename_; - string rnnlm_lm_fst_rxfilename_; string rnnlm_lm_rxfilename_; kaldi::OnlineEndpointConfig endpoint_config_; @@ -92,13 +91,13 @@ class Model { fst::Fst *hcl_fst_ = nullptr; fst::Fst *g_fst_ = nullptr; - fst::VectorFst *std_lm_fst_ = nullptr; + fst::VectorFst *graph_lm_fst_ = nullptr; kaldi::ConstArpaLm const_arpa_; kaldi::rnnlm::RnnlmComputeStateComputationOptions rnnlm_compute_opts; CuMatrix word_embedding_mat; - fst::VectorFst *rnnlm_lm_fst_ = NULL; kaldi::nnet3::Nnet rnnlm; + bool rnnlm_enabled_ = false; std::atomic ref_cnt_; };