diff options
-rwxr-xr-x | util/acroterm.py | 363 | ||||
-rwxr-xr-x | util/test_acroterm.py | 113 |
2 files changed, 383 insertions, 93 deletions
diff --git a/util/acroterm.py b/util/acroterm.py index 5b79e059e0..471e861dfd 100755 --- a/util/acroterm.py +++ b/util/acroterm.py @@ -17,11 +17,17 @@ import argparse import atexit import fcntl import glob +import math +import pickle +import os +import re import struct import subprocess import sys +import tempfile import termios import threading +import zlib import serial @@ -106,9 +112,7 @@ class Packet(object): MAGIC = 0xc0 END_MAGIC = 0xc1 - DATA_LEN_OFFSET = 10 - CRC_OFFSET = 12 - HEADER_LEN = 13 + FORMAT = struct.Struct('<x3BI2HB') # Dict of struct handlers indexed by type. Add with set_struct_handler(). # Handlers take (struct type index, bytearray) and return string. @@ -130,11 +134,14 @@ class Packet(object): self.errors = [] self.next_seq = None self.last_timestamp = 0 + self.channel = None + self.data_len = None def reset(self): """Reset the packet state.""" self.data = bytearray() self.expect_len = 0 + self.timestr = '' def get_decoded(self): """Return the last decoded packet, or an empty string if none.""" @@ -150,24 +157,22 @@ class Packet(object): def add_byte(self, b): """Add a byte to the packet. Returns True if the byte was consumed.""" - if not self.expect_len: # Not in a packet - if b != Packet.MAGIC: + if b != self.MAGIC: return False # Now starting a packet - self.expect_len = Packet.HEADER_LEN + self.expect_len = self.FORMAT.size self.data.append(b) if len(self.data) == self.expect_len: - if self.expect_len == Packet.HEADER_LEN: + if self.expect_len == self.FORMAT.size: if not self.validate_header(): self.reset() else: - self.decode_packet() - self.reset() + self.check_trailer_and_decode_packet() return True @@ -181,24 +186,58 @@ class Packet(object): self.data = self.data[count:] return d + def unpack_ph(self, header): + """Class specific header unpack structure + + Saves class unique fields in the instance, returns the common ones to + the caller. + """ + (b1, chan, self.const_str_len, time_lo, + time_hi, data_len, crc) = self.FORMAT.unpack(header) + + self.param_count = b1 >> 5 + return b1, chan, time_lo, time_hi, data_len, crc + def validate_header(self): """Validate the packet header. Returns: True if there is more data needed. """ - if self.expect_len != Packet.HEADER_LEN: + if self.expect_len != self.FORMAT.size: return False - if self.data[Packet.CRC_OFFSET] != crc8(self.data[:Packet.CRC_OFFSET]): - print('Bad packet') + header = self.data[:self.FORMAT.size] + b1, chan, time_lo, time_hi, data_len, crc = self.unpack_ph(header) + + if crc != crc8(header[:-1]): + print('Bad packet size') return False - data_len = struct.unpack('=H', self.data[Packet.DATA_LEN_OFFSET: - Packet.DATA_LEN_OFFSET + 2])[0] + self.channel = chan + + timestamp = time_hi << 32 | time_lo + if timestamp < self.last_timestamp: + # Reboot will restart the sequence at 0 + self.next_seq = 0 + self.last_timestamp = timestamp + self.timestr = '%d.%06d' % (timestamp // 1000000, timestamp % 1000000) + + # Flag dropped packets + if b1 & 0x10: + self.errors.append('(sender dropped packet(s))') + + sequence = b1 & 0x0f + if self.next_seq is not None and sequence != self.next_seq: + self.errors.append('(missing packet(s)); got %d expect %d' % + (sequence, self.next_seq)) + self.next_seq = (sequence + 1) % 16 + + self.data_len = data_len + if not data_len: # No data; just header - self.decode_packet() + self.check_trailer_and_decode_packet() return False self.expect_len += data_len + 1 # +1 for packet end @@ -282,34 +321,24 @@ class Packet(object): return dout - def decode_packet(self): - """Decode a packet, now that it's all shown up.""" + def check_trailer_and_decode_packet(self): + 'Verify trailer presence and decode packet' + # Consume packet header. + self.next_data(self.FORMAT.size) - # Decode header - header = self.next_data(Packet.HEADER_LEN) - (b1, channel, const_str_len, time_lo, time_hi, - data_len) = struct.unpack('=xBBBLHHx', header) - sequence = b1 & 0x0f - sender_dropped = b1 & 0x10 - param_count = b1 >> 5 - timestamp = time_hi << 32 | time_lo - - # If timestamp decreased, board rebooted - if timestamp < self.last_timestamp: - # Reboot will restart the sequence at 0 - self.next_seq = 0 - self.last_timestamp = timestamp + # Flag (but keep processing) bad data: header is still fine. + if self.data_len and ((not self.data) or self.data[-1] != + self.END_MAGIC): + self.errors.append( + '(packet data missing end magic; may be corrupt)') + self.decode_packet() + self.reset() - # Flag dropped packets - if sender_dropped: - self.errors.append('(sender dropped packet(s))') - if self.next_seq is not None and sequence != self.next_seq: - self.errors.append('(missing packet(s)); got %d expect %d' % - (sequence, self.next_seq)) - self.next_seq = (sequence + 1) % 16 + def decode_packet(self): + """Decode a packet, now that it's all shown up.""" + channel = self.channel - self.decoded += '[%d.%06d/' % ( - timestamp // 1000000, timestamp % 1000000) + self.decoded += '[%s/' % self.timestr if channel == CMSG_CHAN_DEFAULT: self.decoded += '??' elif channel == CMSG_CHAN_INTERRUPT: @@ -328,23 +357,19 @@ class Packet(object): self.decoded += '%02x' % channel self.decoded += ']' - # Flag (but print) bad data; header is still fine - if data_len and ((not self.data) or self.data[-1] != Packet.END_MAGIC): - self.errors.append( - '(packet data missing end magic; may be corrupt)') - # Decode data const_str = '' param_decoded = [] try: - if const_str_len: - const_str = self.next_data(const_str_len).decode( - encoding='utf-8', errors='replace') + if self.const_str_len: + const_str = self.next_data(self.const_str_len + ).decode(encoding='utf-8', errors='replace') + param_count = self.param_count if param_count: # Unpack format nibbles formats = [] - fbuf = self.next_data((param_count + 1) // 2) + fbuf = self.next_data((self.param_count + 1) // 2) for f in fbuf: formats += [f & 0xf, f >> 4] @@ -388,9 +413,6 @@ class Packet(object): if param_decoded: self.decoded += ' ' + ' '.join(param_decoded) - # Clear packet for next time - self.reset() - # ------------------------------------------------------------------------------ # Packet struct handlers @@ -635,7 +657,10 @@ class Console(object): @staticmethod def write(text): """Write string""" - sys.stdout.write(text) + try: + sys.stdout.write(text) + except UnicodeEncodeError: + sys.stdout.write(str(bytes(text, 'utf-8'))) sys.stdout.flush() def cancel(self): @@ -729,8 +754,14 @@ class Acroterm(object): self.console.setup() # Start logging - if self.log_filename: - self.log_file = open(self.log_filename, 'wb') + if self.log_filename is not None: + if self.log_filename: + self.log_file = open(self.log_filename, 'wb') + else: + fd, name = tempfile.mkstemp(prefix='acroterm.', suffix='.log', + dir='/tmp') + self.log_file = os.fdopen(fd, 'wb') + notice('Saving log in %s' % name) else: self.log_file = None @@ -778,6 +809,8 @@ class Acroterm(object): else: self.timer = None + self.cr50_mode = type(self).__name__ == 'Cr50Term' + def stop(self): """Set flag to stop worker threads""" self.alive = False @@ -846,6 +879,7 @@ class Acroterm(object): Args: data: a byte array, the received chunk. """ + self.log_file.write(data) # Scan for magic sequences for c in data: @@ -857,17 +891,17 @@ class Acroterm(object): if self.packet.add_byte(c): e = self.packet.get_errors() if e: - self.log_file.write(('\n'.join(e) + '\n').encode()) self.console.write(self.color('error')) self.console.write('\n'.join(e)) self.console.write(self.color('default')+'\n') d = self.packet.get_decoded() if d: - self.log_file.write((d + '\n').encode()) if not self.coverage_filter_active: self.console.write(self.color('normal')) self.console.write(d) - self.console.write(self.color('default')+'\n') + self.console.write(self.color('default')) + if not self.cr50_mode: + self.console.write('\n') self.process_line(d) continue @@ -876,8 +910,6 @@ class Acroterm(object): # start/end byte values. c = chr(c) - self.log_file.write(c.encode()) - if not self.coverage_filter_active: self.console.write(self.color('not_packet')) self.console.write(c) @@ -902,7 +934,7 @@ class Acroterm(object): if data: self.process_output(data) - except serial.SerialException: + except (serial.SerialException, OSError): self.stop() raise @@ -946,8 +978,8 @@ class Acroterm(object): line = raw_line.decode() if self.cmd_filter and self.cmd_filter in line: line = '#' - sys.stdout.write(self.color('program') + line + - self.color('default')) + sys.stdout.write(self.color('program') + + line + self.color('default')) sys.stdout.flush() elif proc.poll() is not None: break @@ -972,12 +1004,181 @@ class Acroterm(object): notice('Hit timeout') self.stop() +int_param = re.compile(r'^[0-9.\-]*([l]{0,2}|z)[Xcdux]') +split_int = re.compile(r'^([0-9]+)\.([0-9]+)') +str_param = re.compile(r'^[0-9.\-]*s') +ptr_param = re.compile(r'^p[hPT]') +ll_struct = struct.Struct('<q') +ull_struct = struct.Struct('<Q') +int_struct = struct.Struct('<i') +uint_struct = struct.Struct('<I') +short_struct = struct.Struct('<H') + +class Cr50Packet(Packet): + 'Cr50 specific packet class, handles messages of Cr50 format' + + FORMAT = struct.Struct('<x2BIHBHB') + MAGIC = 0xc2 + + def __init__(self, strings): + self.strings = strings + super().__init__() + + def unpack_ph(self, header): + """Class specific header unpack structure + + Saves class unique fields in the instance, returns the common ones to + the caller. + """ + (b1, chan, time_lo, time_hi, data_len, + self.str_index, crc) = self.FORMAT.unpack(header) + + return b1, chan, time_lo, time_hi, data_len, crc + + def process_format(self, str_index, data): + """Process C format string converting it into Python format string + + Args: + str_index: int, index of the source code format string in the + strings list. + data: binary blob containing parameters matching the format + string. + + Returns: + A Python format string suitable for printing + """ + fstring = self.strings[str_index] + if fstring.startswith('[^T'): + fstring = '[%s %s]\n' % (self.timestr, fstring[3:]) + tokens = fstring.split('%') + text = tokens[0] + for token in tokens[1:]: + if token[0] == '%': + text += token + continue + m = int_param.search(token) + if m: + fend = m.span()[1] + fmt = token[:fend] + # Python doesn't know what z means in format. + fmt = fmt.replace('z', '') + rest = token[fend:] + signed = fmt[-1] == 'd' + if m.group(1) == 'll': + fmt = fmt.replace('ll', '', 1) + if signed: + s = ll_struct + else: + s = ull_struct + else: + if signed: + s = int_struct + else: + s = uint_struct + value = s.unpack_from(data)[0] + data = data[s.size:] + m = split_int.search(token) + if m: + # This is a complex format spec used in EC codebase. + exp = m.groups()[1] + fvalue = value / math.pow(10, int(exp)) + text += ('%f' % fvalue) + rest + else: + text += ('%' + fmt + rest) % value + continue + if str_param.search(token): + if data[0] == 0xff: + # This is a function name. + index = uint_struct.unpack_from(data[1:])[0] + st = self.strings[index] + data = data[5:] + else: + eos = data.find(0) + param = data[:eos].decode('ascii') + st = param + data = data[eos + 1:] + text += ('%' + token) % st + continue + m = ptr_param.search(token) + if m: + rest = token[m.span()[1]:] + if token[1] == 'P': + s = uint_struct + fmt = '%08x' + elif token[1] == 'T': + s = ull_struct + fmt = '%d' + elif token[1] == 'h': + size = short_struct.unpack_from(data)[0] + data = data[short_struct.size:] + text += ' '.join('%02x' % x for x in data[:size]) + rest + data = data[size:] + continue + else: + notice('unprocessed format %%%s' % token) + continue + value = s.unpack_from(data)[0] + if token[1] == 'T' and value == 0: + # current time, take it from the packet header. + value = self.last_timestamp + text += (fmt + rest) % value + data = data[s.size:] + continue + return text + + def decode_packet(self): + """Decode a packet, now that it's all shown up. + + Sets self.decoded to the text of the decoded packet. + """ + + text = self.process_format(self.str_index, + self.next_data(self.data_len)) + if self.data_len: + self.next_data(1) # Consume the trailing byte. + self.decoded = text + + +class Cr50Term(Acroterm): + 'Cr50 specific Acroterm. Uses Cr50Packet instead of Packet' + + @staticmethod + def parse_blob(cr50_str_blob): + """Read and decode the format string blob + + Args: + cr50_str_blob: name of the file containing the blob prepared by + util_precompile.py. + + Returns: + A list of strings, placed at their appropriate locations such that + string index in the packet sent by Cr50 matches the format string + it was generated with. + + Raises: + FileNotFoundError if the file is not found. + """ + try: + zipped = open(cr50_str_blob, 'rb').read() + except FileNotFoundError: + fatal('Blob file %s not found' % cr50_str_blob) + + pickled = zlib.decompress(zipped) + dump = pickle.loads(pickled) + return dump.split('\0') + + def __init__(self, args): + super().__init__(args) + strings = self.parse_blob(args.cr50_str_blob) + self.packet = Cr50Packet(strings) + # ------------------------------------------------------------------------------ -def main(): - """Main function. +def get_args(): + """Prepare argument parser and retrieve command line arguments. - Parse command line arguments and start operation accordingly. + Returns the parser object with a namespace with all present optional + arguments set. """ parser = argparse.ArgumentParser(description='Acropora terminal') @@ -1015,7 +1216,7 @@ def main(): group.add_argument( '--remote-fail', - help='Exit with error if remote emits this string' + help='Exit with error if remote emits this string ' '(default=%(default)s)', metavar='MATCH', default='***HANGUP-FAIL***') @@ -1024,9 +1225,9 @@ def main(): group.add_argument( '--log', - help='Logfile (default=%(default)s)', + help='Logfile (default=/tmp/acroterm.*.log)', metavar='LOGFILE', - default='acroterm.log') + default='') group.add_argument( '--color', @@ -1062,7 +1263,24 @@ def main(): dest='cmd_filter', action='store_const', const=None) - args = parser.parse_args() + parser.add_argument( + '--cr50_mode', + help='Use Cr50 packet format', + action='store_true') + + parser.add_argument( + '--cr50_str_blob', + help='Binary blob containing Cr50 strings (default=%(default)s)', + default=os.path.normpath(os.path.join(os.path.dirname(__file__), + '../build/cr50/RW/str_blob'))) + return parser.parse_args() + +def main(): + """Main function. + + Parse command line arguments and start operation accordingly. + """ + args = get_args() # Look for exactly one matching TTY, unless running target on host if args.tty == 'host': @@ -1084,7 +1302,10 @@ def main(): args.log = None # Start the terminal - term = Acroterm(args) + if args.cr50_mode: + term = Cr50Term(args) + else: + term = Acroterm(args) # Keep running until the other side exits try: diff --git a/util/test_acroterm.py b/util/test_acroterm.py index 15249daa68..ed1092524c 100755 --- a/util/test_acroterm.py +++ b/util/test_acroterm.py @@ -10,24 +10,35 @@ import unittest import acroterm -class TestPacket(unittest.TestCase): - 'Test various packet failures' +def report_packet_errors(errors): + """Write packet error strings into stderr + + Called if number of error strings does not match test expectations. + """ + if not errors: + return + sys.stderr.write('unexpected error set:\n') + for error in errors: + sys.stderr.write('%s\n' % error) + - def report_packet_errors(self): - """Write packet error strings into stderr +def process_samples(packet, samples): + 'Submit various data samples to packet and validate results' - Called if number of error strings does not match test expectations. - """ - if not self.p.errors: - return - sys.stderr.write('unexpected error set:\n') - for error in self.p.errors: - sys.stderr.write('%s\n' % error) + for data, handler in samples: + for b in data: + packet.add_byte(b) + handler(data) + packet.errors = [] + packet.get_decoded() + +class TestPacket(unittest.TestCase): + 'Test various base Acropora packet failures' def good_packet(self, packet): 'Verify good packet handling' if self.p.errors: - self.report_packet_errors() + report_packet_errors(self.p.errors) self.fail() d = self.p.get_decoded() self.assertEqual(d, '[13581998.891532/t1]') @@ -36,7 +47,7 @@ class TestPacket(unittest.TestCase): def bad_seq(self, _): 'Verify bad sequence number handling' if len(self.p.errors) != 1: - self.report_packet_errors() + report_packet_errors(self.p.errors) self.fail() self.assertEqual(self.p.errors[0], '(missing packet(s)); got 0 expect 1') @@ -46,13 +57,13 @@ class TestPacket(unittest.TestCase): 'Verify both cases of packet with data' if packet[-1] != 0xc1: if len(self.p.errors) != 1: - self.report_packet_errors() + report_packet_errors(self.p.errors) self.fail() self.assertEqual(self.p.errors[0], '(packet data missing end magic; may be corrupt)') else: if self.p.errors: - self.report_packet_errors() + report_packet_errors(self.p.errors) self.fail() d = self.p.get_decoded() self.assertEqual(d, '[13590588.826124/t1] 67305985->134678021') @@ -63,7 +74,7 @@ class TestPacket(unittest.TestCase): # Tuple of two-tuples, the first element in the tuple pair is the # packet to send to the Packet class, the second element is the # function to invoke once the packet has been sent. - packets = ( + samples = ( ((0xc0, 0, 1, 0, 12, 34, 56, 78, 90, 12, 0, 0, 33), self.good_packet), ((0xc0, 0, 1, 0, 12, 34, 56, 78, 91, 12, 0, 0, 55), @@ -78,12 +89,70 @@ class TestPacket(unittest.TestCase): self.with_data), ) self.p = acroterm.Packet() - for packet, handler in packets: - for b in packet: - self.p.add_byte(b) - handler(packet) - self.p.errors = [] - self.p.get_decoded() + process_samples(self.p, samples) + +class TestCr50Packet(unittest.TestCase): + 'Test various base Acropora packet failures' + + def good_packet(self, packet): + 'Verify good packet handling' + if self.p.errors: + report_packet_errors(self.p.errors) + self.fail() + d = self.p.get_decoded() + self.assertEqual(d, 'string 0') + self.assertEqual(self.p.next_seq, (packet[1] & 0xf) + 1) + + def bad_seq(self, _): + 'Verify bad sequence number handling' + if len(self.p.errors) != 1: + report_packet_errors(self.p.errors) + self.fail() + self.assertEqual(self.p.errors[0], + '(missing packet(s)); got 0 expect 1') + self.assertEqual(self.p.next_seq, 1) + + def with_data(self, packet): + 'Verify both cases of packet with data' + d = self.p.get_decoded() + if packet[-1] != 0xc1: + if len(self.p.errors) != 1: + report_packet_errors(self.p.errors) + self.fail() + self.assertEqual(self.p.errors[0], + '(packet data missing end magic; may be corrupt)') + else: + if self.p.errors: + report_packet_errors(self.p.errors) + self.fail() + self.assertEqual(d, 'string ext param 230') + + def test_acorpora_packet(self): + 'Test various good and bad packets' + + strings = ['string 0', 'string 1', 'string %s %d'] + # Tuple of two-tuples, the first element in the tuple pair is the + # packet to send to the Packet class, the second element is the + # function to invoke once the packet has been sent. + samples = ( + ((0xc2, 0, 1, 12, 34, 56, 78, 90, 12, 0, 0, 0, 79), + self.good_packet), + ((0xc2, 0, 1, 12, 34, 56, 78, 90, 13, 0, 0, 0, 89), + self.bad_seq), + ((0xc2, 1, 1, 12, 34, 56, 78, 90, 14, 0, 0, 0, 124), + self.good_packet), + # Packet with valid data, but with an incorrect trailing character. + ((0xc2, 2, 1, 12, 34, 56, 78, 90, 15, 15, 2, 0, 38, + 101, 120, 116, 32, 112, 97, 114, 97, 109, 0, 230, 0, 0, 0, 0, 0), + self.with_data), + # A valid packet with data. + ((0xc2, 3, 1, 12, 34, 56, 78, 90, 15, 15, 2, 0, 57, + 101, 120, 116, 32, 112, 97, 114, 97, 109, 0, 230, 0, 0, 0, 0, + 0xc1), + self.with_data), + ) + self.p = acroterm.Cr50Packet(strings) + process_samples(self.p, samples) if __name__ == '__main__': unittest.main() |