Mercurial > pylearn
diff dataset.py @ 19:57f4015e2e09
Iterators extend LookupList
author | bergstrj@iro.umontreal.ca |
---|---|
date | Thu, 27 Mar 2008 01:59:44 -0400 |
parents | 759d17112b23 |
children | 266c68cb6136 |
line wrap: on
line diff
--- a/dataset.py Thu Mar 27 00:19:16 2008 -0400 +++ b/dataset.py Thu Mar 27 01:59:44 2008 -0400 @@ -40,8 +40,7 @@ i[identifier], but the derived class is free to accept any type of identifier, and add extra functionality to the iterator. """ - for i in self.minibatches( minibatch_size = 1): - yield Example(i.keys(), [v[0] for v in i.values()]) + raise AbstractFunction() def zip(self, *fieldnames): """ @@ -61,8 +60,17 @@ The derived class may accept fieldname arguments of any type. """ - for i in self.minibatches(fieldnames, minibatch_size = 1): - yield [f[0] for f in i] + class Iter(LookupList): + def __init__(self, ll): + LookupList.__init__(self, ll.keys(), ll.values()) + self.ll = ll + def __iter__(self): #makes for loop work + return self + def next(self): + self.ll.next() + self._values = [v[0] for v in self.ll._values] + return self + return Iter(self.minibatches(fieldnames, minibatch_size = 1)) minibatches_fieldnames = None minibatches_minibatch_size = 1 @@ -177,6 +185,8 @@ assert minibatch_size>=1 and minibatch_size<=len(dataset) self.current = -self.minibatch_size self.fieldnames = fieldnames + if len(dataset) % minibatch_size: + raise NotImplementedError() def __iter__(self): return self @@ -287,11 +297,11 @@ by the numpy.array(dataset) call. """ - class Iterator(object): + class Iterator(LookupList): """An iterator over a finite dataset that implements wrap-around""" def __init__(self, dataset, fieldnames, minibatch_size, next_max): + LookupList.__init__(self, fieldnames, [0] * len(fieldnames)) self.dataset=dataset - self.fieldnames = fieldnames self.minibatch_size=minibatch_size self.next_count = 0 self.next_max = next_max @@ -300,8 +310,7 @@ if minibatch_size >= len(dataset): raise NotImplementedError() - def __iter__(self): - #Why do we do this? -JB + def __iter__(self): #makes for loop work return self @staticmethod @@ -323,28 +332,29 @@ raise StopIteration #determine the first and last elements of the slice we'll return + rows = self.dataset.data.shape[0] self.current += self.minibatch_size - if self.current >= len(self.dataset): - self.current -= len(self.dataset) + if self.current >= rows: + self.current -= rows upper = self.current + self.minibatch_size - if upper <= len(self.dataset): + data = self.dataset.data + + if upper <= rows: #this is the easy case, we only need once slice - dataview = self.dataset.data[self.current:upper] + dataview = data[self.current:upper] else: # the minibatch wraps around the end of the dataset - dataview = self.dataset.data[self.current:] - upper -= len(self.dataset) + dataview = data[self.current:] + upper -= rows assert upper > 0 - dataview = self.matcat(dataview, self.dataset.data[:upper]) + dataview = self.matcat(dataview, data[:upper]) - rval = [dataview[:, self.dataset.fields[f]] for f in self.fieldnames] + self._values = [dataview[:, self.dataset.fields[f]]\ + for f in self._names] - if self.fieldnames: - rval = Example(self.fieldnames, rval) - - return rval + return self def __init__(self, data, fields=None): @@ -372,6 +382,9 @@ # and coherent with the data array assert fieldslice.start >= 0 and fieldslice.stop <= cols + def __iter__(self): + return self.zip(*self.fieldNames()) + def minibatches(self, fieldnames = DataSet.minibatches_fieldnames, minibatch_size = DataSet.minibatches_minibatch_size,