changeset 4:0d5ef85b8698

Improved link between ast viewer and code edit
author windel-eee
date Wed, 21 Sep 2011 19:05:18 +0200
parents 77202b0e0f40
children 818f80afa78b
files ide/compiler/codegenerator.py ide/compiler/nodes.py ide/compiler/parser.py ide/ide/astviewer.py ide/ide/ide.py ide/runtests.py
diffstat 6 files changed, 346 insertions(+), 39 deletions(-) [+]
line wrap: on
line diff
--- a/ide/compiler/codegenerator.py	Sun Sep 18 21:21:54 2011 +0200
+++ b/ide/compiler/codegenerator.py	Wed Sep 21 19:05:18 2011 +0200
@@ -405,12 +405,21 @@
       elif type(node) is ForStatement:
          # Initial load of iterator variable:
          self.genexprcode(node.begin)
-         self.storeRegInDesignator(node.begin.reg, node.variable)
-         self.freereg(node.begin)
+         self.genexprcode(node.end)
+         # TODO: link reg with variable so that a register is used instead of a variable
+         iterreg = node.begin.reg # Get the register used for the loop
+         #self.addCode(cmpreg64(iterreg, node.endvalue))
          rip1 = self.rip
          self.gencode(node.statements)
          #self.loadDesignatorInReg(node.
          #self.addCode( addreg64(node.variable, node.increment) )
+         self.addCode(nearjump(0x0))
+         fixloc1 = self.rip - 4
+         rip2 = self.rip
+         self.fixCode(fixloc1, imm32(rip1 - rip2))
+
+         self.freereg(node.begin) # Release register used in loop
+         self.freereg(node.end)
          Error('No implementation of FOR statement')
 
       elif type(node) is AsmCode:
--- a/ide/compiler/nodes.py	Sun Sep 18 21:21:54 2011 +0200
+++ b/ide/compiler/nodes.py	Wed Sep 21 19:05:18 2011 +0200
@@ -2,6 +2,7 @@
 Parse tree elements
 """
 class Node:
+   location = None
    def getChildren(self):
       children = []
       members = dir(self)
@@ -16,7 +17,13 @@
       return children
 
 class Symbol(Node):
-  pass
+   pass
+
+class Id(Node):
+   def __init__(self, name):
+      self.name = name
+   def __repr__(self):
+      return 'ID {0}'.format(self.name)
 
 # Selectors:
 class Field(Node):
--- a/ide/compiler/parser.py	Sun Sep 18 21:21:54 2011 +0200
+++ b/ide/compiler/parser.py	Wed Sep 21 19:05:18 2011 +0200
@@ -37,20 +37,23 @@
      
    def NextToken(self):
      self.token = self.tokens.__next__()
+     self.location = (self.token.row, self.token.col)
 
-   def setLocation(self):
-      pass
-   def attachLocation(self, node):
-      node.row, node.col = self.token.row, self.token.col
-      return node
+   # Helpers to find location of the error in the code:
+   def setLocation(self, obj, location):
+      obj.location = location
+      return obj
+   def getLocation(self):
+      return self.location
+
    """
      Recursive descent parser functions:
         A set of mutual recursive functions.
         Starting symbol is the Module.
    """
-
    def parseModule(self):
        self.imports = []
+       loc = self.getLocation()
        self.Consume('module')
        modname = self.Consume('ID')
        self.Consume(';')
@@ -81,7 +84,7 @@
        self.Consume('.')
 
        mod.imports = self.imports
-       return mod
+       return self.setLocation(mod, loc)
 
    # Import part
    def parseImportList(self):
@@ -92,15 +95,21 @@
          self.Consume(';')
 
    def parseImport(self):
+      loc = self.getLocation()
       modname = self.Consume('ID')
       mod = loadModule(modname)
+      self.setLocation(mod, loc)
       self.cst.addSymbol(mod)
 
    # Helper to parse an identifier defenitions
    def parseIdentDef(self):
+      loc = self.getLocation()
       name = self.Consume('ID')
       ispublic = self.hasConsumed('*')
-      return (name, ispublic)
+      # Make a node of this thing:
+      i = Id(name)
+      i.ispublic = ispublic
+      return self.setLocation(i, loc)
 
    def parseIdentList(self):
       ids = [ self.parseIdentDef() ]
@@ -138,6 +147,7 @@
            The base location in memory is denoted by the qualified identifier
            The actual address depends on the selector.
       """
+      loc = self.getLocation()
       obj = self.parseQualIdent()
       typ = obj.typ
       selectors = []
@@ -164,7 +174,7 @@
          elif self.hasConsumed('^'):
             selectors.append(Deref())
             typ = typ.pointedType
-      return Designator(obj, selectors, typ)
+      return self.setLocation(Designator(obj, selectors, typ), loc)
 
    # Declaration sequence
    def parseDeclarationSequence(self):
@@ -222,11 +232,12 @@
       """ Parse const part of a module """
       if self.hasConsumed('const'):
          while self.token.typ == 'ID':
-            name, ispublic = self.parseIdentDef()
+            i = self.parseIdentDef()
             self.Consume('=')
             constvalue, typ = self.parseConstExpression()
             self.Consume(';')
-            c = Constant(constvalue, typ, name=name, public=ispublic)
+            c = Constant(constvalue, typ, name=i.name, public=i.ispublic)
+            self.setLocation(c, i.location)
             self.cst.addSymbol(c)
      
    # Type system
@@ -276,10 +287,10 @@
             self.Consume(':')
             typ = self.parseType()
             self.Consume(';')
-            for id, public in identifiers:
-               if id in fields.keys():
-                  self.Error('record field "{0}" multiple defined.'.format(id))
-               fields[id] = typ
+            for i in identifiers:
+               if i.name in fields.keys():
+                  self.Error('record field "{0}" multiple defined.'.format(i.name))
+               fields[i.name] = typ
             # TODO store this in another way, symbol table?
          self.Consume('end')
          return RecordType(fields)
@@ -302,8 +313,9 @@
                self.Consume(':')
                typename = self.parseType()
                self.Consume(';')
-               for name, ispublic in ids:
-                  v = Variable(name, typename, public=ispublic)
+               for i in ids:
+                  v = Variable(i.name, typename, public=i.ispublic)
+                  self.setLocation(v, i.location)
                   self.cst.addSymbol(v)
          else:
             self.Error('Expected ID, got'+str(self.token))
@@ -349,7 +361,8 @@
 
    def parseProcedureDeclaration(self):
      self.Consume('procedure')
-     name, ispublic = self.parseIdentDef()
+     i = self.parseIdentDef()
+     procname = i.name
      proctyp = self.parseFormalParameters()
      procsymtable = SymbolTable(parent = self.cst)
      self.cst = procsymtable    # Switch symbol table:
@@ -388,16 +401,17 @@
 
      self.Consume('end')
      endname = self.Consume('ID')
-     if endname != name:
+     if endname != procname:
         self.Error('endname should match {0}'.format(name))
      self.cst = procsymtable.parent    # Switch back to parent symbol table
-     proc = Procedure(name, proctyp, block, procsymtable, returnexpression)
+     proc = Procedure(procname, proctyp, block, procsymtable, returnexpression)
      self.cst.addSymbol(proc)
-     proc.public = ispublic
+     proc.public = i.ispublic
      return proc
 
    # Statements:
    def parseAssignment(self, lval):
+      loc = self.getLocation()
       self.Consume(':=')
       rval = self.parseExpression()
       if isType(lval.typ, real) and isType(rval.typ, integer):
@@ -407,7 +421,7 @@
             self.Error('Can assign nil only to pointers or procedure types, not {0}'.format(lval))
       elif not isType(lval.typ, rval.typ):
          self.Error('Type mismatch {0} != {1}'.format(lval.typ, rval.typ))
-      return Assignment(lval, rval)
+      return self.setLocation(Assignment(lval, rval), loc)
 
    def parseExpressionList(self):
       expressions = [ self.parseExpression() ]
@@ -432,6 +446,7 @@
       return ProcedureCall(procedure, args)
 
    def parseIfStatement(self):
+     loc = self.getLocation()
      self.Consume('if')
      ifs = []
      condition = self.parseExpression()
@@ -454,7 +469,7 @@
      self.Consume('end')
      for condition, truestatement in reversed(ifs):
          statement = IfStatement(condition, truestatement, statement)
-     return statement
+     return self.setLocation(statement, loc)
 
    def parseCase(self):
       # TODO
@@ -470,6 +485,7 @@
       self.Consume('end')
 
    def parseWhileStatement(self):
+      loc = self.getLocation()
       self.Consume('while')
       condition = self.parseExpression()
       self.Consume('do')
@@ -477,7 +493,7 @@
       if self.hasConsumed('elsif'):
          self.Error('elsif in while not yet implemented')
       self.Consume('end')
-      return WhileStatement(condition, statements)
+      return self.setLocation(WhileStatement(condition, statements), loc)
 
    def parseRepeatStatement(self):
       self.Consume('repeat')
@@ -486,6 +502,7 @@
       cond = self.parseBoolExpression()
 
    def parseForStatement(self):
+      loc = self.getLocation()
       self.Consume('for')
       variable = self.parseDesignator()
       if not variable.typ.isType(integer):
@@ -509,9 +526,10 @@
       self.Consume('do')
       statements = self.parseStatementSequence()
       self.Consume('end')
-      return ForStatement(variable, begin, end, increment, statements)
+      return self.setLocation(ForStatement(variable, begin, end, increment, statements), loc)
 
    def parseAsmcode(self):
+      # TODO: move this to seperate file
       def parseOpcode():
          return self.Consume('ID')
       def parseOperand():
@@ -622,6 +640,7 @@
    def parseTerm(self):
        a = self.parseFactor()
        while self.token.typ in ['*', '/', 'mod', 'div', 'and']:
+           loc = self.getLocation()
            op = self.Consume()
            b = self.parseTerm()
            # Type determination and checking:
@@ -668,7 +687,7 @@
            else:
               self.Error('Unknown operand {0}'.format(op))
 
-           a = Binop(a, op, b, typ)
+           a = self.setLocation(Binop(a, op, b, typ), loc)
        return a
 
    def parseFactor(self):
@@ -677,11 +696,13 @@
          self.Consume(')')
          return e
       elif self.token.typ == 'NUMBER':
-          val = self.Consume('NUMBER')
-          return Constant(val, integer)
+         loc = self.getLocation() 
+         val = self.Consume('NUMBER')
+         return self.setLocation(Constant(val, integer), loc)
       elif self.token.typ == 'REAL':
-          val = self.Consume('REAL')
-          return Constant(val, real)
+         loc = self.getLocation()
+         val = self.Consume('REAL')
+         return self.setLocation(Constant(val, real), loc)
       elif self.token.typ == 'CHAR':
           val = self.Consume('CHAR')
           return Constant(val, char)
@@ -723,6 +744,7 @@
       else:
          a = self.parseTerm()
       while self.token.typ in ['+', '-', 'or']:
+           loc = self.getLocation()
            op = self.Consume()
            b = self.parseTerm()
            if op in ['+', '-']:
@@ -749,6 +771,6 @@
               typ = boolean
            else:
               self.Error('Unknown operand {0}'.format(op))
-           a = Binop(a, op, b, typ)
+           a = self.setLocation(Binop(a, op, b, typ), loc)
       return a
 
--- a/ide/ide/astviewer.py	Sun Sep 18 21:21:54 2011 +0200
+++ b/ide/ide/astviewer.py	Wed Sep 21 19:05:18 2011 +0200
@@ -4,16 +4,18 @@
 def astToNamedElement(astNode, parentNode):
    """ Helper to convert and AST tree to NamedElement tree: """
    item = QStandardItem(str(astNode))
+   item.setData(astNode)
    parentNode.appendRow(item)
    for c in astNode.getChildren():
       astToNamedElement(c, item)
 
 # The actual widget:
 class AstViewer(QTreeView):
+   sigNodeSelected = pyqtSignal(object)
    def __init__(self, parent=None):
       super(AstViewer, self).__init__(parent)
       self.setHeaderHidden(True)
-      self.clicked.connect(self.woei)
+      self.clicked.connect(self.selectHandler)
 
    def setAst(self, ast):
       """ Create a new model and add all ast elements to it """
@@ -23,11 +25,12 @@
       self.setModel( model )
       self.expandAll()
 
-   def woei(self, index):
+   def selectHandler(self, index):
       if not index.isValid():
-         print('Invalid index')
          return
-      print(index.data)
+      model = self.model()
+      item = model.itemFromIndex(index)
+      node = item.data()
+      self.sigNodeSelected.emit(node)
 
 
-
--- a/ide/ide/ide.py	Sun Sep 18 21:21:54 2011 +0200
+++ b/ide/ide/ide.py	Wed Sep 21 19:05:18 2011 +0200
@@ -47,6 +47,7 @@
 
     self.astViewer = AstViewer()
     self.addComponent('AST viewer', self.astViewer)
+    self.astViewer.sigNodeSelected.connect(self.nodeSelected)
 
     # Create actions:
     self.buildAction = QAction('Build!', self)
@@ -79,6 +80,14 @@
      self.settings.setValue('mainwindowgeometry', self.saveGeometry())
      self.codeedit.saveFile()
      ev.accept()
+
+  def nodeSelected(self, node):
+      if node.location:
+         row, col = node.location
+         self.codeedit.highlightErrorLocation( row, col )
+      else:
+         self.codeedit.clearErrors()
+
   def buildFile(self):
      self.buildOutput.clear()
      self.codeedit.clearErrors()
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/ide/runtests.py	Wed Sep 21 19:05:18 2011 +0200
@@ -0,0 +1,257 @@
+import unittest
+
+from compiler.compiler import Compiler
+from compiler.errors import CompilerException, printError
+from compiler import lexer
+from compiler.parser import Parser
+from compiler import assembler
+from compiler.codegenerator import CodeGenerator
+
+class CompilerTestCase(unittest.TestCase):
+   """ test methods start with 'test*' """
+   def testSource1(self):
+      source = """
+      module lcfos;
+      var  
+        a : integer;
+
+      procedure putchar(num : integer);
+      begin
+      end putchar;
+
+      procedure WriteNum( num: integer);
+        var 
+          d, base :  integer;
+          dgt : integer;
+        begin
+          d := 1;
+          base := 10;
+          while num div d >= base do
+            d := d * base
+          end;
+          while d <> 0 do
+             dgt := num div d;
+             num := num mod d;
+             d   := d div base;
+             putchar(48 + dgt)
+           end
+        end WriteNum;
+
+      begin
+        a := 1;
+        while a < 26
+         do
+           putchar(65+a);
+           a := a * 2
+         end;
+      end lcfos.
+      """
+      pc = Compiler()
+      pc.compilesource(source)
+   def testSource2(self):
+      source = """
+      module lcfos;
+      var  
+        a, b : integer;
+        arr: array 30 of integer;
+        arr2: array 10, 12 of integer;
+        procedure t2*() : integer;
+        begin
+        a := 2;
+        while a < 5 do
+           b := arr[a-1] + arr[a-2];
+           arr2[a,2] := b;
+           arr2[a,3] := arr2[a,2] + arr2[a,2]*3 + b;
+           arr[a] := b;
+           a := a  + 1;
+         end;
+         return b
+        end t2;
+        begin
+         b := 12;
+         arr[0] := 1;
+         arr[1] := 1;
+      end lcfos.
+      """
+      pc = Compiler()
+      mod = pc.compilesource(source)
+   def testSource5(self):
+      source = """
+      module lcfos;
+      procedure WriteLn() : integer;
+        const zzz = 13;
+        var
+          a, b, c: integer;
+        begin
+         a := 2;
+         b := 7;
+         c := 10 * a + b*10*a;
+         return c
+        end WriteLn;
+      begin  end lcfos.
+      """
+      pc = Compiler()
+      pc.compilesource(source)
+   def testForStatement(self):
+      source = """
+      module fortest;
+      var  
+        a,b,c : integer;
+      begin
+         c := 0;
+         for a := 1 to 10 by 1 do
+            b := a + 15;
+            c := c + b * a;
+         end;
+      end fortest.
+      """
+      pc = Compiler()
+      pc.compilesource(source)
+   def testSourceIfAndWhilePattern(self):
+      source = """
+      module lcfos;
+      procedure WriteLn() : integer;
+        const zzz = 13;
+        var
+          a, b, c: integer;
+        begin
+         a := 1;
+         b := 2;
+         if a * 3 > b then
+            c := 10*a + b*10*a*a*a*b;
+         else
+            c := 13;
+         end;
+         while a < 101 do
+            a := a + 1;
+            c := c + 2;
+         end;
+         return c
+        end WriteLn;
+      begin end lcfos.
+      """
+      pc = Compiler()
+      pc.compilesource(source)
+
+   def testPattern1(self):
+      """ Test if expression can be compiled into byte code """
+      src = "12*13+33-12*2*3"
+      tokens = lexer.tokenize(src)
+      ast = Parser(tokens).parseExpression()
+      code = CodeGenerator().genexprcode(ast)
+
+   def testAssembler(self):
+      """ Check all kind of assembler cases """
+      assert(assembler.shortjump(5) == [0xeb, 0x5])
+      assert(assembler.shortjump(-2) == [0xeb, 0xfc])
+      assert(assembler.shortjump(10,'GE') == [0x7d, 0xa])
+      assert(assembler.nearjump(5) == [0xe9, 0x5,0x0,0x0,0x0])
+      assert(assembler.nearjump(-2) == [0xe9, 0xf9, 0xff,0xff,0xff])
+      assert(assembler.nearjump(10,'LE') == [0x0f, 0x8e, 0xa,0x0,0x0,0x0])
+
+   def testCall(self):
+      assert(assembler.call('r10') == [0x41, 0xff, 0xd2])
+      assert(assembler.call('rcx') == [0xff, 0xd1])
+   def testXOR(self):
+      assert(assembler.xorreg64('rax', 'rax') == [0x48, 0x31, 0xc0])
+      assert(assembler.xorreg64('r9', 'r8') == [0x4d, 0x31, 0xc1])
+      assert(assembler.xorreg64('rbx', 'r11') == [0x4c, 0x31, 0xdb])
+
+   def testINC(self):
+      assert(assembler.increg64('r11') == [0x49, 0xff, 0xc3])
+      assert(assembler.increg64('rcx') == [0x48, 0xff, 0xc1])
+
+   def testPush(self):
+      assert(assembler.push('rbp') == [0x55])
+      assert(assembler.push('rbx') == [0x53])
+      assert(assembler.push('r12') == [0x41, 0x54])
+   def testPop(self):
+      assert(assembler.pop('rbx') == [0x5b])
+      assert(assembler.pop('rbp') == [0x5d])
+      assert(assembler.pop('r12') == [0x41, 0x5c])
+
+   def testAsmLoads(self):
+      # TODO constant add testcases
+      assert(assembler.mov('rbx', 'r14') == [0x4c, 0x89, 0xf3])
+      assert(assembler.mov('r12', 'r8')  == [0x4d, 0x89, 0xc4])
+      assert(assembler.mov('rdi', 'rsp') == [0x48, 0x89, 0xe7])
+
+   def testAsmMemLoads(self):
+      assert(assembler.mov('rax', ['r8','r15',0x11]) == [0x4b,0x8b,0x44,0x38,0x11])
+      assert(assembler.mov('r13', ['rbp','rcx',0x23]) == [0x4c,0x8b,0x6c,0xd,0x23])
+
+      assert(assembler.mov('r9', ['rbp',-0x33]) == [0x4c,0x8b,0x4d,0xcd])
+      #assert(assembler.movreg64('rbx', ['rax']) == [0x48, 0x8b,0x18])
+
+      assert(assembler.mov('rax', [0xb000]) == [0x48,0x8b,0x4,0x25,0x0,0xb0,0x0,0x0])
+      assert(assembler.mov('r11', [0xa0]) == [0x4c,0x8b,0x1c,0x25,0xa0,0x0,0x0,0x0])
+
+      assert(assembler.mov('r11', ['RIP', 0xf]) == [0x4c,0x8b,0x1d,0x0f,0x0,0x0,0x0])
+
+   def testAsmMemStores(self):
+      assert(assembler.mov(['rbp', 0x13],'rbx') == [0x48,0x89,0x5d,0x13])
+      assert(assembler.mov(['r12', 0x12],'r9') == [0x4d,0x89,0x4c,0x24,0x12])
+      assert(assembler.mov(['rcx', 0x11],'r14') == [0x4c,0x89,0x71,0x11])
+
+
+      assert(assembler.mov([0xab], 'rbx') == [0x48,0x89,0x1c,0x25,0xab,0x0,0x0,0x0])
+      assert(assembler.mov([0xcd], 'r13') == [0x4c,0x89,0x2c,0x25,0xcd,0x0,0x0,0x0])
+
+      assert(assembler.mov(['RIP', 0xf], 'r9') == [0x4c,0x89,0x0d,0x0f,0x0,0x0,0x0])
+
+   def testAsmMOV8(self):
+      assert(assembler.mov(['rbp', -8], 'al') == [0x88, 0x45, 0xf8])
+      assert(assembler.mov(['r11', 9], 'cl') == [0x41, 0x88, 0x4b, 0x09])
+
+      assert(assembler.mov(['rbx'], 'al') == [0x88, 0x03])
+      assert(assembler.mov(['r11'], 'dl') == [0x41, 0x88, 0x13])
+
+   def testAsmLea(self):
+      assert(assembler.leareg64('r11', ['RIP', 0xf]) == [0x4c,0x8d,0x1d,0x0f,0x0,0x0,0x0])
+      assert(assembler.leareg64('rsi', ['RIP', 0x7]) == [0x48,0x8d,0x35,0x07,0x0,0x0,0x0])
+
+      assert(assembler.leareg64('rcx', ['rbp', -8]) == [0x48,0x8d,0x4d,0xf8])
+
+   def testAssemblerCMP(self):
+      assert(assembler.cmpreg64('rdi', 'r13') == [0x4c, 0x39, 0xef])
+      assert(assembler.cmpreg64('rbx', 'r14') == [0x4c, 0x39, 0xf3])
+      assert(assembler.cmpreg64('r12', 'r9')  == [0x4d, 0x39, 0xcc])
+
+      assert(assembler.cmpreg64('rdi', 1)  == [0x48, 0x83, 0xff, 0x01])
+      assert(assembler.cmpreg64('r11', 2)  == [0x49, 0x83, 0xfb, 0x02])
+   def testAssemblerADD(self):
+      assert(assembler.addreg64('rbx', 'r13') == [0x4c, 0x01, 0xeb])
+      assert(assembler.addreg64('rax', 'rbx') == [0x48, 0x01, 0xd8])
+      assert(assembler.addreg64('r12', 'r13') == [0x4d, 0x01, 0xec])
+
+      assert(assembler.addreg64('rbx', 0x13) == [0x48, 0x83, 0xc3, 0x13])
+      assert(assembler.addreg64('r11', 0x1234567) == [0x49, 0x81, 0xc3, 0x67, 0x45,0x23,0x1])
+      assert(assembler.addreg64('rsp', 0x33) == [0x48, 0x83, 0xc4, 0x33])
+
+   def testAssemblerSUB(self):
+      assert(assembler.subreg64('rdx', 'r14') == [0x4c, 0x29, 0xf2])
+      assert(assembler.subreg64('r15', 'rbx') == [0x49, 0x29, 0xdf])
+      assert(assembler.subreg64('r8', 'r9') == [0x4d, 0x29, 0xc8])
+
+      assert(assembler.subreg64('rsp', 0x123456) == [0x48, 0x81, 0xec, 0x56,0x34,0x12,0x0])
+      assert(assembler.subreg64('rsp', 0x12) == [0x48, 0x83, 0xec, 0x12])
+
+   def testAssemblerIDIV(self):
+      assert(assembler.idivreg64('r11') == [0x49, 0xf7, 0xfb])
+      assert(assembler.idivreg64('rcx') == [0x48, 0xf7, 0xf9])
+      assert(assembler.idivreg64('rsp') == [0x48, 0xf7, 0xfc])
+
+   def testAssemblerIMUL(self):
+      assert(assembler.imulreg64_rax('rdi') == [0x48, 0xf7, 0xef])
+      assert(assembler.imulreg64_rax('r10') == [0x49, 0xf7, 0xea])
+      assert(assembler.imulreg64_rax('rdx') == [0x48, 0xf7, 0xea])
+
+      assert(assembler.imulreg64('r11', 'rdi') == [0x4c, 0xf, 0xaf, 0xdf])
+      assert(assembler.imulreg64('r12', 'rbx') == [0x4c, 0xf, 0xaf, 0xe3])
+      # nasm generates this machine code: 0x4d, 0x6b, 0xff, 0xee
+      # This also works: 4D0FAFFE (another variant?? )
+      assert(assembler.imulreg64('r15', 'r14') == [0x4d, 0x0f, 0xaf, 0xfe])
+
+if __name__ == '__main__':
+   unittest.main()
+