class Dataset(object):
    """Iterable over (data, labels) minibatches drawn from paired arrays."""

    def __init__(self, X, y, batch_size, shuffle=False):
        """
        Construct a Dataset object to iterate over data X and labels y

        Inputs:
        - X: Numpy array of data, of any shape
        - y: Numpy array of labels, of any shape but with
          y.shape[0] == X.shape[0]
        - batch_size: Integer giving number of elements per minibatch
        - shuffle: (optional) Boolean, whether to shuffle the data on each
          pass over the dataset (a new order is drawn every time iteration
          starts)
        """
        assert X.shape[0] == y.shape[0], 'Got different numbers of data and labels'
        self.X, self.y = X, y
        self.batch_size, self.shuffle = batch_size, shuffle

    def __iter__(self):
        """
        Return an iterator of (X_batch, y_batch) tuples.

        Each batch holds up to batch_size elements along the first axis; the
        last batch is smaller when the number of elements is not a multiple
        of batch_size. When shuffle is set, one permutation is drawn with
        np.random.shuffle and applied to both X and y, so every data row
        stays paired with its label.
        """
        N, B = self.X.shape[0], self.batch_size
        if not self.shuffle:
            # Contiguous slices: batches come out in storage order.
            return iter((self.X[i:i+B], self.y[i:i+B]) for i in range(0, N, B))
        idxs = np.arange(N)
        np.random.shuffle(idxs)
        # Index through the permutation; slicing X directly would ignore the
        # shuffled order entirely.
        return iter((self.X[idxs[i:i+B]], self.y[idxs[i:i+B]])
                    for i in range(0, N, B))
# (Web-page residue: "Ask Question" / "95 times")