changeset 128:51cc127648e4

Splitup in interface and device
author Windel Bouwman
date Sun, 13 Jan 2013 17:31:35 +0100
parents ec1f2cc04d95
children 9e350a7dde98
files python/devices.py python/st-flash.py python/stlink.py python/stm32.py
diffstat 4 files changed, 294 insertions(+), 237 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/python/devices.py	Sun Jan 13 17:31:35 2013 +0100
@@ -0,0 +1,43 @@
+import sys
+
+
+# Global device list to which devices are registered.
+deviceList = {}
+
+def registerDevice(chipId):
+   """ Decorator to register a device """
+   def wrapper(dev):
+      deviceList[chipId] = dev
+      return dev
+   return wrapper
+
+# Global interface dictionary.
+interfaces = {}
+
+def registerInterface(vid_pid):
+   def wrapper(iface):
+      interfaces[vid_pid] = iface
+      return iface
+   return wrapper
+
+class Device:
+   """
+      Base class for a device possibly connected via an interface.
+   """
+   pass
+
+class Interface:
+   """
+      Generic interface class. Connected via Usb to a JTAG interface.
+      Possibly is connected with a certain chip.
+   """
+   def getDevice(self):
+      """ Try to get the device connected to this interface """
+      if self.ChipId in deviceList:
+         return deviceList[self.ChipId](self)
+      raise STLinkException('No device found!')
+
+class STLinkException(Exception):
+   """ Exception used for interfaces and devices """
+   pass
+
--- a/python/st-flash.py	Sun Jan 13 13:02:29 2013 +0100
+++ b/python/st-flash.py	Sun Jan 13 17:31:35 2013 +0100
@@ -1,13 +1,13 @@
 #!/usr/bin/python
 
 import argparse, sys
-import stlink 
+import stlink, stm32
 
 def hex2int(s):
    if s.startswith('0x'):
       s = s[2:]
       return int(s, 16)
-   return int(s)
+   raise ValueError('Hexadecimal value must begin with 0x')
 
 parser = argparse.ArgumentParser(
    description='ST-link flash utility by Windel Bouwman')
@@ -36,7 +36,7 @@
    sys.exit(1)
 
 # In any command case, open a device:
-stl = stlink.STLink()
+stl = stlink.STLink2()
 stl.open()
 
 # Enter the right mode:
@@ -50,17 +50,20 @@
    print('Only working on stm32f4discovery board for now.')
    sys.exit(2)
 
+# Retrieve the connected device, if any:
+dev = stl.getDevice()
+
 if args.command == 'read':
-   dev_content = stl.readFlash(args.address, args.size)
+   dev_content = dev.readFlash(args.address, args.size)
    args.filename.write(dev_content)
 elif args.command == 'write':
    content = args.filename.read()
-   stl.writeFlash(args.address, content)
+   dev.writeFlash(args.address, content)
 elif args.command == 'verify':
    content = args.filename.read()
-   stl.verifyFlash(args.address, content)
+   dev.verifyFlash(args.address, content)
 elif args.command == 'erase':
-   stl.eraseFlash()
+   dev.eraseFlash()
 else:
    print('unknown command', args.command)
    
--- a/python/stlink.py	Sun Jan 13 13:02:29 2013 +0100
+++ b/python/stlink.py	Sun Jan 13 17:31:35 2013 +0100
@@ -1,8 +1,6 @@
 import struct, time
 from usb import UsbContext
-
-class STLinkException(Exception):
-   pass
+from devices import Interface, STLinkException, registerInterface
 
 ST_VID, STLINK2_PID = 0x0483, 0x3748
 
@@ -43,65 +41,15 @@
 # cortex M3
 CM3_REG_CPUID = 0xE000ED00
 
-# F4 specifics:
-STM32_FLASH_BASE = 0x08000000
-STM32_SRAM_BASE  = 0x20000000
-
-FLASH_KEY1 = 0x45670123
-FLASH_KEY2 = 0xcdef89ab
-# flash registers:
-FLASH_F4_REGS_ADDR = 0x40023c00
-
-FLASH_F4_KEYR = FLASH_F4_REGS_ADDR + +0x04
-FLASH_F4_SR = FLASH_F4_REGS_ADDR + 0x0c
-FLASH_F4_CR = FLASH_F4_REGS_ADDR + 0x10
-
-FLASH_F4_CR_START = 16
-FLASH_F4_CR_LOCK = 31
-FLASH_CR_PG = 0
-FLASH_F4_CR_SER = 1
-FLASH_CR_MER = 2
-FLASH_F4_CR_SNB = 3
-FLASH_F4_CR_SNB_MASK = 0x38
-FLASH_F4_SR_BSY = 16
-
-def calculate_F4_sector(address):
-   """
-      from 0x8000000 to 0x80FFFFF
-      4 sectors of 0x4000 (16 kB)
-      1 sector of 0x10000 (64 kB)
-      7 of 0x20000 (128 kB)
-   """
-   sectorsizes = [0x4000] * 4 + [0x10000] + [0x20000] * 7
-   sectorstarts = []
-   a = STM32_FLASH_BASE
-   for sectorsize in sectorsizes:
-      sectorstarts.append(a)
-      a += sectorsize
-   # linear search:
-   sec = 0
-   while sec < len(sectorsizes) and address >= sectorstarts[sec]:
-      sec += 1
-   sec -= 1 # one back.
-   return sec, sectorsizes[sec]
-
-def calcSectors(address, size):
-   off = 0
-   sectors = []
-   while off < size:
-      sectornum, sectorsize = calculate_F4_sector(address + off)
-      sectors.append((sectornum, sectorsize))
-      off += sectorsize
-   return sectors
-
-class STLink:
-   """ ST link interface code """
+@registerInterface((ST_VID, STLINK2_PID))
+class STLink2(Interface):
+   """ STlink2 interface implementation. """
    def __init__(self, stlink2=None):
       if not stlink2:
          context = UsbContext()
          stlink2s = list(filter(checkDevice, context.DeviceList))
          if not stlink2s:
-            raise STLinkException('Could not find an ST link')
+            raise STLinkException('Could not find an ST link 2 interface')
          if len(stlink2s) > 1:
             print('More then one stlink2 found, picking first one')
          stlink2 = stlink2s[0]
@@ -202,178 +150,6 @@
       cmd[0:2] = DEBUG_COMMAND, DEBUG_RUNCORE
       self.send_recv(cmd, 2)
 
-   # flashing commands:
-   def writeFlash(self, address, content):
-      # TODO:
-      flashsize = 0x100000 # fixed 1 MB for now..
-      print('WARNING: using 1 MB as flash size')
-      pagesize = 0x4000 # fixed for now!
-
-      # Check address range:
-      if address < STM32_FLASH_BASE:
-         raise STLinkException('Flashing below flash start')
-      if address + len(content) > STM32_FLASH_BASE + flashsize:
-         raise STLinkException('Flashing above flash size')
-      if address & 1 == 1:
-         raise STLinkException('Unaligned flash')
-      if len(content) & 1 == 1:
-         print('unaligned length, padding with zero')
-         content += bytes([0])
-      if address & (pagesize - 1) != 0:
-         raise STLinkException('Address not aligned with pagesize')
-      
-      # erase required space
-      sectors = calcSectors(address, len(content))
-      print('erasing {0} sectors'.format(len(sectors)))
-      for sector, secsize in sectors:
-         print('erasing sector {0} of {1} bytes'.format(sector, secsize))
-         self.eraseFlashSector(sector)
-
-      # program pages:
-      self.unlockFlashIf()
-      self.writeFlashCrPsiz(2) # writes are 32 bits aligned
-      self.setFlashCrPg()
-
-      print('writing {0} bytes'.format(len(content)), end='')
-      offset = 0
-      t1 = time.time()
-      while offset < len(content):
-         size = len(content) - offset
-         if size > 0x8000:
-            size = 0x8000
-
-         chunk = content[offset:offset + size]
-         while len(chunk) % 4 != 0:
-            print('padding chunk')
-            chunk = chunk + bytes([0])
-
-         # Use simple mem32 writes:
-         self.write_mem32(address + offset, chunk)
-
-         offset += size
-         print('.', end='', flush=True)
-      t2 = time.time()
-      print('Done!')
-      print('Speed: {0} bytes/second'.format(len(content)/(t2-t1)))
-
-      self.lockFlash()
-
-      # verfify program:
-      self.verifyFlash(address, content)
-   def eraseFlashSector(self, sector):
-      self.waitFlashBusy()
-      self.unlockFlashIf()
-      self.writeFlashCrSnb(sector)
-      self.setFlashCrStart()
-      self.waitFlashBusy()
-      self.lockFlash()
-   def eraseFlash(self):
-      self.waitFlashBusy()
-      self.unlockFlashIf()
-      self.setFlashCrMer()
-      self.setFlashCrStart()
-      self.waitFlashBusy()
-      self.clearFlashCrMer()
-      self.lockFlash()
-   def verifyFlash(self, address, content):
-      device_content = self.readFlash(address, len(content))
-      ok = content == device_content
-      print('Verify:', ok)
-   def readFlash(self, address, size):
-      print('Reading {1} bytes from 0x{0:X}'.format(address, size), end='')
-      offset = 0
-      tmp_size = 0x1800
-      t1 = time.time()
-      image = bytes()
-      while offset < size:
-         # Correct for last page:
-         if offset + tmp_size > size:
-            tmp_size = size - offset
-
-         # align size to 4 bytes:
-         aligned_size = tmp_size
-         while aligned_size % 4 != 0:
-            aligned_size += 1
-
-         mem = self.read_mem32(address + offset, aligned_size)
-         image += mem[:tmp_size]
-
-         # indicate progress:
-         print('.', end='', flush=True)
-
-         # increase for next piece:
-         offset += tmp_size
-      t2 = time.time()
-      assert size == len(image)
-      print('Done!')
-      print('Speed: {0} bytes/second'.format(size/(t2-t1)))
-      return image
-   def readFlashSr(self):
-      return self.read_debug32(FLASH_F4_SR)
-   def readFlashCr(self):
-      return self.read_debug32(FLASH_F4_CR)
-   def writeFlashCr(self, x):
-      self.write_debug32(FLASH_F4_CR, x)
-   def writeFlashCrSnb(self, sector):
-      x = self.readFlashCr()
-      x &= ~FLASH_F4_CR_SNB_MASK
-      x |= sector << FLASH_F4_CR_SNB
-      x |= 1 << FLASH_F4_CR_SER
-      self.writeFlashCr(x)
-   def setFlashCrMer(self):
-      x = self.readFlashCr()
-      x |= 1 << FLASH_CR_MER
-      self.writeFlashCr(x)
-   def setFlashCrPg(self):
-      x = self.readFlashCr()
-      x |= 1 << FLASH_CR_PG
-      self.writeFlashCr(x)
-   def writeFlashCrPsiz(self, n):
-      x = self.readFlashCr()
-      x &= (0x3 << 8)
-      x |= n << 8
-      self.writeFlashCr(x)
-   def clearFlashCrMer(self):
-      x = self.readFlashCr()
-      x &= ~(1 << FLASH_CR_MER)
-      self.writeFlashCr(x)
-   def setFlashCrStart(self):
-      x = self.readFlashCr()
-      x |= 1 << FLASH_F4_CR_START
-      self.writeFlashCr(x)
-   def isFlashBusy(self):
-      mask = 1 << FLASH_F4_SR_BSY
-      sr = self.readFlashSr()
-      # Check for error bits:
-      errorbits = {}
-      errorbits[7] = 'Programming sequence error'
-      errorbits[6] = 'Programming parallelism error'
-      errorbits[5] = 'Programming alignment error'
-      errorbits[4] = 'Write protection error'
-      errorbits[1] = 'Operation error'
-      #errorbits[0] = 'End of operation'
-      for bit, msg in errorbits.items():
-         if sr & (1 << bit) == (1 << bit):
-            raise STLinkException(msg)
-
-      return sr & mask == mask
-   def waitFlashBusy(self):
-      """ block until flash operation completes. """
-      while self.isFlashBusy():
-         pass
-   def isFlashLocked(self):
-      cr = self.readFlashCr()
-      mask = 1 << FLASH_F4_CR_LOCK
-      return cr & mask == mask
-   def unlockFlashIf(self):
-      if self.isFlashLocked():
-         self.write_debug32(FLASH_F4_KEYR, FLASH_KEY1)
-         self.write_debug32(FLASH_F4_KEYR, FLASH_KEY2)
-         if self.isFlashLocked():
-            raise STLinkException('Failed to unlock')
-   def lockFlash(self):
-      x = self.readFlashCr() | (1 << FLASH_F4_CR_LOCK)
-      self.write_debug32(FLASH_F4_CR, x)
 
    # Helper 1 functions:
    def write_debug32(self, address, value):
@@ -432,7 +208,7 @@
 
 if __name__ == '__main__':
    # Test program
-   sl = STLink()
+   sl = STLink2()
    sl.open()
    print('version:', sl.Version)
    print('mode before doing anything:', sl.CurrentModeString)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/python/stm32.py	Sun Jan 13 17:31:35 2013 +0100
@@ -0,0 +1,235 @@
+import time
+from devices import Device, registerDevice, STLinkException, Interface
+
+# F4 specifics:
+STM32_FLASH_BASE = 0x08000000
+STM32_SRAM_BASE  = 0x20000000
+
+FLASH_KEY1 = 0x45670123
+FLASH_KEY2 = 0xcdef89ab
+
+# flash registers:
+FLASH_F4_REGS_ADDR = 0x40023c00
+FLASH_F4_KEYR = FLASH_F4_REGS_ADDR + 0x04
+FLASH_F4_SR = FLASH_F4_REGS_ADDR + 0x0c
+FLASH_F4_CR = FLASH_F4_REGS_ADDR + 0x10
+
+FLASH_F4_CR_START = 16
+FLASH_F4_CR_LOCK = 31
+FLASH_CR_PG = 0
+FLASH_F4_CR_SER = 1
+FLASH_CR_MER = 2
+FLASH_F4_CR_SNB = 3
+FLASH_F4_CR_SNB_MASK = 0x38
+FLASH_F4_SR_BSY = 16
+
+def calculate_F4_sector(address):
+   """
+      from 0x8000000 to 0x80FFFFF
+      4 sectors of 0x4000 (16 kB)
+      1 sector of 0x10000 (64 kB)
+      7 of 0x20000 (128 kB)
+   """
+   sectorsizes = [0x4000] * 4 + [0x10000] + [0x20000] * 7
+   sectorstarts = []
+   a = STM32_FLASH_BASE
+   for sectorsize in sectorsizes:
+      sectorstarts.append(a)
+      a += sectorsize
+   # linear search:
+   sec = 0
+   while sec < len(sectorsizes) and address >= sectorstarts[sec]:
+      sec += 1
+   sec -= 1 # one back.
+   return sec, sectorsizes[sec]
+
+def calcSectors(address, size):
+   off = 0
+   sectors = []
+   while off < size:
+      sectornum, sectorsize = calculate_F4_sector(address + off)
+      sectors.append((sectornum, sectorsize))
+      off += sectorsize
+   return sectors
+
+@registerDevice(0x10016413)
+class Stm32F4(Device):
+   def __init__(self, iface):
+      super().__init__()
+      assert isinstance(iface, Interface)
+      self.iface = iface
+   # flashing commands:
+   def writeFlash(self, address, content):
+      # TODO:
+      flashsize = 0x100000 # fixed 1 MB for now..
+      print('WARNING: using 1 MB as flash size')
+      pagesize = 0x4000 # fixed for now!
+      print('warning: page size hardcoded')
+
+      # Check address range:
+      if address < STM32_FLASH_BASE:
+         raise STLinkException('Flashing below flash start')
+      if address + len(content) > STM32_FLASH_BASE + flashsize:
+         raise STLinkException('Flashing above flash size')
+      if address & 1 == 1:
+         raise STLinkException('Unaligned flash')
+      if len(content) & 1 == 1:
+         print('unaligned length, padding with zero')
+         content += bytes([0])
+      if address & (pagesize - 1) != 0:
+         raise STLinkException('Address not aligned with pagesize')
+      
+      # erase required space
+      sectors = calcSectors(address, len(content))
+      print('erasing {0} sectors'.format(len(sectors)))
+      for sector, secsize in sectors:
+         print('erasing sector {0} of {1} bytes'.format(sector, secsize))
+         self.eraseFlashSector(sector)
+
+      # program pages:
+      self.unlockFlashIf()
+      self.writeFlashCrPsiz(2) # writes are 32 bits aligned
+      self.setFlashCrPg()
+
+      print('writing {0} bytes'.format(len(content)), end='')
+      offset = 0
+      t1 = time.time()
+      while offset < len(content):
+         size = len(content) - offset
+         if size > 0x8000:
+            size = 0x8000
+
+         chunk = content[offset:offset + size]
+         while len(chunk) % 4 != 0:
+            print('padding chunk')
+            chunk = chunk + bytes([0])
+
+         # Use simple mem32 writes:
+         self.iface.write_mem32(address + offset, chunk)
+
+         offset += size
+         print('.', end='', flush=True)
+      t2 = time.time()
+      print('Done!')
+      print('Speed: {0} bytes/second'.format(len(content)/(t2-t1)))
+
+      self.lockFlash()
+
+      # verfify program:
+      self.verifyFlash(address, content)
+   def eraseFlashSector(self, sector):
+      self.waitFlashBusy()
+      self.unlockFlashIf()
+      self.writeFlashCrSnb(sector)
+      self.setFlashCrStart()
+      self.waitFlashBusy()
+      self.lockFlash()
+   def eraseFlash(self):
+      self.waitFlashBusy()
+      self.unlockFlashIf()
+      self.setFlashCrMer()
+      self.setFlashCrStart()
+      self.waitFlashBusy()
+      self.clearFlashCrMer()
+      self.lockFlash()
+   def verifyFlash(self, address, content):
+      device_content = self.readFlash(address, len(content))
+      ok = content == device_content
+      print('Verify:', ok)
+   def readFlash(self, address, size):
+      print('Reading {1} bytes from 0x{0:X}'.format(address, size), end='')
+      offset = 0
+      tmp_size = 0x1800
+      t1 = time.time()
+      image = bytes()
+      while offset < size:
+         # Correct for last page:
+         if offset + tmp_size > size:
+            tmp_size = size - offset
+
+         # align size to 4 bytes:
+         aligned_size = tmp_size
+         while aligned_size % 4 != 0:
+            aligned_size += 1
+
+         mem = self.iface.read_mem32(address + offset, aligned_size)
+         image += mem[:tmp_size]
+
+         # indicate progress:
+         print('.', end='', flush=True)
+
+         # increase for next piece:
+         offset += tmp_size
+      t2 = time.time()
+      assert size == len(image)
+      print('Done!')
+      print('Speed: {0} bytes/second'.format(size/(t2-t1)))
+      return image
+
+   def waitFlashBusy(self):
+      """ block until flash operation completes. """
+      while self.isFlashBusy():
+         pass
+   def isFlashLocked(self):
+      cr = self.readFlashCr()
+      mask = 1 << FLASH_F4_CR_LOCK
+      return cr & mask == mask
+   def unlockFlashIf(self):
+      if self.isFlashLocked():
+         self.iface.write_debug32(FLASH_F4_KEYR, FLASH_KEY1)
+         self.iface.write_debug32(FLASH_F4_KEYR, FLASH_KEY2)
+         if self.isFlashLocked():
+            raise STLinkException('Failed to unlock')
+   def lockFlash(self):
+      x = self.readFlashCr() | (1 << FLASH_F4_CR_LOCK)
+      self.writeFlashCr(x)
+
+   def readFlashSr(self):
+      return self.iface.read_debug32(FLASH_F4_SR)
+   def readFlashCr(self):
+      return self.iface.read_debug32(FLASH_F4_CR)
+   def writeFlashCr(self, x):
+      self.iface.write_debug32(FLASH_F4_CR, x)
+   def writeFlashCrSnb(self, sector):
+      x = self.readFlashCr()
+      x &= ~FLASH_F4_CR_SNB_MASK
+      x |= sector << FLASH_F4_CR_SNB
+      x |= 1 << FLASH_F4_CR_SER
+      self.writeFlashCr(x)
+   def setFlashCrMer(self):
+      x = self.readFlashCr()
+      x |= 1 << FLASH_CR_MER
+      self.writeFlashCr(x)
+   def setFlashCrPg(self):
+      x = self.readFlashCr()
+      x |= 1 << FLASH_CR_PG
+      self.writeFlashCr(x)
+   def writeFlashCrPsiz(self, n):
+      x = self.readFlashCr()
+      x &= (0x3 << 8)
+      x |= n << 8
+      self.writeFlashCr(x)
+   def clearFlashCrMer(self):
+      x = self.readFlashCr()
+      x &= ~(1 << FLASH_CR_MER)
+      self.writeFlashCr(x)
+   def setFlashCrStart(self):
+      x = self.readFlashCr()
+      x |= 1 << FLASH_F4_CR_START
+      self.writeFlashCr(x)
+   def isFlashBusy(self):
+      mask = 1 << FLASH_F4_SR_BSY
+      sr = self.readFlashSr()
+      # Check for error bits:
+      errorbits = {}
+      errorbits[7] = 'Programming sequence error'
+      errorbits[6] = 'Programming parallelism error'
+      errorbits[5] = 'Programming alignment error'
+      errorbits[4] = 'Write protection error'
+      errorbits[1] = 'Operation error'
+      #errorbits[0] = 'End of operation'
+      for bit, msg in errorbits.items():
+         if sr & (1 << bit) == (1 << bit):
+            raise STLinkException(msg)
+      return sr & mask == mask
+