Mercurial > pylearn
comparison dataset.py @ 376:c9a89be5cb0a
Redesigning linear_regression
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Mon, 07 Jul 2008 10:08:35 -0400 |
parents | 18702ceb2096 |
children | 835830e52b42 |
Comparison legend: equal, deleted, inserted, replaced
375:12ce29abf27d | 376:c9a89be5cb0a |
---|---|
1 | 1 |
2 from lookup_list import LookupList as Example | 2 from lookup_list import LookupList as Example |
3 from misc import unique_elements_list_intersection | 3 from common.misc import unique_elements_list_intersection |
4 from string import join | 4 from string import join |
5 from sys import maxint | 5 from sys import maxint |
6 import numpy, copy | 6 import numpy, copy |
7 | 7 |
8 from exceptions import * | 8 from exceptions import * |
379 | 379 |
380 Note: A list-like container is something like a tuple, list, numpy.ndarray or | 380 Note: A list-like container is something like a tuple, list, numpy.ndarray or |
381 any other object that supports integer indexing and slicing. | 381 any other object that supports integer indexing and slicing. |
382 | 382 |
383 @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete | 383 @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete |
384 batches only, raise StopIteration | 384 batches only, raise StopIteration. |
385 @ATTENTION: minibatches returns a LookupList, we can't iterate over examples on it. | |
385 | 386 |
386 """ | 387 """ |
387 #return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)\ | 388 #return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)\ |
388 assert offset >= 0 | 389 assert offset >= 0 |
389 assert offset < len(self) | 390 assert offset < len(self) |
431 raise AbstractFunction() | 432 raise AbstractFunction() |
432 | 433 |
433 def __call__(self,*fieldnames): | 434 def __call__(self,*fieldnames): |
434 """ | 435 """ |
435 Return a dataset that sees only the fields whose name are specified. | 436 Return a dataset that sees only the fields whose name are specified. |
437 """ | |
438 assert self.hasFields(*fieldnames) | |
439 #return self.fields(*fieldnames).examples() | |
440 fieldnames_list = list(fieldnames) | |
441 return FieldsSubsetDataSet(self,fieldnames_list) | |
442 | |
443 def cached_fields_subset(self,*fieldnames) : | |
444 """ | |
445 Behaviour is supposed to be the same as __call__(*fieldnames), but the dataset returned is cached. | |
446 @see : dataset.__call__ | |
436 """ | 447 """ |
437 assert self.hasFields(*fieldnames) | 448 assert self.hasFields(*fieldnames) |
438 return self.fields(*fieldnames).examples() | 449 return self.fields(*fieldnames).examples() |
439 | 450 |
440 def fields(self,*fieldnames): | 451 def fields(self,*fieldnames): |
690 self.new_fieldnames=new_fieldnames | 701 self.new_fieldnames=new_fieldnames |
691 assert src.hasFields(*src_fieldnames) | 702 assert src.hasFields(*src_fieldnames) |
692 assert len(src_fieldnames)==len(new_fieldnames) | 703 assert len(src_fieldnames)==len(new_fieldnames) |
693 self.valuesHStack = src.valuesHStack | 704 self.valuesHStack = src.valuesHStack |
694 self.valuesVStack = src.valuesVStack | 705 self.valuesVStack = src.valuesVStack |
706 self.lookup_fields = Example(new_fieldnames,src_fieldnames) | |
695 | 707 |
696 def __len__(self): return len(self.src) | 708 def __len__(self): return len(self.src) |
697 | 709 |
698 def fieldNames(self): | 710 def fieldNames(self): |
699 return self.new_fieldnames | 711 return self.new_fieldnames |
717 return self.example | 729 return self.example |
718 return FieldsSubsetIterator(self) | 730 return FieldsSubsetIterator(self) |
719 | 731 |
720 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 732 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
721 assert self.hasFields(*fieldnames) | 733 assert self.hasFields(*fieldnames) |
722 return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset) | 734 cursor = Example(fieldnames,[0]*len(fieldnames)) |
735 for batch in self.src.minibatches_nowrap([self.lookup_fields[f] for f in fieldnames],minibatch_size,n_batches,offset): | |
736 cursor._values=batch._values | |
737 yield cursor | |
738 | |
723 def __getitem__(self,i): | 739 def __getitem__(self,i): |
724 return FieldsSubsetDataSet(self.src[i],self.new_fieldnames) | 740 # return FieldsSubsetDataSet(self.src[i],self.new_fieldnames) |
741 complete_example = self.src[i] | |
742 return Example(self.new_fieldnames, | |
743 [complete_example[field] | |
744 for field in self.src_fieldnames]) | |
745 | |
725 | 746 |
726 | 747 |
727 class DataSetFields(Example): | 748 class DataSetFields(Example): |
728 """ | 749 """ |
729 Although a L{DataSet} iterates over examples (like rows of a matrix), an associated | 750 Although a L{DataSet} iterates over examples (like rows of a matrix), an associated |
857 raise NotImplementedError() | 878 raise NotImplementedError() |
858 def __iter__(self): | 879 def __iter__(self): |
859 return self | 880 return self |
860 def next(self): | 881 def next(self): |
861 upper = self.next_example+minibatch_size | 882 upper = self.next_example+minibatch_size |
862 assert upper<=self.ds.length | 883 if upper > len(self.ds) : |
884 raise StopIteration() | |
885 assert upper<=len(self.ds) # instead of self.ds.length | |
863 #minibatch = Example(self.ds._fields.keys(), | 886 #minibatch = Example(self.ds._fields.keys(), |
864 # [field[self.next_example:upper] | 887 # [field[self.next_example:upper] |
865 # for field in self.ds._fields]) | 888 # for field in self.ds._fields]) |
866 # tbm: modif to use fieldnames | 889 # tbm: modif to use fieldnames |
867 values = [] | 890 values = [] |
1312 if cache_all_upon_construction: | 1335 if cache_all_upon_construction: |
1313 # this potentially brings all the source examples | 1336 # this potentially brings all the source examples |
1314 # into memory at once, which may be too much | 1337 # into memory at once, which may be too much |
1315 # the work could possibly be done by minibatches | 1338 # the work could possibly be done by minibatches |
1316 # that are as large as possible but no more than what memory allows. | 1339 # that are as large as possible but no more than what memory allows. |
1317 fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() | 1340 # |
1341 # field_values is supposed to be an DataSetFields, that inherits from LookupList | |
1342 #fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() | |
1343 fields_values = DataSetFields(source_dataset,None) | |
1318 assert all([len(self)==len(field_values) for field_values in fields_values]) | 1344 assert all([len(self)==len(field_values) for field_values in fields_values]) |
1319 for example in fields_values.examples(): | 1345 for example in fields_values.examples(): |
1320 self.cached_examples.append(copy.copy(example)) | 1346 self.cached_examples.append(copy.copy(example)) |
1321 | 1347 |
1322 self.fieldNames = source_dataset.fieldNames | 1348 self.fieldNames = source_dataset.fieldNames |
1331 class CacheIterator(object): | 1357 class CacheIterator(object): |
1332 def __init__(self,dataset): | 1358 def __init__(self,dataset): |
1333 self.dataset=dataset | 1359 self.dataset=dataset |
1334 self.current=offset | 1360 self.current=offset |
1335 self.all_fields = self.dataset.fieldNames()==fieldnames | 1361 self.all_fields = self.dataset.fieldNames()==fieldnames |
1362 self.n_batches = n_batches | |
1363 self.batch_counter = 0 | |
1336 def __iter__(self): return self | 1364 def __iter__(self): return self |
1337 def next(self): | 1365 def next(self): |
1366 self.batch_counter += 1 | |
1367 if self.n_batches and self.batch_counter > self.n_batches : | |
1368 raise StopIteration() | |
1338 upper = self.current+minibatch_size | 1369 upper = self.current+minibatch_size |
1370 if upper > len(self.dataset.source_dataset): | |
1371 raise StopIteration() | |
1339 cache_len = len(self.dataset.cached_examples) | 1372 cache_len = len(self.dataset.cached_examples) |
1340 if upper>cache_len: # whole minibatch is not already in cache | 1373 if upper>cache_len: # whole minibatch is not already in cache |
1341 # cache everything from current length to upper | 1374 # cache everything from current length to upper |
1342 for example in self.dataset.source_dataset[cache_len:upper]: | 1375 #for example in self.dataset.source_dataset[cache_len:upper]: |
1376 for example in self.dataset.source_dataset.subset[cache_len:upper]: | |
1343 self.dataset.cached_examples.append(example) | 1377 self.dataset.cached_examples.append(example) |
1344 all_fields_minibatch = Example(self.dataset.fieldNames(), | 1378 all_fields_minibatch = Example(self.dataset.fieldNames(), |
1345 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) | 1379 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) |
1380 | |
1346 self.current+=minibatch_size | 1381 self.current+=minibatch_size |
1347 if self.all_fields: | 1382 if self.all_fields: |
1348 return all_fields_minibatch | 1383 return all_fields_minibatch |
1349 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) | 1384 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) |
1350 return CacheIterator(self) | 1385 return CacheIterator(self) |