changeset 163:4b28d7382dbf

Add initial implementation of datasets. For the moment only nist_digits is defined.
author Arnaud Bergeron <abergeron@gmail.com>
date Thu, 25 Feb 2010 18:40:01 -0500
parents 050c7ff6b449
children e3de934a98b6
files datasets/__init__.py datasets/dataset.py datasets/dsetiter.py datasets/ftfile.py datasets/nist.py
diffstat 5 files changed, 382 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/datasets/__init__.py	Thu Feb 25 18:40:01 2010 -0500
@@ -0,0 +1,1 @@
+from nist import *
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/datasets/dataset.py	Thu Feb 25 18:40:01 2010 -0500
@@ -0,0 +1,46 @@
+from dsetiter import DataIterator
+
+class DataSet(object):
+    def test(self, batchsize, bufsize=None):
+        r"""
+        Returns an iterator over the test examples.
+
+        Parameters
+          batchsize (int) -- the size of the minibatches, 0 means
+                             return the whole set at once.
+          bufsize (int, optional) -- the size of the in-memory buffer,
+                                     0 to disable.
+        """
+        return self._return_it(batchsize, bufsize, self._test)
+
+    def train(self, batchsize, bufsize=None):
+        r"""
+        Returns an iterator over the training examples.
+
+        Parameters
+          batchsize (int) -- the size of the minibatches, 0 means
+                             return the whole set at once.
+          bufsize (int, optional) -- the size of the in-memory buffer,
+                                     0 to disable.
+        """
+        return self._return_it(batchsize, bufsize, self._train)
+
+    def valid(self, batchsize, bufsize=None):
+        r"""
+        Returns an iterator over the validation examples.
+
+        Parameters
+          batchsize (int) -- the size of the minibatches, 0 means
+                             return the whole set at once.
+          bufsize (int, optional) -- the size of the in-memory buffer,
+                                     0 to disable.
+        """
+        return self._return_it(batchsize, bufsize, self._valid)
+
+    def _return_it(self, batchsize, bufsize, data):
+        r"""
+        Must return an iterator over the specified dataset (`data`).
+
+        Implement this in subclasses.
+        """
+        raise NotImplementedError
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/datasets/dsetiter.py	Thu Feb 25 18:40:01 2010 -0500
@@ -0,0 +1,156 @@
+import numpy
+
+class DummyFile(object):
+    r"""Fake data source yielding zero-filled arrays; used by the doctests below."""
+    def __init__(self, size):
+        self.size = size
+
+    def read(self, num):
+        num = min(num, self.size)
+        self.size -= num
+        return numpy.zeros((num, 3, 2))
+
+class DataIterator(object):
+
+    def __init__(self, files, batchsize, bufsize=None):
+        r"""
+        Makes an iterator which will read examples from `files`
+        and return them in `batchsize` lots.
+
+        Parameters:
+            files -- list of numpy readers
+            batchsize -- (int) the size of returned batches
+            bufsize -- (int, default=None) internal read buffer size.
+
+        Tests:
+            >>> d = DataIterator([DummyFile(930)], 10, 100)
+            >>> d.batchsize
+            10
+            >>> d.bufsize
+            100
+            >>> d = DataIterator([DummyFile(1)], 10)
+            >>> d.batchsize
+            10
+            >>> d.bufsize
+            10000
+            >>> d = DataIterator([DummyFile(1)], 99)
+            >>> d.batchsize
+            99
+            >>> d.bufsize
+            9999
+            >>> d = DataIterator([DummyFile(1)], 10, 121)
+            >>> d.batchsize
+            10
+            >>> d.bufsize
+            120
+            >>> d = DataIterator([DummyFile(1)], 10, 1)
+            >>> d.batchsize
+            10
+            >>> d.bufsize
+            10
+            >>> d = DataIterator([DummyFile(1)], 2000)
+            >>> d.batchsize
+            2000
+            >>> d.bufsize
+            20000
+            >>> d = DataIterator([DummyFile(1)], 2000, 31254)
+            >>> d.batchsize
+            2000
+            >>> d.bufsize
+            30000
+            >>> d = DataIterator([DummyFile(1)], 2000, 10)
+            >>> d.batchsize
+            2000
+            >>> d.bufsize
+            2000
+        """
+        self.batchsize = batchsize
+        if bufsize is None:
+            self.bufsize = max(10*batchsize, 10000)
+        else:
+            self.bufsize = bufsize
+        self.bufsize -= self.bufsize % self.batchsize  # round down to a whole number of batches
+        if self.bufsize < self.batchsize:
+            self.bufsize = self.batchsize
+        self.files = iter(files)
+        self.curfile = self.files.next()  # python 2 iterator protocol
+        self.empty = False
+        self._fill_buf()
+
+    def _fill_buf(self):
+        r"""
+        Fill the internal buffer.
+
+        Will fill across files in case the current one runs out.
+
+        Test:
+            >>> d = DataIterator([DummyFile(20)], 10, 10)
+            >>> d._fill_buf()
+            >>> d.curpos
+            0
+            >>> len(d.buffer)
+            10
+            >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10)
+            >>> d._fill_buf()
+            >>> len(d.buffer)
+            10
+            >>> d._fill_buf()
+            Traceback (most recent call last):
+              ...
+            StopIteration
+            >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10)
+            >>> d._fill_buf()
+            >>> len(d.buffer)
+            9
+            >>> d._fill_buf()
+            Traceback (most recent call last):
+              ...
+            StopIteration
+        """
+        if self.empty:
+            raise StopIteration
+        self.buffer = self.curfile.read(self.bufsize)
+
+        while len(self.buffer) < self.bufsize:
+            try:
+                self.curfile = self.files.next()
+            except StopIteration:
+                self.empty = True
+                if len(self.buffer) == 0:
+                    raise StopIteration
+                self.curpos = 0
+                return
+            tmpbuf = self.curfile.read(self.bufsize - len(self.buffer))
+            self.buffer = numpy.row_stack((self.buffer, tmpbuf))
+        self.curpos = 0
+
+    def __next__(self):
+        r"""
+        Returns the next portion of the dataset.
+
+        Test:
+            >>> d = DataIterator([DummyFile(20)], 10, 20)
+            >>> len(d.next())
+            10
+            >>> len(d.next())
+            10
+            >>> d.next()
+            Traceback (most recent call last):
+              ...
+            StopIteration
+            >>> d.next()
+            Traceback (most recent call last):
+              ...
+            StopIteration
+
+        """
+        if self.curpos >= len(self.buffer):  # final buffer may be short; avoids yielding empty batches
+            self._fill_buf()
+        res = self.buffer[self.curpos:self.curpos+self.batchsize]
+        self.curpos += self.batchsize
+        return res
+
+    next = __next__
+
+    def __iter__(self):
+        return self
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/datasets/ftfile.py	Thu Feb 25 18:40:01 2010 -0500
@@ -0,0 +1,156 @@
+from pylearn.io.filetensor import _read_header, _prod
+import numpy
+from dataset import DataSet
+from dsetiter import DataIterator
+
+class FTFile(object):
+    def __init__(self, fname):
+        r"""
+        Open the filetensor file `fname` and parse its header.  Tests:
+            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
+        """
+        self.file = open(fname, 'rb')
+        self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)  # dtype name, element byte size, dims
+        self.size = self.dim[0]  # leading dimension = number of examples left to read
+
+    def skip(self, num):
+        r"""
+        Skips `num` items in the file.  Skipping past the end only marks the file as exhausted.
+
+        Tests:
+            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
+            >>> f.size
+            58646
+            >>> f.elsize
+            4
+            >>> f.file.tell()
+            20
+            >>> f.skip(1000)
+            >>> f.file.tell()
+            4020
+            >>> f.size
+            57646
+        """
+        if num >= self.size:
+            self.size = 0  # exhausted: the file position is left alone, all later reads return empty
+        else:
+            self.size -= num
+            f_start = self.file.tell()
+            self.file.seek(f_start + (self.elsize * _prod(self.dim[1:]) * num))  # num examples of elsize-byte elements
+    
+    def read(self, num):
+        r"""
+        Reads `num` elements from the file and return the result as a
+        numpy array; requests past the end are truncated to the remaining size.
+
+        Tests:
+            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
+            >>> f.read(1)
+            array([6], dtype=int32)
+            >>> f.read(10)
+            array([7, 4, 7, 5, 6, 4, 8, 0, 9, 6], dtype=int32)
+            >>> f.skip(58630)
+            >>> f.read(10)
+            array([9, 2, 4, 2, 8], dtype=int32)
+            >>> f.read(10)
+            array([], dtype=int32)
+            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
+            >>> f.read(1)
+            array([[0, 0, 0, ..., 0, 0, 0]], dtype=uint8)
+        """
+        if num > self.size:
+            num = self.size
+        self.dim[0] = num  # reuse the header dims with the leading axis set to this read's length
+        self.size -= num
+        return numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
+
+class FTSource(object):
+    def __init__(self, file, skip=0, size=None):
+        r"""
+        Create a data source from a possible subset of a .ft file.
+
+        Parameters:
+            `file` (string) -- the filename
+            `skip` (int, optional) -- amount of examples to skip from the start of the file
+            `size` (int, optional) -- truncates number of examples read (after skipping)
+        
+        Tests:
+           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
+           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
+           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
+           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
+        """
+        self.file = file  # path of the .ft file; opened lazily in open()
+        self.skip = skip  # examples to drop from the start
+        self.size = size  # cap on examples served (None = no cap)
+    
+    def open(self):
+        r"""
+        Returns an FTFile for this source, with `skip` examples skipped and the size capped to `size`.
+        
+        Tests:
+           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
+           >>> f = s.open()
+           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
+           >>> len(s.open().read(2))
+           1
+           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
+           >>> s.open().size
+           1000
+           >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
+           >>> s.open().size
+           1
+        """
+        f = FTFile(self.file)
+        if self.skip != 0:
+            f.skip(self.skip)
+        if self.size is not None and self.size < f.size:
+            f.size = self.size  # cap the number of readable examples
+        return f
+
+class FTData(object):
+    r"""
+    Groups matching lists of input and label sources (as FTSource objects).
+    """
+    def __init__(self, datafiles, labelfiles, skip=0, size=None):
+        self.inputs = [FTSource(fname, skip, size) for fname in datafiles]
+        self.outputs = [FTSource(fname, skip, size) for fname in labelfiles]
+
+    def open_inputs(self):
+        return [src.open() for src in self.inputs]
+
+    def open_outputs(self):
+        return [src.open() for src in self.outputs]
+
+
+class FTDataSet(DataSet):
+    def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None):
+        r"""
+        Defines a DataSet from a bunch of files.
+        
+        Parameters:
+           `train_data` -- list of train data files
+           `train_lbl` -- list of train label files (same length as `train_data`)
+           `test_data`, `test_lbl` -- same thing as train, but for
+                                         test.  The number of files
+                                         can differ from train.
+           `valid_data`, `valid_lbl` -- same thing again for validation.
+                                           (optional)
+
+        If `valid_data` and `valid_lbl` are not supplied then a sample
+        approximately equal in size to the test set is taken from the train 
+        set.
+        """
+        if valid_data is None:
+            total_valid_size = sum(FTFile(td).size for td in test_data)
+            valid_size = total_valid_size/len(train_data)  # per-train-file share (integer division)
+            self._train = FTData(train_data, train_lbl, skip=valid_size)  # train skips the reserved examples
+            self._valid = FTData(train_data, train_lbl, size=valid_size)  # valid takes the first valid_size of each file
+        else:
+            self._train = FTData(train_data, train_lbl)
+            self._valid = FTData(valid_data, valid_lbl)
+        self._test = FTData(test_data, test_lbl)
+
+    def _return_it(self, batchsize, bufsize, ftdata):
+        return zip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),  # pairs each input batch with its label batch
+                   DataIterator(ftdata.open_outputs(), batchsize, bufsize))
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/datasets/nist.py	Thu Feb 25 18:40:01 2010 -0500
@@ -0,0 +1,23 @@
+__all__ = ['nist_digits', 'nist_lower', 'nist_upper', 'nist_all']
+
+from ftfile import FTDataSet
+
+PATH = '/data/lisa/data/nist/by_class/'
+
+nist_digits = FTDataSet(train_data = [PATH+'digits/digits_train_data.ft'],
+                        train_lbl = [PATH+'digits/digits_train_labels.ft'],
+                        test_data = [PATH+'digits/digits_test_data.ft'],
+                        test_lbl = [PATH+'digits/digits_test_labels.ft'])
+nist_lower = FTDataSet(train_data = [PATH+'lower/lower_train_data.ft'],
+                        train_lbl = [PATH+'lower/lower_train_labels.ft'],
+                        test_data = [PATH+'lower/lower_test_data.ft'],
+                        test_lbl = [PATH+'lower/lower_test_labels.ft'])
+nist_upper = FTDataSet(train_data = [PATH+'upper/upper_train_data.ft'],
+                        train_lbl = [PATH+'upper/upper_train_labels.ft'],
+                        test_data = [PATH+'upper/upper_test_data.ft'],
+                        test_lbl = [PATH+'upper/upper_test_labels.ft'])
+nist_all = FTDataSet(train_data = [PATH+'all/all_train_data.ft'],
+                        train_lbl = [PATH+'all/all_train_labels.ft'],
+                        test_data = [PATH+'all/all_test_data.ft'],
+                        test_lbl = [PATH+'all/all_test_labels.ft'])
+