# HG changeset patch # User Thierry Bertin-Mahieux # Date 1211572861 14400 # Node ID 517364d48ae0fe2ad48a6b11bf3a8abf804cda28 # Parent f6a7eb1b797098cde134fd64c46c9579d21d7299 should have solved the problem with minibatches not handling subsets of fieldnames, although maybe not super efficient diff -r f6a7eb1b7970 -r 517364d48ae0 dataset.py --- a/dataset.py Fri May 23 14:16:54 2008 -0400 +++ b/dataset.py Fri May 23 16:01:01 2008 -0400 @@ -669,6 +669,11 @@ assert len(fields_lookuplist)>0 self.length=len(fields_lookuplist[0]) for field in fields_lookuplist[1:]: + if self.length != len(field) : + print 'self.length = ',self.length + print 'len(field) = ', len(field) + print 'self._fields.keys() = ', self._fields.keys() + print 'field=',field assert self.length==len(field) self.values_vstack=values_vstack self.values_hstack=values_hstack @@ -697,8 +702,13 @@ return True def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): + #@TODO bug somewhere here, fieldnames doesnt seem to be well handled class Iterator(object): - def __init__(self,ds): + def __init__(self,ds,fieldnames): + # tbm: added two next lines to handle fieldnames + if fieldnames is None: fieldnames = ds._fields.keys() + self.fieldnames = fieldnames + self.ds=ds self.next_example=offset assert minibatch_size > 0 @@ -709,13 +719,21 @@ def next(self): upper = self.next_example+minibatch_size assert upper<=self.ds.length - minibatch = Example(self.ds._fields.keys(), - [field[self.next_example:upper] - for field in self.ds._fields]) + #minibatch = Example(self.ds._fields.keys(), + # [field[self.next_example:upper] + # for field in self.ds._fields]) + # tbm: modif to use fieldnames + values = [] + for f in self.fieldnames : + #print 'we have field',f,'in fieldnames' + values.append( self.ds._fields[f][self.next_example:upper] ) + minibatch = Example(self.fieldnames,values) + #print minibatch self.next_example+=minibatch_size return minibatch - return Iterator(self) + # tbm: added fieldnames to handle subset of fieldnames + return Iterator(self,fieldnames) def valuesVStack(self,fieldname,fieldvalues): return self.values_vstack(fieldname,fieldvalues)