pylearn: comparison dataset.py @ 40:88fd1cce08b9
replaced infinity for length by raise UnboundedDataSet and use & instead of + to concatenate datasets
author | bengioy@esprit.iro.umontreal.ca |
---|---|
date | Fri, 25 Apr 2008 10:41:19 -0400 |
parents | c682c6e9bf93 |
children | 283e95c15b47 |
39:c682c6e9bf93 | 40:88fd1cce08b9 |
---|---|
4 from misc import * | 4 from misc import * |
5 import copy | 5 import copy |
6 | 6 |
7 class AbstractFunction (Exception): """Derived class must override this function""" | 7 class AbstractFunction (Exception): """Derived class must override this function""" |
8 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" | 8 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" |
9 class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)""" | |
9 | 10 |
10 class DataSet(object): | 11 class DataSet(object): |
11 """A virtual base class for datasets. | 12 """A virtual base class for datasets. |
12 | 13 |
13 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction | 14 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction |
14 with learning algorithms (for training and testing them): rows/records are called examples, and | 15 with learning algorithms (for training and testing them): rows/records are called examples, and |
15 columns/attributes are called fields. The field value for a particular example can be an arbitrary | 16 columns/attributes are called fields. The field value for a particular example can be an arbitrary |
16 python object, which depends on the particular dataset. | 17 python object, which depends on the particular dataset. |
17 | 18 |
18 We call a DataSet a 'stream' when its length is unbounded (len(dataset)==float("infinity")). | 19 We call a DataSet a 'stream' when its length is unbounded (in which case its __len__ method |
20 should raise an UnboundedDataSet exception). | |
19 | 21 |
20 A DataSet is a generator of iterators; these iterators can run through the | 22 A DataSet is a generator of iterators; these iterators can run through the |
21 examples or the fields in a variety of ways. A DataSet need not necessarily have a finite | 23 examples or the fields in a variety of ways. A DataSet need not necessarily have a finite |
22 or known length, so this class can be used to interface to a 'stream' which | 24 or known length, so this class can be used to interface to a 'stream' which |
23 feeds on-line learning (however, as noted below, some operations are not | 25 feeds on-line learning (however, as noted below, some operations are not |
24 feasible or not recommended on streams). | 26 feasible or not recommended on streams). |
25 | 27 |
26 To iterate over examples, there are several possibilities: | 28 To iterate over examples, there are several possibilities: |
27 * for example in dataset([field1, field2,field3, ...]): | 29 * for example in dataset([field1, field2,field3, ...]): |
28 * for val1,val2,val3 in dataset([field1, field2,field3]): | 30 * for val1,val2,val3 in dataset([field1, field2,field3]): |
29 * for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): | 31 * for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): |
32 * for mini1,mini2,mini3 in dataset.minibatches([field1, field2, ...],minibatch_size=N): | |
30 * for example in dataset: | 33 * for example in dataset: |
31 Each of these is documented below. All of these iterators are expected | 34 Each of these is documented below. All of these iterators are expected |
32 to provide, in addition to the usual 'next()' method, a 'next_index()' method | 35 to provide, in addition to the usual 'next()' method, a 'next_index()' method |
33 which returns a non-negative integer pointing to the position of the next | 36 which returns a non-negative integer pointing to the position of the next |
34 example that will be returned by 'next()' (or of the first example in the | 37 example that will be returned by 'next()' (or of the first example in the |
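As a usage sketch of the iteration protocols above (my_dataset and the field names 'x' and 'y' are hypothetical placeholders, not names from this file):

    # for-loop over selected fields; each example is indexable by field name
    for example in my_dataset(['x','y']):
        print example['x'], example['y']
    # positional unpacking of the same fields
    for x,y in my_dataset(['x','y']):
        print x, y
    # explicit iterator, using the next_index() protocol described above
    it = my_dataset(['x','y'])
    print it.next_index()   # position of the example the next next() will return
    print it.next()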
80 * dataset1 | dataset2 | dataset3 == dataset.hstack([dataset1,dataset2,dataset3]) | 83 * dataset1 | dataset2 | dataset3 == dataset.hstack([dataset1,dataset2,dataset3]) |
81 | 84 |
82 creates a new dataset whose list of fields is the concatenation of the list of | 85 creates a new dataset whose list of fields is the concatenation of the list of |
83 fields of the argument datasets. This only works if they all have the same length. | 86 fields of the argument datasets. This only works if they all have the same length. |
84 | 87 |
85 * dataset1 + dataset2 + dataset3 == dataset.vstack([dataset1,dataset2,dataset3]) | 88 * dataset1 & dataset2 & dataset3 == dataset.vstack([dataset1,dataset2,dataset3]) |
86 | 89 |
87 creates a new dataset that concatenates the examples from the argument datasets | 90 creates a new dataset that concatenates the examples from the argument datasets |
88 (and whose length is the sum of the length of the argument datasets). This only | 91 (and whose length is the sum of the length of the argument datasets). This only |
89 works if they all have the same fields. | 92 works if they all have the same fields. |
90 | 93 |
91 According to the same logic, and viewing a DataSetFields object associated to | 94 According to the same logic, and viewing a DataSetFields object associated to |
92 a DataSet as a kind of transpose of it, fields1 + fields2 concatenates fields of | 95 a DataSet as a kind of transpose of it, fields1 & fields2 concatenates fields of |
93 a DataSetFields fields1 and fields2, and fields1 | fields2 concatenates their | 96 a DataSetFields fields1 and fields2, and fields1 | fields2 concatenates their |
94 examples. | 97 examples. |
95 | 98 |
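A sketch of the two concatenation operators just described (d1 and d2 are placeholder bounded datasets with compatible shapes):

    wide = d1 | d2   # same examples, fields of d1 then fields of d2 (hstack); requires equal lengths
    tall = d1 & d2   # same fields, examples of d1 then examples of d2 (vstack); requires identical fields
    # when both operands are bounded, len(tall) == len(d1) + len(d2)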
96 | |
97 A DataSet sub-class should always redefine the following methods: | 99 A DataSet sub-class should always redefine the following methods: |
98 * __len__ if it is not a stream | 100 * __len__ if it is not a stream |
99 * __getitem__ may not be feasible with some streams | |
100 * fieldNames | 101 * fieldNames |
101 * minibatches_nowrap (called by DataSet.minibatches()) | 102 * minibatches_nowrap (called by DataSet.minibatches()) |
102 * valuesHStack | 103 * valuesHStack |
103 * valuesVStack | 104 * valuesVStack |
104 For efficiency of implementation, a sub-class might also want to redefine | 105 For efficiency of implementation, a sub-class might also want to redefine |
105 * hasFields | 106 * hasFields |
106 """ | 107 * __getitem__ may not be feasible with some streams |
107 | 108 * __iter__ |
108 infinity = float("infinity") | 109 """ |
109 | 110 |
110 def __init__(self): | 111 def __init__(self): |
111 pass | 112 pass |
112 | 113 |
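A minimal sketch of a finite subclass implementing the methods listed above (the minibatches_nowrap signature is assumed from calls visible later in this file; the Example/LookupList details are simplified):

    class ListDataSet(DataSet):
        "Hypothetical toy subclass storing each field as a python list."
        def __init__(self,fields_dict):
            DataSet.__init__(self)
            self.fields_dict = fields_dict
        def __len__(self):               # finite, so do not raise UnboundedDataSet
            return len(self.fields_dict.values()[0])
        def fieldNames(self):
            return self.fields_dict.keys()
        def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
            raise NotImplementedYet()    # signature assumed; yield LookupList minibatches here
        def valuesVStack(self,fieldname,values):
            return sum([list(v) for v in values],[])
        def valuesHStack(self,fieldnames,values):
            return values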
113 class MinibatchToSingleExampleIterator(object): | 114 class MinibatchToSingleExampleIterator(object): |
114 """ | 115 """ |
122 def __init__(self, minibatch_iterator): | 123 def __init__(self, minibatch_iterator): |
123 self.minibatch_iterator = minibatch_iterator | 124 self.minibatch_iterator = minibatch_iterator |
124 def __iter__(self): #makes for loop work | 125 def __iter__(self): #makes for loop work |
125 return self | 126 return self |
126 def next(self): | 127 def next(self): |
127 return self.minibatch_iterator.next()[0] | 128 size1_minibatch = self.minibatch_iterator.next() |
129 return Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()]) | |
130 | |
128 def next_index(self): | 131 def next_index(self): |
129 return self.minibatch_iterator.next_index() | 132 return self.minibatch_iterator.next_index() |
130 | 133 |
131 def __iter__(self): | 134 def __iter__(self): |
132 """Supports the syntax "for i in dataset: ..." | 135 """Supports the syntax "for i in dataset: ..." |
221 of a batch of current examples. In the second case, i[0] is | 224 of a batch of current examples. In the second case, i[0] is |
222 list-like container of the f1 field of a batch current examples, i[1] is | 225 list-like container of the f1 field of a batch current examples, i[1] is |
223 a list-like container of the f2 field, etc. | 226 a list-like container of the f2 field, etc. |
224 | 227 |
225 Using the first syntax, all the fields will be returned in "i". | 228 Using the first syntax, all the fields will be returned in "i". |
226 Beware that some datasets may not support this syntax, if the number | |
227 of fields is infinite (i.e. field values may be computed "on demand"). | |
228 | |
229 Using the third syntax, i1, i2, i3 will be list-like containers of the | 229 Using the third syntax, i1, i2, i3 will be list-like containers of the |
230 f1, f2, and f3 fields of a batch of examples on each loop iteration. | 230 f1, f2, and f3 fields of a batch of examples on each loop iteration. |
231 | 231 |
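A sketch of the minibatch syntaxes discussed here (placeholder dataset and field names again):

    # one LookupList-style object per iteration, indexed by field name
    for minibatch in my_dataset.minibatches(['f1','f2'],minibatch_size=10):
        f1_batch = minibatch['f1']       # list-like container of up to 10 values
    # positional unpacking, one container per requested field
    for f1_batch,f2_batch in my_dataset.minibatches(['f1','f2'],minibatch_size=10):
        pass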
232 The minibatches iterator is expected to return upon each call to next() | 232 The minibatches iterator is expected to return upon each call to next() |
233 a DataSetFields object, which is a LookupList (indexed by the field names) whose | 233 a DataSetFields object, which is a LookupList (indexed by the field names) whose |
275 raise AbstractFunction() | 275 raise AbstractFunction() |
276 | 276 |
277 def __len__(self): | 277 def __len__(self): |
278 """ | 278 """ |
279 len(dataset) returns the number of examples in the dataset. | 279 len(dataset) returns the number of examples in the dataset. |
280 By default, a DataSet is a 'stream', i.e. it has an unbounded (infinite) length. | 280 By default, a DataSet is a 'stream', i.e. it has an unbounded length (raises UnboundedDataSet). |
281 Sub-classes which implement finite-length datasets should redefine this method. | 281 Sub-classes which implement finite-length datasets should redefine this method. |
282 Some methods only make sense for finite-length datasets, and will perform | 282 Some methods only make sense for finite-length datasets. |
283 assert len(dataset)<DataSet.infinity | 283 """ |
284 in order to check the finiteness of the dataset. | 284 raise UnboundedDataSet() |
285 """ | |
286 return infinity | |
287 | 285 |
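With this contract, code that needs a finite length catches the exception instead of comparing against DataSet.infinity, e.g.:

    try:
        n = len(my_dataset)      # bounded dataset: the number of examples
    except UnboundedDataSet:
        n = None                 # a stream: iterate instead of indexing by length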
288 def hasFields(self,*fieldnames): | 286 def hasFields(self,*fieldnames): |
289 """ | 287 """ |
290 Return true if the given field name (or field names, if multiple arguments are | 288 Return true if the given field name (or field names, if multiple arguments are |
291 given) is recognized by the DataSet (i.e. can be used as a field name in one | 289 given) is recognized by the DataSet (i.e. can be used as a field name in one |
325 | 323 |
326 Note that some stream datasets may be unable to implement random access, i.e. | 324 Note that some stream datasets may be unable to implement random access, i.e. |
327 arbitrary slicing/indexing | 325 arbitrary slicing/indexing |
328 because they can only iterate through examples, one at a time or one minibatch at a time, | 326 because they can only iterate through examples, one at a time or one minibatch at a time, |
329 and do not actually store or keep past (or future) examples. | 327 and do not actually store or keep past (or future) examples. |
330 """ | 328 |
331 raise NotImplementedError() | 329 The default implementation of __getitem__ uses the minibatches iterator |
330 to obtain one example, one slice, or a list of examples. It may not | |
331 always be the most efficient way to obtain the result, especially if | |
332 the data are actually stored in a memory array. | |
333 """ | |
334 if type(i) is int: | |
335 return DataSet.MinibatchToSingleExampleIterator( | |
336 self.minibatches(minibatch_size=1,n_batches=1,offset=i)).next() | |
337 if type(i) is slice: | |
338 start = i.start or 0 | |
339 step = i.step or 1 | |
340 if step == 1: | |
341 return self.minibatches(minibatch_size=i.stop-start,n_batches=1,offset=start).next().examples() | |
342 rows = range(start,i.stop,step) | |
343 else: | |
344 assert type(i) is list | |
345 rows = i | |
346 fields_values = zip(*[self[row] for row in rows]) | |
347 return MinibatchDataSet( | |
348 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) | |
349 for fieldname,field_values | |
350 in zip(self.fieldNames(),fields_values)])) | |
332 | 351 |
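Under this default implementation, the three supported index types behave as follows (sketch with a placeholder dataset):

    one      = my_dataset[3]         # a single Example, fetched via a size-1 minibatch
    first10  = my_dataset[0:10]      # a contiguous slice: one minibatch of 10 examples
    selected = my_dataset[[1,4,7]]   # arbitrary rows, reassembled with valuesVStack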
333 def valuesHStack(self,fieldnames,fieldvalues): | 352 def valuesHStack(self,fieldnames,fieldvalues): |
334 """ | 353 """ |
335 Return a value that corresponds to concatenating (horizontally) several field values. | 354 Return a value that corresponds to concatenating (horizontally) several field values. |
336 This can be useful to merge some fields. The implementation of this operation is likely | 355 This can be useful to merge some fields. The implementation of this operation is likely |
375 dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of | 394 dataset1 | dataset2 returns a dataset whose list of fields is the concatenation of the list of |
376 fields of the argument datasets. This only works if they all have the same length. | 395 fields of the argument datasets. This only works if they all have the same length. |
377 """ | 396 """ |
378 return HStackedDataSet(self,other) | 397 return HStackedDataSet(self,other) |
379 | 398 |
380 def __add__(self,other): | 399 def __and__(self,other): |
381 """ | 400 """ |
382 dataset1 + dataset2 is a dataset that concatenates the examples from the argument datasets | 401 dataset1 & dataset2 is a dataset that concatenates the examples from the argument datasets |
383 (and whose length is the sum of the length of the argument datasets). This only | 402 (and whose length is the sum of the length of the argument datasets). This only |
384 works if they all have the same fields. | 403 works if they all have the same fields. |
385 """ | 404 """ |
386 return VStackedDataSet(self,other) | 405 return VStackedDataSet(self,other) |
387 | 406 |
396 return datasets[0] | 415 return datasets[0] |
397 return HStackedDataSet(datasets) | 416 return HStackedDataSet(datasets) |
398 | 417 |
399 def vstack(datasets): | 418 def vstack(datasets): |
400 """ | 419 """ |
401 vstack(dataset1,dataset2,...) returns dataset1 + dataset2 + ... | 420 vstack(dataset1,dataset2,...) returns dataset1 & dataset2 & ... |
402 which is a dataset which iterates first over the examples of dataset1, then | 421 which is a dataset which iterates first over the examples of dataset1, then |
403 over those of dataset2, etc. | 422 over those of dataset2, etc. |
404 """ | 423 """ |
405 assert len(datasets)>0 | 424 assert len(datasets)>0 |
406 if len(datasets)==1: | 425 if len(datasets)==1: |
428 The result of fields() is a DataSetFields object, which iterates over fields, | 447 The result of fields() is a DataSetFields object, which iterates over fields, |
429 and whose elements are iterable over examples. A DataSetFields object can | 448 and whose elements are iterable over examples. A DataSetFields object can |
430 be turned back into a DataSet with its examples() method: | 449 be turned back into a DataSet with its examples() method: |
431 dataset2 = dataset1.fields().examples() | 450 dataset2 = dataset1.fields().examples() |
432 and dataset2 should behave exactly like dataset1 (in fact by default dataset2==dataset1). | 451 and dataset2 should behave exactly like dataset1 (in fact by default dataset2==dataset1). |
452 | |
453 DataSetFields can be concatenated vertically or horizontally. Since a DataSetFields | |
454 is a kind of transpose of its DataSet, the operators swap roles: | concatenates | |
455 their examples and & concatenates their fields. | |
433 """ | 456 """ |
434 def __init__(self,dataset,*fieldnames): | 457 def __init__(self,dataset,*fieldnames): |
435 self.dataset=dataset | 458 self.dataset=dataset |
459 if not fieldnames: | |
460 fieldnames=dataset.fieldNames() | |
436 assert dataset.hasFields(*fieldnames) | 461 assert dataset.hasFields(*fieldnames) |
437 LookupList.__init__(self,dataset.fieldNames(), | 462 LookupList.__init__(self,fieldnames, |
438 dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(), | 463 dataset.minibatches(fieldnames, |
439 minibatch_size=len(dataset)).next()) | 464 minibatch_size=len(dataset)).next()) |
440 def examples(self): | 465 def examples(self): |
445 fields1 | fields2 is a DataSetFields whose list of examples is the concatenation | 470 fields1 | fields2 is a DataSetFields whose list of examples is the concatenation |
446 of the list of examples of DataSetFields fields1 and fields2. | 471 of the list of examples of DataSetFields fields1 and fields2. |
447 """ | 472 """ |
448 return (self.examples() + other.examples()).fields() | 473 return (self.examples() & other.examples()).fields() |
449 | 474 |
450 def __add__(self,other): | 475 def __and__(self,other): |
451 """ | 476 """ |
452 fields1 + fields2 is a DataSetFields whose list of fields is the concatenation | 477 fields1 & fields2 is a DataSetFields whose list of fields is the concatenation |
453 of the fields of DataSetFields fields1 and fields2. | 478 of the fields of DataSetFields fields1 and fields2. |
454 """ | 479 """ |
455 return (self.examples() | other.examples()).fields() | 480 return (self.examples() | other.examples()).fields() |
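A sketch of the round trip and the transposed operators (consistent with the __or__/__and__ docstrings above):

    fields1  = dataset1.fields()        # iterate over fields instead of examples
    dataset2 = fields1.examples()       # back to a DataSet equivalent to dataset1
    more_examples = fields1 | fields2   # concatenates their examples (vstack underneath)
    more_fields   = fields1 & fields2   # concatenates their fields (hstack underneath)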
477 | 502 |
478 def __len__(self): | 503 def __len__(self): |
479 return self.length | 504 return self.length |
480 | 505 |
481 def __getitem__(self,i): | 506 def __getitem__(self,i): |
482 return Example(self.fields.keys(),[field[i] for field in self.fields]) | 507 return DataSetFields(MinibatchDataSet( |
508 Example(self.fields.keys(),[field[i] for field in self.fields])),*self.fields.keys()) | |
483 | 509 |
484 def fieldNames(self): | 510 def fieldNames(self): |
485 return self.fields.keys() | 511 return self.fields.keys() |
486 | 512 |
487 def hasFields(self,*fieldnames): | 513 def hasFields(self,*fieldnames): |
507 [field[next_example:upper] | 533 [field[next_example:upper] |
508 for field in self.ds.fields]) | 534 for field in self.ds.fields]) |
509 self.next_example+=minibatch_size | 535 self.next_example+=minibatch_size |
510 return DataSetFields(MinibatchDataSet(minibatch),*fieldnames) | 536 return DataSetFields(MinibatchDataSet(minibatch),*fieldnames) |
511 | 537 |
512 return MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset) | 538 return Iterator(self) |
513 | 539 |
514 def valuesVStack(self,fieldname,fieldvalues): | 540 def valuesVStack(self,fieldname,fieldvalues): |
515 return self.values_vstack(fieldname,fieldvalues) | 541 return self.values_vstack(fieldname,fieldvalues) |
516 | 542 |
517 def valuesHStack(self,fieldnames,fieldvalues): | 543 def valuesHStack(self,fieldnames,fieldvalues): |
637 fieldnames = datasets[-1].fieldNames() | 663 fieldnames = datasets[-1].fieldNames() |
638 self.datasets_start_row=[] | 664 self.datasets_start_row=[] |
639 # We use this map from row index to dataset index for constant-time random access of examples, | 665 # We use this map from row index to dataset index for constant-time random access of examples, |
640 # to avoid having to search for the appropriate dataset each time an example or slice is asked for. | 666 # to avoid having to search for the appropriate dataset each time an example or slice is asked for. |
641 for k,dataset in enumerate(datasets[0:-1]): | 667 for k,dataset in enumerate(datasets[0:-1]): |
642 L=len(dataset) | 668 try: |
643 assert L<DataSet.infinity | 669 L=len(dataset) |
670 except UnboundedDataSet: | |
671 print "All VStacked datasets (except possibly the last) must be bounded (have a length)." | |
672 assert False | |
644 for i in xrange(L): | 673 for i in xrange(L): |
645 self.index2dataset[self.length+i]=k | 674 self.index2dataset[self.length+i]=k |
646 self.datasets_start_row.append(self.length) | 675 self.datasets_start_row.append(self.length) |
647 self.length+=L | 676 self.length+=L |
648 assert dataset.fieldNames()==fieldnames | 677 assert dataset.fieldNames()==fieldnames |
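For concreteness, a comment-sketch of the map this loop builds, with hypothetical lengths:

    # With datasets == [d0,d1,d2], len(d0)==3, len(d1)==2 and d2 possibly unbounded,
    # the loop over datasets[0:-1] leaves:
    #   index2dataset == {0:0, 1:0, 2:0, 3:1, 4:1}
    #   datasets_start_row == [0,3] and self.length == 5
    # so bounded row r is read from datasets[index2dataset[r]] at local index
    # r - datasets_start_row[index2dataset[r]], in constant time; rows >= 5 fall in d2.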
719 if self.n_left_in_mb: | 748 if self.n_left_in_mb: |
720 extra_mb = [] | 749 extra_mb = [] |
721 while self.n_left_in_mb>0: | 750 while self.n_left_in_mb>0: |
722 self.move_to_next_dataset() | 751 self.move_to_next_dataset() |
723 extra_mb.append(self.next_iterator.next()) | 752 extra_mb.append(self.next_iterator.next()) |
724 mb = Example(names, | 753 examples = Example(names, |
725 [dataset.valuesVStack(name,[mb[name]]+[b[name] for b in extra_mb]) | 754 [dataset.valuesVStack(name, |
726 for name in fieldnames]) | 755 [mb[name]]+[b[name] for b in extra_mb]) |
756 for name in fieldnames]) | |
757 mb = DataSetFields(MinibatchDataSet(examples),*fieldnames) | |
758 | |
727 self.next_row+=minibatch_size | 759 self.next_row+=minibatch_size |
728 self.next_dataset_row+=minibatch_size | 760 self.next_dataset_row+=minibatch_size |
729 if self.next_row+minibatch_size>len(dataset): | 761 if self.next_row+minibatch_size>len(dataset): |
730 self.move_to_next_dataset() | 762 self.move_to_next_dataset() |
731 return mb | 763 return mb |
732 | 764 |
733 | 765 |
734 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | 766 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): |
735 """ | 767 """ |
736 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the | 768 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the |