Mercurial > pylearn
comparison dataset.py @ 5:8039918516fe
Added MinibatchIterator
author | bengioy@bengiomac.local |
---|---|
date | Sun, 23 Mar 2008 22:44:43 -0400 |
parents | f7dcfb5f9d5b |
children | d5738b79089a |
comparison
equal
deleted
inserted
replaced
4:f7dcfb5f9d5b | 5:8039918516fe |
---|---|
54 | 54 |
def __getslice__(self, *slice_args):
    """Return the sub-dataset for dataset[i:j], i.e. examples i, i+1, ..., j-1.

    This version is not implemented and always raises NotImplementedError.
    NOTE(review): Python 3 never calls __getslice__ (slicing goes through
    __getitem__ with a slice object); this hook only matters on Python 2.
    """
    raise NotImplementedError()
58 | 58 |
def minibatches(self, minibatch_size):
    """Iterate over this dataset in chunks of `minibatch_size` consecutive examples.

    Builds and returns a MinibatchIterator bound to this dataset; that class
    defines the slicing and the accepted range of `minibatch_size`.
    """
    iterator = MinibatchIterator(self, minibatch_size)
    return iterator
62 | |
class MinibatchIterator(object):
    """
    Iterator class for FiniteDataSet that can iterate by minibatches
    (sub-dataset of consecutive examples).

    Each call to next() returns dataset[k:k + minibatch_size] for
    k = 0, minibatch_size, 2 * minibatch_size, ... while k < len(dataset).
    The last minibatch may be shorter when the dataset's slicing clips at the
    end (as list slicing does). Once the examples are exhausted,
    StopIteration is raised and the cursor is rewound, so the same iterator
    object can be iterated again from the start.

    Parameters:
      dataset: object supporting len() and slicing dataset[i:j]
               (presumably a FiniteDataSet).
      minibatch_size: number of consecutive examples per minibatch;
               must satisfy 0 < minibatch_size <= len(dataset).
    """

    def __init__(self, dataset, minibatch_size):
        # A minibatch as large as the whole dataset is legitimate (a single
        # batch), so the upper bound is inclusive.
        assert 0 < minibatch_size <= len(dataset)
        self.dataset = dataset
        self.minibatch_size = minibatch_size
        # Start one step before 0 so the first next() lands on example 0.
        self.current = -minibatch_size

    def __iter__(self):
        return self

    def next(self):
        """Return the next minibatch, or raise StopIteration when exhausted."""
        self.current += self.minibatch_size
        if self.current >= len(self.dataset):
            # Rewind so that iterating again restarts from the beginning.
            self.current = -self.minibatch_size
            raise StopIteration
        return self.dataset[self.current:self.current + self.minibatch_size]

    # Python 3 iterator-protocol name; `next` is kept for Python 2 callers.
    __next__ = next
81 | |
59 # we may want ArrayDataSet defined in another python file | 82 # we may want ArrayDataSet defined in another python file |
60 | 83 |
61 import numpy | 84 import numpy |
62 | 85 |
63 class ArrayDataSet(FiniteDataSet): | 86 class ArrayDataSet(FiniteDataSet): |
195 # copy the field here | 218 # copy the field here |
196 result[:,slice(c,slice_width)]=self.data[field_slice] | 219 result[:,slice(c,slice_width)]=self.data[field_slice] |
197 c+=slice_width | 220 c+=slice_width |
198 return result | 221 return result |
199 return self.data | 222 return self.data |
200 | 223 |