comparison dataset.py @ 223:517364d48ae0

Should solve the problem of minibatches not handling subsets of fieldnames, although the fix may not be very efficient.
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Fri, 23 May 2008 16:01:01 -0400
parents df3fae88ab46
children 17c5d080964b
comparison
equal deleted inserted replaced
222:f6a7eb1b7970 223:517364d48ae0
667 667
668 self._fields=fields_lookuplist 668 self._fields=fields_lookuplist
669 assert len(fields_lookuplist)>0 669 assert len(fields_lookuplist)>0
670 self.length=len(fields_lookuplist[0]) 670 self.length=len(fields_lookuplist[0])
671 for field in fields_lookuplist[1:]: 671 for field in fields_lookuplist[1:]:
672 if self.length != len(field) :
673 print 'self.length = ',self.length
674 print 'len(field) = ', len(field)
675 print 'self._fields.keys() = ', self._fields.keys()
676 print 'field=',field
672 assert self.length==len(field) 677 assert self.length==len(field)
673 self.values_vstack=values_vstack 678 self.values_vstack=values_vstack
674 self.values_hstack=values_hstack 679 self.values_hstack=values_hstack
675 680
676 def __len__(self): 681 def __len__(self):
695 if fieldname not in self._fields.keys(): 700 if fieldname not in self._fields.keys():
696 return False 701 return False
697 return True 702 return True
698 703
699 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 704 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
705 #@TODO bug somewhere here, fieldnames doesn't seem to be handled correctly
700 class Iterator(object): 706 class Iterator(object):
701 def __init__(self,ds): 707 def __init__(self,ds,fieldnames):
708 # tbm: added the next two lines to handle fieldnames
709 if fieldnames is None: fieldnames = ds._fields.keys()
710 self.fieldnames = fieldnames
711
702 self.ds=ds 712 self.ds=ds
703 self.next_example=offset 713 self.next_example=offset
704 assert minibatch_size > 0 714 assert minibatch_size > 0
705 if offset+minibatch_size > ds.length: 715 if offset+minibatch_size > ds.length:
706 raise NotImplementedError() 716 raise NotImplementedError()
707 def __iter__(self): 717 def __iter__(self):
708 return self 718 return self
709 def next(self): 719 def next(self):
710 upper = self.next_example+minibatch_size 720 upper = self.next_example+minibatch_size
711 assert upper<=self.ds.length 721 assert upper<=self.ds.length
712 minibatch = Example(self.ds._fields.keys(), 722 #minibatch = Example(self.ds._fields.keys(),
713 [field[self.next_example:upper] 723 # [field[self.next_example:upper]
714 for field in self.ds._fields]) 724 # for field in self.ds._fields])
725 # tbm: modif to use fieldnames
726 values = []
727 for f in self.fieldnames :
728 #print 'we have field',f,'in fieldnames'
729 values.append( self.ds._fields[f][self.next_example:upper] )
730 minibatch = Example(self.fieldnames,values)
731 #print minibatch
715 self.next_example+=minibatch_size 732 self.next_example+=minibatch_size
716 return minibatch 733 return minibatch
717 734
718 return Iterator(self) 735 # tbm: added fieldnames to handle subset of fieldnames
736 return Iterator(self,fieldnames)
719 737
720 def valuesVStack(self,fieldname,fieldvalues): 738 def valuesVStack(self,fieldname,fieldvalues):
721 return self.values_vstack(fieldname,fieldvalues) 739 return self.values_vstack(fieldname,fieldvalues)
722 740
723 def valuesHStack(self,fieldnames,fieldvalues): 741 def valuesHStack(self,fieldnames,fieldvalues):