diff --git a/dataset.py b/dataset.py index ebf4785..1d4e677 100644 --- a/dataset.py +++ b/dataset.py @@ -31,12 +31,14 @@ def __init__( tokenizer: transformers.PreTrainedTokenizer, max_len: int = 512, inference: bool = False, + n_mels: int = 80, ): super(SpeechDataset, self).__init__() print("Formatting inputs...") self.tokenizer = tokenizer self.max_len = max_len self.inference = inference + self.n_mels = n_mels self.raw_data = [] with open(data_path, "r") as f: for line in f: @@ -55,7 +57,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]: audio = torchaudio.transforms.Resample(sample_rate, 16000)(audio) audio = audio[0] # get the first channel audio = whisper.pad_or_trim(audio) - mel = whisper.log_mel_spectrogram(audio) + mel = whisper.log_mel_spectrogram(audio, n_mels=self.n_mels) ids_audio = [0] * int(mel.shape[1] / 10) # 10x downsample tgt_audio = [IGNORE_TOKEN_ID] * len(ids_audio) chat = [{"role": "user", "content": "Transcribe the speech"}] diff --git a/speech_llm.py b/speech_llm.py index 3aa2f7f..75b8296 100644 --- a/speech_llm.py +++ b/speech_llm.py @@ -155,4 +155,4 @@ def init_model(model_args): model = SpeechLLM(config, llm_model, encoder, projector) if model_args.projector_model_path is not None: model.load_projector(model_args.projector_model_path) - return model + return model, encoder.dims.n_mels diff --git a/train.py b/train.py index 6a5fe74..2f7aa88 100644 --- a/train.py +++ b/train.py @@ -30,7 +30,7 @@ def main(): training_args, ) = parser.parse_args_into_dataclasses() - model = init_model(model_args) + model, n_mels = init_model(model_args) model.freeze_llm() model.freeze_encoder() @@ -48,11 +48,13 @@ def main(): print("Loading data...") train_dataset = SpeechDataset(data_args.data_path, tokenizer=tokenizer, - max_len=training_args.model_max_length) + max_len=training_args.model_max_length, + n_mels=n_mels) if data_args.eval_data_path: eval_dataset = SpeechDataset(data_args.eval_data_path, tokenizer=tokenizer, - max_len=training_args.model_max_length) + max_len=training_args.model_max_length, + n_mels=n_mels) else: eval_dataset = None # Start trainer