diff dataset.py @ 223:517364d48ae0

should have solved the problem with minibatches not handling subsets of fieldnames, although maybe not super efficient
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Fri, 23 May 2008 16:01:01 -0400
parents df3fae88ab46
children 17c5d080964b
line wrap: on
line diff
--- a/dataset.py	Fri May 23 14:16:54 2008 -0400
+++ b/dataset.py	Fri May 23 16:01:01 2008 -0400
@@ -669,6 +669,11 @@
         assert len(fields_lookuplist)>0
         self.length=len(fields_lookuplist[0])
         for field in fields_lookuplist[1:]:
+            if self.length != len(field) :
+                print 'self.length = ',self.length
+                print 'len(field) = ', len(field)
+                print 'self._fields.keys() = ', self._fields.keys()
+                print 'field=',field
             assert self.length==len(field)
         self.values_vstack=values_vstack
         self.values_hstack=values_hstack
@@ -697,8 +702,13 @@
         return True
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+        #@TODO bug somewhere here, fieldnames doesnt seem to be well handled
         class Iterator(object):
-            def __init__(self,ds):
+            def __init__(self,ds,fieldnames):
+                # tbm: added two next lines to handle fieldnames
+                if fieldnames is None: fieldnames = ds._fields.keys()
+                self.fieldnames = fieldnames
+
                 self.ds=ds
                 self.next_example=offset
                 assert minibatch_size > 0
@@ -709,13 +719,21 @@
             def next(self):
                 upper = self.next_example+minibatch_size
                 assert upper<=self.ds.length
-                minibatch = Example(self.ds._fields.keys(),
-                                    [field[self.next_example:upper]
-                                     for field in self.ds._fields])
+                #minibatch = Example(self.ds._fields.keys(),
+                #                    [field[self.next_example:upper]
+                #                     for field in self.ds._fields])
+                # tbm: modif to use fieldnames
+                values = []
+                for f in self.fieldnames :
+                    #print 'we have field',f,'in fieldnames'
+                    values.append( self.ds._fields[f][self.next_example:upper] )
+                minibatch = Example(self.fieldnames,values)
+                #print minibatch
                 self.next_example+=minibatch_size
                 return minibatch
 
-        return Iterator(self)
+        # tbm: added fieldnames to handle subset of fieldnames
+        return Iterator(self,fieldnames)
 
     def valuesVStack(self,fieldname,fieldvalues):
         return self.values_vstack(fieldname,fieldvalues)