diff --git a/configs/train_config.json b/configs/train_config.json index d3c2392..a833566 100644 --- a/configs/train_config.json +++ b/configs/train_config.json @@ -3,10 +3,10 @@ "data": { "data_dict_path": "uniref50_gpt_data.pkl", "data_df_path": "uniref50_gzip_subsample.csv", - "split_ratios": [0.8, 0.1, 0.1] + "split_ratios": [0.88, 0.02, 0.1] }, "datamodule": { - "batch_size": 4, + "batch_size": 8, "max_protein_length": 1500 }, "model": { diff --git a/cprt/data/cprt_datamodule.py b/cprt/data/cprt_datamodule.py index 9022abc..7e60b99 100644 --- a/cprt/data/cprt_datamodule.py +++ b/cprt/data/cprt_datamodule.py @@ -97,23 +97,25 @@ def train_dataloader(self) -> DataLoader: # type: ignore[type-arg] shuffle=True, collate_fn=self.collate_fn, num_workers=4, + drop_last=True, ) def val_dataloader(self) -> DataLoader: # type: ignore[type-arg] """Set up val loader.""" return DataLoader( self.val_dataset, - batch_size=self.batch_size, + batch_size=self.batch_size * 2, shuffle=False, collate_fn=self.collate_fn, num_workers=4, + drop_last=True, ) def test_dataloader(self) -> DataLoader: # type: ignore[type-arg] """Set up test loader.""" return DataLoader( self.test_dataset, - batch_size=self.batch_size, + batch_size=self.batch_size * 2, shuffle=False, collate_fn=self.collate_fn, num_workers=4,