# HG changeset patch # User bergstrj@iro.umontreal.ca # Date 1206597584 14400 # Node ID 57f4015e2e094b0e0329dc9e2cc7a658056d9179 # Parent 60b164a0d84ae46b88ad59f42c6d2e99f2489c1f Iterators extend LookupList diff -r 60b164a0d84a -r 57f4015e2e09 _test_dataset.py --- a/_test_dataset.py Thu Mar 27 00:19:16 2008 -0400 +++ b/_test_dataset.py Thu Mar 27 01:59:44 2008 -0400 @@ -73,7 +73,7 @@ for i, x in enumerate(a.minibatches(["x"], minibatch_size=3, n_batches=6)): self.failUnless(numpy.all( x == arr2[i*3:i*3+3,0:2])) - + if __name__ == '__main__': unittest.main() diff -r 60b164a0d84a -r 57f4015e2e09 dataset.py --- a/dataset.py Thu Mar 27 00:19:16 2008 -0400 +++ b/dataset.py Thu Mar 27 01:59:44 2008 -0400 @@ -40,8 +40,7 @@ i[identifier], but the derived class is free to accept any type of identifier, and add extra functionality to the iterator. """ - for i in self.minibatches( minibatch_size = 1): - yield Example(i.keys(), [v[0] for v in i.values()]) + raise AbstractFunction() def zip(self, *fieldnames): """ @@ -61,8 +60,17 @@ The derived class may accept fieldname arguments of any type. """ - for i in self.minibatches(fieldnames, minibatch_size = 1): - yield [f[0] for f in i] + class Iter(LookupList): + def __init__(self, ll): + LookupList.__init__(self, ll.keys(), ll.values()) + self.ll = ll + def __iter__(self): #makes for loop work + return self + def next(self): + self.ll.next() + self._values = [v[0] for v in self.ll._values] + return self + return Iter(self.minibatches(fieldnames, minibatch_size = 1)) minibatches_fieldnames = None minibatches_minibatch_size = 1 @@ -177,6 +185,8 @@ assert minibatch_size>=1 and minibatch_size<=len(dataset) self.current = -self.minibatch_size self.fieldnames = fieldnames + if len(dataset) % minibatch_size: + raise NotImplementedError() def __iter__(self): return self @@ -287,11 +297,11 @@ by the numpy.array(dataset) call. """ - class Iterator(object): + class Iterator(LookupList): """An iterator over a finite dataset that implements wrap-around""" def __init__(self, dataset, fieldnames, minibatch_size, next_max): + LookupList.__init__(self, fieldnames, [0] * len(fieldnames)) self.dataset=dataset - self.fieldnames = fieldnames self.minibatch_size=minibatch_size self.next_count = 0 self.next_max = next_max @@ -300,8 +310,7 @@ if minibatch_size >= len(dataset): raise NotImplementedError() - def __iter__(self): - #Why do we do this? -JB + def __iter__(self): #makes for loop work return self @staticmethod @@ -323,28 +332,29 @@ raise StopIteration #determine the first and last elements of the slice we'll return + rows = self.dataset.data.shape[0] self.current += self.minibatch_size - if self.current >= len(self.dataset): - self.current -= len(self.dataset) + if self.current >= rows: + self.current -= rows upper = self.current + self.minibatch_size - if upper <= len(self.dataset): + data = self.dataset.data + + if upper <= rows: #this is the easy case, we only need once slice - dataview = self.dataset.data[self.current:upper] + dataview = data[self.current:upper] else: # the minibatch wraps around the end of the dataset - dataview = self.dataset.data[self.current:] - upper -= len(self.dataset) + dataview = data[self.current:] + upper -= rows assert upper > 0 - dataview = self.matcat(dataview, self.dataset.data[:upper]) + dataview = self.matcat(dataview, data[:upper]) - rval = [dataview[:, self.dataset.fields[f]] for f in self.fieldnames] + self._values = [dataview[:, self.dataset.fields[f]]\ + for f in self._names] - if self.fieldnames: - rval = Example(self.fieldnames, rval) - - return rval + return self def __init__(self, data, fields=None): @@ -372,6 +382,9 @@ # and coherent with the data array assert fieldslice.start >= 0 and fieldslice.stop <= cols + def __iter__(self): + return self.zip(*self.fieldNames()) + def minibatches(self, fieldnames = DataSet.minibatches_fieldnames, minibatch_size = DataSet.minibatches_minibatch_size, diff -r 60b164a0d84a -r 57f4015e2e09 lookup_list.py --- a/lookup_list.py Thu Mar 27 00:19:16 2008 -0400 +++ b/lookup_list.py Thu Mar 27 01:59:44 2008 -0400 @@ -47,20 +47,28 @@ if key in self._name2index: self._values[self._name2index[key]]=value else: - self._name2index[key]=len(self) - self._values.append(value) - self._names.append(key) + raise KeyError(key) def __getattr__(self,name): - return self._values[self._name2index[name]] + try: + return self._values[self._name2index[name]] + except KeyError, e: + raise AttributeError(name) - def __setattr__(self,name,value): - if name in self._name2index: - self._values[self._name2index[name]]=value - else: - self._name2index[name]=len(self) - self._values.append(value) - self._names.append(name) + if 0: + # This makes subclassing horrible, just call append_keyval if it's + # really what you want to do. + # -JB + def __setattr__(self,name,value): + if name in self._name2index: + self._values[self._name2index[name]]=value + else: + raise AttributeError(name) + + def append_keyval(self, key, value): + self._name2index[key]=len(self) + self._values.append(value) + self._names.append(key) def __len__(self): return len(self._values)