comparison dataset.py @ 435:eac0a7d44ff0

merge
author Olivier Breuleux <breuleuo@iro.umontreal.ca>
date Mon, 04 Aug 2008 16:29:30 -0400
parents 52b4908d8971
children 739612d316a4 ce6b4fd3ab29
comparison
equal deleted inserted replaced
434:0f366ecb11ee 435:eac0a7d44ff0
218 len(dataset) returns the number of examples in the dataset. 218 len(dataset) returns the number of examples in the dataset.
219 By default, a DataSet is a 'stream', i.e. it has an unbounded length (sys.maxint). 219 By default, a DataSet is a 'stream', i.e. it has an unbounded length (sys.maxint).
220 Sub-classes which implement finite-length datasets should redefine this method. 220 Sub-classes which implement finite-length datasets should redefine this method.
221 Some methods only make sense for finite-length datasets. 221 Some methods only make sense for finite-length datasets.
222 """ 222 """
223 return None 223 from sys import maxint
224 return maxint
224 225
225 226
226 class MinibatchToSingleExampleIterator(object): 227 class MinibatchToSingleExampleIterator(object):
227 """ 228 """
228 Converts the result of minibatch iterator with minibatch_size==1 into 229 Converts the result of minibatch iterator with minibatch_size==1 into
941 self.fieldname2dataset[fieldname]=i 942 self.fieldname2dataset[fieldname]=i
942 for fieldname,i in names_to_change: 943 for fieldname,i in names_to_change:
943 del self.fieldname2dataset[fieldname] 944 del self.fieldname2dataset[fieldname]
944 self.fieldname2dataset[rename_field(fieldname,self.datasets[i],i)]=i 945 self.fieldname2dataset[rename_field(fieldname,self.datasets[i],i)]=i
945 946
947 def __len__(self):
948 return len(self.datasets[0])
949
946 def hasFields(self,*fieldnames): 950 def hasFields(self,*fieldnames):
947 for fieldname in fieldnames: 951 for fieldname in fieldnames:
948 if not fieldname in self.fieldname2dataset: 952 if not fieldname in self.fieldname2dataset:
949 return False 953 return False
950 return True 954 return True
1221 # - James 22/05/2008 1225 # - James 22/05/2008
1222 self.fields_columns[fieldname]=[fieldcolumns] 1226 self.fields_columns[fieldname]=[fieldcolumns]
1223 else: 1227 else:
1224 self.fields_columns[fieldname]=fieldcolumns 1228 self.fields_columns[fieldname]=fieldcolumns
1225 elif type(fieldcolumns) is slice: 1229 elif type(fieldcolumns) is slice:
1226 start,step=None,None 1230 start,step=fieldcolumns.start,fieldcolumns.step
1227 if not fieldcolumns.start: 1231 if not start:
1228 start=0 1232 start=0
1229 if not fieldcolumns.step: 1233 if not step:
1230 step=1 1234 step=1
1231 if start or step: 1235 self.fields_columns[fieldname]=slice(start,fieldcolumns.stop,step)
1232 self.fields_columns[fieldname]=slice(start,fieldcolumns.stop,step)
1233 elif hasattr(fieldcolumns,"__iter__"): # something like a list 1236 elif hasattr(fieldcolumns,"__iter__"): # something like a list
1234 for i in fieldcolumns: 1237 for i in fieldcolumns:
1235 assert i>=0 and i<data_array.shape[1] 1238 assert i>=0 and i<data_array.shape[1]
1236 1239
1237 def fieldNames(self): 1240 def fieldNames(self):
1449 1452
1450 Note that the expected semantics of the function differs in minibatch mode 1453 Note that the expected semantics of the function differs in minibatch mode
1451 (it takes minibatches of inputs and produces minibatches of outputs, as 1454 (it takes minibatches of inputs and produces minibatches of outputs, as
1452 documented in the class comment). 1455 documented in the class comment).
1453 1456
1454 TBM: are filedtypes the old field types (from input_dataset) or the new ones 1457 TBM: are fieldtypes the old field types (from input_dataset) or the new ones
1455 (for the new dataset created)? 1458 (for the new dataset created)?
1456 """ 1459 """
1457 self.input_dataset=input_dataset 1460 self.input_dataset=input_dataset
1458 self.function=function 1461 self.function=function
1459 self.output_names=output_names 1462 self.output_names=output_names
1463 #print 'self.output_names in afds:', self.output_names
1464 #print 'length in afds:', len(self.output_names)
1460 self.minibatch_mode=minibatch_mode 1465 self.minibatch_mode=minibatch_mode
1461 DataSet.__init__(self,description,fieldtypes) 1466 DataSet.__init__(self,description,fieldtypes)
1462 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack 1467 self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack
1463 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack 1468 self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack
1464 1469
1479 input_examples = zip(*input_fields) #makes so that [i] means example i 1484 input_examples = zip(*input_fields) #makes so that [i] means example i
1480 output_examples = [self.function(*input_example) 1485 output_examples = [self.function(*input_example)
1481 for input_example in input_examples] 1486 for input_example in input_examples]
1482 all_output_fields = zip(*output_examples) 1487 all_output_fields = zip(*output_examples)
1483 1488
1489 #print 'output_names=', self.output_names
1490 #print 'all_output_fields', all_output_fields
1491 #print 'len(all_output_fields)=', len(all_output_fields)
1484 all_outputs = Example(self.output_names, all_output_fields) 1492 all_outputs = Example(self.output_names, all_output_fields)
1485 #print 'input_fields', input_fields
1486 #print 'all_outputs', all_outputs
1487 if fieldnames==self.output_names: 1493 if fieldnames==self.output_names:
1488 rval = all_outputs 1494 rval = all_outputs
1489 else: 1495 else:
1490 rval = Example(fieldnames,[all_outputs[name] for name in fieldnames]) 1496 rval = Example(fieldnames,[all_outputs[name] for name in fieldnames])
1491 #print 'rval', rval 1497 #print 'rval', rval