Mercurial > pylearn
comparison dataset.py @ 223:517364d48ae0
should have solved the problem with minibatches not handling subsets of fieldnames, although maybe not super efficient
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Fri, 23 May 2008 16:01:01 -0400 |
parents | df3fae88ab46 |
children | 17c5d080964b |
comparison (legend: equal | deleted | inserted | replaced)
222:f6a7eb1b7970 | 223:517364d48ae0 |
---|---|
667 | 667 |
668 self._fields=fields_lookuplist | 668 self._fields=fields_lookuplist |
669 assert len(fields_lookuplist)>0 | 669 assert len(fields_lookuplist)>0 |
670 self.length=len(fields_lookuplist[0]) | 670 self.length=len(fields_lookuplist[0]) |
671 for field in fields_lookuplist[1:]: | 671 for field in fields_lookuplist[1:]: |
672 if self.length != len(field) : | |
673 print 'self.length = ',self.length | |
674 print 'len(field) = ', len(field) | |
675 print 'self._fields.keys() = ', self._fields.keys() | |
676 print 'field=',field | |
672 assert self.length==len(field) | 677 assert self.length==len(field) |
673 self.values_vstack=values_vstack | 678 self.values_vstack=values_vstack |
674 self.values_hstack=values_hstack | 679 self.values_hstack=values_hstack |
675 | 680 |
676 def __len__(self): | 681 def __len__(self): |
695 if fieldname not in self._fields.keys(): | 700 if fieldname not in self._fields.keys(): |
696 return False | 701 return False |
697 return True | 702 return True |
698 | 703 |
699 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 704 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
705 #@TODO bug somewhere here, fieldnames doesnt seem to be well handled | |
700 class Iterator(object): | 706 class Iterator(object): |
701 def __init__(self,ds): | 707 def __init__(self,ds,fieldnames): |
708 # tbm: added two next lines to handle fieldnames | |
709 if fieldnames is None: fieldnames = ds._fields.keys() | |
710 self.fieldnames = fieldnames | |
711 | |
702 self.ds=ds | 712 self.ds=ds |
703 self.next_example=offset | 713 self.next_example=offset |
704 assert minibatch_size > 0 | 714 assert minibatch_size > 0 |
705 if offset+minibatch_size > ds.length: | 715 if offset+minibatch_size > ds.length: |
706 raise NotImplementedError() | 716 raise NotImplementedError() |
707 def __iter__(self): | 717 def __iter__(self): |
708 return self | 718 return self |
709 def next(self): | 719 def next(self): |
710 upper = self.next_example+minibatch_size | 720 upper = self.next_example+minibatch_size |
711 assert upper<=self.ds.length | 721 assert upper<=self.ds.length |
712 minibatch = Example(self.ds._fields.keys(), | 722 #minibatch = Example(self.ds._fields.keys(), |
713 [field[self.next_example:upper] | 723 # [field[self.next_example:upper] |
714 for field in self.ds._fields]) | 724 # for field in self.ds._fields]) |
725 # tbm: modif to use fieldnames | |
726 values = [] | |
727 for f in self.fieldnames : | |
728 #print 'we have field',f,'in fieldnames' | |
729 values.append( self.ds._fields[f][self.next_example:upper] ) | |
730 minibatch = Example(self.fieldnames,values) | |
731 #print minibatch | |
715 self.next_example+=minibatch_size | 732 self.next_example+=minibatch_size |
716 return minibatch | 733 return minibatch |
717 | 734 |
718 return Iterator(self) | 735 # tbm: added fieldnames to handle subset of fieldnames |
736 return Iterator(self,fieldnames) | |
719 | 737 |
720 def valuesVStack(self,fieldname,fieldvalues): | 738 def valuesVStack(self,fieldname,fieldvalues): |
721 return self.values_vstack(fieldname,fieldvalues) | 739 return self.values_vstack(fieldname,fieldvalues) |
722 | 740 |
723 def valuesHStack(self,fieldnames,fieldvalues): | 741 def valuesHStack(self,fieldnames,fieldvalues): |