Mercurial > pylearn
diff dataset.py @ 26:672fe4b23032
Fixed dataset errors so that _test_dataset.py works again.
author | bengioy@grenat.iro.umontreal.ca |
---|---|
date | Fri, 11 Apr 2008 11:14:54 -0400 |
parents | 526e192b0699 |
children | 541a273bc89f |
line wrap: on
line diff
--- a/dataset.py Wed Apr 09 18:27:13 2008 -0400 +++ b/dataset.py Fri Apr 11 11:14:54 2008 -0400 @@ -1,6 +1,7 @@ from lookup_list import LookupList Example = LookupList +import copy class AbstractFunction (Exception): """Derived class must override this function""" @@ -142,7 +143,7 @@ """ raise AbstractFunction() - def hasFields(*fieldnames): + def hasFields(self,*fieldnames): """ Return true if the given field name (or field names, if multiple arguments are given) is recognized by the DataSet (i.e. can be used as a field name in one @@ -150,7 +151,7 @@ """ raise AbstractFunction() - def merge_fields(*specifications): + def merge_fields(self,*specifications): """ Return a new dataset that maps old fields (of self) to new fields (of the returned dataset). The minimal syntax that should be supported is the following: @@ -162,7 +163,7 @@ """ raise AbstractFunction() - def merge_field_values(*field_value_pairs) + def merge_field_values(self,*field_value_pairs): """ Return the value that corresponds to merging the values of several fields, given as arguments (field_name, field_value) pairs with self.hasField(field_name). @@ -172,22 +173,28 @@ fieldnames,fieldvalues = zip(*field_value_pairs) raise ValueError("Unable to merge values of these fields:"+repr(fieldnames)) - def examples2minibatch(examples): + def examples2minibatch(self,examples): """ Combine a list of Examples into a minibatch. A minibatch is an Example whose fields are iterable over the examples of the minibatch. """ raise AbstractFunction() - def rename(rename_dict): + def rename(self,rename_dict): """ Return a new dataset that renames fields, using a dictionnary that maps old field names to new field names. The only fields visible by the returned dataset are those whose names are keys of the rename_dict. """ - return RenamingDataSet(self,rename_dict) - - def applyFunction(function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): + self_class = self.__class__ + class SelfRenamingDataSet(RenamingDataSet,self_class): + pass + self.__class__ = SelfRenamingDataSet + # set the rename_dict and src fields + SelfRenamingDataSet.__init__(self,self,rename_dict) + return self + + def applyFunction(self,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True): """ Return a dataset that contains as fields the results of applying the given function (example-wise) to the specified input_fields. The @@ -204,25 +211,6 @@ """ return ApplyFunctionDataSet(function, input_fields, output_fields, copy_inputs, accept_minibatches, cache) -class RenamingDataSet(DataSet): - """A DataSet that wraps another one, and makes it look like the field names - are different - - Renaming is done by a dictionary that maps new names to the old ones used in - self.src. - """ - def __init__(self, src, rename_dct): - DataSet.__init__(self) - self.src = src - self.rename_dct = copy.copy(rename_dct) - - def minibatches(self, - fieldnames = DataSet.minibatches_fieldnames, - minibatch_size = DataSet.minibatches_minibatch_size, - n_batches = DataSet.minibatches_n_batches): - dct = self.rename_dct - new_fieldnames = [dct.get(f, f) for f in fieldnames] - return self.src.minibatches(new_fieldnames, minibatches_size, n_batches) class FiniteLengthDataSet(DataSet): """ @@ -278,10 +266,11 @@ def __init__(self): DataSet.__init__(self) - def hasFields(*fieldnames): + def hasFields(self,*fields): has_fields=True - for fieldname in fieldnames: - if fieldname not in self.fields.keys(): + fieldnames = self.fieldNames() + for name in fields: + if name not in fieldnames: has_fields=False return has_fields @@ -291,6 +280,30 @@ raise AbstractFunction() +class RenamingDataSet(FiniteWidthDataSet): + """A DataSet that wraps another one, and makes it look like the field names + are different + + Renaming is done by a dictionary that maps new names to the old ones used in + self.src. + """ + def __init__(self, src, rename_dct): + DataSet.__init__(self) + self.src = src + self.rename_dct = copy.copy(rename_dct) + + def fieldNames(self): + return self.rename_dct.keys() + + def minibatches(self, + fieldnames = DataSet.minibatches_fieldnames, + minibatch_size = DataSet.minibatches_minibatch_size, + n_batches = DataSet.minibatches_n_batches): + dct = self.rename_dct + new_fieldnames = [dct.get(f, f) for f in fieldnames] + return self.src.minibatches(new_fieldnames, minibatches_size, n_batches) + + # we may want ArrayDataSet defined in another python file import numpy @@ -548,19 +561,6 @@ c+=slice_width return result - def rename(*new_field_specifications): - """ - Return a new dataset that maps old fields (of self) to new fields (of the returned - dataset). The minimal syntax that should be supported is the following: - new_field_specifications = [new_field_spec1, new_field_spec2, ...] - new_field_spec = ([old_field1, old_field2, ...], new_field) - In general both old_field and new_field should be strings, but some datasets may also - support additional indexing schemes within each field (e.g. column slice - of a matrix-like field). - """ - # if all old fields of each spec are - raise NotImplementedError() - class ApplyFunctionDataSet(DataSet): """ A dataset that contains as fields the results of applying @@ -599,7 +599,7 @@ else: # compute a list with one tuple per example, with the function outputs self.cached_examples = [ function(input) for input in src.zip(input_fields) ] - else if cache: + elif cache: # maybe a fixed-size array kind of structure would be more efficient than a list # in the case where src is FiniteDataSet. -YB self.cached_examples = []