comparison dataset.py @ 152:3f627e844cba
Fixes in CacheDataSet
author:   Yoshua Bengio <bengioy@iro.umontreal.ca>
date:     Mon, 12 May 2008 16:11:24 -0400
parents:  39bb21348fdf
children: f8a1ae7eb83e
--- dataset.py (151:39bb21348fdf)
+++ dataset.py (152:3f627e844cba)
@@ -1043,19 +1043,20 @@
         the record most likely to be accessed next.
         """
     def __init__(self,source_dataset,cache_all_upon_construction=False):
         self.source_dataset=source_dataset
         self.cache_all_upon_construction=cache_all_upon_construction
+        self.cached_examples = []
         if cache_all_upon_construction:
             # this potentially brings all the source examples
             # into memory at once, which may be too much
             # the work could possibly be done by minibatches
             # that are as large as possible but no more than what memory allows.
             fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
-            self.cached_examples = zip(*fields_values)
-        else:
-            self.cached_examples = []
+            assert all([len(self)==len(field_values) for field_values in fields_values])
+            for example in fields_values.examples():
+                self.cached_examples.append(example)
 
         self.fieldNames = source_dataset.fieldNames
         self.hasFields = source_dataset.hasFields
         self.valuesHStack = source_dataset.valuesHStack
         self.valuesVStack = source_dataset.valuesVStack
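The first hunk initializes self.cached_examples unconditionally and fills it one example (row) at a time, so the eager and the lazy code paths now agree on a row-major cache. A minimal sketch of that layout, using toy lists rather than the pylearn minibatch API (the zip transposition below is assumed to be analogous to what fields_values.examples() yields in the real code):

# Toy stand-in for a full-dataset minibatch: one column of values per field.
x_column = [1, 2, 3]
y_column = [4, 5, 6]

# zip(*columns) transposes per-field columns into per-example rows.
cached_examples = []
for example in zip(*[x_column, y_column]):
    cached_examples.append(example)   # [(1, 4), (2, 5), (3, 6)]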
@@ -1075,11 +1076,11 @@
                 if upper>cache_len: # whole minibatch is not already in cache
                     # cache everything from current length to upper
                     for example in self.dataset.source_dataset[cache_len:upper]:
                         self.dataset.cached_examples.append(example)
                 all_fields_minibatch = Example(self.dataset.fieldNames(),
-                                               *self.dataset.cached_examples[self.current:self.current+minibatch_size])
+                                               zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
                 if self.dataset.fieldNames()==fieldnames:
                     return all_fields_minibatch
                 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
         return CacheIterator(self)
 
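The second hunk is the core fix: the cache stores rows, but the Example constructor expects one column of values per field name, so the slice of cached rows has to be transposed with zip(*rows) rather than unpacked as positional arguments. A minimal sketch of the difference, with a plain dict standing in for the pylearn Example class (an assumption for illustration, not its actual implementation):

# Toy stand-in: pair each field name with a column of values,
# as the Example constructor is assumed to do.
def make_example(field_names, field_values):
    return dict(zip(field_names, field_values))

cached = [(1, 4), (2, 5), (3, 6)]   # row-major cache: one tuple per example

# Before the fix, *cached unpacked the rows as extra positional arguments,
# so the constructor never received a per-field column for each name.
# zip(*cached) transposes the rows back into per-field columns.
minibatch = make_example(['x', 'y'], zip(*cached))
# {'x': (1, 2, 3), 'y': (4, 5, 6)}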