comparison dataset.py @ 152:3f627e844cba

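Fixes in CacheDataSet: always initialise cached_examples; when caching everything at construction, fill the cache from the minibatch's examples() (with a length check); when building a minibatch, transpose the cached examples into per-field values with zip(*...)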
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 12 May 2008 16:11:24 -0400
parents 39bb21348fdf
children f8a1ae7eb83e
comparison
equal deleted inserted replaced
151:39bb21348fdf 152:3f627e844cba
1043 the record most likely to be accessed next. 1043 the record most likely to be accessed next.
1044 """ 1044 """
1045 def __init__(self,source_dataset,cache_all_upon_construction=False): 1045 def __init__(self,source_dataset,cache_all_upon_construction=False):
1046 self.source_dataset=source_dataset 1046 self.source_dataset=source_dataset
1047 self.cache_all_upon_construction=cache_all_upon_construction 1047 self.cache_all_upon_construction=cache_all_upon_construction
1048 self.cached_examples = []
1048 if cache_all_upon_construction: 1049 if cache_all_upon_construction:
1049 # this potentially brings all the source examples 1050 # this potentially brings all the source examples
1050 # into memory at once, which may be too much 1051 # into memory at once, which may be too much
1051 # the work could possibly be done by minibatches 1052 # the work could possibly be done by minibatches
1052 # that are as large as possible but no more than what memory allows. 1053 # that are as large as possible but no more than what memory allows.
1053 fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() 1054 fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
1054 self.cached_examples = zip(*fields_values) 1055 assert all([len(self)==len(field_values) for field_values in fields_values])
1055 else: 1056 for example in fields_values.examples():
1056 self.cached_examples = [] 1057 self.cached_examples.append(example)
1057 1058
1058 self.fieldNames = source_dataset.fieldNames 1059 self.fieldNames = source_dataset.fieldNames
1059 self.hasFields = source_dataset.hasFields 1060 self.hasFields = source_dataset.hasFields
1060 self.valuesHStack = source_dataset.valuesHStack 1061 self.valuesHStack = source_dataset.valuesHStack
1061 self.valuesVStack = source_dataset.valuesVStack 1062 self.valuesVStack = source_dataset.valuesVStack
1075 if upper>cache_len: # whole minibatch is not already in cache 1076 if upper>cache_len: # whole minibatch is not already in cache
1076 # cache everything from current length to upper 1077 # cache everything from current length to upper
1077 for example in self.dataset.source_dataset[cache_len:upper]: 1078 for example in self.dataset.source_dataset[cache_len:upper]:
1078 self.dataset.cached_examples.append(example) 1079 self.dataset.cached_examples.append(example)
1079 all_fields_minibatch = Example(self.dataset.fieldNames(), 1080 all_fields_minibatch = Example(self.dataset.fieldNames(),
1080 *self.dataset.cached_examples[self.current:self.current+minibatch_size]) 1081 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
1081 if self.dataset.fieldNames()==fieldnames: 1082 if self.dataset.fieldNames()==fieldnames:
1082 return all_fields_minibatch 1083 return all_fields_minibatch
1083 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) 1084 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
1084 return CacheIterator(self) 1085 return CacheIterator(self)
1085 1086