view datasets/ftfile.py @ 239:42005ec87747

Mergé (manuellement) les changements de Sylvain pour utiliser le code de dataset d'Arnaud, à cette différence près que je n'utilse pas les givens. J'ai probablement une approche différente pour limiter la taille du dataset dans mon débuggage, aussi.
author fsavard
date Mon, 15 Mar 2010 18:30:21 -0400
parents 1faae5079522
children 966272e7f14b
line wrap: on
line source

from pylearn.io.filetensor import _read_header, _prod
import numpy, theano
from dataset import DataSet
from dsetiter import DataIterator
from itertools import izip, imap

class FTFile(object):
    def __init__(self, fname, scale=1, dtype=None):
        r"""
        Tests:
            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
        """
        self.file = open(fname, 'rb')
        self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)
        self.size = self.dim[0]
        self.scale = scale
        self.dtype = dtype

    def skip(self, num):
        r"""
        Skips `num` items in the file.

        If `num` is negative, skips size-num.

        Tests:
            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
            >>> f.size
            58646
            >>> f.elsize
            4
            >>> f.file.tell()
            20
            >>> f.skip(1000)
            >>> f.file.tell()
            4020
            >>> f.size
            57646
            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
            >>> f.size
            58646
            >>> f.file.tell()
            20
            >>> f.skip(-1000)
            >>> f.file.tell()
            230604
            >>> f.size
            1000
        """
        if num < 0:
            num += self.size
        if num < 0:
            raise ValueError('Skipping past the start of the file')
        if num >= self.size:
            self.size = 0
        else:
            self.size -= num
            f_start = self.file.tell()
            self.file.seek(f_start + (self.elsize * _prod(self.dim[1:]) * num))
    
    def read(self, num):
        r"""
        Reads `num` elements from the file and return the result as a
        numpy matrix.  Last read is truncated.

        Tests:
            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
            >>> f.read(1)
            array([6], dtype=int32)
            >>> f.read(10)
            array([7, 4, 7, 5, 6, 4, 8, 0, 9, 6], dtype=int32)
            >>> f.skip(58630)
            >>> f.read(10)
            array([9, 2, 4, 2, 8], dtype=int32)
            >>> f.read(10)
            array([], dtype=int32)
            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
            >>> f.read(1)
            array([[0, 0, 0, ..., 0, 0, 0]], dtype=uint8)
        """
        if num > self.size:
            num = self.size
        self.dim[0] = num
        self.size -= num
        res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
        if self.dtype is not None:
            res = res.astype(self.dtype)
        if self.scale != 1:
            res /= self.scale
        return res

class FTSource(object):
    def __init__(self, file, skip=0, size=None, dtype=None, scale=1):
        r"""
        Create a data source from a possible subset of a .ft file.

        Parameters:
            `file` (string) -- the filename
            `skip` (int, optional) -- amount of examples to skip from
                                      the start of the file.  If
                                      negative, skips filesize - skip.
            `size` (int, optional) -- truncates number of examples
                                      read (after skipping).  If
                                      negative truncates to 
                                      filesize - size 
                                      (also after skipping).
            `dtype` (dtype, optional) -- convert the data to this
                                         dtype after reading.
            `scale` (number, optional) -- scale (that is divide) the
                                          data by this number (after
                                          dtype conversion, if any).

        Tests:
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
        """
        self.file = file
        self.skip = skip
        self.size = size
        self.dtype = dtype
        self.scale = scale
    
    def open(self):
        r"""
        Returns an FTFile that corresponds to this dataset.
        
        Tests:
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
           >>> f = s.open()
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
           >>> len(s.open().read(2))
           1
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
           >>> s.open().size
           1000
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
           >>> s.open().size
           1
           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10)
           >>> s.open().size
           58636
        """
        f = FTFile(self.file, scale=self.scale, dtype=self.dtype)
        if self.skip != 0:
            f.skip(self.skip)
        if self.size is not None and self.size < f.size:
            if self.size < 0:
                f.size += self.size
            else:
                f.size = self.size
        return f

class FTData(object):
    r"""
    This is a list of FTSources.
    """
    def __init__(self, datafiles, labelfiles, skip=0, size=None,
                 inscale=1, indtype=None, outscale=1, outdtype=None):
        self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype)
                       for f in  datafiles]
        self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype)
                        for f in labelfiles]

    def open_inputs(self):
        return [f.open() for f in self.inputs]

    def open_outputs(self):
        return [f.open() for f in self.outputs]
    

class FTDataSet(DataSet):
    def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1):
        r"""
        Defines a DataSet from a bunch of files.
        
        Parameters:
           `train_data` -- list of train data files
           `train_label` -- list of train label files (same length as `train_data`)
           `test_data`, `test_labels` -- same thing as train, but for
                                         test.  The number of files
                                         can differ from train.
           `valid_data`, `valid_labels` -- same thing again for validation.
                                           (optional)
           `indtype`, `outdtype`,  -- see FTSource.__init__()
           `inscale`, `outscale`      (optional)
                                                             

        If `valid_data` and `valid_labels` are not supplied then a sample
        approximately equal in size to the test set is taken from the train 
        set.
        """
        if valid_data is None:
            total_valid_size = sum(FTFile(td).size for td in test_data)
            valid_size = total_valid_size/len(train_data)
            self._train = FTData(train_data, train_lbl, size=-valid_size,
                    inscale=inscale, outscale=outscale, indtype=indtype,
                    outdtype=outdtype)
            self._valid = FTData(train_data, train_lbl, skip=-valid_size,
                    inscale=inscale, outscale=outscale, indtype=indtype, 
                    outdtype=outdtype)
        else:
            self._train = FTData(train_data, train_lbl,inscale=inscale,
                    outscale=outscale, indtype=indtype, outdtype=outdtype)
            self._valid = FTData(valid_data, valid_lbl,inscale=inscale,
                    outscale=outscale, indtype=indtype, outdtype=outdtype)
        self._test = FTData(test_data, test_lbl,inscale=inscale,
                outscale=outscale, indtype=indtype, outdtype=outdtype)

    def _return_it(self, batchsize, bufsize, ftdata):
        return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),
                    DataIterator(ftdata.open_outputs(), batchsize, bufsize))