# HG changeset patch # User bengioy@grenat.iro.umontreal.ca # Date 1209482736 14400 # Node ID a5c70dc42972de2691a649c1e370eb9bf115dc25 # Parent 5a85fda9b19b3455d7db050350a1beb18890c85d Test functions for dataset.py diff -r 5a85fda9b19b -r a5c70dc42972 dataset.py --- a/dataset.py Mon Apr 28 13:52:54 2008 -0400 +++ b/dataset.py Tue Apr 29 11:25:36 2008 -0400 @@ -4,6 +4,7 @@ from misc import unique_elements_list_intersection from string import join from sys import maxint +import numpy class AbstractFunction (Exception): """Derived class must override this function""" class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" @@ -376,7 +377,7 @@ rows = i if rows is not None: fields_values = zip(*[self[row] for row in rows]) - return DataSet.MinibatchDataSet( + return MinibatchDataSet( Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) for fieldname,field_values in zip(self.fieldNames(),fields_values)])) diff -r 5a85fda9b19b -r a5c70dc42972 lookup_list.py --- a/lookup_list.py Mon Apr 28 13:52:54 2008 -0400 +++ b/lookup_list.py Tue Apr 29 11:25:36 2008 -0400 @@ -46,9 +46,11 @@ The key in example[key] can either be an integer to index the fields or the name of the field. """ - if isinstance(key,int): + if isinstance(key,int) or isinstance(key,slice) or isinstance(key,list): return self._values[key] else: # if not an int, key must be a name + # expecting key to be a valid field name + assert isinstance(key,str) return self._values[self._name2index[key]] def __setitem__(self,key,value): diff -r 5a85fda9b19b -r a5c70dc42972 test_dataset.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test_dataset.py Tue Apr 29 11:25:36 2008 -0400 @@ -0,0 +1,24 @@ + +from dataset import * +from math import * +import numpy + +def test1(): + global a,ds + a = numpy.random.rand(10,4) + print a + ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]}) + print "len(ds)=",len(ds) + print "example 0 = ",ds[0] + print "x=",ds["x"] + print "x|y" + for x,y in ds("x","y"): + print x,y + minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4) + minibatch = minibatch_iterator.__iter__().next() + print "minibatch=",minibatch + for var in minibatch: + print "var=",var + print "take a slice:",ds[1:6:2] + +test1()