Skip to content

Commit

Permalink
add drop_last
Browse files Browse the repository at this point in the history
  • Loading branch information
konas122 committed Mar 7, 2024
1 parent 4c053bb commit cb88bad
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions dazero/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@


class DataLoader:
def __init__(self, dataset, batch_size, shuffle=True, gpu=False):
def __init__(self, dataset, batch_size, shuffle=True, gpu=False, drop_last=False):
if not isinstance(dataset, Dataset):
raise TypeError("It must be of type dazero.Dataset")
self.dataset = dataset
self.drop_last = drop_last
self.batch_size = batch_size
self.shuffle = shuffle
self.data_size = len(dataset)
self.max_iter = math.ceil(self.data_size / batch_size)

if drop_last:
self.max_iter = self.data_size // batch_size
else:
self.max_iter = math.ceil(self.data_size / batch_size)
self.gpu = gpu

self._reset()
Expand Down

0 comments on commit cb88bad

Please sign in to comment.