Mercurial > pylearn
comparison dataset.py @ 19:57f4015e2e09
Iterators extend LookupList
author | bergstrj@iro.umontreal.ca |
---|---|
date | Thu, 27 Mar 2008 01:59:44 -0400 |
parents | 759d17112b23 |
children | 266c68cb6136 |
comparison
equal
deleted
inserted
replaced
18:60b164a0d84a | 19:57f4015e2e09 |
---|---|
38 all the fields of DataSet self. Every field of "i" will give access to | 38 all the fields of DataSet self. Every field of "i" will give access to |
39 a the field of a single example. Fields should be accessible via | 39 a the field of a single example. Fields should be accessible via |
40 i[identifier], but the derived class is free to accept any type of | 40 i[identifier], but the derived class is free to accept any type of |
41 identifier, and add extra functionality to the iterator. | 41 identifier, and add extra functionality to the iterator. |
42 """ | 42 """ |
43 for i in self.minibatches( minibatch_size = 1): | 43 raise AbstractFunction() |
44 yield Example(i.keys(), [v[0] for v in i.values()]) | |
45 | 44 |
46 def zip(self, *fieldnames): | 45 def zip(self, *fieldnames): |
47 """ | 46 """ |
48 Supports two forms of syntax: | 47 Supports two forms of syntax: |
49 | 48 |
59 f1, f2, and f3 fields of a single example on each loop iteration. | 58 f1, f2, and f3 fields of a single example on each loop iteration. |
60 | 59 |
61 The derived class may accept fieldname arguments of any type. | 60 The derived class may accept fieldname arguments of any type. |
62 | 61 |
63 """ | 62 """ |
64 for i in self.minibatches(fieldnames, minibatch_size = 1): | 63 class Iter(LookupList): |
65 yield [f[0] for f in i] | 64 def __init__(self, ll): |
65 LookupList.__init__(self, ll.keys(), ll.values()) | |
66 self.ll = ll | |
67 def __iter__(self): #makes for loop work | |
68 return self | |
69 def next(self): | |
70 self.ll.next() | |
71 self._values = [v[0] for v in self.ll._values] | |
72 return self | |
73 return Iter(self.minibatches(fieldnames, minibatch_size = 1)) | |
66 | 74 |
67 minibatches_fieldnames = None | 75 minibatches_fieldnames = None |
68 minibatches_minibatch_size = 1 | 76 minibatches_minibatch_size = 1 |
69 minibatches_n_batches = None | 77 minibatches_n_batches = None |
70 def minibatches(self, | 78 def minibatches(self, |
175 self.dataset=dataset | 183 self.dataset=dataset |
176 self.minibatch_size=minibatch_size | 184 self.minibatch_size=minibatch_size |
177 assert minibatch_size>=1 and minibatch_size<=len(dataset) | 185 assert minibatch_size>=1 and minibatch_size<=len(dataset) |
178 self.current = -self.minibatch_size | 186 self.current = -self.minibatch_size |
179 self.fieldnames = fieldnames | 187 self.fieldnames = fieldnames |
188 if len(dataset) % minibatch_size: | |
189 raise NotImplementedError() | |
180 | 190 |
181 def __iter__(self): | 191 def __iter__(self): |
182 return self | 192 return self |
183 | 193 |
184 def next(self): | 194 def next(self): |
285 each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array. | 295 each 'example' is just a one-row ArrayDataSet, otherwise it is a numpy array. |
286 Any dataset can also be converted to a numpy array (losing the notion of fields | 296 Any dataset can also be converted to a numpy array (losing the notion of fields |
287 by the numpy.array(dataset) call). | 297 by the numpy.array(dataset) call). |
288 """ | 298 """ |
289 | 299 |
290 class Iterator(object): | 300 class Iterator(LookupList): |
291 """An iterator over a finite dataset that implements wrap-around""" | 301 """An iterator over a finite dataset that implements wrap-around""" |
292 def __init__(self, dataset, fieldnames, minibatch_size, next_max): | 302 def __init__(self, dataset, fieldnames, minibatch_size, next_max): |
303 LookupList.__init__(self, fieldnames, [0] * len(fieldnames)) | |
293 self.dataset=dataset | 304 self.dataset=dataset |
294 self.fieldnames = fieldnames | |
295 self.minibatch_size=minibatch_size | 305 self.minibatch_size=minibatch_size |
296 self.next_count = 0 | 306 self.next_count = 0 |
297 self.next_max = next_max | 307 self.next_max = next_max |
298 self.current = -self.minibatch_size | 308 self.current = -self.minibatch_size |
299 assert minibatch_size > 0 | 309 assert minibatch_size > 0 |
300 if minibatch_size >= len(dataset): | 310 if minibatch_size >= len(dataset): |
301 raise NotImplementedError() | 311 raise NotImplementedError() |
302 | 312 |
303 def __iter__(self): | 313 def __iter__(self): #makes for loop work |
304 #Why do we do this? -JB | |
305 return self | 314 return self |
306 | 315 |
307 @staticmethod | 316 @staticmethod |
308 def matcat(a, b): | 317 def matcat(a, b): |
309 a0, a1 = a.shape | 318 a0, a1 = a.shape |
321 self.next_count += 1 | 330 self.next_count += 1 |
322 if self.next_count == self.next_max: | 331 if self.next_count == self.next_max: |
323 raise StopIteration | 332 raise StopIteration |
324 | 333 |
325 #determine the first and last elements of the slice we'll return | 334 #determine the first and last elements of the slice we'll return |
335 rows = self.dataset.data.shape[0] | |
326 self.current += self.minibatch_size | 336 self.current += self.minibatch_size |
327 if self.current >= len(self.dataset): | 337 if self.current >= rows: |
328 self.current -= len(self.dataset) | 338 self.current -= rows |
329 upper = self.current + self.minibatch_size | 339 upper = self.current + self.minibatch_size |
330 | 340 |
331 if upper <= len(self.dataset): | 341 data = self.dataset.data |
342 | |
343 if upper <= rows: | |
332 #this is the easy case, we only need one slice | 344 #this is the easy case, we only need one slice |
333 dataview = self.dataset.data[self.current:upper] | 345 dataview = data[self.current:upper] |
334 else: | 346 else: |
335 # the minibatch wraps around the end of the dataset | 347 # the minibatch wraps around the end of the dataset |
336 dataview = self.dataset.data[self.current:] | 348 dataview = data[self.current:] |
337 upper -= len(self.dataset) | 349 upper -= rows |
338 assert upper > 0 | 350 assert upper > 0 |
339 dataview = self.matcat(dataview, self.dataset.data[:upper]) | 351 dataview = self.matcat(dataview, data[:upper]) |
340 | 352 |
341 | 353 |
342 rval = [dataview[:, self.dataset.fields[f]] for f in self.fieldnames] | 354 self._values = [dataview[:, self.dataset.fields[f]]\ |
343 | 355 for f in self._names] |
344 if self.fieldnames: | 356 |
345 rval = Example(self.fieldnames, rval) | 357 return self |
346 | |
347 return rval | |
348 | 358 |
349 | 359 |
350 def __init__(self, data, fields=None): | 360 def __init__(self, data, fields=None): |
351 """ | 361 """ |
352 There are two ways to construct an ArrayDataSet: (1) from an | 362 There are two ways to construct an ArrayDataSet: (1) from an |
369 step=1 | 379 step=1 |
370 if not fieldslice.start or not fieldslice.step: | 380 if not fieldslice.start or not fieldslice.step: |
371 fields[fieldname] = fieldslice = slice(start,fieldslice.stop,step) | 381 fields[fieldname] = fieldslice = slice(start,fieldslice.stop,step) |
372 # and coherent with the data array | 382 # and coherent with the data array |
373 assert fieldslice.start >= 0 and fieldslice.stop <= cols | 383 assert fieldslice.start >= 0 and fieldslice.stop <= cols |
384 | |
385 def __iter__(self): | |
386 return self.zip(*self.fieldNames()) | |
374 | 387 |
375 def minibatches(self, | 388 def minibatches(self, |
376 fieldnames = DataSet.minibatches_fieldnames, | 389 fieldnames = DataSet.minibatches_fieldnames, |
377 minibatch_size = DataSet.minibatches_minibatch_size, | 390 minibatch_size = DataSet.minibatches_minibatch_size, |
378 n_batches = DataSet.minibatches_n_batches): | 391 n_batches = DataSet.minibatches_n_batches): |