Mercurial > pylearn
diff dataset.py @ 23:526e192b0699
Working on ApplyFunctionDataSet, added constraint that
DataSet iterators must have a next_index() method.
author | bengioy@esprit.iro.umontreal.ca |
---|---|
date | Wed, 09 Apr 2008 18:27:13 -0400 |
parents | b6b36f65664f |
children | 672fe4b23032 |
line wrap: on
line diff
--- a/dataset.py Mon Apr 07 20:44:37 2008 -0400 +++ b/dataset.py Wed Apr 09 18:27:13 2008 -0400 @@ -17,7 +17,14 @@ - for val1,val2,val3 in dataset.zip([field1, field2,field3]) - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N) - for example in dataset - Each of these is documented below. + Each of these is documented below. All of these iterators are expected + to provide, in addition to the usual 'next()' method, a 'next_index()' method + which returns a non-negative integer pointing to the position of the next + example that will be returned by 'next()' (or of the first example in the + next minibatch returned). This is important because these iterators + can wrap around the dataset in order to do multiple passes through it, + in possibly unregular ways if the minibatch size is not a divisor of the + dataset length. Note: Fields are not mutually exclusive, i.e. two fields can overlap in their actual content. @@ -40,7 +47,7 @@ def __init__(self): pass - class Iter(LookupList): + class Iterator(LookupList): def __init__(self, ll): LookupList.__init__(self, ll.keys(), ll.values()) self.ll = ll @@ -50,6 +57,8 @@ self.ll.next() self._values = [v[0] for v in self.ll._values] return self + def next_index(self): + return self.ll.next_index() def __iter__(self): """Supports the syntax "for i in dataset: ..." @@ -61,7 +70,7 @@ Example returned by this iterator), but the derived class is free to accept any type of identifier, and add extra functionality to the iterator. """ - return DataSet.Iter(self.minibatches(None, minibatch_size = 1)) + return DataSet.Iterator(self.minibatches(None, minibatch_size = 1)) def zip(self, *fieldnames): """ @@ -81,7 +90,7 @@ The derived class may accept fieldname arguments of any type. """ - return DataSet.Iter(self.minibatches(fieldnames, minibatch_size = 1)) + return DataSet.Iterator(self.minibatches(fieldnames, minibatch_size = 1)) minibatches_fieldnames = None minibatches_minibatch_size = 1 @@ -141,15 +150,7 @@ """ raise AbstractFunction() - def rename(*new_field_specifications): - #Yoshua- - # Do you mean for this to be a virtual method? - # Wouldn't this functionality be easier to provide via a - # RenamingDataSet, such as the one I've written below? - # -JB - # You are right. Whichever implementation, however, we need a generic way to - # 'concatenate' fields, to handle the ([old_field1, old_field2, ...], new_field) semantics. - # -YB + def merge_fields(*specifications): """ Return a new dataset that maps old fields (of self) to new fields (of the returned dataset). The minimal syntax that should be supported is the following: @@ -161,6 +162,30 @@ """ raise AbstractFunction() + def merge_field_values(*field_value_pairs) + """ + Return the value that corresponds to merging the values of several fields, + given as arguments (field_name, field_value) pairs with self.hasField(field_name). + This may be used by implementations of merge_fields. + Raise a ValueError if the operation is not possible. + """ + fieldnames,fieldvalues = zip(*field_value_pairs) + raise ValueError("Unable to merge values of these fields:"+repr(fieldnames)) + + def examples2minibatch(examples): + """ + Combine a list of Examples into a minibatch. A minibatch is an Example whose fields + are iterable over the examples of the minibatch. + """ + raise AbstractFunction() + + def rename(rename_dict): + """ + Return a new dataset that renames fields, using a dictionnary that maps old field + names to new field names. The only fields visible by the returned dataset are those + whose names are keys of the rename_dict. + """ + return RenamingDataSet(self,rename_dict) def applyFunction(function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): """ @@ -278,7 +303,7 @@ if hasattr(dataset, 'as_array_dataset'): return dataset.as_array_dataset() - raise NotImplementedError() + raise NotImplementedError # Make ONE big minibatch with all the examples, to separate the fields. n_examples = len(dataset) @@ -343,6 +368,13 @@ rval[:a0,:] = a rval[a0:,:] = b return rval + + def next_index(self): + n_rows = self.dataset.data.shape[0] + next_i = self.current+self.minibatch_size + if next_i >= n_rows: + next_i -= n_rows + return next_i def next(self): @@ -352,21 +384,19 @@ raise StopIteration #determine the first and last elements of the slice we'll return - rows = self.dataset.data.shape[0] - self.current += self.minibatch_size - if self.current >= rows: - self.current -= rows + n_rows = self.dataset.data.shape[0] + self.current = self.next_index() upper = self.current + self.minibatch_size data = self.dataset.data - if upper <= rows: + if upper <= n_rows: #this is the easy case, we only need once slice dataview = data[self.current:upper] else: # the minibatch wraps around the end of the dataset dataview = data[self.current:] - upper -= rows + upper -= n_rows assert upper > 0 dataview = self.matcat(dataview, data[:upper]) @@ -518,6 +548,19 @@ c+=slice_width return result + def rename(*new_field_specifications): + """ + Return a new dataset that maps old fields (of self) to new fields (of the returned + dataset). The minimal syntax that should be supported is the following: + new_field_specifications = [new_field_spec1, new_field_spec2, ...] + new_field_spec = ([old_field1, old_field2, ...], new_field) + In general both old_field and new_field should be strings, but some datasets may also + support additional indexing schemes within each field (e.g. column slice + of a matrix-like field). + """ + # if all old fields of each spec are + raise NotImplementedError() + class ApplyFunctionDataSet(DataSet): """ A dataset that contains as fields the results of applying @@ -532,31 +575,35 @@ once the output fields for some examples have been computed, then are cached (to avoid recomputation if the same examples are again requested). """ - def __init__(src,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): + def __init__(src,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True, compute_now=False): DataSet.__init__(self) self.src=src self.function=function + assert src.hasFields(input_fields) self.input_fields=input_fields self.output_fields=output_fields + assert not (copy_inputs and compute_now and not hasattr(src,'fieldNames')) self.copy_inputs=copy_inputs self.accept_minibatches=accept_minibatches - src_fieldnames = src.fieldNames() - if copy_inputs: - for src_field in src_fieldnames: - assert src_field not in output_fields - self.fieldnames=src_fieldnames+output_fields - else: - self.fieldnames=output_fields - for input_field in input_fields: - assert input_field in src_fieldnames self.cache=cache - if cache: + self.compute_now=compute_now + if compute_now: + assert hasattr(src,'__len__') and len(src)>=0 + fieldnames = output_fields + if copy_inputs: fieldnames = src.fieldNames() + output_fields + if accept_minibatches: + # make a single minibatch with all the inputs + inputs = src.minibatches(input_fields,len(src)).next() + # and apply the function to it, and transpose into a list of examples (field values, actually) + self.cached_examples = zip(*Example(output_fields,function(*inputs))) + else: + # compute a list with one tuple per example, with the function outputs + self.cached_examples = [ function(input) for input in src.zip(input_fields) ] + else if cache: # maybe a fixed-size array kind of structure would be more efficient than a list # in the case where src is FiniteDataSet. -YB - self.cached_examples = [] + self.cached_examples = [] - def fieldNames(self): return self.fieldnames - def minibatches(self, fieldnames = DataSet.minibatches_fieldnames, minibatch_size = DataSet.minibatches_minibatch_size, @@ -566,30 +613,69 @@ def __init__(self,dataset): if fieldnames is None: - LookupList.__init__(self, [],[]) - else: - LookupList.__init__(self, fieldnames, [0]*len(fieldnames)) + assert hasattr(dataset,"fieldNames") + fieldnames = dataset.fieldNames() + self.example_index=0 + LookupList.__init__(self, fieldnames, [0]*len(fieldnames)) self.dataset=dataset - self.src_iterator=self.src.minibatches(list(set.union(set(fieldnames),set(self.dataset.input_fields))), + self.src_iterator=self.src.minibatches(list(set.union(set(fieldnames),set(dataset.input_fields))), minibatch_size,n_batches) + self.fieldnames_not_in_input = [] + if self.copy_inputs: + self.fieldnames_not_in_input = filter(lambda x: not x in dataset.input_fields, fieldnames) def __iter__(self): return self + def next_index(self): + return self.src_iterator.next_index() + def next(self): + example_index = self.src_iterator.next_index() src_examples = self.src_iterator.next() if self.dataset.copy_inputs: - function_inputs = src_examples + function_inputs = [src_examples[field_name] for field_name in self.dataset.input_fields] else: - function_inputs = [src_examples[field_name] for field_name in self.dataset.input_fields] - outputs = Example(self.dataset.output_fields,self.dataset.function(*function_inputs)) - if self.dataset.copy_inputs: - return src_examples + outputs + function_inputs = src_examples + if self.dataset.cached_examples: + cache_len=len(self.cached_examples) + if example_index<cache_len+minibatch_size: + outputs_list = self.cached_examples[example_index:example_index+minibatch_size] + # convert the minibatch list of examples + # into a list of fields each of which iterate over the minibatch + outputs = zip(*outputs_list) + else: + outputs = self.dataset.function(*function_inputs) + if self.dataset.cache: + # convert the list of fields, each of which can iterate over the minibatch + # into a list of examples in the minibatch (each of which is a list of field values) + outputs_list = zip(*outputs) + # copy the outputs_list into the cache + for i in xrange(cache_len,example_index): + self.cached_examples.append(None) + self.cached_examples += outputs_list else: - return outputs + outputs = self.dataset.function(*function_inputs) + + return Example(self.fieldnames_not_in_input+self.dataset.output_fields, + [src_examples[field_name] for field_name in self.fieldnames_not_in_input]+outputs) + for fieldname in fieldnames: assert fieldname in self.output_fields or self.src.hasFields(fieldname) return Iterator(self) +def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): + """ + Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the + user to define a set of fields as the 'input' field and a set of fields + as the 'target' field. Optionally, a single weight_field can also be defined. + """ + args = ((input_fields,'input'),(output_fields,'target')) + if weight_field: args+=(([weight_field],'weight')) + return src_dataset.rename(*args) + + + +