# view dataset.py @ 456:131e19dfe793
#
# Added sandbox.embeddings
# author Joseph Turian <turian@iro.umontreal.ca>
# date Tue, 07 Oct 2008 17:56:52 -0400
# parents fb62f0e4bcfe
# children
# line wrap: on
# line source


from lookup_list import LookupList as Example
from common.misc import unique_elements_list_intersection
from string import join
from sys import maxint
import numpy, copy

from exceptions import *

class AttributesHolder(object):
    """
    Base class giving bulk, name-based access to instance attributes.

    Sub-classes report which attributes they expose through attributeNames().
    """
    def __init__(self):
        pass

    def attributeNames(self):
        """
        Return the names of the exposed attributes. Abstract in this class.
        """
        raise AbstractFunction()

    def setAttributes(self,attribute_names,attribute_values,make_copies=False):
        """
        Set each named attribute to the value found at the same position.

        When attribute_names has length 1, attribute_values may be a single
        value instead of a list or tuple. With make_copies=True each value is
        deep-copied before being stored. Names and values are paired with zip,
        so surplus names or values are ignored.
        """
        values_are_sequence = isinstance(attribute_values,(list,tuple))
        if len(attribute_names)==1 and not values_are_sequence:
            attribute_values = [attribute_values]
        for attr_name,attr_value in zip(attribute_names,attribute_values):
            if make_copies:
                attr_value = copy.deepcopy(attr_value)
            self.__setattr__(attr_name,attr_value)

    def getAttributes(self,attribute_names=None, return_copy=False):
        """
        Return the values of the named attributes as a list.

        With attribute_names=None, all attributes are returned in the order
        of attributeNames(). With return_copy=True each value is a shallow
        copy (copy.copy) of the stored one.
        """
        if attribute_names is None:
            wanted = self.attributeNames()
        else:
            wanted = attribute_names
        fetched = [self.__getattribute__(attr_name) for attr_name in wanted]
        if return_copy:
            fetched = [copy.copy(attr_value) for attr_value in fetched]
        return fetched
    
class DataSet(AttributesHolder):
    """A virtual base class for datasets.

    A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction
    with learning algorithms (for training and testing them): rows/records are called examples, and
    columns/attributes are called fields. The field value for a particular example can be an arbitrary
    python object, which depends on the particular dataset.
    
    We call a DataSet a 'stream' when its length is unbounded (in which case its __len__ method
    should return sys.maxint).

    A DataSet is a generator of iterators; these iterators can run through the
    examples or the fields in a variety of ways.  A DataSet need not necessarily have a finite
    or known length, so this class can be used to interface to a 'stream' which
    feeds on-line learning (however, as noted below, some operations are not
    feasible or not recommended on streams).

    To iterate over examples, there are several possibilities:
     - for example in dataset:
     - for val1,val2,... in dataset:
     - for example in dataset(field1, field2,field3, ...):
     - for val1,val2,val3 in dataset(field1, field2,field3):
     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
     Each of these is documented below. All of these iterators are expected
     to provide, in addition to the usual 'next()' method, a 'next_index()' method
     which returns a non-negative integer pointing to the position of the next
     example that will be returned by 'next()' (or of the first example in the
     next minibatch returned). This is important because these iterators
     can wrap around the dataset in order to do multiple passes through it,
     in possibly irregular ways if the minibatch size is not a divisor of the
     dataset length.
    
    To iterate over fields, one can do
     - for field in dataset.fields():
         for field_value in field: # iterate over the values associated to that field for all the dataset examples
     - for field in dataset(field1,field2,...).fields() to select a subset of fields
     - for field in dataset.fields(field1,field2,...) to select a subset of fields
    and each of these fields is iterable over the examples:
     - for field_examples in dataset.fields():
        for example_value in field_examples:
           ...
    but when the dataset is a stream (unbounded length), it is not recommended to do 
    such things because the underlying dataset may refuse to access the different fields in
    an unsynchronized way. Hence the fields() method is illegal for streams, by default.
    The result of fields() is a L{DataSetFields} object, which iterates over fields,
    and whose elements are iterable over examples. A DataSetFields object can
    be turned back into a DataSet with its examples() method::
      dataset2 = dataset1.fields().examples()
    and dataset2 should behave exactly like dataset1 (in fact by default dataset2==dataset1).
    
    Note: Fields are not mutually exclusive, i.e. two fields can overlap in their actual content.

    Note: The content of a field can be of any type. Field values can also be 'missing'
    (e.g. to handle semi-supervised learning), and in the case of numeric (numpy array)
    fields (i.e. an ArrayFieldsDataSet), NaN plays the role of a missing value. 
    What about non-numeric values? None.

    Dataset elements can be indexed and sub-datasets (with a subset
    of examples) can be extracted. These operations are not supported
    by default in the case of streams.

     - dataset[:n] returns an Example with the n first examples.

     - dataset[i1:i2:s] returns an Example with the examples i1,i1+s,...i2-s.

     - dataset[i] returns an Example.

     - dataset[[i1,i2,...in]] returns an Example with examples i1,i2,...in.

    A similar command gives you a DataSet instead of Examples :

     - dataset.subset[:n] returns a DataSet with the n first examples.

     - dataset.subset[i1:i2:s] returns a DataSet with the examples i1,i1+s,...i2-s.

     - dataset.subset[i] returns a DataSet.

     - dataset.subset[[i1,i2,...in]] returns a DataSet with examples i1,i2,...in.


     - dataset.<property> returns the value of a property associated with
     the name <property>. The following properties should be supported:
          - 'description': a textual description or name for the dataset
          - 'fieldtypes': a list of types (one per field)
    A DataSet may have other attributes that it makes visible to other objects. These are
    used to store information that is not example-wise but global to the dataset.
    The list of names of these attributes is given by the attributeNames() method.

    Datasets can be concatenated either vertically (increasing the length) or
    horizontally (augmenting the set of fields), if they are compatible, using
    the following operations (with the same basic semantics as numpy.hstack
    and numpy.vstack):

     - dataset1 | dataset2 | dataset3 == dataset.hstack([dataset1,dataset2,dataset3])

    creates a new dataset whose list of fields is the concatenation of the list of
    fields of the argument datasets. This only works if they all have the same length.

     - dataset1 & dataset2 & dataset3 == dataset.vstack([dataset1,dataset2,dataset3])

    creates a new dataset that concatenates the examples from the argument datasets
    (and whose length is the sum of the length of the argument datasets). This only
    works if they all have the same fields.

    According to the same logic, and viewing a DataSetFields object associated to
    a DataSet as a kind of transpose of it, fields1 & fields2 concatenates fields of
    a DataSetFields fields1 and fields2, and fields1 | fields2 concatenates their
    examples.

    A dataset can hold arbitrary key-value pairs that may be used to access meta-data
    or other properties of the dataset or associated with the dataset or the result
    of a computation stored in a dataset. These can be accessed through the [key] syntax
    when key is a string (or more specifically, neither an integer, a slice, nor a list).

    A DataSet sub-class should always redefine the following methods:
       - __len__ if it is not a stream
       - fieldNames
       - minibatches_nowrap (called by DataSet.minibatches())
    For efficiency of implementation, a sub-class might also want to redefine
       - valuesHStack
       - valuesVStack
       - hasFields
       - __getitem__ may not be feasible with some streams
       - __iter__
    A sub-class should also append attributes to self._attribute_names
    (the default value returned by attributeNames()).
    By convention, attributes not in attributeNames() should have a name
    starting with an underscore.
    @todo enforce/test that convention!
    """

    # Stacking helpers for numpy-valued fields: the field name(s) argument is
    # ignored and the values are handed straight to numpy.
    # NOTE(review): these take no `self`; looked up through an instance, the
    # instance would be bound to the first parameter -- confirm how callers
    # are expected to reach them.
    def numpy_vstack(fieldname,values):
        return numpy.vstack(values)

    def numpy_hstack(fieldnames,values):
        return numpy.hstack(values)
        
    def __init__(self, description=None, fieldnames=None, fieldtypes=None):
        """
        @type fieldnames: list of strings
        @type fieldtypes: list of python types, same length as fieldnames
        @type description: string 
        @param description: description/name for this dataset
        """
        def default_desc():
            return type(self).__name__ \
                    + " ( " + join([x.__name__ for x in type(self).__bases__]) + " )"

        #self.fieldnames = fieldnames

        self.fieldtypes = fieldtypes if fieldtypes is not None \
                else [None]*1 #len(fieldnames)

        self.description =  default_desc() if description is None \
                else description
        self._attribute_names = ["description"]


    attributeNames = property(lambda self: copy.copy(self._attribute_names))

    def __contains__(self, fieldname):
        return (fieldname in self.fieldNames()) \
                or (fieldname in self.attributeNames())

    def __iter__(self):
        """Supports the syntax "for i in dataset: ..."

        Using this syntax, "i" will be an Example instance (or equivalent) with
        all the fields of DataSet self.  Every field of "i" will give access to
        a field of a single example.  Fields should be accessible via
        i["fielname"] or i[3] (in the order defined by the elements of the
        Example returned by this iterator), but the derived class is free
        to accept any type of identifier, and add extra functionality to the iterator.

        The default implementation calls the minibatches iterator and extracts the first example of each field.
        """
        return DataSet.MinibatchToSingleExampleIterator(self.minibatches(None, minibatch_size = 1))

    def __len__(self):
        """
        len(dataset) returns the number of examples in the dataset.
        By default, a DataSet is a 'stream', i.e. it has an unbounded length (sys.maxint).
        Sub-classes which implement finite-length datasets should redefine this method.
        Some methods only make sense for finite-length datasets.
        """
        from sys import maxint
        return maxint


    class MinibatchToSingleExampleIterator(object):
        """
        Converts the result of minibatch iterator with minibatch_size==1 into
        single-example values in the result. Therefore the result of
        iterating on the dataset itself gives a sequence of single examples
        (whereas the result of iterating over minibatches gives in each
        Example field an iterable object over the individual examples in
        the minibatch).

        The same Example object is returned by every call to next(); only its
        values are replaced, so callers that need to keep an example must copy it.
        """
        def __init__(self, minibatch_iterator):
            # minibatch_iterator: object whose next() yields size-1 minibatches
            # exposing keys() and values().
            self.minibatch_iterator = minibatch_iterator
            # Example reused across calls; created on the first next().
            self.minibatch = None

        def __iter__(self): #makes for loop work
            return self

        @staticmethod
        def _single_values(size1_minibatch):
            """Return the first element of every field value of the minibatch.

            This is a hack: some datasets hand back values that cannot be
            indexed (e.g. [array(327)]); if indexing fails for any field, all
            the field values are returned as they are.
            """
            try:
                return [value[0] for value in size1_minibatch.values()]
            except Exception:
                # Best effort only; unlike a bare except this lets
                # KeyboardInterrupt and SystemExit propagate.
                return [value for value in size1_minibatch.values()]

        def next(self):
            """Fetch the next size-1 minibatch and return it as a single example."""
            size1_minibatch = self.minibatch_iterator.next()
            # The same unwrapping (including its fallback) is used for every
            # minibatch, not only for the first one.
            values = self._single_values(size1_minibatch)
            if self.minibatch is None:
                self.minibatch = Example(size1_minibatch.keys(),values)
            else:
                self.minibatch._values = values
            return self.minibatch

        def next_index(self):
            """Delegate to the wrapped minibatch iterator's next_index()."""
            return self.minibatch_iterator.next_index()

    class MinibatchWrapAroundIterator(object):
        """
        An iterator for minibatches that handles the case where we need to wrap around the
        dataset because n_batches*minibatch_size > len(dataset). It is constructed from
        a dataset that provides a minibatch iterator that does not need to handle that problem.
        This class is a utility for dataset subclass writers, so that they do not have to handle
        this issue multiple times, nor check that fieldnames are valid, nor handle the
        empty fieldnames (meaning 'use all the fields').

        State kept between calls: next_row (row at which the next minibatch starts),
        n_batches_done, and iterator (the current minibatches_nowrap iterator).
        A falsy n_batches (None or 0) means "a single pass up to the end of the dataset".
        """
        def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
            """
            dataset must provide __len__, hasFields, fieldNames, minibatches_nowrap,
            valuesVStack and valuesHStack (all are used by this class).
            """
            self.dataset=dataset
            self.fieldnames=fieldnames
            self.minibatch_size=minibatch_size
            self.n_batches=n_batches
            self.n_batches_done=0
            # note: next_row starts at the raw offset, not at offset % L
            self.next_row=offset
            self.L=len(dataset)
            # stored but not read anywhere else in this class
            self.offset=offset % self.L
            # number of minibatches that fit between next_row and the end of the
            # dataset (relies on integer division of ints), capped by n_batches
            ds_nbatches =  (self.L-self.next_row)/self.minibatch_size
            if n_batches is not None:
                ds_nbatches = min(n_batches,ds_nbatches)
            if fieldnames:
                assert dataset.hasFields(*fieldnames)
            else:
                # empty/None fieldnames means: use all the fields of the dataset
                self.fieldnames=dataset.fieldNames()
            self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size, ds_nbatches,self.next_row)

        def __iter__(self):
            return self

        def next_index(self):
            """Return the row at which the next minibatch will start."""
            return self.next_row

        def next(self):
            """
            Return the next minibatch wrapped as
            DataSetFields(MinibatchDataSet(...), field names), or raise StopIteration.
            """
            # stop after n_batches minibatches, or (when n_batches is falsy) once
            # the end of the dataset has been reached exactly
            if self.n_batches and self.n_batches_done==self.n_batches:
                raise StopIteration
            elif not self.n_batches and self.next_row ==self.L:
                raise StopIteration
            upper = self.next_row+self.minibatch_size
            if upper <=self.L:
                # the whole minibatch fits before the end: use the current iterator
                minibatch = self.iterator.next()
            else:
                if not self.n_batches:
                    upper=min(upper, self.L)
                    # if there is no fixed number of batches, we continue to the end of the dataset.
                    # this can create a minibatch that is smaller than the minibatch_size
                    assert (self.L-self.next_row)<=self.minibatch_size
                    minibatch = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
                else:
                    # we must concatenate (vstack) the bottom and top parts of our minibatch
                    # first_part: rows next_row..L-1 (the tail of the dataset)
                    first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
                    # second_part: the upper-L rows taken from the start of the dataset
                    second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
                    minibatch = Example(self.fieldnames,
                                        [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
                                         for name in self.fieldnames])
            self.next_row=upper
            self.n_batches_done+=1
            if upper >= self.L and self.n_batches:
                # wrapped around (or landed exactly on the end): restart a
                # no-wrap iterator from the new position in the dataset
                self.next_row -= self.L
                ds_nbatches =  (self.L-self.next_row)/self.minibatch_size
                if self.n_batches is not None:
                    ds_nbatches = min(self.n_batches,ds_nbatches)
                self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size,
                                                                ds_nbatches,self.next_row)
            return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack,
                                                  self.dataset.valuesHStack),
                                 minibatch.keys())


    # Class-level defaults used as the default argument values of minibatches()
    # (fieldnames, minibatch_size and n_batches respectively).
    minibatches_fieldnames = None
    minibatches_minibatch_size = 1
    minibatches_n_batches = None
    def minibatches(self,
                    fieldnames = minibatches_fieldnames,
                    minibatch_size = minibatches_minibatch_size,
                    n_batches = minibatches_n_batches,
                    offset = 0):
        """
        Return an iterator that supports three forms of syntax:

            for i in dataset.minibatches(None,**kwargs): ...

            for i in dataset.minibatches([f1, f2, f3],**kwargs): ...

            for i1, i2, i3 in dataset.minibatches([f1, f2, f3],**kwargs): ...

        Using the first two syntaxes, "i" will be an indexable object, such as a list,
        tuple, or Example instance. In both cases, i[k] is a list-like container
        of a batch of current examples. In the second case, i[0] is
        list-like container of the f1 field of a batch current examples, i[1] is
        a list-like container of the f2 field, etc.

        Using the first syntax, all the fields will be returned in "i".
        Using the third syntax, i1, i2, i3 will be list-like containers of the
        f1, f2, and f3 fields of a batch of examples on each loop iteration.

        The minibatches iterator is expected to return upon each call to next()
        a DataSetFields object, which is a Example (indexed by the field names) whose
        elements are iterable and indexable over the minibatch examples, and which keeps a pointer to
        a sub-dataset that can be used to iterate over the individual examples
        in the minibatch. Hence a minibatch can be converted back to a regular
        dataset or its fields can be looked at individually (and possibly iterated over).

        PARAMETERS
        - fieldnames (list of any type, default None):
        The loop variables i1, i2, i3 (in the example above) should contain the
        f1, f2, and f3 fields of the current batch of examples.  If None, the
        derived class can choose a default, e.g. all fields.

        - minibatch_size (integer, default 1)
        On every iteration, the variables i1, i2, i3 will have
        exactly minibatch_size elements. e.g. len(i1) == minibatch_size

        @DEPRECATED n_batches : not used anywhere
        - n_batches (integer, default None)
        The iterator will loop exactly this many times, and then stop.  If None,
        the derived class can choose a default.  If (-1), then the returned
        iterator should support looping indefinitely.

        - offset (integer, default 0)
        The iterator will start at example 'offset' in the dataset, rather than the default.
        
        Note: A list-like container is something like a tuple, list, numpy.ndarray or
        any other object that supports integer indexing and slicing.

        @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete
        batches only, raise StopIteration.
        @ATTENTION: minibatches returns a LookupList, we can't iterate over examples on it.

        """
        #return DataSet.MinibatchWrapAroundIterator(self, fieldnames, minibatch_size, n_batches,offset)
        assert offset >= 0
        assert offset < len(self)
        assert offset + minibatch_size -1 < len(self)
        if fieldnames == None :
            fieldnames = self.fieldNames()
        return self.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset)

    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
        """
        This is the minibatches iterator generator that sub-classes must define
        (this base implementation only raises AbstractFunction).

        DataSet.minibatches() forwards its arguments to this method after
        asserting that offset and the first minibatch lie inside the dataset
        and after replacing a None fieldnames by self.fieldNames(); the
        arguments are, in order, the field names, the minibatch size, the
        number of batches and the index of the first example.

        The returned iterator does not need to wrap around the end of the
        dataset, nor to implement a next_index() method: both are provided by
        DataSet.MinibatchWrapAroundIterator when that helper is layered on top
        of this method. Its next() method can assert that its next row will
        always be within [0,len(dataset)).
        """
        raise AbstractFunction()

    def is_unbounded(self):
        """
        Tests whether a dataset is unbounded (e.g. a stream).
        """
        return len(self)==maxint

    def hasFields(self,*fieldnames):
        """
        Return true if the given field name (or field names, if multiple arguments are
        given) is recognized by the DataSet (i.e. can be used as a field name in one
        of the iterators).

        The default implementation may be inefficient (O(# fields in dataset)), as it calls the fieldNames()
        method. Many datasets may store their field names in a dictionary, which would allow more efficiency.
        """
        return len(unique_elements_list_intersection(fieldnames,self.fieldNames()))>0
        
    def fieldNames(self):
        """
        Return the list of field names that are supported by the iterators,
        and for which hasFields(fieldname) would return True.

        Abstract: this base implementation only raises AbstractFunction, so
        every sub-class has to redefine it.
        """
        raise AbstractFunction()

    def __call__(self,*fieldnames):
        """
        dataset(field1, field2, ...) returns a FieldsSubsetDataSet over this
        dataset that sees only the named fields. The names are checked with
        an assertion on hasFields().
        """
        assert self.hasFields(*fieldnames)
        selected_names = list(fieldnames)
        return FieldsSubsetDataSet(self,selected_names)

    def cached_fields_subset(self,*fieldnames) :
        """
        Behaviour is supposed to be the same as __call__(*fieldnames), but the dataset returned is cached.
        @see : dataset.__call__
        """
        assert self.hasFields(*fieldnames)
        return self.fields(*fieldnames).examples()

    def fields(self,*fieldnames):
        """
        Build and return the DataSetFields object for this dataset and the
        given field names (passed on as a tuple, possibly empty).
        """
        fields_view = DataSetFields(self,fieldnames)
        return fields_view

    def getitem_key(self, fieldname):
        """A not-so-well thought-out place to put code that used to be in
        getitem.
        """
        #removing as per discussion June 4. --JSB

        i = fieldname
        # else check for a fieldname
        if self.hasFields(i):
            return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0]
        # else we are trying to access a property of the dataset
        assert i in self.__dict__ # else it means we are trying to access a non-existing property
        return self.__dict__[i]

    def __getitem__(self,i):
        """
        @rtype: Example 
        @returns: single or multiple examples

        @type i: integer or slice or <iterable> of integers
        @param i:
            dataset[i] returns the (i+1)-th example of the dataset.
            dataset[i:j] returns a LookupList with examples i,i+1,...,j-1.
            dataset[i:j:s] returns a LookupList with examples i,i+2,i+4...,j-2.
            dataset[[i1,i2,..,in]] returns a LookupList with examples i1,i2,...,in.

        @note:
        Some stream datasets may be unable to implement random access, i.e.
        arbitrary slicing/indexing because they can only iterate through
        examples one or a minibatch at a time and do not actually store or keep
        past (or future) examples.

        The default implementation of getitem uses the minibatches iterator
        to obtain one example, one slice, or a list of examples. It may not
        always be the most efficient way to obtain the result, especially if
        the data are actually stored in a memory array.
        """

        if type(i) is int:
            assert i >= 0 # TBM: see if someone complains and want negative i
            if i >= len(self) :
                raise IndexError
            i_batch = self.minibatches_nowrap(self.fieldNames(),
                    minibatch_size=1, n_batches=1, offset=i)
            return DataSet.MinibatchToSingleExampleIterator(i_batch).next()

        #if i is a contiguous slice
        if type(i) is slice and (i.step in (None, 1)):
            offset = 0 if i.start is None else i.start
            upper_bound = len(self) if i.stop is None else i.stop
            upper_bound = min(len(self) , upper_bound)
            #return MinibatchDataSet(self.minibatches_nowrap(self.fieldNames(),
            #        minibatch_size=upper_bound - offset,
            #        n_batches=1,
            #        offset=offset).next())
            # now returns a LookupList
            return self.minibatches_nowrap(self.fieldNames(),
                    minibatch_size=upper_bound - offset,
                    n_batches=1,
                    offset=offset).next()

        # if slice has a step param, convert it to list and handle it with the
        # list code
        if type(i) is slice:
            offset = 0 if i.start is None else i.start
            upper_bound = len(self) if i.stop is None else i.stop
            upper_bound = min(len(self) , upper_bound)
            i = list(range(offset, upper_bound, i.step))

        # handle tuples, arrays, lists
        if hasattr(i, '__getitem__'):
            for idx in i:
                #dis-allow nested slices
                if not isinstance(idx, int):
                    raise TypeError(idx)
                if idx >= len(self) :
                    raise IndexError
            # call back into self.__getitem__
            examples = [self.minibatches_nowrap(self.fieldNames(),
                    minibatch_size=1, n_batches=1, offset=ii).next()
                    for ii in i]
            # re-index the fields in each example by field instead of by example
            field_values = [[] for blah in  self.fieldNames()]
            for e in examples:
                for f,v in zip(field_values, e):
                    f.append(v)
            #build them into a LookupList (a.ka. Example)
            zz = zip(self.fieldNames(),field_values)
            vst = [self.valuesVStack(fieldname,field_values) for fieldname,field_values in zz]
            example = Example(self.fieldNames(), vst)
            #return MinibatchDataSet(example, self.valuesVStack, self.valuesHStack)
            # now returns a LookupList
            return example

        # what in the world is i?
        raise TypeError(i, type(i))


    # Enables the call dataset.subset[a:b:c] that will return a DataSet
    # around the examples returned by __getitem__(slice(a,b,c)).
    #
    # @SEE DataSet.__getsubset(self)
    subset = property(fget=lambda dataset: dataset.__getsubset(),
                      doc="returns a subset as a DataSet")


    def __getsubset(self) :
        """
        Enables the call data.subset[a:b:c], returns a DataSet.
        Default implementation is a simple wrap around __getitem__() using MinibatchDataSet.

        @RETURN DataSet
        @SEE DataSet.subset = property(lambda s : s.__getsubset())
        """
        _self = self
        class GetSliceReturnsDataSet(object) :
            def __getitem__(self,slice) :
                return MinibatchDataSet(_self.__getitem__(slice))
        return GetSliceReturnsDataSet()



    def valuesHStack(self,fieldnames,fieldvalues):
        """
        Return a value that corresponds to concatenating (horizontally) several field values.
        This can be useful to merge some fields. The implementation of this operation is likely
        to involve a copy of the original values. When the values are numpy arrays, the
        result should be numpy.hstack(values). If it makes sense, this operation should
        work as well when each value corresponds to multiple examples in a minibatch
        e.g. if each value is a Ni-vector and a minibatch of length L is a LxNi matrix,
        then the result should be a Lx(N1+N2+..) matrix equal to numpy.hstack(values).

        The default is to use numpy.hstack when every value is exactly a
        numpy.ndarray, and to return fieldvalues itself (a list pointing to the
        original values) otherwise. An empty fieldvalues is also returned as is,
        since numpy.hstack refuses an empty sequence.

        @param fieldnames: names of the fields being merged (unused by this default)
        @param fieldvalues: sequence of the values to merge
        """
        if len(fieldvalues)==0:
            # nothing to stack
            return fieldvalues
        for value in fieldvalues:
            if type(value) is not numpy.ndarray:
                # the default implementation of horizontal stacking is to put values in a list
                return fieldvalues
        return numpy.hstack(fieldvalues)

    def valuesVStack(self,fieldname,values):
        """
        Default vertical stacking: chain the elements of every value into
        one flat python list.

        @param fieldname: the name of the field from which the values were taken
            (unused by this default implementation)
        @type fieldname: any type

        @param values: bits near the beginning or end of the dataset
        @type values: list of minibatches (returned by minibatches_nowrap)

        @return: the concatenation (stacking) of the values
        @rtype: something suitable as a minibatch field (here, a list)
        """
        return [element for chunk in values for element in chunk]

    def __or__(self,other):
        """
        dataset1 | dataset2 builds an HStackedDataSet from the two operands:
        a dataset whose list of fields is the concatenation of the list of
        fields of the argument datasets. This only works if they all have
        the same length.
        """
        operands = [self,other]
        return HStackedDataSet(operands)

    def __and__(self,other):
        """
        dataset1 & dataset2 builds a VStackedDataSet from the two operands:
        a dataset that concatenates the examples from the argument datasets
        (and whose length is the sum of the length of the argument datasets).
        This only works if they all have the same fields.
        """
        operands = [self,other]
        return VStackedDataSet(operands)

def hstack(datasets):
    """
    hstack([dataset1,dataset2,...]) returns dataset1 | dataset2 | ...
    which is a dataset whose fields list is the concatenation of the fields
    of the individual datasets.

    The argument is a non-empty sequence (checked with an assertion); a
    sequence holding a single dataset yields that dataset itself.
    """
    n_datasets = len(datasets)
    assert n_datasets>0
    if n_datasets!=1:
        return HStackedDataSet(datasets)
    return datasets[0]

def vstack(datasets):
    """
    vstack([dataset1,dataset2,...]) returns dataset1 & dataset2 & ...
    which is a dataset which iterates first over the examples of dataset1, then
    over those of dataset2, etc.

    The argument is a non-empty sequence (checked with an assertion); a
    sequence holding a single dataset yields that dataset itself.
    """
    n_datasets = len(datasets)
    assert n_datasets>0
    if n_datasets!=1:
        return VStackedDataSet(datasets)
    return datasets[0]

class FieldsSubsetDataSet(DataSet):
    """
    A sub-class of L{DataSet} that selects a subset of the fields.

    It is a view: examples and minibatches are read from the source dataset
    and value stacking is delegated to the source's valuesHStack/valuesVStack.
    """
    def __init__(self,src,fieldnames):
        """
        @param src: the dataset whose fields are selected
        @param fieldnames: list of the selected field names; they must be
            accepted by src.hasFields (checked with an assertion)
        """
        assert src.hasFields(*fieldnames)
        # Set up the state every DataSet is expected to carry (description,
        # fieldtypes, _attribute_names); without it attributeNames() and
        # "name in dataset" fail with AttributeError on this sub-class.
        DataSet.__init__(self)
        self.src=src
        self.fieldnames=fieldnames
        self.valuesHStack = src.valuesHStack
        self.valuesVStack = src.valuesVStack

    def __len__(self): return len(self.src)

    def fieldNames(self):
        """Return the list of selected field names."""
        return self.fieldnames

    def __iter__(self):
        """
        Iterate over the examples of the source, restricted to the selected
        fields. The same Example object is reused for every example.
        """
        class FieldsSubsetIterator(object):
            def __init__(self,ds):
                self.ds=ds
                self.src_iter=ds.src.__iter__()
                self.example=None
            def __iter__(self): return self
            def next(self):
                complete_example = self.src_iter.next()
                if self.example:
                    # reuse the Example built on a previous call
                    self.example._values=[complete_example[field]
                                          for field in self.ds.fieldnames]
                else:
                    self.example=Example(self.ds.fieldnames,
                                         [complete_example[field] for field in self.ds.fieldnames])
                return self.example
        return FieldsSubsetIterator(self)

    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
        """Forward to the source dataset after checking the field names."""
        assert self.hasFields(*fieldnames)
        return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset)
    def dontuse__getitem__(self,i):
        # Disabled __getitem__ override (note the name): kept for reference.
        return FieldsSubsetDataSet(self.src[i],self.fieldnames)
    
class RenamedFieldsDataSet(DataSet):
    """
    A L{DataSet} that wraps a source dataset, keeps some of its fields and
    shows them under new names.
    """
    def __init__(self,src,src_fieldnames,new_fieldnames):
        """
        @param src: the dataset to wrap.
        @param src_fieldnames: names, in src, of the fields to keep.
        @param new_fieldnames: visible names of these fields, in the same order
        (both lists must have the same length).
        """
        self.src=src
        self.src_fieldnames=src_fieldnames
        self.new_fieldnames=new_fieldnames
        assert src.hasFields(*src_fieldnames)
        assert len(src_fieldnames)==len(new_fieldnames)
        # field values are stacked with the functions of the source dataset
        self.valuesHStack = src.valuesHStack
        self.valuesVStack = src.valuesVStack
        # lookup from a visible (new) field name to its name in the source dataset
        self.lookup_fields = Example(new_fieldnames,src_fieldnames)

    def __len__(self):
        return len(self.src)

    def fieldNames(self):
        return self.new_fieldnames

    def __iter__(self):
        renamed_ds=self

        class RenamedFieldsIterator(object):
            """Walk over the source examples, exposing the kept fields under their new names."""
            def __init__(self):
                self.source_examples=renamed_ds.src.__iter__()
                # one Example is created at the first step and refilled afterwards
                self.example=None

            def __iter__(self):
                return self

            def next(self):
                full_example = self.source_examples.next()
                kept_values = [full_example[name] for name in renamed_ds.src_fieldnames]
                if not self.example:
                    self.example=Example(renamed_ds.new_fieldnames,kept_values)
                else:
                    self.example._values=kept_values
                return self.example

        return RenamedFieldsIterator()

    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
        """
        Generate the minibatches of the source dataset for the requested fields
        (given by their new names). The same lookup-list object is yielded at
        every step, refilled with the values of the current source minibatch.
        """
        assert self.hasFields(*fieldnames)
        cursor = Example(fieldnames,[0]*len(fieldnames))
        # translate the requested (new) names into the names known to the source
        source_names = [self.lookup_fields[name] for name in fieldnames]
        source_batches = self.src.minibatches_nowrap(source_names,minibatch_size,n_batches,offset)
        for source_batch in source_batches:
            cursor._values=source_batch._values
            yield cursor

    def __getitem__(self,i):
        """Return src[i] restricted to the kept fields, under their new names."""
        source_example = self.src[i]
        kept_values = [source_example[name] for name in self.src_fieldnames]
        return Example(self.new_fieldnames,kept_values)



class DataSetFields(Example):
    """
    Although a L{DataSet} iterates over examples (like rows of a matrix), an associated
    DataSetFields iterates over fields (like columns of a matrix), and can be understood
    as a transpose of the associated dataset.

    To iterate over fields, one can do
    * for fields in dataset.fields()
    * for fields in dataset(field1,field2,...).fields() to select a subset of fields
    * for fields in dataset.fields(field1,field2,...) to select a subset of fields
    and each of these fields is iterable over the examples:
    * for field_examples in dataset.fields():
        for example_value in field_examples:
           ...
    but when the dataset is a stream (unbounded length), it is not recommended to do
    such things because the underlying dataset may refuse to access the different fields
    in an unsynchronized way. Hence the fields() method is illegal for streams, by default.
    The result of fields() is a DataSetFields object, which iterates over fields,
    and whose elements are iterable over examples. A DataSetFields object can
    be turned back into a DataSet with its examples() method:
      dataset2 = dataset1.fields().examples()
    and dataset2 should behave exactly like dataset1 (in fact by default dataset2==dataset1).

    DataSetFields can be combined with the | and & operators, which apply respectively
    + and | to the underlying datasets and take the fields of the result.

    NOTE(review): this docstring used to state that | concatenates the fields and
    & concatenates the examples, while __or__ / __and__ are documented (and
    implemented) the other way around -- confirm which one is intended.
    """
    def __init__(self,dataset,fieldnames):
        """
        @param dataset: the dataset whose fields are exposed.
        @param fieldnames: names of the fields to expose; when empty or None,
        all the fields of dataset are used.
        """
        source_dataset=dataset
        if not fieldnames:
            fieldnames=dataset.fieldNames()
        elif list(fieldnames)!=list(dataset.fieldNames()):
            # compare as lists, otherwise ('x','y')!=['x','y']
            dataset = FieldsSubsetDataSet(dataset,fieldnames)
        assert dataset.hasFields(*fieldnames)
        self.dataset=dataset

        if isinstance(dataset,MinibatchDataSet):
            # the fields are already stored whole: use them directly
            values = list(dataset._fields)
        elif isinstance(source_dataset,MinibatchDataSet):
            # a subset of a MinibatchDataSet: pick its stored fields by name
            values = [source_dataset._fields[name] for name in fieldnames]
        else:
            # general case: ask for one minibatch as long as the dataset
            whole_dataset = dataset.minibatches(fieldnames,
                                                minibatch_size=len(dataset),
                                                n_batches=1)
            values = whole_dataset.next()
        Example.__init__(self,fieldnames,values)

    def examples(self):
        """Return the dataset whose fields are held by this object."""
        return self.dataset

    def __or__(self,other):
        """
        fields1 | fields2 is the DataSetFields of fields1.examples() + fields2.examples(),
        documented as the one whose list of examples is the concatenation of the lists
        of examples of fields1 and fields2.
        """
        stacked = self.examples() + other.examples()
        return stacked.fields()

    def __and__(self,other):
        """
        fields1 & fields2 is the DataSetFields of fields1.examples() | fields2.examples(),
        documented as the one whose list of fields is the concatenation of the fields
        of fields1 and fields2.
        """
        stacked = self.examples() | other.examples()
        return stacked.fields()

    
class MinibatchDataSet(DataSet):
    """
    Turn a L{Example} of same-length (iterable) fields into an example-iterable dataset.
    Each element of the lookup-list should be an iterable and sliceable, all of the same length.
    """
    def __init__(self,fields_lookuplist,values_vstack=DataSet().valuesVStack,
                 values_hstack=DataSet().valuesHStack):
        """
        The user can (and generally should) also provide values_vstack(fieldname,fieldvalues)
        and a values_hstack(fieldnames,fieldvalues) functions behaving with the same
        semantics as the DataSet methods of the same name (but without the self argument).

        @param fields_lookuplist: non-empty lookup-list with one entry per field,
        all entries having the same length (checked by assertions).
        """
        self._fields=fields_lookuplist
        assert len(fields_lookuplist)>0
        self.length=len(fields_lookuplist[0])
        for field in fields_lookuplist[1:]:
            # the diagnostic is carried by the assertion instead of being printed
            assert self.length==len(field), \
                   "all fields must have the same length: first field has length %d, " \
                   "another one has length %d (fields = %s)" \
                   % (self.length,len(field),self._fields.keys())
        self.valuesVStack=values_vstack
        self.valuesHStack=values_hstack

    def __len__(self):
        return self.length

    def dontuse__getitem__(self,i):
        # not an active __getitem__ (see the name), kept for reference
        if type(i) in (slice,list):
            return DataSetFields(MinibatchDataSet(
                Example(self._fields.keys(),[field[i] for field in self._fields])),self.fieldNames())
        if type(i) is int:
            return Example(self._fields.keys(),[field[i] for field in self._fields])
        if self.hasFields(i):
            return self._fields[i]
        assert i in self.__dict__ # else it means we are trying to access a non-existing property
        return self.__dict__[i]

    def fieldNames(self):
        return self._fields.keys()

    def hasFields(self,*fieldnames):
        """Return True when every given name is one of the fields of this dataset."""
        known_names = self._fields.keys()
        for fieldname in fieldnames:
            if fieldname not in known_names:
                return False
        return True

    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
        """
        Return an iterator over consecutive minibatches of minibatch_size examples,
        starting at example offset. Each minibatch is a lookup-list holding, for each
        requested field (all of them if fieldnames is None), the matching slice of
        that field.

        Iteration stops after n_batches minibatches (when n_batches is not None),
        or as soon as a complete minibatch does not fit before the end of the data.

        @raise NotImplementedError: if not even the first minibatch fits.
        """
        class Iterator(object):
            def __init__(self,ds,fieldnames):
                if fieldnames is None: fieldnames = ds._fields.keys()
                self.fieldnames = fieldnames

                self.ds=ds
                self.next_example=offset
                self.n_batches_done=0
                assert minibatch_size >= 0
                if offset+minibatch_size > ds.length:
                    raise NotImplementedError()

            def __iter__(self):
                return self

            def next(self):
                # honor n_batches, like the other minibatches_nowrap implementations
                if n_batches is not None and self.n_batches_done>=n_batches:
                    raise StopIteration()
                upper = self.next_example+minibatch_size
                if upper > len(self.ds):
                    raise StopIteration()
                values = [self.ds._fields[name][self.next_example:upper]
                          for name in self.fieldnames]
                minibatch = Example(self.fieldnames,values)
                self.next_example=upper
                self.n_batches_done+=1
                return minibatch

        return Iterator(self,fieldnames)

class HStackedDataSet(DataSet):
    """
    A L{DataSet} that wraps several datasets and shows a view that includes all their fields,
    i.e. whose list of fields is the concatenation of their lists of fields.

    If a field name is found in more than one of the datasets, then either an error is
    raised or the fields are renamed (either by prefixing the __name__ attribute
    of the dataset + ".", if it exists, or by suffixing the dataset index in the argument list).

    @todo: automatically detect a chain of stacked datasets due to A | B | C | D ...
    """
    def __init__(self,datasets,accept_nonunique_names=False,description=None,field_types=None):
        """
        @param datasets: the datasets to stack; they must all have the same length.
        @param accept_nonunique_names: when True, a field whose name is already taken
        by a previous dataset is renamed instead of raising a ValueError.
        @param description: passed on to DataSet.__init__.
        @param field_types: passed on to DataSet.__init__.
        @raise ValueError: on a field name conflict that is not (or cannot be) solved
        by renaming.
        """
        DataSet.__init__(self,description,field_types)
        self.datasets=datasets
        self.accept_nonunique_names=accept_nonunique_names
        # visible field name -> index (in self.datasets) of the dataset providing it
        self.fieldname2dataset={}
        # visible field name -> name of that field inside the dataset providing it
        self.fieldname2source={}
        # visible field names, in stacking order
        self._fieldnames=[]

        def rename_field(fieldname,dataset,i):
            if hasattr(dataset,"__name__"):
                return dataset.__name__ + "." + fieldname
            return fieldname+"."+str(i)

        # make sure all datasets have the same length and unique field names
        self.length=None
        for i,dataset in enumerate(datasets):
            length=len(dataset)
            if self.length is None: # 'is None' so that a length of 0 is checked too
                self.length=length
            else:
                assert self.length==length
            for source_name in dataset.fieldNames():
                fieldname=source_name
                if fieldname in self.fieldname2dataset: # name conflict!
                    if not accept_nonunique_names:
                        raise ValueError("Incompatible datasets: non-unique field name = "+fieldname)
                    # rename exactly once, remembering the name known to the sub-dataset
                    fieldname=rename_field(source_name,dataset,i)
                    if fieldname in self.fieldname2dataset:
                        raise ValueError("Incompatible datasets: non-unique field name = "+fieldname)
                self.fieldname2dataset[fieldname]=i
                self.fieldname2source[fieldname]=source_name
                self._fieldnames.append(fieldname)

    def __len__(self):
        return len(self.datasets[0])

    def hasFields(self,*fieldnames):
        """Return True when every given name is a (visible) field of this dataset."""
        for fieldname in fieldnames:
            if not fieldname in self.fieldname2dataset:
                return False
        return True

    def fieldNames(self):
        """Return the visible field names, in stacking order (a new list each time)."""
        return list(self._fieldnames)

    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
        """
        Return an iterator over minibatches made of the requested fields (all the
        fields when fieldnames is empty or None), each minibatch being a lookup-list
        whose fields are in the requested order.

        Only the sub-datasets that provide at least one requested field are iterated;
        each one is asked for its fields under the names it knows them by.
        """
        if not fieldnames:
            fieldnames=self.fieldNames()
        assert self.hasFields(*fieldnames)

        # find out which underlying datasets are necessary to service the required
        # fields, and which of their fields each one must provide
        needed_datasets=[]   # dataset indices, in order of first use
        fields_in_dataset={} # dataset index -> field names as known to that dataset
        for fieldname in fieldnames:
            i=self.fieldname2dataset[fieldname]
            if i not in fields_in_dataset:
                fields_in_dataset[i]=[]
                needed_datasets.append(i)
            fields_in_dataset[i].append(self.fieldname2source[fieldname])
        iterators=[(i,self.datasets[i].minibatches(fields_in_dataset[i],minibatch_size,n_batches,offset))
                   for i in needed_datasets]

        class HStackedIterator(object):
            def __init__(self,hsds,fieldnames,iterators):
                self.hsds=hsds
                self.fieldnames=fieldnames
                self.iterators=iterators

            def __iter__(self):
                return self

            def next(self):
                # advance every sub-dataset by one minibatch (StopIteration from
                # any of them ends the iteration) ...
                minibatches={}
                for i,iterator in self.iterators:
                    minibatches[i]=iterator.next()
                # ... and put all their fields together, in the requested order
                hsds=self.hsds
                values=[minibatches[hsds.fieldname2dataset[name]][hsds.fieldname2source[name]]
                        for name in self.fieldnames]
                return Example(self.fieldnames,values)

        return HStackedIterator(self,fieldnames,iterators)


    def untested_valuesVStack(self,fieldname,fieldvalues):
        return self.datasets[self.fieldname2dataset[fieldname]].valuesVStack(fieldname,fieldvalues)

    def untested_valuesHStack(self,fieldnames,fieldvalues):
        """
        We will use the sub-dataset associated with the first fieldname in the fieldnames list
        to do the work, hoping that it can cope with the other values (i.e. won't care
        about the incompatible fieldnames). Hence this heuristic will always work if
        all the fieldnames are of the same sub-dataset.
        """
        return self.datasets[self.fieldname2dataset[fieldnames[0]]].valuesHStack(fieldnames,fieldvalues)

class VStackedDataSet(DataSet):
    """
    A L{DataSet} that wraps several datasets and shows a view that includes all their examples,
    in the order provided. This clearly assumes that they all have the same field names
    and all (except possibly the last one) are of finite length.

    @todo: automatically detect a chain of stacked datasets due to A + B + C + D ...
    """
    def __init__(self,datasets):
        """
        @param datasets: non-empty sequence of datasets with identical field names;
        all of them, except possibly the last one, must be bounded.
        """
        self.datasets=datasets
        self.length=0
        self.index2dataset={}
        assert len(datasets)>0
        fieldnames = datasets[-1].fieldNames()
        self.datasets_start_row=[]
        # We use this map from row index to dataset index for constant-time random access of examples,
        # to avoid having to search for the appropriate dataset each time and slice is asked for.
        # The rows of the last dataset are not stored: any valid row which is not in the
        # map belongs to the last dataset (see locate_row).
        for k,dataset in enumerate(datasets[0:-1]):
            assert not dataset.is_unbounded() # All VStacked datasets (except possibly the last) must be bounded (have a length).
            L=len(dataset)
            for i in xrange(L):
                self.index2dataset[self.length+i]=k
            self.datasets_start_row.append(self.length)
            self.length+=L
            assert dataset.fieldNames()==fieldnames
        self.datasets_start_row.append(self.length)
        self.length+=len(datasets[-1])
        # If length is very large, we should use a more memory-efficient mechanism
        # that does not store all indices
        if self.length>1000000:
            # 1 million entries would require about 60 meg for the index2dataset map
            # TODO
            print("A more efficient mechanism for index2dataset should be implemented")

    def __len__(self):
        return self.length

    def fieldNames(self):
        return self.datasets[0].fieldNames()

    def hasFields(self,*fieldnames):
        return self.datasets[0].hasFields(*fieldnames)

    def locate_row(self,row):
        """
        Return (dataset_index, row_within_dataset) for global row number.

        @raise IndexError: if row is not in [0, len(self)).
        """
        if row<0 or row>=self.length:
            raise IndexError("row %d is out of range" % row)
        # rows which are not in the map are those of the last dataset
        dataset_index = self.index2dataset.get(row,len(self.datasets)-1)
        row_within_dataset = row-self.datasets_start_row[dataset_index]
        return dataset_index, row_within_dataset

    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
        """
        Generate consecutive minibatches of minibatch_size examples, starting at
        global row offset, for the requested fields (all of them when fieldnames is
        empty or None).

        A minibatch that spans several of the stacked datasets is assembled from
        one piece per dataset, the pieces of each field being put together with the
        valuesVStack of the dataset in which the minibatch starts.

        Iteration stops after n_batches minibatches (when n_batches is not None),
        or as soon as a complete minibatch does not fit before the end of the data.
        """
        if not fieldnames:
            fieldnames=self.fieldNames()
        assert minibatch_size>0
        next_row=offset
        n_batches_done=0
        while n_batches is None or n_batches_done<n_batches:
            if next_row+minibatch_size>self.length:
                break
            # collect the pieces of this minibatch, one per dataset that it covers
            pieces=[]
            first_dataset=None
            row=next_row
            n_missing=minibatch_size
            while n_missing>0:
                dataset_index,row_within_dataset=self.locate_row(row)
                dataset=self.datasets[dataset_index]
                if first_dataset is None:
                    first_dataset=dataset
                # take what is needed, but no more than what is left in this dataset
                n_taken=min(n_missing,len(dataset)-row_within_dataset)
                piece_iterator=dataset.minibatches(fieldnames,
                                                   minibatch_size=n_taken,
                                                   n_batches=1,
                                                   offset=row_within_dataset)
                pieces.append(piece_iterator.next())
                row+=n_taken
                n_missing-=n_taken
            if len(pieces)==1:
                minibatch=pieces[0]
            else:
                minibatch=Example(fieldnames,
                                  [first_dataset.valuesVStack(name,
                                                              [piece[name] for piece in pieces])
                                   for name in fieldnames])
            yield minibatch
            next_row+=minibatch_size
            n_batches_done+=1
                        
class ArrayFieldsDataSet(DataSet):
    """
    Virtual super-class of the datasets whose field values are numpy arrays.

    It gives its sub-classes numpy-based horizontal and vertical stacking of
    field values. Note that these two methods are named untested_valuesHStack
    and untested_valuesVStack: under these names they do not replace the
    valuesHStack / valuesVStack of L{DataSet}.
    """
    def __init__(self,description=None,field_types=None):
        """Both arguments are passed on, unchanged, to DataSet.__init__."""
        DataSet.__init__(self,description,field_types)

    def untested_valuesHStack(self,fieldnames,fieldvalues):
        """
        Horizontal concatenation of field values (numpy.hstack), e.g. two vectors
        give a longer vector and two matrices give a wider matrix.
        The fieldnames argument is not used.
        """
        hstacked = numpy.hstack(fieldvalues)
        return hstacked

    def untested_valuesVStack(self,fieldname,values):
        """
        Vertical concatenation of field values (numpy.vstack), e.g. two vectors
        give a two-row matrix and two matrices give a longer matrix.
        The fieldname argument is not used.
        """
        vstacked = numpy.vstack(values)
        return vstacked



class NArraysDataSet(ArrayFieldsDataSet) :
    """
    An NArraysDataSet stores fields that are numpy tensor, whose first axis
    iterates over examples. It's a generalization of ArrayDataSet.
    """
    #@TODO not completely implemented yet
    def __init__(self, data_arrays, fieldnames, **kwargs) :
        """
        Construct an NArraysDataSet from a list of numpy tensor (data_arrays) and a list
        of fieldnames. The number of arrays must be the same as the number of
        fieldnames. Each set of numpy tensor must have the same first dimension (first
        axis) corresponding to the number of examples.

        Every tensor is treated as a numpy array (using numpy.asarray)

        The keyword arguments are passed on to ArrayFieldsDataSet.__init__.
        """
        ArrayFieldsDataSet.__init__(self,**kwargs)
        assert len(data_arrays) == len(fieldnames)
        assert len(fieldnames) > 0
        ndarrays = [numpy.asarray(a) for a in data_arrays]
        num_examples = ndarrays[0].shape[0]
        self._fieldnames = fieldnames
        for a in ndarrays :
            # all the arrays must hold the same number of examples
            assert a.shape[0] == num_examples
        self._datas = ndarrays
        # map from a field name to the position of its array in self._datas
        self.map_field_idx = dict()
        for k,fieldname in enumerate(fieldnames):
            self.map_field_idx[fieldname] = k


    def __len__(self) :
        """
        Length of the dataset is based on the first array = data_arrays[0], using its shape
        """
        return self._datas[0].shape[0]

    def fieldNames(self) :
        """
        Returns the fieldnames as set in self.__init__
        """
        return self._fieldnames

    def field_pos(self,fieldname) :
        """
        Returns the index of a given fieldname. Fieldname must exists! see fieldNames().
        """
        return self.map_field_idx[fieldname]

    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
        """
        Generate up to n_batches consecutive minibatches of (at most) minibatch_size
        examples, starting at example offset, for the requested fields (all of them
        if fieldnames is None). When n_batches is None, as many complete minibatches
        as fit between offset and the end of the data are generated.

        The same lookup-list object is yielded at every step, refilled with the
        slices of the current minibatch; the last minibatch may be shorter than
        minibatch_size.
        """
        # the default must be applied before the field names are used
        if fieldnames is None:
            fieldnames = self.fieldNames()
        cursor = Example(fieldnames,[0]*len(fieldnames))
        arrays = [self._datas[self.field_pos(name)] for name in cursor._names]
        n_examples = len(self)
        if n_batches is None:
            n_batches = (n_examples - offset) // minibatch_size
        for n in xrange(n_batches):
            if offset >= n_examples:
                break
            upper = min(offset+minibatch_size,n_examples) #can be less than minibatch_size at end
            for f,array in enumerate(arrays):
                cursor._values[f] = array[offset : upper]
            offset = upper
            yield cursor




class ArrayDataSet(ArrayFieldsDataSet):
    """
    An ArrayDataSet stores the fields as groups of columns in a numpy tensor,
    whose first axis iterates over examples, second axis determines fields.
    If the underlying array is N-dimensional (has N axes), then the field
    values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2).
    """

    def __init__(self, data_array, fields_columns, **kwargs):
        """
        Construct an ArrayDataSet from the underlying numpy array (data) and
        a map (fields_columns) from fieldnames to field columns. The columns of a field are specified
        using the standard arguments for indexing/slicing: integer for a column index,
        slice for an interval of columns (with possible stride), or iterable of column indices.

        Note that fields_columns is kept (not copied) and normalized in place: an
        integer column becomes a one-element list, and the missing start / step of
        a slice are filled in. The keyword arguments are passed on to
        ArrayFieldsDataSet.__init__.
        """
        ArrayFieldsDataSet.__init__(self, **kwargs)
        self.data=data_array
        self.fields_columns=fields_columns

        # check consistency and complete slices definitions
        for fieldname, fieldcolumns in self.fields_columns.items():
            if type(fieldcolumns) is int:
                assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1]
                # a single column is stored as a one-element list
                # (change made by James, 22/05/2008)
                self.fields_columns[fieldname]=[fieldcolumns]
            elif type(fieldcolumns) is slice:
                start,step=fieldcolumns.start,fieldcolumns.step
                if not start:
                    start=0
                if not step:
                    step=1
                self.fields_columns[fieldname]=slice(start,fieldcolumns.stop,step)
            elif hasattr(fieldcolumns,"__iter__"): # something like a list
                for i in fieldcolumns:
                    assert i>=0 and i<data_array.shape[1]

    def fieldNames(self):
        return self.fields_columns.keys()

    def __len__(self):
        return len(self.data)

    def __getitem__(self,key):
        """
        More efficient implementation than the default __getitem__

        An int, slice or list key selects rows and gives a lookup-list of all the
        fields; a field name gives the columns of that field for all the rows;
        any other key must be the name of an attribute of the dataset.
        """
        fieldnames=self.fields_columns.keys()
        values=self.fields_columns.values()
        if type(key) is int or type(key) is slice:
            return Example(fieldnames,[self.data[key,col] for col in values])
        if type(key) is list:
            # work on a new list, so that the list of the caller is left untouched
            rows=[]
            for k in key:
                if self.hasFields(k):
                    rows.append(self.fields_columns[k])
                else:
                    rows.append(k)
            return Example(fieldnames,
                               #we must separate differently for list as numpy
                               # doesn't support self.data[[i1,...],[i2,...]]
                               # when there is more than two i1 and i2
                               [self.data[rows,:][:,col]
                               if isinstance(col,list) else
                               self.data[rows,col] for col in values])

        # else check for a fieldname
        if self.hasFields(key):
            return self.data[:,self.fields_columns[key]]
        # else we are trying to access a property of the dataset
        assert key in self.__dict__ # else it means we are trying to access a non-existing property
        return self.__dict__[key]

    def dontuse__iter__(self):
        # not an active __iter__ (see the name), kept for reference
        class ArrayDataSetIteratorIter(object):
            def __init__(self,dataset,fieldnames):
                if fieldnames is None: fieldnames = dataset.fieldNames()
                # store the resulting minibatch in a lookup-list of values
                self.minibatch = Example(fieldnames,[0]*len(fieldnames))
                self.dataset=dataset
                self.current=0
                self.columns = [self.dataset.fields_columns[f]
                                for f in self.minibatch._names]
                self.l = self.dataset.data.shape[0]
            def __iter__(self):
                return self
            def next(self):
                #@todo: we suppose that we need to stop only when minibatch_size == 1.
                # Otherwise, MinibatchWrapAroundIterator do it.
                if self.current>=self.l:
                    raise StopIteration
                sub_data =  self.dataset.data[self.current]
                self.minibatch._values = [sub_data[c] for c in self.columns]

                self.current+=1
                return self.minibatch

        return ArrayDataSetIteratorIter(self,self.fieldNames())

    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
        """
        Generate up to n_batches consecutive minibatches of (at most) minibatch_size
        examples, starting at example offset, for the requested fields (all of them
        if fieldnames is None). When n_batches is None, as many complete minibatches
        as fit between offset and the end of the data are generated.

        The same lookup-list object is yielded at every step, refilled with the
        values of the current minibatch; the last minibatch may be shorter than
        minibatch_size.
        """
        # the default must be applied before the field names are used
        if fieldnames is None:
            fieldnames = self.fieldNames()
        cursor = Example(fieldnames,[0]*len(fieldnames))
        columns = [self.fields_columns[f] for f in cursor._names]
        if n_batches is None:
            n_batches = (len(self) - offset) // minibatch_size
        for n in xrange(n_batches):
            if offset >= len(self):
                break
            sub_data = self.data[offset : offset+minibatch_size]
            offset += len(sub_data) #can be less than minibatch_size at end
            cursor._values = [sub_data[:,c] for c in columns]
            yield cursor


class CachedDataSet(DataSet):
  """
  Wrap a L{DataSet} whose values are computationally expensive to obtain
  (e.g. because they involve some computation, or disk access),
  so that repeated accesses to the same example are done cheaply,
  by caching every example value that has been accessed at least once.

  Optionally, for finite-length dataset, all the values can be computed
  (and cached) upon construction of the CachedDataSet, rather at the
  first access.

  @todo: when cache_all_upon_construction create mini-batches that are as 
  large as possible but not so large as to fill up memory.
  
  @todo: add disk-buffering capability, so that when the cache becomes too
  big for memory, we cache things on disk, trying to keep in memory only
  the record most likely to be accessed next.
  """
  def __init__(self,source_dataset,cache_all_upon_construction=False):
      """
      @param source_dataset: the dataset whose examples are cached.
      @param cache_all_upon_construction: when True, all the examples of
      source_dataset are read and cached right away.
      """
      self.source_dataset=source_dataset
      self.cache_all_upon_construction=cache_all_upon_construction
      # cached_examples[i] is the i-th example of source_dataset; the cache is
      # always filled in order, from the first example on
      self.cached_examples = []
      if cache_all_upon_construction:
          # this potentially brings all the source examples
          # into memory at once, which may be too much
          # the work could possibly be done by minibatches
          # that are as large as possible but no more than what memory allows.
          #
          # field_values is supposed to be an DataSetFields, that inherits from LookupList
          fields_values = DataSetFields(source_dataset,None)
          assert all([len(self)==len(field_values) for field_values in fields_values])
          for example in fields_values.examples():
              self.cached_examples.append(copy.copy(example))

      self.fieldNames = source_dataset.fieldNames
      self.hasFields = source_dataset.hasFields
      self.valuesHStack = source_dataset.valuesHStack
      self.valuesVStack = source_dataset.valuesVStack
      
  def __len__(self):
      return len(self.source_dataset)

  def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
      """
      Return an iterator over consecutive minibatches of minibatch_size examples,
      starting at example offset, for the requested fields (all of them if
      fieldnames is None). The examples that are not in the cache yet are read
      from the source dataset and cached on the way.

      Iteration stops after n_batches minibatches (when n_batches is set), or as
      soon as a complete minibatch does not fit before the end of the source dataset.
      """
      class CacheIterator(object):
          def __init__(self,dataset):
              self.dataset=dataset
              self.current=offset
              # compare as lists, otherwise ('x','y')!=['x','y']
              self.all_fields = fieldnames is None or \
                                list(self.dataset.fieldNames())==list(fieldnames)
              self.n_batches = n_batches
              self.batch_counter = 0
          def __iter__(self): return self
          def next(self):
              self.batch_counter += 1
              if self.n_batches and self.batch_counter > self.n_batches :
                  raise StopIteration()
              upper = self.current+minibatch_size
              if upper > len(self.dataset.source_dataset):
                  raise StopIteration()
              cache_len = len(self.dataset.cached_examples)
              if upper>cache_len: # whole minibatch is not already in cache
                  # cache everything from current length to upper
                  for example in self.dataset.source_dataset.subset[cache_len:upper]:
                      # copy, as in __init__: an iterator may hand back, at every
                      # step, the same example object with new values in it
                      self.dataset.cached_examples.append(copy.copy(example))
              all_fields_minibatch = Example(self.dataset.fieldNames(),
                                             zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))

              self.current+=minibatch_size
              if self.all_fields:
                  return all_fields_minibatch
              return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
      return CacheIterator(self)

  def dontuse__getitem__(self,i):
      # not an active __getitem__ (see the name), kept for reference
      if type(i)==int and len(self.cached_examples)>i:
          return self.cached_examples[i]
      else:
          return self.source_dataset[i]
      
  def __iter__(self):
      """
      Return an iterator over the examples, which are taken from the cache and
      read from the source dataset (then cached) only when they are not in it yet.
      The same lookup-list object is returned at every step, with new values in it.
      """
      class CacheIteratorIter(object):
          def __init__(self,dataset):
              self.dataset=dataset
              self.l = len(dataset)
              self.current = 0
              self.fieldnames = self.dataset.fieldNames()
              self.example = Example(self.fieldnames,[0]*len(self.fieldnames))
          def __iter__(self): return self
          def next(self):
              if self.current>=self.l:
                  raise StopIteration
              cache_len = len(self.dataset.cached_examples)
              if self.current>=cache_len: # current example is not already in cache
                  self.dataset.cached_examples.append(
                      self.dataset.source_dataset[self.current])
              self.example._values = self.dataset.cached_examples[self.current]
              self.current+=1
              return self.example

      return CacheIteratorIter(self)

class ApplyFunctionDataSet(DataSet):
    """
    A L{DataSet} whose fields are the results of applying a given function,
    example-wise or minibatch-wise, to all the fields of an input dataset.
    The function must return an iterable (e.g. a list or a LookupList) over
    the resulting values.

    The function receives the fields of the input dataset as separate
    arguments; it does not receive example objects.

    In minibatch mode the function works on whole minibatches: it takes a
    minibatch as input and returns a minibatch as output. Each element of
    the input list and of the output list must then be iterable and
    indexable over the individual example values (typically these elements
    are numpy arrays), and all of them must have the same length, namely
    the length of the minibatch.

    The function is re-applied every time an example or a minibatch is
    accessed. Wrap this dataset inside a CachedDataSet to avoid repeating
    the computation.

    When the values_{h,v}stack functions are not given, the
    input_dataset.values{H,V}Stack functions are used instead.

    """

    def __init__(self,input_dataset,function,output_names,minibatch_mode=True,
                 values_hstack=None,values_vstack=None,
                 description=None,fieldtypes=None):
        """
        Build the dataset from an input dataset that has as many fields as
        the function expects arguments. The new dataset has as many fields
        as the function produces outputs, which should match the number of
        names in the output_names list.

        The semantics expected of the function are different in minibatch
        mode: it then takes minibatches of inputs and returns minibatches of
        outputs, as described in the class documentation.

        TBM: are fieldtypes the old field types (from input_dataset) or the
        new ones (for the new dataset created)?
        """
        self.input_dataset=input_dataset
        self.function=function
        self.output_names=output_names
        self.minibatch_mode=minibatch_mode
        DataSet.__init__(self,description,fieldtypes)
        # fall back on the input dataset's stacking functions when none is given
        self.valuesHStack = values_hstack or input_dataset.valuesHStack
        self.valuesVStack = values_vstack or input_dataset.valuesVStack

    def __len__(self):
        """Same number of examples as the input dataset."""
        return len(self.input_dataset)

    def fieldNames(self):
        """The fields of this dataset are the outputs of the function."""
        return self.output_names

    def minibatches_nowrap(self, fieldnames, *args, **kwargs):
        """
        Generate minibatches of function outputs, restricted to fieldnames.

        Every field of the input dataset is requested (the function needs
        all of them); the remaining positional and keyword arguments are
        passed through to the input dataset's minibatches_nowrap.
        """
        source = self.input_dataset
        input_names = source.fieldNames()
        make_input_batches = source.minibatches_nowrap

        for batch in make_input_batches(input_names, *args, **kwargs):
            if not self.minibatch_mode:
                # transpose fields -> examples, apply the function to each
                # example, then transpose the results back to fields
                per_example = [self.function(*example_values)
                               for example_values in zip(*batch)]
                output_fields = zip(*per_example)
            else:
                output_fields = self.function(*batch)

            all_outputs = Example(self.output_names, output_fields)
            if fieldnames==self.output_names:
                yield all_outputs
            else:
                yield Example(fieldnames,[all_outputs[name] for name in fieldnames])

    def untested__iter__(self): # only implemented for increased efficiency
        """
        Return an iterator yielding one Example of function outputs per
        example of the input dataset.
        """
        class ApplyFunctionSingleExampleIterator(object):
            def __init__(self,output_dataset):
                self.current=0
                self.output_dataset=output_dataset
                self.input_iterator=output_dataset.input_dataset.__iter__()
            def __iter__(self):
                return self
            def next(self):
                owner = self.output_dataset
                if not owner.minibatch_mode:
                    example_values = self.input_iterator.next()
                    results = owner.function(*example_values)
                else:
                    # present the single example as a minibatch of length 1,
                    # then take element 0 of each output field
                    singleton_batch = [[value] for value in self.input_iterator.next()]
                    batch_results = owner.function(*singleton_batch)
                    assert all([hasattr(field,'__iter__') for field in batch_results])
                    results = [field[0] for field in batch_results]
                return Example(owner.output_names,results)
        return ApplyFunctionSingleExampleIterator(self)
    
def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
    """
    Wraps an arbitrary L{DataSet} into one for supervised learning tasks
    by forcing the user to define a set of fields as the 'input' field
    and a set of fields as the 'target' field. Optionally, a single
    weight_field can also be defined.

    The grouping is delegated to src_dataset.merge_fields, which is called
    with one (fields, merged_name) pair per group:
    (input_fields,'input'), (target_fields,'target') and, when weight_field
    is given (truthy), ([weight_field],'weight').

    Returns whatever src_dataset.merge_fields returns.
    """
    args = ((input_fields,'input'),(target_fields,'target'))
    if weight_field:
        # the trailing comma matters: it appends ONE (fields, name) pair
        # instead of splicing its two elements into args
        args += (([weight_field],'weight'),)
    return src_dataset.merge_fields(*args)