diff dataset.py @ 226:3595ba2610f7

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 23 May 2008 17:12:12 -0400
parents 517364d48ae0
children 17c5d080964b
line wrap: on
line diff
--- a/dataset.py	Fri May 23 17:11:39 2008 -0400
+++ b/dataset.py	Fri May 23 17:12:12 2008 -0400
@@ -245,8 +245,7 @@
             if n_batches is not None:
                 ds_nbatches = min(n_batches,ds_nbatches)
             if fieldnames:
-                if not dataset.hasFields(*fieldnames):
-                    raise ValueError('field not present', fieldnames)
+                assert dataset.hasFields(*fieldnames)
             else:
                 self.fieldnames=dataset.fieldNames()
             self.iterator = self.dataset.minibatches_nowrap(self.fieldnames,self.minibatch_size,
@@ -670,6 +669,11 @@
         assert len(fields_lookuplist)>0
         self.length=len(fields_lookuplist[0])
         for field in fields_lookuplist[1:]:
+            if self.length != len(field) :
+                print 'self.length = ',self.length
+                print 'len(field) = ', len(field)
+                print 'self._fields.keys() = ', self._fields.keys()
+                print 'field=',field
             assert self.length==len(field)
         self.values_vstack=values_vstack
         self.values_hstack=values_hstack
@@ -698,8 +702,13 @@
         return True
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+        #@TODO bug somewhere here, fieldnames doesn't seem to be well handled
         class Iterator(object):
-            def __init__(self,ds):
+            def __init__(self,ds,fieldnames):
+                # tbm: added the next two lines to handle fieldnames
+                if fieldnames is None: fieldnames = ds._fields.keys()
+                self.fieldnames = fieldnames
+
                 self.ds=ds
                 self.next_example=offset
                 assert minibatch_size > 0
@@ -710,13 +719,21 @@
             def next(self):
                 upper = self.next_example+minibatch_size
                 assert upper<=self.ds.length
-                minibatch = Example(self.ds._fields.keys(),
-                                    [field[self.next_example:upper]
-                                     for field in self.ds._fields])
+                #minibatch = Example(self.ds._fields.keys(),
+                #                    [field[self.next_example:upper]
+                #                     for field in self.ds._fields])
+                # tbm: modified to use fieldnames
+                values = []
+                for f in self.fieldnames :
+                    #print 'we have field',f,'in fieldnames'
+                    values.append( self.ds._fields[f][self.next_example:upper] )
+                minibatch = Example(self.fieldnames,values)
+                #print minibatch
                 self.next_example+=minibatch_size
                 return minibatch
 
-        return Iterator(self)
+        # tbm: added fieldnames to handle subset of fieldnames
+        return Iterator(self,fieldnames)
 
     def valuesVStack(self,fieldname,fieldvalues):
         return self.values_vstack(fieldname,fieldvalues)
@@ -970,16 +987,7 @@
         for fieldname, fieldcolumns in self.fields_columns.items():
             if type(fieldcolumns) is int:
                 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1]
-
-                if 0:
-                    #I changed this because it didn't make sense to me,
-                    # and it made it more difficult to write my learner.
-                    # If it breaks stuff, let's talk about it.
-                    # - James 22/05/2008
-                    self.fields_columns[fieldname]=[fieldcolumns]
-                else:
-                    self.fields_columns[fieldname]=fieldcolumns
-
+                self.fields_columns[fieldname]=[fieldcolumns]
             elif type(fieldcolumns) is slice:
                 start,step=None,None
                 if not fieldcolumns.start: