Mercurial > pylearn
diff dataset.py @ 223:517364d48ae0
Should have solved the problem of minibatches not handling subsets of fieldnames, although the fix may not be very efficient.
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Fri, 23 May 2008 16:01:01 -0400 |
parents | df3fae88ab46 |
children | 17c5d080964b |
line wrap: on
line diff
--- a/dataset.py Fri May 23 14:16:54 2008 -0400 +++ b/dataset.py Fri May 23 16:01:01 2008 -0400 @@ -669,6 +669,11 @@ assert len(fields_lookuplist)>0 self.length=len(fields_lookuplist[0]) for field in fields_lookuplist[1:]: + if self.length != len(field) : + print 'self.length = ',self.length + print 'len(field) = ', len(field) + print 'self._fields.keys() = ', self._fields.keys() + print 'field=',field assert self.length==len(field) self.values_vstack=values_vstack self.values_hstack=values_hstack @@ -697,8 +702,13 @@ return True def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): + #@TODO bug somewhere here, fieldnames doesnt seem to be well handled class Iterator(object): - def __init__(self,ds): + def __init__(self,ds,fieldnames): + # tbm: added two next lines to handle fieldnames + if fieldnames is None: fieldnames = ds._fields.keys() + self.fieldnames = fieldnames + self.ds=ds self.next_example=offset assert minibatch_size > 0 @@ -709,13 +719,21 @@ def next(self): upper = self.next_example+minibatch_size assert upper<=self.ds.length - minibatch = Example(self.ds._fields.keys(), - [field[self.next_example:upper] - for field in self.ds._fields]) + #minibatch = Example(self.ds._fields.keys(), + # [field[self.next_example:upper] + # for field in self.ds._fields]) + # tbm: modif to use fieldnames + values = [] + for f in self.fieldnames : + #print 'we have field',f,'in fieldnames' + values.append( self.ds._fields[f][self.next_example:upper] ) + minibatch = Example(self.fieldnames,values) + #print minibatch self.next_example+=minibatch_size return minibatch - return Iterator(self) + # tbm: added fieldnames to handle subset of fieldnames + return Iterator(self,fieldnames) def valuesVStack(self,fieldname,fieldvalues): return self.values_vstack(fieldname,fieldvalues)