Mercurial > pylearn
comparison dataset.py @ 6:d5738b79089a
Removed MinibatchIterator and instead made minibatch_size a field of all DataSets,
so that they can all iterate over minibatches, optionally.
author | bengioy@bengiomac.local |
---|---|
date | Mon, 24 Mar 2008 09:04:06 -0400 |
parents | 8039918516fe |
children | 6f8f338686db |
comparison
legend: equal | deleted | inserted | replaced
5:8039918516fe | 6:d5738b79089a |
---|---|
4 """ | 4 """ |
5 This is a virtual base class or interface for datasets. | 5 This is a virtual base class or interface for datasets. |
6 A dataset is basically an iterator over examples. It does not necessarily | 6 A dataset is basically an iterator over examples. It does not necessarily |
7 have a fixed length (this is useful for 'streams' which feed on-line learning). | 7 have a fixed length (this is useful for 'streams' which feed on-line learning). |
8 Datasets with fixed and known length are FiniteDataSet, a subclass of DataSet. | 8 Datasets with fixed and known length are FiniteDataSet, a subclass of DataSet. |
9 Examples and datasets have named fields. | 9 Examples and datasets optionally have named fields. |
10 One can obtain a sub-dataset by taking dataset.field or dataset(field1,field2,field3,...). | 10 One can obtain a sub-dataset by taking dataset.field or dataset(field1,field2,field3,...). |
11 Fields are not mutually exclusive, i.e. two fields can overlap in their actual content. | 11 Fields are not mutually exclusive, i.e. two fields can overlap in their actual content. |
12 The content of a field can be of any type, but often will be a numpy tensor. | 12 The content of a field can be of any type, but often will be a numpy array. |
13 """ | 13 The minibatch_size field, if different from 1, means that the iterator (next() method) |
14 | 14 returns not a single example but an array of length minibatch_size, i.e., an indexable |
15 def __init__(self): | 15 object. |
16 pass | 16 """ |
17 | |
18 def __init__(self,minibatch_size=1): | |
19 assert minibatch_size>0 | |
20 self.minibatch_size=minibatch_size | |
17 | 21 |
18 def __iter__(self): | 22 def __iter__(self): |
19 return self | 23 return self |
20 | 24 |
21 def next(self): | 25 def next(self): |
22 """Return the next example in the dataset.""" | 26 """ |
27 Return the next example or the next minibatch in the dataset. | |
28 A minibatch (of length > 1) should be something one can iterate on again in order | |
29 to obtain the individual examples. If the dataset has fields, | |
30 then the example or the minibatch must have the same fields | |
31 (typically this is implemented by returning another (small) dataset, when | |
32 there are fields). | |
33 """ | |
23 raise NotImplementedError | 34 raise NotImplementedError |
24 | 35 |
25 def __getattr__(self,fieldname): | 36 def __getattr__(self,fieldname): |
26 """Return a sub-dataset containing only the given fieldname as field.""" | 37 """Return a sub-dataset containing only the given fieldname as field.""" |
27 return self(fieldname) | 38 return self(fieldname) |
39 Virtual interface, a subclass of DataSet for datasets which have a finite, known length. | 50 Virtual interface, a subclass of DataSet for datasets which have a finite, known length. |
40 Examples are indexed by an integer between 0 and self.length()-1, | 51 Examples are indexed by an integer between 0 and self.length()-1, |
41 and a subdataset can be obtained by slicing. | 52 and a subdataset can be obtained by slicing. |
42 """ | 53 """ |
43 | 54 |
44 def __init__(self): | 55 def __init__(self,minibatch_size): |
45 pass | 56 DataSet.__init__(self,minibatch_size) |
46 | 57 |
47 def __len__(self): | 58 def __len__(self): |
48 """len(dataset) returns the number of examples in the dataset.""" | 59 """len(dataset) returns the number of examples in the dataset.""" |
49 raise NotImplementedError | 60 raise NotImplementedError |
50 | 61 |
54 | 65 |
55 def __getslice__(self,*slice_args): | 66 def __getslice__(self,*slice_args): |
56 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" | 67 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" |
57 raise NotImplementedError | 68 raise NotImplementedError |
58 | 69 |
59 def minibatches(self,minibatch_size): | |
60 """Return an iterator for the dataset that goes through minibatches of the given size.""" | |
61 return MinibatchIterator(self,minibatch_size) | |
62 | |
63 class MinibatchIterator(object): | |
64 """ | |
65 Iterator class for FiniteDataSet that can iterate by minibatches | |
66 (sub-dataset of consecutive examples). | |
67 """ | |
68 def __init__(self,dataset,minibatch_size): | |
69 assert minibatch_size>0 and minibatch_size<len(dataset) | |
70 self.dataset=dataset | |
71 self.minibatch_size=minibatch_size | |
72 self.current=-minibatch_size | |
73 def __iter__(self): | |
74 return self | |
75 def next(self): | |
76 self.current+=self.minibatch_size | |
77 if self.current>=len(self.dataset): | |
78 self.current=-self.minibatchsize | |
79 raise StopIteration | |
80 return self.dataset[self.current:self.current+self.minibatchsize] | |
81 | |
82 # we may want ArrayDataSet defined in another python file | 70 # we may want ArrayDataSet defined in another python file |
83 | 71 |
84 import numpy | 72 import numpy |
85 | 73 |
86 class ArrayDataSet(FiniteDataSet): | 74 class ArrayDataSet(FiniteDataSet): |
87 """ | 75 """ |
88 A fixed-length and fixed-width dataset in which each element is a numpy.array | 76 A fixed-length and fixed-width dataset in which each element is a numpy array |
89 or a number, hence the whole dataset corresponds to a numpy.array. Fields | 77 or a number, hence the whole dataset corresponds to a numpy array. Fields |
90 must correspond to a slice of columns. If the dataset has fields, | 78 must correspond to a slice of columns. If the dataset has fields, |
91 each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array. | 79 each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array. |
92 Any dataset can also be converted to a numpy.array (losing the notion of fields) | 80 Any dataset can also be converted to a numpy array (losing the notion of fields) |
93 by the asarray(dataset) call. | 81 by the asarray(dataset) call. |
94 """ | 82 """ |
95 | 83 |
96 def __init__(self,dataset=None,data=None,fields={}): | 84 def __init__(self,dataset=None,data=None,fields={},minibatch_size=1): |
97 """ | 85 """ |
98 Construct an ArrayDataSet, either from a DataSet, or from | 86 Construct an ArrayDataSet, either from a DataSet, or from |
99 a numpy.array plus an optional specification of fields (by | 87 a numpy array plus an optional specification of fields (by |
100 a dictionary of column slices indexed by field names). | 88 a dictionary of column slices indexed by field names). |
101 """ | 89 """ |
90 FiniteDataSet.__init__(self,minibatch_size) | |
102 self.current_row=-1 # used for view of this dataset as an iterator | 91 self.current_row=-1 # used for view of this dataset as an iterator |
103 if dataset!=None: | 92 if dataset!=None: |
104 assert data==None and fields=={} | 93 assert data==None and fields=={} |
105 # convert dataset to an ArrayDataSet | 94 # convert dataset to an ArrayDataSet |
106 raise NotImplementedError | 95 raise NotImplementedError |
120 step=1 | 109 step=1 |
121 if not fieldslice.start or not fieldslice.step: | 110 if not fieldslice.start or not fieldslice.step: |
122 fieldslice = slice(start,fieldslice.stop,step) | 111 fieldslice = slice(start,fieldslice.stop,step) |
123 # and coherent with the data array | 112 # and coherent with the data array |
124 assert fieldslice.start>=0 and fieldslice.stop<=self.width | 113 assert fieldslice.start>=0 and fieldslice.stop<=self.width |
114 assert minibatch_size<=len(self.data) | |
125 | 115 |
126 def next(self): | 116 def next(self): |
127 """ | 117 """ |
128 Return the next example in the dataset. If the dataset has fields, | 118 Return the next example(s) in the dataset. If self.minibatch_size>1 return that |
129 the 'example' is just a one-row ArrayDataSet, otherwise it is a numpy.array. | 119 many examples. If the dataset has fields, the example or the minibatch of examples |
120 is just a minibatch_size-rows ArrayDataSet (so that the fields can be accessed), | |
121 but that resulting mini-dataset has a minibatch_size of 1, so that one can iterate | |
122 example-wise on it. On the other hand, if the dataset has no fields (e.g. because | |
123 it is already the field of a bigger dataset), then the returned example or minibatch | |
124 is a numpy array. Following the array semantics of indexing and slicing, | |
125 if the minibatch_size is 1 (and there are no fields), then the result is an array | |
126 with one less dimension (e.g., a vector, if the dataset is a matrix), corresponding | |
127 to a row. Again, if the minibatch_size is >1, one can iterate on the result to | |
128 obtain individual examples (as rows). | |
130 """ | 129 """ |
131 if self.fields: | 130 if self.fields: |
132 self.current_row+=1 | 131 self.current_row+=self.minibatch_size |
133 if self.current_row==len(self.data): | 132 if self.current_row>=len(self.data): |
134 self.current_row=-1 | 133 self.current_row=-self.minibatch_size |
135 raise StopIteration | 134 raise StopIteration |
136 return self[self.current_row] | 135 if self.minibatch_size==1: |
136 return self[self.current_row] | |
137 else: | |
138 return self[self.current_row:self.current_row+self.minibatch_size] | |
137 else: | 139 else: |
138 return self.data[self.current_row] | 140 if self.minibatch_size==1: |
141 return self.data[self.current_row] | |
142 else: | |
143 return self.data[self.current_row:self.current_row+self.minibatch_size] | |
139 | 144 |
140 def __getattr__(self,fieldname): | 145 def __getattr__(self,fieldname): |
141 """Return a sub-dataset containing only the given fieldname as field.""" | 146 """Return a numpy array with the content associated with the given field name.""" |
142 data=self.data[self.fields[fieldname]] | 147 return self.data[self.fields[fieldname]] |
143 if len(data)==1: | |
144 return data | |
145 else: | |
146 return ArrayDataSet(data=data) | |
147 | 148 |
148 def __call__(self,*fieldnames): | 149 def __call__(self,*fieldnames): |
149 """Return a sub-dataset containing only the given fieldnames as fields.""" | 150 """Return a sub-dataset containing only the given fieldnames as fields.""" |
150 min_col=self.data.shape[1] | 151 min_col=self.data.shape[1] |
151 max_col=0 | 152 max_col=0 |
153 min_col=min(min_col,field_slice.start) | 154 min_col=min(min_col,field_slice.start) |
154 max_col=max(max_col,field_slice.stop) | 155 max_col=max(max_col,field_slice.stop) |
155 new_fields={} | 156 new_fields={} |
156 for field in self.fields: | 157 for field in self.fields: |
157 new_fields[field[0]]=slice(field[1].start-min_col,field[1].stop-min_col,field[1].step) | 158 new_fields[field[0]]=slice(field[1].start-min_col,field[1].stop-min_col,field[1].step) |
158 return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields) | 159 return ArrayDataSet(data=self.data[:,min_col:max_col],fields=new_fields,minibatch_size=self.minibatch_size) |
159 | 160 |
160 def fieldNames(self): | 161 def fieldNames(self): |
161 """Return the list of field names that are supported by getattr and getFields.""" | 162 """Return the list of field names that are supported by getattr and getFields.""" |
162 return self.fields.keys() | 163 return self.fields.keys() |
163 | 164 |
177 else: | 178 else: |
178 return data[i] | 179 return data[i] |
179 | 180 |
180 def __getslice__(self,*slice_args): | 181 def __getslice__(self,*slice_args): |
181 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" | 182 """dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.""" |
182 return ArrayDataSet(data=self.data[slice(slice_args)],fields=self.fields) | 183 return ArrayDataSet(data=self.data[apply(slice,slice_args)],fields=self.fields) |
183 | 184 |
184 def asarray(self): | 185 def asarray(self): |
185 if self.fields: | 186 if self.fields: |
186 columns_used = numpy.zeros((self.data.shape[1]),dtype=bool) | 187 columns_used = numpy.zeros((self.data.shape[1]),dtype=bool) |
187 for field_slice in self.fields.values(): | 188 for field_slice in self.fields.values(): |