You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
344 lines
10 KiB
344 lines
10 KiB
# SPDX-FileCopyrightText: 2020 Foundation Devices, Inc. <hello@foundationdevices.com>
|
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
#
|
|
# SPDX-FileCopyrightText: 2018 Coinkite, Inc. <coldcardwallet.com>
|
|
# SPDX-License-Identifier: GPL-3.0-only
|
|
#
|
|
# (c) Copyright 2018 by Coinkite Inc. This file is part of Coldcard <coldcardwallet.com>
|
|
# and is covered by GPLv3 license found in COPYING.
|
|
#
|
|
# utils.py
|
|
#
|
|
|
|
import gc, sys, ustruct
|
|
from ubinascii import unhexlify as a2b_hex
|
|
from ubinascii import hexlify as b2a_hex
|
|
from ubinascii import a2b_base64, b2a_base64
|
|
|
|
class imported:
|
|
# Context manager that temporarily imports
|
|
# a list of modules.
|
|
# LATER: doubtful this saves any memory when all the code is frozen.
|
|
|
|
def __init__(self, *modules):
|
|
self.modules = modules
|
|
|
|
def __enter__(self):
|
|
# import everything required
|
|
rv = tuple(__import__(n) for n in self.modules)
|
|
|
|
return rv[0] if len(self.modules) == 1 else rv
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
|
|
for n in self.modules:
|
|
if n in sys.modules:
|
|
del sys.modules[n]
|
|
|
|
# recovery that tasty memory.
|
|
gc.collect()
|
|
|
|
def pretty_delay(n):
|
|
# decode # of seconds into various ranges, need not be precise.
|
|
if n < 120:
|
|
return '%d seconds' % n
|
|
n /= 60
|
|
if n < 60:
|
|
return '%d minutes' % n
|
|
n /= 60
|
|
if n < 48:
|
|
return '%.1f hours' % n
|
|
n /= 24
|
|
return 'about %d days' % n
|
|
|
|
def pretty_short_delay(sec):
|
|
# precise, shorter on screen display
|
|
if sec >= 3600:
|
|
return '%2dh %2dm %2ds' % (sec //3600, (sec//60) % 60, sec % 60)
|
|
else:
|
|
return '%2dm %2ds' % ((sec//60) % 60, sec % 60)
|
|
|
|
def pop_count(i):
|
|
# 32-bit population count for integers
|
|
# from <https://stackoverflow.com/questions/9829578>
|
|
i = i - ((i >> 1) & 0x55555555)
|
|
i = (i & 0x33333333) + ((i >> 2) & 0x33333333)
|
|
|
|
return (((i + (i >> 4) & 0xF0F0F0F) * 0x1010101) & 0xffffffff) >> 24
|
|
|
|
def get_filesize(fn):
|
|
# like os.path.getsize()
|
|
import uos
|
|
return uos.stat(fn)[6]
|
|
|
|
class HexWriter:
|
|
# Emulate a file/stream but convert binary to hex as they write
|
|
def __init__(self, fd):
|
|
self.fd = fd
|
|
|
|
def __enter__(self):
|
|
self.fd.__enter__()
|
|
return self
|
|
|
|
def __exit__(self, *a, **k):
|
|
self.fd.write(b'\r\n')
|
|
return self.fd.__exit__(*a, **k)
|
|
|
|
def write(self, b):
|
|
self.fd.write(b2a_hex(b))
|
|
|
|
class Base64Writer:
|
|
# Emulate a file/stream but convert binary to Base64 as they write
|
|
def __init__(self, fd):
|
|
self.fd = fd
|
|
self.runt = b''
|
|
|
|
def __enter__(self):
|
|
self.fd.__enter__()
|
|
return self
|
|
|
|
def __exit__(self, *a, **k):
|
|
if self.runt:
|
|
self.fd.write(b2a_base64(self.runt))
|
|
self.fd.write(b'\r\n')
|
|
return self.fd.__exit__(*a, **k)
|
|
|
|
def write(self, buf):
|
|
if self.runt:
|
|
buf = self.runt + buf
|
|
rl = len(buf) % 3
|
|
self.runt = buf[-rl:] if rl else b''
|
|
if rl < len(buf):
|
|
tmp = b2a_base64(buf[:(-rl if rl else None)])
|
|
# library puts in newlines!?
|
|
assert tmp[-1:] == b'\n', tmp
|
|
assert tmp[-2:-1] != b'=', tmp
|
|
self.fd.write(tmp[:-1])
|
|
|
|
def swab32(n):
|
|
# endian swap: 32 bits
|
|
return ustruct.unpack('>I', ustruct.pack('<I', n))[0]
|
|
|
|
def xfp2str(xfp):
|
|
# Standardized way to show an xpub's fingerprint... it's a 4-byte string
|
|
# and not really an integer. Used to show as '0x%08x' but that's wrong endian.
|
|
return b2a_hex(ustruct.pack('<I', xfp)).decode().upper()
|
|
|
|
def str2xfp(txt):
|
|
# Inverse of xfp2str
|
|
return ustruct.unpack('<I', a2b_hex(txt))[0]
|
|
|
|
def problem_file_line(exc):
|
|
# return a string of just the filename.py and line number where
|
|
# an exception occured. Best used on AssertionError.
|
|
import uio, sys, ure
|
|
|
|
tmp = uio.StringIO()
|
|
sys.print_exception(exc, tmp)
|
|
lines = tmp.getvalue().split('\n')[-3:]
|
|
del tmp
|
|
|
|
# convert:
|
|
# File "main.py", line 63, in interact
|
|
# into just:
|
|
# main.py:63
|
|
#
|
|
# on simulator, huge path is included, remove that too
|
|
|
|
rv = None
|
|
for ln in lines:
|
|
mat = ure.match(r'.*"(/.*/|)(.*)", line (.*), ', ln)
|
|
if mat:
|
|
try:
|
|
rv = mat.group(2) + ':' + mat.group(3)
|
|
except: pass
|
|
|
|
return rv or str(exc) or 'Exception'
|
|
|
|
def cleanup_deriv_path(bin_path, allow_star=False):
|
|
# Clean-up path notation as string.
|
|
# - raise exceptions on junk
|
|
# - standardize on 'prime' notation (34' not 34p, or 34h)
|
|
# - assume 'm' prefix, so '34' becomes 'm/34', etc
|
|
# - do not assume /// is m/0/0/0
|
|
# - if allow_star, then final position can be * or *' (wildcard)
|
|
import ure
|
|
from public_constants import MAX_PATH_DEPTH
|
|
try:
|
|
s = str(bin_path, 'ascii').lower()
|
|
except UnicodeError:
|
|
raise AssertionError('must be ascii')
|
|
|
|
# empty string is valid
|
|
if s == '': return 'm'
|
|
|
|
s = s.replace('p', "'").replace('h', "'")
|
|
mat = ure.match(r"(m|m/|)[0-9/']*" + ('' if not allow_star else r"(\*'|\*|)"), s)
|
|
assert mat.group(0) == s, "invalid characters"
|
|
|
|
parts = s.split('/')
|
|
|
|
# the m/ prefix is optional
|
|
if parts and parts[0] == 'm':
|
|
parts = parts[1:]
|
|
|
|
if not parts:
|
|
# rather than: m/
|
|
return 'm'
|
|
|
|
assert len(parts) <= MAX_PATH_DEPTH, "too deep"
|
|
|
|
for p in parts:
|
|
assert p != '' and p != "'", "empty path component"
|
|
if allow_star and '*' in p:
|
|
# - star or star' can be last only (checked by regex above)
|
|
assert p == '*' or p == "*'", "bad wildcard"
|
|
continue
|
|
if p[-1] == "'":
|
|
p = p[0:-1]
|
|
try:
|
|
ip = int(p, 10)
|
|
except:
|
|
ip = -1
|
|
assert 0 <= ip < 0x80000000 and p == str(ip), "bad component: "+p
|
|
|
|
return 'm/' + '/'.join(parts)
|
|
|
|
def match_deriv_path(patterns, path):
|
|
# check for exact string match, or wildcard match (star in last position)
|
|
# - both args must be cleaned by cleanup_deriv_path() already
|
|
# - will accept any path, if 'any' in patterns
|
|
if 'any' in patterns:
|
|
return True
|
|
|
|
for pat in patterns:
|
|
if pat == path:
|
|
return True
|
|
|
|
if pat.endswith("/*") or pat.endswith("/*'"):
|
|
if pat[-1] == "'" and path[-1] != "'": continue
|
|
if pat[-1] == "*" and path[-1] == "'": continue
|
|
|
|
# same hardness so check up to last component of path
|
|
if pat.split('/')[:-1] == path.split('/')[:-1]:
|
|
return True
|
|
|
|
return False
|
|
|
|
class DecodeStreamer:
|
|
def __init__(self):
|
|
self.runt = bytearray()
|
|
|
|
def more(self, buf):
|
|
# Generator:
|
|
# - accumulate into mod-N groups
|
|
# - strip whitespace
|
|
for ch in buf:
|
|
if chr(ch).isspace(): continue
|
|
self.runt.append(ch)
|
|
if len(self.runt) == 128*self.mod:
|
|
yield self.a2b(self.runt)
|
|
self.runt = bytearray()
|
|
|
|
here = len(self.runt) - (len(self.runt) % self.mod)
|
|
if here:
|
|
yield self.a2b(self.runt[0:here])
|
|
self.runt = self.runt[here:]
|
|
|
|
class HexStreamer(DecodeStreamer):
|
|
# be a generator that converts hex digits into binary
|
|
# NOTE: mpy a2b_hex doesn't care about unicode vs bytes
|
|
mod = 2
|
|
def a2b(self, x):
|
|
return a2b_hex(x)
|
|
|
|
class Base64Streamer(DecodeStreamer):
|
|
# be a generator that converts Base64 into binary
|
|
mod = 4
|
|
def a2b(self, x):
|
|
return a2b_base64(x)
|
|
|
|
|
|
def check_firmware_hdr(hdr, binary_size=None, bad_magic_ok=False):
|
|
# Check basics of new firmware being loaded. Return text of error msg if any.
|
|
# - basic checks only: for confused customers, not attackers.
|
|
# - hdr must be a bytearray(FW_HEADER_SIZE+more)
|
|
|
|
from sigheader import FW_HEADER_SIZE, FW_HEADER_MAGIC, FWH_PY_FORMAT
|
|
from sigheader import MK_1_OK, MK_2_OK, MK_3_OK
|
|
from ustruct import unpack_from
|
|
from version import hw_label
|
|
|
|
try:
|
|
assert len(hdr) >= FW_HEADER_SIZE
|
|
|
|
magic_value, timestamp, version_string, pk, fw_size, install_flags, hw_compat = \
|
|
unpack_from(FWH_PY_FORMAT, hdr)[0:7]
|
|
|
|
if bad_magic_ok and magic_value != FW_HEADER_MAGIC:
|
|
# it's just not a firmware file, and that's ok
|
|
return None
|
|
|
|
assert magic_value == FW_HEADER_MAGIC, 'bad magic'
|
|
if binary_size is not None:
|
|
assert fw_size == binary_size, 'truncated'
|
|
|
|
# TODO: maybe show the version string? Warn them that downgrade doesn't work?
|
|
|
|
except Exception as exc:
|
|
return "That does not look like a firmware " \
|
|
"file we would want to use: %s" % exc
|
|
|
|
if hw_compat != 0:
|
|
# check this hardware is compatible
|
|
ok = False
|
|
if hw_label == 'mk1':
|
|
ok = (hw_compat & MK_1_OK)
|
|
elif hw_label == 'mk2':
|
|
ok = (hw_compat & MK_2_OK)
|
|
elif hw_label == 'mk3':
|
|
ok = (hw_compat & MK_3_OK)
|
|
|
|
if not ok:
|
|
return "New firmware doesn't support this version of Coldcard hardware (%s)."%hw_label
|
|
|
|
return None
|
|
|
|
|
|
def clean_shutdown(style=0):
|
|
# wipe SPI flash and shutdown (wiping main memory)
|
|
import callgate
|
|
|
|
try:
|
|
from common import sf
|
|
sf.wipe_most()
|
|
except: pass
|
|
|
|
callgate.show_logout(style)
|
|
|
|
|
|
class UXStateMachine:
|
|
def __init__(self, initial_state, machine_name=None):
|
|
print('UXStateMachine init: initial_state={}'.format(initial_state))
|
|
self.state = initial_state
|
|
self.prev_states = []
|
|
|
|
def goto(self, new_state):
|
|
print('Go from {} to {}'.format(self.state, new_state))
|
|
self.prev_states.append(self.state)
|
|
self.state = new_state
|
|
|
|
# Transition back to previous state
|
|
def goto_prev(self):
|
|
if len(self.prev_state) > 0:
|
|
prev_state = self.prev_state.pop()
|
|
if self.machine_name != None:
|
|
print('{}: Go from {} to PREVIOUS state {}'.format(self.machine_name, self.state, prev_state))
|
|
else:
|
|
print('Go from {} to PREVIOUS state {}'.format(self.state, prev_state))
|
|
self.state = prev_state
|
|
|
|
async def show(self):
|
|
pass
|
|
|
|
# EOF
|
|
|