comparison dataset.py @ 376:c9a89be5cb0a

Redesigning linear_regression
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 07 Jul 2008 10:08:35 -0400
parents 18702ceb2096
children 835830e52b42
375:12ce29abf27d 376:c9a89be5cb0a
1 1
2 from lookup_list import LookupList as Example 2 from lookup_list import LookupList as Example
3 from misc import unique_elements_list_intersection 3 from common.misc import unique_elements_list_intersection
4 from string import join 4 from string import join
5 from sys import maxint 5 from sys import maxint
6 import numpy, copy 6 import numpy, copy
7 7
8 from exceptions import * 8 from exceptions import *
379 379
380 Note: A list-like container is something like a tuple, list, numpy.ndarray or 380 Note: A list-like container is something like a tuple, list, numpy.ndarray or
381 any other object that supports integer indexing and slicing. 381 any other object that supports integer indexing and slicing.
382 382
383 @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete 383 @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete
384 batches only, raise StopIteration 384 batches only, and raise StopIteration otherwise.
385 @ATTENTION: minibatches returns a LookupList, so we can't iterate over its examples.
385 386
386 """ 387 """
387 #return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)\ 388 #return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)\
388 assert offset >= 0 389 assert offset >= 0
389 assert offset < len(self) 390 assert offset < len(self)
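
A minimal usage sketch of the contract described by the @ATTENTION notes above; the dataset and field names here are hypothetical:

    # Hypothetical dataset with fields 'x' and 'y'; only complete minibatches
    # are produced, and iteration ends by raising StopIteration.
    for batch in dataset.minibatches(['x', 'y'], minibatch_size=10):
        xs = batch['x']   # list-like container holding the batch's 'x' values
        ys = batch['y']   # list-like container holding the batch's 'y' values
        # batch is a LookupList: index it by field name; do not expect
        # "for example in batch" to iterate over individual examples.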
431 raise AbstractFunction() 432 raise AbstractFunction()
432 433
433 def __call__(self,*fieldnames): 434 def __call__(self,*fieldnames):
434 """ 435 """
435 Return a dataset that sees only the fields whose name are specified. 436 Return a dataset that sees only the fields whose name are specified.
437 """
438 assert self.hasFields(*fieldnames)
439 #return self.fields(*fieldnames).examples()
440 fieldnames_list = list(fieldnames)
441 return FieldsSubsetDataSet(self,fieldnames_list)
442
443 def cached_fields_subset(self,*fieldnames) :
444 """
445 Behaviour is supposed to be the same as __call__(*fieldnames), but the dataset returned is cached.
446 @see: dataset.__call__
436 """ 447 """
437 assert self.hasFields(*fieldnames) 448 assert self.hasFields(*fieldnames)
438 return self.fields(*fieldnames).examples() 449 return self.fields(*fieldnames).examples()
439 450
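
A hedged sketch of how the two entry points are meant to be used, per the docstrings above; 'ds' and its field names are hypothetical, and cached_fields_subset is only described as returning the same view as __call__, but cached:

    view = ds('input', 'target')                               # field-subset view of ds
    cached_view = ds.cached_fields_subset('input', 'target')   # same fields, intended to be cached
    # both views expose only the requested fields of each example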
440 def fields(self,*fieldnames): 451 def fields(self,*fieldnames):
690 self.new_fieldnames=new_fieldnames 701 self.new_fieldnames=new_fieldnames
691 assert src.hasFields(*src_fieldnames) 702 assert src.hasFields(*src_fieldnames)
692 assert len(src_fieldnames)==len(new_fieldnames) 703 assert len(src_fieldnames)==len(new_fieldnames)
693 self.valuesHStack = src.valuesHStack 704 self.valuesHStack = src.valuesHStack
694 self.valuesVStack = src.valuesVStack 705 self.valuesVStack = src.valuesVStack
706 self.lookup_fields = Example(new_fieldnames,src_fieldnames)
695 707
696 def __len__(self): return len(self.src) 708 def __len__(self): return len(self.src)
697 709
698 def fieldNames(self): 710 def fieldNames(self):
699 return self.new_fieldnames 711 return self.new_fieldnames
717 return self.example 729 return self.example
718 return FieldsSubsetIterator(self) 730 return FieldsSubsetIterator(self)
719 731
720 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 732 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
721 assert self.hasFields(*fieldnames) 733 assert self.hasFields(*fieldnames)
722 return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset) 734 cursor = Example(fieldnames,[0]*len(fieldnames))
735 for batch in self.src.minibatches_nowrap([self.lookup_fields[f] for f in fieldnames],minibatch_size,n_batches,offset):
736 cursor._values=batch._values
737 yield cursor
738
723 def __getitem__(self,i): 739 def __getitem__(self,i):
724 return FieldsSubsetDataSet(self.src[i],self.new_fieldnames) 740 # return FieldsSubsetDataSet(self.src[i],self.new_fieldnames)
741 complete_example = self.src[i]
742 return Example(self.new_fieldnames,
743 [complete_example[field]
744 for field in self.src_fieldnames])
745
725 746
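
A hedged sketch of what the reworked minibatches_nowrap above does with lookup_fields (the wrapper instance and field names are hypothetical): requested names are translated back to the source's names before delegating, and the same cursor Example is reused for every yielded batch.

    # Suppose the wrapper renames the source field 'raw_x' to 'input', so that
    # lookup_fields == Example(['input'], ['raw_x']).
    for batch in renamed_ds.minibatches_nowrap(['input'], 10, None, 0):
        inputs = batch['input']     # values fetched from the source's 'raw_x'
        kept = copy.copy(batch)     # copy if needed later: the cursor object is reused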
726 747
727 class DataSetFields(Example): 748 class DataSetFields(Example):
728 """ 749 """
729 Although a L{DataSet} iterates over examples (like rows of a matrix), an associated 750 Although a L{DataSet} iterates over examples (like rows of a matrix), an associated
857 raise NotImplementedError() 878 raise NotImplementedError()
858 def __iter__(self): 879 def __iter__(self):
859 return self 880 return self
860 def next(self): 881 def next(self):
861 upper = self.next_example+minibatch_size 882 upper = self.next_example+minibatch_size
862 assert upper<=self.ds.length 883 if upper > len(self.ds) :
884 raise StopIteration()
885 assert upper<=len(self.ds) # instead of self.ds.length
863 #minibatch = Example(self.ds._fields.keys(), 886 #minibatch = Example(self.ds._fields.keys(),
864 # [field[self.next_example:upper] 887 # [field[self.next_example:upper]
865 # for field in self.ds._fields]) 888 # for field in self.ds._fields])
866 # tbm: modif to use fieldnames 889 # tbm: modif to use fieldnames
867 values = [] 890 values = []
1312 if cache_all_upon_construction: 1335 if cache_all_upon_construction:
1313 # this potentially brings all the source examples 1336 # this potentially brings all the source examples
1314 # into memory at once, which may be too much 1337 # into memory at once, which may be too much
1315 # the work could possibly be done by minibatches 1338 # the work could possibly be done by minibatches
1316 # that are as large as possible but no more than what memory allows. 1339 # that are as large as possible but no more than what memory allows.
1317 fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() 1340 #
1341 # fields_values is supposed to be a DataSetFields, which inherits from LookupList
1342 #fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
1343 fields_values = DataSetFields(source_dataset,None)
1318 assert all([len(self)==len(field_values) for field_values in fields_values]) 1344 assert all([len(self)==len(field_values) for field_values in fields_values])
1319 for example in fields_values.examples(): 1345 for example in fields_values.examples():
1320 self.cached_examples.append(copy.copy(example)) 1346 self.cached_examples.append(copy.copy(example))
1321 1347
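
Restated as a standalone sketch, the new construction-time caching path above amounts to the following ('source' is hypothetical, copy is the standard-library module):

    fields_values = DataSetFields(source, None)   # all fields of all examples, in memory at once
    cached_examples = [copy.copy(example) for example in fields_values.examples()]
    # later minibatch requests are served from cached_examples instead of 'source'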
1322 self.fieldNames = source_dataset.fieldNames 1348 self.fieldNames = source_dataset.fieldNames
1331 class CacheIterator(object): 1357 class CacheIterator(object):
1332 def __init__(self,dataset): 1358 def __init__(self,dataset):
1333 self.dataset=dataset 1359 self.dataset=dataset
1334 self.current=offset 1360 self.current=offset
1335 self.all_fields = self.dataset.fieldNames()==fieldnames 1361 self.all_fields = self.dataset.fieldNames()==fieldnames
1362 self.n_batches = n_batches
1363 self.batch_counter = 0
1336 def __iter__(self): return self 1364 def __iter__(self): return self
1337 def next(self): 1365 def next(self):
1366 self.batch_counter += 1
1367 if self.n_batches and self.batch_counter > self.n_batches :
1368 raise StopIteration()
1338 upper = self.current+minibatch_size 1369 upper = self.current+minibatch_size
1370 if upper > len(self.dataset.source_dataset):
1371 raise StopIteration()
1339 cache_len = len(self.dataset.cached_examples) 1372 cache_len = len(self.dataset.cached_examples)
1340 if upper>cache_len: # whole minibatch is not already in cache 1373 if upper>cache_len: # whole minibatch is not already in cache
1341 # cache everything from current length to upper 1374 # cache everything from current length to upper
1342 for example in self.dataset.source_dataset[cache_len:upper]: 1375 #for example in self.dataset.source_dataset[cache_len:upper]:
1376 for example in self.dataset.source_dataset.subset[cache_len:upper]:
1343 self.dataset.cached_examples.append(example) 1377 self.dataset.cached_examples.append(example)
1344 all_fields_minibatch = Example(self.dataset.fieldNames(), 1378 all_fields_minibatch = Example(self.dataset.fieldNames(),
1345 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) 1379 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
1380
1346 self.current+=minibatch_size 1381 self.current+=minibatch_size
1347 if self.all_fields: 1382 if self.all_fields:
1348 return all_fields_minibatch 1383 return all_fields_minibatch
1349 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) 1384 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
1350 return CacheIterator(self) 1385 return CacheIterator(self)
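
A hedged usage sketch of the iterator behaviour introduced above (the cached dataset and parameter names are hypothetical): n_batches now bounds the number of batches returned, and running past the end of the source raises StopIteration instead of failing an assertion.

    it = cached_ds.minibatches_nowrap(['input'], minibatch_size=5, n_batches=3, offset=0)
    batches = list(it)   # at most 3 batches; fewer if the source runs out first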