Mercurial > pylearn
comparison dataset.py @ 156:cc8b032417db
merged
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Mon, 12 May 2008 16:17:22 -0400 |
parents | f8a1ae7eb83e |
children | 28a988bd19c3 |
comparison
equal
deleted
inserted
replaced
155:ae5651a3696b | 156:cc8b032417db |
---|---|
1043 the record most likely to be accessed next. | 1043 the record most likely to be accessed next. |
1044 """ | 1044 """ |
1045 def __init__(self,source_dataset,cache_all_upon_construction=False): | 1045 def __init__(self,source_dataset,cache_all_upon_construction=False): |
1046 self.source_dataset=source_dataset | 1046 self.source_dataset=source_dataset |
1047 self.cache_all_upon_construction=cache_all_upon_construction | 1047 self.cache_all_upon_construction=cache_all_upon_construction |
1048 self.cached_examples = [] | |
1048 if cache_all_upon_construction: | 1049 if cache_all_upon_construction: |
1049 # this potentially brings all the source examples | 1050 # this potentially brings all the source examples |
1050 # into memory at once, which may be too much | 1051 # into memory at once, which may be too much |
1051 # the work could possibly be done by minibatches | 1052 # the work could possibly be done by minibatches |
1052 # that are as large as possible but no more than what memory allows. | 1053 # that are as large as possible but no more than what memory allows. |
1053 fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() | 1054 fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() |
1054 self.cached_examples = zip(*fields_values) | 1055 assert all([len(self)==len(field_values) for field_values in fields_values]) |
1055 else: | 1056 for example in fields_values.examples(): |
1056 self.cached_examples = [] | 1057 self.cached_examples.append(example) |
1057 | 1058 |
1058 self.fieldNames = source_dataset.fieldNames | 1059 self.fieldNames = source_dataset.fieldNames |
1059 self.hasFields = source_dataset.hasFields | 1060 self.hasFields = source_dataset.hasFields |
1060 self.valuesHStack = source_dataset.valuesHStack | 1061 self.valuesHStack = source_dataset.valuesHStack |
1061 self.valuesVStack = source_dataset.valuesVStack | 1062 self.valuesVStack = source_dataset.valuesVStack |
1075 if upper>cache_len: # whole minibatch is not already in cache | 1076 if upper>cache_len: # whole minibatch is not already in cache |
1076 # cache everything from current length to upper | 1077 # cache everything from current length to upper |
1077 for example in self.dataset.source_dataset[cache_len:upper]: | 1078 for example in self.dataset.source_dataset[cache_len:upper]: |
1078 self.dataset.cached_examples.append(example) | 1079 self.dataset.cached_examples.append(example) |
1079 all_fields_minibatch = Example(self.dataset.fieldNames(), | 1080 all_fields_minibatch = Example(self.dataset.fieldNames(), |
1080 *self.dataset.cached_examples[self.current:self.current+minibatch_size]) | 1081 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) |
1081 if self.dataset.fieldNames()==fieldnames: | 1082 if self.dataset.fieldNames()==fieldnames: |
1082 return all_fields_minibatch | 1083 return all_fields_minibatch |
1083 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) | 1084 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) |
1084 return CacheIterator(self) | 1085 return CacheIterator(self) |
1085 | 1086 |
1087 def __getitem__(self,i): | |
1088 if type(i)==int and len(self.cached_examples)>i: | |
1089 return self.cached_examples[i] | |
1090 else: | |
1091 return DataSet.__getitem__(self,i) | |
1086 | 1092 |
1087 class ApplyFunctionDataSet(DataSet): | 1093 class ApplyFunctionDataSet(DataSet): |
1088 """ | 1094 """ |
1089 A dataset that contains as fields the results of applying a given function | 1095 A dataset that contains as fields the results of applying a given function |
1090 example-wise or minibatch-wise to all the fields of an input dataset. | 1096 example-wise or minibatch-wise to all the fields of an input dataset. |