view python/stlink.py @ 116:90b03bc018cf

Added loader code from openocd
author Windel Bouwman
date Mon, 07 Jan 2013 19:30:01 +0100
parents 92b2bf0da1ec
children f2b37d78082d
line wrap: on
line source

import struct, time
from usb import UsbContext

class STLinkException(Exception):
   """ Error raised by the ST-Link driver, e.g. when no probe is found, a
   flash range/alignment check fails or the flash cannot be unlocked. """

# USB ids used by checkDevice to recognise an ST-Link/V2 probe:
ST_VID = 0x0483
STLINK2_PID = 0x3748

def checkDevice(device):
   """ Return True when the USB device has the ST-Link/V2 vendor and product id. """
   if device.VendorId != ST_VID:
      return False
   return device.ProductId == STLINK2_PID

# Probe modes, as reported by GET_CURRENT_MODE:
DFU_MODE, MASS_MODE, DEBUG_MODE = 0, 1, 2
# Commands (first byte of a command block):
GET_VERSION = 0xf1
DEBUG_COMMAND = 0xf2
DFU_COMMAND = 0xf3
GET_CURRENT_MODE = 0xf5

# dfu commands (second byte, after DFU_COMMAND):
DFU_EXIT = 0x7

# debug commands (second byte, after DEBUG_COMMAND):
DEBUG_ENTER = 0x20
DEBUG_EXIT = 0x21
DEBUG_ENTER_SWD = 0xa3
DEBUG_GETSTATUS = 0x01
DEBUG_RESETSYS = 0x03
DEBUG_READREG = 0x5
DEBUG_WRITEREG = 0x6
DEBUG_READMEM_32BIT = 0x7
DEBUG_WRITEMEM_32BIT = 0x8
DEBUG_RUNCORE = 0x9    # used by STLink.run()
DEBUG_STEPCORE = 0xa   # used by STLink.step()

JTAG_WRITEDEBUG_32BIT = 0x35
JTAG_READDEBUG_32BIT = 0x36

# cortex M3
CM3_REG_CPUID = 0xE000ED00

# F4 specifics:
STM32_FLASH_BASE = 0x08000000
STM32_SRAM_BASE  = 0x20000000

# Key sequence written to FLASH_F4_KEYR to unlock FLASH_F4_CR:
FLASH_KEY1 = 0x45670123
FLASH_KEY2 = 0xcdef89ab
# flash registers:
FLASH_F4_REGS_ADDR = 0x40023c00

FLASH_F4_KEYR = FLASH_F4_REGS_ADDR + 0x04
FLASH_F4_SR = FLASH_F4_REGS_ADDR + 0x0c
FLASH_F4_CR = FLASH_F4_REGS_ADDR + 0x10

# Bit positions in FLASH_F4_CR / FLASH_F4_SR:
FLASH_F4_CR_START = 16
FLASH_F4_CR_LOCK = 31
FLASH_F4_CR_SER = 1
FLASH_CR_MER = 2
FLASH_F4_CR_SNB = 3
# SNB (sector number) occupies bits 6:3: sector numbers go up to 11
# (see calculate_F4_sector), which needs 4 bits. A 3 bit mask (0x38) would
# leave bit 6 of a previously selected sector 8..11 set.
FLASH_F4_CR_SNB_MASK = 0x78
FLASH_F4_SR_BSY = 16

# flashloaders/stm32f4.s
# Thumb code (decoded below). It copies r2 32-bit words from the address
# in r0 to the address in r1. After each word it waits until the BSY flag
# in FLASH_SR is clear, then stops at a breakpoint.
loader_code_stm32f4 = bytes([
   0x07, 0x4b,                # ldr   r3, [pc, #28]    ; r3 = word at offset 32

   0x62, 0xb1,                # next: cbz r2, done
   0x04, 0x68,                # ldr   r4, [r0]
   0x0c, 0x60,                # str   r4, [r1]

   0xdc, 0x89,                # wait: ldrh r4, [r3, #14] ; upper half of FLASH_SR
   0x14, 0xf0, 0x01, 0x0f,    # tst   r4, #1             ; BSY (SR bit 16)
   0xfb, 0xd1,                # bne   wait
   0x00, 0xf1, 0x04, 0x00,    # add   r0, r0, #4
   0x01, 0xf1, 0x04, 0x01,    # add   r1, r1, #4
   0xa2, 0xf1, 0x01, 0x02,    # sub   r2, r2, #1
   0xf1, 0xe7,                # b     next

   0x00, 0xbe,                # done: bkpt

   0x00, 0x3c, 0x02, 0x40     # .word 0x40023c00 (FLASH_F4_REGS_ADDR, little endian)
])

def calculate_F4_sector(address):
   """
      Return (sector number, sector size) of the F4 flash sector that
      contains address.

      Sector map, from 0x8000000 to 0x80FFFFF:
      4 sectors of 0x4000 (16 kB)
      1 sector of 0x10000 (64 kB)
      7 of 0x20000 (128 kB)

      Raises ValueError when address is outside this map. Previously an
      address below the flash start yielded sector -1, and an address past
      the end silently yielded the last sector.
   """
   sectorsizes = [0x4000] * 4 + [0x10000] + [0x20000] * 7
   start = STM32_FLASH_BASE
   for sec, sectorsize in enumerate(sectorsizes):
      if start <= address < start + sectorsize:
         return sec, sectorsize
      start += sectorsize
   raise ValueError(
      'address 0x{0:X} is outside the flash sector map'.format(address))

def calcSectors(address, size):
   """
      Return a list of (sector number, sector size) tuples, one for every
      flash sector touched by the byte range [address, address + size).
      Returns an empty list when size <= 0.
   """
   sectors = []
   end = address + size
   current = address
   while current < end:
      sectornum, sectorsize = calculate_F4_sector(current)
      sectors.append((sectornum, sectorsize))
      # Move to the first byte of the next sector. Adding the full sector
      # size to a mid-sector address would overshoot the next boundary and
      # skip a sector. For example, start 0x08014000 with size 0x10000 would
      # miss sector 5. Every sector in the F4 map is aligned to its own size
      # (offsets 0/0x4000/.. for 16 kB, 0x10000 for 64 kB, multiples of
      # 0x20000 for 128 kB, on top of 0x08000000), so the remainder gives
      # the distance into the sector.
      current += sectorsize - (current % sectorsize)
   return sectors

class STLink:
   """ Driver for an ST-Link/V2 USB debug probe.

   Most commands are 16 byte command blocks written to USB bulk endpoint 2;
   replies are read back from bulk endpoint 1 (see send_recv).
   """
   def __init__(self):
      self.context = UsbContext()

   def open(self):
      """ Open the first ST-Link/V2 found and claim interface 0.

      Then bring the probe out of DFU mode and into SWD debug mode, and
      reset the target.

      Raises STLinkException when no ST-Link/V2 is connected.
      """
      # Reuse the context created in __init__. A context local to this
      # method would be dropped on return, while the device handle opened
      # through it is still in use.
      stlink2s = list(filter(checkDevice, self.context.DeviceList))
      if not stlink2s:
         raise STLinkException('Could not find an ST link')
      if len(stlink2s) > 1:
         print('More then one stlink2 found, picking first one')
      stlink2 = stlink2s[0]
      self.devHandle = stlink2.open()
      if self.devHandle.Configuration != 1:
         self.devHandle.Configuration = 1
      self.devHandle.claimInterface(0)

      # First initialization:
      if self.CurrentMode == DFU_MODE:
         self.exitDfuMode()
      if self.CurrentMode != DEBUG_MODE:
         self.enterSwdMode()
      self.reset()

   def close(self):
      """ Close the connection.

      TODO: the claimed interface and the device handle are not released yet.
      """

   # modes:
   def getCurrentMode(self):
      """ Return the mode byte reported by the probe (DFU_MODE, MASS_MODE, ...). """
      cmd = bytearray(16)
      cmd[0] = GET_CURRENT_MODE
      reply = self.send_recv(cmd, 2) # Expect 2 bytes back
      return reply[0]
   CurrentMode = property(getCurrentMode)

   @property
   def CurrentModeString(self):
      """ Human readable name of the current probe mode. """
      modes = {DFU_MODE: 'dfu', MASS_MODE: 'massmode', DEBUG_MODE: 'debug'}
      mode = self.CurrentMode
      # Do not raise KeyError for mode values missing from this table.
      return modes.get(mode, 'unknown mode {0}'.format(mode))

   def exitDfuMode(self):
      """ Leave DFU mode. """
      cmd = bytearray(16)
      cmd[0:2] = DFU_COMMAND, DFU_EXIT
      self.send_recv(cmd)

   def enterSwdMode(self):
      """ Enter debug mode using the SWD transport. """
      cmd = bytearray(16)
      cmd[0:3] = DEBUG_COMMAND, DEBUG_ENTER, DEBUG_ENTER_SWD
      self.send_recv(cmd)

   def exitDebugMode(self):
      """ Leave debug mode. """
      cmd = bytearray(16)
      cmd[0:2] = DEBUG_COMMAND, DEBUG_EXIT
      self.send_recv(cmd)

   def getVersion(self):
      """ Query the probe version.

      Returns a string with the ST-Link, JTAG and SWIM version numbers and
      the USB vendor/product id reported by the probe.
      """
      cmd = bytearray(16)
      cmd[0] = GET_VERSION
      data = self.send_recv(cmd, 6) # Expect 6 bytes back
      # Parse 6 bytes into various versions:
      b0, b1, b2, b3, b4, b5 = data
      stlink_v = b0 >> 4                        # upper nibble of byte 0
      jtag_v = ((b0 & 0xf) << 2) | (b1 >> 6)    # next 6 bits
      swim_v = b1 & 0x3f                        # lowest 6 bits of byte 1
      vid = (b3 << 8) | b2                      # little endian u16
      pid = (b5 << 8) | b4                      # little endian u16
      return 'stlink={0} jtag={1} swim={2} vid:pid={3:04X}:{4:04X}'.format(
         stlink_v, jtag_v, swim_v, vid, pid)
   Version = property(getVersion)

   @property
   def ChipId(self):
      """ Raw 32 bit value at 0xE0042000 (presumably DBGMCU_IDCODE). """
      return self.read_debug32(0xE0042000)

   @property
   def CpuId(self):
      """ Return (implementer, variant, part number, revision) from CPUID. """
      u32 = self.read_debug32(CM3_REG_CPUID)
      # The implementer field is the full top byte (bits 31:24).
      implementer_id = (u32 >> 24) & 0xff
      variant = (u32 >> 20) & 0xf
      part = (u32 >> 4) & 0xfff
      revision = u32 & 0xf
      return implementer_id, variant, part, revision

   def status(self):
      """ Return the first byte of the DEBUG_GETSTATUS reply. """
      cmd = bytearray(16)
      cmd[0:2] = DEBUG_COMMAND, DEBUG_GETSTATUS
      reply = self.send_recv(cmd, 2)
      return reply[0]

   def reset(self):
      """ Issue a system reset of the target. """
      cmd = bytearray(16)
      cmd[0:2] = DEBUG_COMMAND, DEBUG_RESETSYS
      self.send_recv(cmd, 2)

   # debug commands:
   def step(self):
      """ Single step the target core. """
      cmd = bytearray(16)
      cmd[0:2] = DEBUG_COMMAND, DEBUG_STEPCORE
      self.send_recv(cmd, 2)

   def run(self):
      """ Let the target core run. """
      cmd = bytearray(16)
      cmd[0:2] = DEBUG_COMMAND, DEBUG_RUNCORE
      self.send_recv(cmd, 2)

   # flashing commands:
   def writeFlash(self, address, content):
      """ Flash content at address.

      Steps:
      - check the range and alignment of address and content;
      - erase every sector touched by the range;
      - load the flash loader into SRAM;
      - verify the flash against content.

      NOTE: programming the pages is still a TODO. The final verification
      therefore reflects the freshly erased flash, not content.

      Raises STLinkException on an invalid range or alignment.
      """
      # TODO: read the real flash size from the device.
      flashsize = 0x100000 # fixed 1 MB for now..
      print('WARNING: using 1 MB as flash size')
      pagesize = 0x4000 # fixed for now!

      # Check address range:
      if address < STM32_FLASH_BASE:
         raise STLinkException('Flashing below flash start')
      if address + len(content) > STM32_FLASH_BASE + flashsize:
         raise STLinkException('Flashing above flash size')
      if address & 1 == 1:
         raise STLinkException('Unaligned flash')
      if len(content) & 1 == 1:
         print('unaligned length, padding with zero')
         # Build a new object: 'content += ...' would extend a caller's
         # bytearray in place.
         content = content + bytes([0])
      if address & (pagesize - 1) != 0:
         raise STLinkException('Address not aligned with pagesize')

      # erase required space
      sectors = calcSectors(address, len(content))
      print('erasing {0} sectors'.format(len(sectors)))
      for sector, secsize in sectors:
         print('erasing sector {0} of size {1}'.format(sector, secsize))
         self.eraseFlashSector(sector)

      # program pages:
      self.initFlashLoader()
      self.unlockFlashIf()
      # TODO: run the flash loader to actually program the pages.

      self.lockFlash()

      # verify program:
      self.verifyFlash(address, content)

   def eraseFlashSector(self, sector):
      """ Erase one flash sector by number and lock the flash afterwards. """
      self.waitFlashBusy()
      self.unlockFlashIf()
      self.writeFlashCrSnb(sector)
      self.setFlashCrStart()
      self.waitFlashBusy()
      self.lockFlash()

   def eraseFlash(self):
      """ Mass erase the flash and lock it afterwards. """
      self.waitFlashBusy()
      self.unlockFlashIf()
      self.setFlashCrMer()
      self.setFlashCrStart()
      self.waitFlashBusy()
      self.clearFlashCrMer()
      self.lockFlash()

   def verifyFlash(self, address, content):
      """ Compare the memory at address with content.

      Prints the result and returns True when the contents match.
      """
      device_content = self.readFlash(address, len(content))
      ok = content == device_content
      print('Verify:', ok)
      return ok

   def readFlash(self, address, size):
      """ Read size bytes of target memory starting at address, as bytes.

      The read is split into chunks of at most 0x1800 bytes. read_mem32 only
      accepts multiples of 4 bytes, so each chunk length is rounded up and
      the surplus bytes are dropped. Progress and the transfer rate are
      printed.
      """
      print('reading', address, size)
      chunk_size = 0x1800
      t1 = time.time()
      image = bytearray()   # grows in place; avoids quadratic bytes concat
      offset = 0
      while offset < size:
         # The last chunk may be shorter:
         tmp_size = min(chunk_size, size - offset)

         # align size to 4 bytes:
         aligned_size = (tmp_size + 3) & ~3

         mem = self.read_mem32(address + offset, aligned_size)
         image += mem[:tmp_size]

         # indicate progress:
         print('.', end='', flush=True)

         # increase for next piece:
         offset += tmp_size
      t2 = time.time()
      assert size == len(image)
      elapsed = t2 - t1
      # Guard against a zero interval (e.g. size == 0 or a coarse clock).
      if elapsed > 0:
         print('done! {0} bytes/second'.format(size / elapsed))
      else:
         print('done!')
      return bytes(image)

   def initFlashLoader(self):
      """ Copy the flash loader into SRAM and set bufAddress just behind it. """
      # TODO: support other loader code.
      self.write_mem32(STM32_SRAM_BASE, loader_code_stm32f4)
      self.bufAddress = STM32_SRAM_BASE + len(loader_code_stm32f4)

   def readFlashSr(self):
      """ Read the flash status register. """
      return self.read_debug32(FLASH_F4_SR)

   def readFlashCr(self):
      """ Read the flash control register. """
      return self.read_debug32(FLASH_F4_CR)

   def writeFlashCrSnb(self, sector):
      """ Select sector for erasing: set the SNB field and the SER bit. """
      x = self.readFlashCr()
      x &= ~FLASH_F4_CR_SNB_MASK
      x |= sector << FLASH_F4_CR_SNB
      x |= 1 << FLASH_F4_CR_SER
      self.write_debug32(FLASH_F4_CR, x)

   def setFlashCrMer(self):
      """ Set the mass erase bit. """
      x = self.readFlashCr()
      x |= 1 << FLASH_CR_MER
      self.write_debug32(FLASH_F4_CR, x)

   def clearFlashCrMer(self):
      """ Clear the mass erase bit. """
      x = self.readFlashCr()
      x &= ~(1 << FLASH_CR_MER)
      self.write_debug32(FLASH_F4_CR, x)

   def setFlashCrStart(self):
      """ Set the start bit to begin the selected erase operation. """
      x = self.readFlashCr()
      x |= 1 << FLASH_F4_CR_START
      self.write_debug32(FLASH_F4_CR, x)

   def isFlashBusy(self):
      """ True while the BSY bit in the flash status register is set. """
      mask = 1 << FLASH_F4_SR_BSY
      sr = self.readFlashSr()
      return sr & mask == mask

   def waitFlashBusy(self):
      """ block until flash operation completes. """
      while self.isFlashBusy():
         pass

   def isFlashLocked(self):
      """ True while the LOCK bit in the flash control register is set. """
      cr = self.readFlashCr()
      mask = 1 << FLASH_F4_CR_LOCK
      return cr & mask == mask

   def unlockFlashIf(self):
      """ Unlock the flash control register if it is locked.

      Raises STLinkException when the key sequence does not unlock it.
      """
      if self.isFlashLocked():
         print('unlocking')
         self.write_debug32(FLASH_F4_KEYR, FLASH_KEY1)
         self.write_debug32(FLASH_F4_KEYR, FLASH_KEY2)
         if self.isFlashLocked():
            raise STLinkException('Failed to unlock')

   def lockFlash(self):
      """ Set the LOCK bit in the flash control register. """
      print('locking')
      x = self.readFlashCr() | (1 << FLASH_F4_CR_LOCK)
      self.write_debug32(FLASH_F4_CR, x)

   # Helper 1 functions:
   def write_debug32(self, address, value):
      """ Write one 32 bit value to a target address. """
      cmd = bytearray(16)
      cmd[0:2] = DEBUG_COMMAND, JTAG_WRITEDEBUG_32BIT
      cmd[2:6] = struct.pack('<I', address)
      cmd[6:10] = struct.pack('<I', value)
      self.send_recv(cmd, 2)

   def read_debug32(self, address):
      """ Read one 32 bit value from a target address. """
      cmd = bytearray(16)
      cmd[0:2] = DEBUG_COMMAND, JTAG_READDEBUG_32BIT
      cmd[2:6] = struct.pack('<I', address) # pack into u32 little endian
      reply = self.send_recv(cmd, 8)
      # The value is in the last 4 of the 8 reply bytes.
      return struct.unpack('<I', reply[4:8])[0]

   def write_reg(self, reg, value):
      """ Write value into core register number reg. """
      cmd = bytearray(16)
      cmd[0:3] = DEBUG_COMMAND, DEBUG_WRITEREG, reg
      cmd[3:7] = struct.pack('<I', value)
      self.send_recv(cmd, 2)

   def read_reg(self, reg):
      """ Read core register number reg. """
      cmd = bytearray(16)
      cmd[0:3] = DEBUG_COMMAND, DEBUG_READREG, reg
      reply = self.send_recv(cmd, 4)
      return struct.unpack('<I', reply)[0]

   def write_mem32(self, address, content):
      """ Write content (length a multiple of 4, < 64 kB) to target memory. """
      assert len(content) % 4 == 0
      cmd = bytearray(16)
      cmd[0:2] = DEBUG_COMMAND, DEBUG_WRITEMEM_32BIT
      cmd[2:6] = struct.pack('<I', address)
      cmd[6:8] = struct.pack('<H', len(content))
      self.send_recv(cmd)
      # The data itself follows the command block as a separate transfer.
      self.send_recv(content)

   def read_mem32(self, address, length):
      """ Read length bytes (a multiple of 4) of target memory. """
      assert length % 4 == 0
      cmd = bytearray(16)
      cmd[0:2] = DEBUG_COMMAND, DEBUG_READMEM_32BIT
      cmd[2:6] = struct.pack('<I', address)
      cmd[6:8] = struct.pack('<H', length) # uint16
      reply = self.send_recv(cmd, length) # expect memory back!
      return reply

   # Helper 2 functions:
   def send_recv(self, tx, rxsize=0):
      """ Helper function that transmits and receives data in bulk mode.

      Writes tx to endpoint 2. When rxsize > 0, returns the result of a
      bulk read of rxsize bytes from endpoint 1; otherwise returns None.
      """
      # TODO: we could use here the non-blocking libusb api.
      tx = bytes(tx)
      self.devHandle.bulkWrite(2, tx) # write to endpoint 2
      if rxsize > 0:
         return self.devHandle.bulkRead(1, rxsize) # read from endpoint 1
      return None

# Names for raw ChipId values, looked up by the test program below.
# NOTE(review): 0x1 -> 'x' looks like a placeholder entry.
knownChipIds = {}
knownChipIds[0x1] = 'x'

if __name__ == '__main__':
   # Test program: exercises a connected ST-Link/V2 and its target.
   sl = STLink()
   sl.open()
   try:
      print('version:', sl.Version)
      print('mode before doing anything:', sl.CurrentModeString)
      if sl.CurrentMode == DFU_MODE:
         sl.exitDfuMode()
      sl.enterSwdMode()
      print('mode after entering swd mode:', sl.CurrentModeString)

      i = sl.ChipId
      if i in knownChipIds:
         print('chip id: 0x{0:X} -> {1}'.format(i, knownChipIds[i]))
      else:
         print('chip id: 0x{0:X}'.format(i))
      print('cpu: {0}'.format(sl.CpuId))

      print('status: {0}'.format(sl.status()))

      # test registers: write two core registers and read them back.
      sl.write_reg(3, 0x1337)
      sl.write_reg(2, 0x1332)
      assert sl.read_reg(3) == 0x1337
      assert sl.read_reg(2) == 0x1332

      sl.exitDebugMode()
      print('mode at end:', sl.CurrentModeString)
   finally:
      # Release the connection even when one of the steps above fails.
      sl.close()