#! /usr/bin/env python3
# Script to parse spec output CSVs and produce C files.
# Released by lisa neigut under CC0:
# https://creativecommons.org/publicdomain/zero/1.0/
#
# Reads from stdin, outputs C header or body file.
#
# Standard message types:
#   msgtype,<msgname>,<value>[,<option>]
#   msgdata,<msgname>,<fieldname>,<typename>,[<count>][,<option>]
#
# TLV types:
#   tlvtype,<tlvstreamname>,<tlvname>,<value>[,<option>]
#   tlvdata,<tlvstreamname>,<tlvname>,<fieldname>,<typename>,[<count>][,<option>]
#
# Subtypes:
#   subtype,<subtypename>
#   subtypedata,<subtypename>,<fieldname>,<typename>,[<count>]

from argparse import ArgumentParser, REMAINDER
from collections import OrderedDict
import copy
import fileinput
import os
import re
import sys


# Generator to give us one line at a time.
def next_line(args, lines):
    """Yield ``(line_number, stripped_line)`` pairs, numbering from 1.

    If `lines` is None the input is read with ``fileinput.input(args)``;
    otherwise `lines` itself is iterated.
    """
    if lines is None:
        lines = fileinput.input(args)
    for number, line in enumerate(lines, 1):
        yield number, line.strip()


# Class definitions, to keep things classy
class Field(object):
    """A single named field of a message, TLV record or subtype."""

    def __init__(self, name, type_obj, extensions=None, field_comments=None,
                 optional=False):
        self.name = name
        self.type_obj = type_obj
        # Either a fixed element count, or False for a variable-length field.
        self.count = 1
        # Name of the field whose length this field holds (if any).
        self.len_field_of = None
        # Name of the field holding this field's length (if any).
        self.len_field = None
        self.implicit_len = False
        # None defaults (rather than `[]`) so instances never share a list.
        self.extension_names = [] if extensions is None else extensions
        self.is_optional = optional
        self.field_comments = [] if field_comments is None else field_comments

    def __deepcopy__(self, memo):
        """Deep-copy every attribute except `type_obj`, which stays shared.

        The copy is registered in `memo` before its attributes are copied so
        that other references to this field inside the same deepcopy resolve
        to the same new object.
        """
        cls = self.__class__
        field = cls.__new__(cls)
        memo[id(self)] = field
        for key, value in self.__dict__.items():
            if key == 'type_obj':
                field.type_obj = value
            else:
                setattr(field, key, copy.deepcopy(value, memo))
        return field

    def add_count(self, count):
        """Set a fixed element count (converted with ``int``)."""
        self.count = int(count)

    def add_len_field(self, len_field):
        """Make this a variable-length field sized by `len_field`."""
        self.count = False
        # we cache our len-field's name
        self.len_field = len_field.name
        # the len-field caches our name
        len_field.len_field_of = self.name

    def add_implicit_len(self):
        """Make this a variable-length field with an implicit length."""
        self.count = False
        self.implicit_len = True

    def is_array(self):
        return self.count > 1

    def is_varlen(self):
        return not self.count

    def is_implicit_len(self):
        return self.implicit_len

    def is_extension(self):
        return bool(self.extension_names)

    def size(self, implicit_expression=None):
        """Return the fixed count, the length field's name, or, for an
        implicit-length field, the caller-supplied `implicit_expression`."""
        if self.count:
            return self.count
        if self.len_field:
            return self.len_field
        assert self.is_implicit_len()
        assert implicit_expression
        return implicit_expression

    def needs_context(self):
        """ A field needs a context if it's varsized """
        return self.is_varlen() or self.type_obj.needs_context()

    def arg_desc_to(self):
        """Return this field's C parameter text (with a leading ', ') for a
        'to wire' function; empty for fields that hold another's length."""
        if self.len_field_of:
            return ''
        type_name = self.type_obj.type_name()
        if self.is_array():
            return ', const {} {}[{}]'.format(type_name, self.name, self.count)
        if self.type_obj.is_assignable() and not self.is_varlen():
            name = self.name
            if self.is_optional:
                name = '*' + name
            return ', {} {}'.format(type_name, name)
        if self.is_varlen() and self.type_obj.is_varsize():
            return ', const {} **{}'.format(type_name, self.name)
        return ', const {} *{}'.format(type_name, self.name)

    def arg_desc_from(self):
        """Return this field's C parameter text (with a leading ', ') for a
        'from wire' function; empty for fields that hold another's length."""
        type_name = self.type_obj.type_name()
        if self.type_obj.is_const_ptr_ptr_type():
            return ', const {} **{}'.format(type_name, self.name)
        if self.len_field_of:
            return ''
        if self.is_array():
            return ', {} {}[{}]'.format(type_name, self.name, self.count)
        ptrs = '*'
        if self.is_varlen() or self.is_optional or self.type_obj.is_varsize():
            ptrs += '*'
        if self.is_varlen() and self.type_obj.is_varsize():
            ptrs += '*'
        return ', {} {}{}'.format(type_name, ptrs, self.name)


class FieldSet(object):
    """An ordered collection of `Field` objects, keyed by field name."""

    def __init__(self):
        self.fields = OrderedDict()
        self.len_fields = {}

    def add_data_field(self, field_name, type_obj, count=1,
                       extensions=None, comments=None, optional=False,
                       implicit_len_ok=False):
        """Create a field and append it to this set.

        `count` may be an integer (or numeric string), the name of an
        already-added field that holds the length, or '...' (only when
        `implicit_len_ok`). A falsy `count` leaves the default count of 1.

        Raises:
            ValueError: if `count` is neither numeric, a known field name,
                nor a permitted '...'; or if the previously added field has
                an implicit length.
        """
        field = Field(field_name, type_obj, extensions=extensions,
                      field_comments=comments, optional=optional)
        if count:
            try:
                field.add_count(int(count))
            except ValueError:
                if count in self.fields:
                    len_field = self.find_data_field(count)
                    field.add_len_field(len_field)
                    self.len_fields[len_field.name] = len_field
                elif implicit_len_ok and count == '...':
                    # '...' means "rest of TLV"
                    field.add_implicit_len()
                else:
                    raise ValueError(
                        "Unknown count or length field '{}' for field '{}'"
                        .format(count, field_name))

        # You can't have any fields after an implicit-length field.
        if self.fields:
            last_field = self.fields[next(reversed(self.fields))]
            if last_field.is_implicit_len():
                raise ValueError(
                    "Field '{}' cannot follow implicit-length field '{}'"
                    .format(field_name, last_field.name))
        self.fields[field_name] = field

    def find_data_field(self, field_name):
        return self.fields[field_name]

    def get_len_fields(self):
        return list(self.len_fields.values())

    def has_len_fields(self):
        return bool(self.len_fields)

    def needs_context(self):
        return any(field.needs_context() or field.is_optional
                   for field in self.fields.values())

    def singleton(self):
        """Return the single field, if there's only one, otherwise None"""
        if len(self.fields) == 1:
            return next(iter(self.fields.values()))
        return None


class Type(FieldSet):
    """A named type; it counts as a subtype once it has fields of its own."""

    assignables = [
        'u8',
        'u16',
        'u32',
        'u64',
        'tu16',
        'tu32',
        'tu64',
        'bool',
        'amount_sat',
        'amount_msat',
        'errcode_t',
        'bigsize',
        'varint'
    ]

    typedefs = [
        'u8',
        'u16',
        'u32',
        'u64',
        'bool',
        'secp256k1_ecdsa_signature',
        'secp256k1_ecdsa_recoverable_signature',
        'wirestring',
        'errcode_t',
        'bigsize',
        'varint',
    ]

    truncated_typedefs = [
        'tu16',
        'tu32',
        'tu64',
    ]

    # Externally defined variable size types (require a context)
    varsize_types = [
        'peer_features',
        'gossip_getnodes_entry',
        'gossip_getchannels_entry',
        'failed_htlc',
        'existing_htlc',
        'utxo',
        'bitcoin_tx',
        'wirestring',
        'per_peer_state',
        'bitcoin_tx_output',
        'exclude_entry',
        'fee_states',
        'onionreply',
        'feature_set',
        'onionmsg_path',
        'route_hop',
        'tx_parts',
    ]

    # Some BOLT types are re-typed based on their field name
    # ('fieldname partial', 'original type', 'outer type'): ('true type', 'collapse array?')
    name_field_map = {
        ('txid', 'sha256'): ('bitcoin_txid', False),
        ('amt', 'u64'): ('amount_msat', False),
        ('msat', 'u64'): ('amount_msat', False),
        ('satoshis', 'u64'): ('amount_sat', False),
        ('node_id', 'pubkey', 'channel_announcement'): ('node_id', False),
        ('node_id', 'pubkey', 'node_announcement'): ('node_id', False),
        ('temporary_channel_id', 'u8'): ('channel_id', True),
        ('secret', 'u8'): ('secret', True),
        ('preimage', 'u8'): ('preimage', True),
    }

    # For BOLT specified types, a few type names need to be simply 'remapped'
    # 'original type': 'true type'
    name_remap = {
        'byte': 'u8',
        'signature': 'secp256k1_ecdsa_signature',
        'chain_hash': 'bitcoin_blkid',
        'point': 'pubkey',
        # FIXME: omits 'pad'
    }

    # Types that are const pointer-to-pointers, such as chainparams, i.e.,
    # they set a reference to some const entry.
    const_ptr_ptr_types = [
        'chainparams'
    ]

    @staticmethod
    def true_type(type_name, field_name=None, outer_name=None):
        """ Returns 'true' type of a given type and a flag if
            we've remapped a variable size/array type to a single struct
            (an example of this is 'temporary_channel_id' which is specified
            as a 32*byte, but we re-map it to a channel_id
        """
        if type_name in Type.name_remap:
            type_name = Type.name_remap[type_name]

        if field_name:
            for key, mapped in Type.name_field_map.items():
                if key[0] in field_name and key[1] == type_name:
                    # A third key element restricts the mapping to one
                    # enclosing message/type name.
                    if len(key) == 2 or outer_name == key[2]:
                        return mapped
        return (type_name, False)

    def __init__(self, name):
        FieldSet.__init__(self)
        self.name, self.is_enum = self.parse_name(name)
        self.depends_on = {}
        self.type_comments = []
        self.tlv = False

    def parse_name(self, name):
        """Split a leading 'enum ' off `name`; return ``(name, is_enum)``."""
        if name.startswith('enum '):
            return name[5:], True
        return name, False

    def add_data_field(self, field_name, type_obj, count=1,
                       extensions=None, comments=None, optional=False):
        """Add a field and record `type_obj` as a dependency of this type."""
        FieldSet.add_data_field(self, field_name, type_obj, count,
                                extensions=extensions,
                                comments=comments, optional=optional)
        if type_obj.name not in self.depends_on:
            self.depends_on[type_obj.name] = type_obj

    def type_name(self):
        """Return the C spelling of this type (typedef, enum or struct)."""
        if self.name in self.typedefs:
            return self.name
        if self.name in self.truncated_typedefs:
            # Drop the leading 't': 'tu32' -> 'u32'.
            return self.name[1:]
        prefix = 'enum ' if self.is_enum else 'struct '
        return prefix + self.struct_name()

    # We only accelerate the u8 case: it's common and trivial.
    def has_array_helper(self):
        return self.name in ['u8']

    def struct_name(self):
        if self.is_tlv():
            return self.tlv.struct_name()
        return self.name

    def subtype_deps(self):
        return [dep for dep in self.depends_on.values() if dep.is_subtype()]

    def is_subtype(self):
        return bool(self.fields)

    def is_const_ptr_ptr_type(self):
        return self.name in self.const_ptr_ptr_types

    def is_truncated(self):
        return self.name in self.truncated_typedefs

    def needs_context(self):
        return self.is_varsize()

    def is_assignable(self):
        """ Generally typedef's and enums """
        return self.name in self.assignables or self.is_enum

    def is_varsize(self):
        """ A type is variably sized if it's marked as such (in varsize_types)
            or it contains a field of variable length """
        return self.name in self.varsize_types or self.has_len_fields()

    def add_comments(self, comments):
        self.type_comments = comments

    def mark_tlv(self, tlv):
        self.tlv = tlv

    def is_tlv(self):
        return bool(self.tlv)


class Message(FieldSet):
    """A wire message (or one TLV record): a numbered, named set of fields."""

    def __init__(self, name, number, option=None, enum_prefix='wire',
                 struct_prefix=None, comments=None):
        FieldSet.__init__(self)
        self.name = name
        self.number = number
        self.enum_prefix = enum_prefix
        # Only the first option token is kept.
        self.option = option[0] if option else None
        self.struct_prefix = struct_prefix
        self.enumname = None
        self.msg_comments = [] if comments is None else comments
        self.if_token = None

    def has_option(self):
        return self.option is not None

    def enum_name(self):
        name = self.enumname if self.enumname else self.name
        return "{}_{}".format(self.enum_prefix, name).upper()

    def struct_name(self):
        if self.struct_prefix:
            return self.struct_prefix + "_" + self.name
        return self.name

    def add_if(self, if_token):
        self.if_token = if_token


class Tlv(object):
    """A named TLV stream holding its record types as `Message` objects."""

    def __init__(self, name):
        self.name = name
        self.messages = {}

    def add_message(self, tokens, comments=None):
        """ tokens -> (name, value[, option]) """
        self.messages[tokens[0]] = Message(tokens[0], tokens[1],
                                           option=tokens[2:],
                                           enum_prefix=self.name,
                                           struct_prefix=self.struct_name(),
                                           comments=comments)

    def type_name(self):
        return 'struct ' + self.struct_name()

    def struct_name(self):
        return "tlv_{}".format(self.name)

    def find_message(self, name):
        """Return the named record, or None if it has not been declared."""
        return self.messages.get(name)

    def ordered_msgs(self):
        return sorted(self.messages.values(),
                      key=lambda item: int(item.number))


class Master(object):
    """Coordinator holding everything parsed from the input."""

    def __init__(self):
        # Per-instance containers: as class attributes they would be shared
        # by every Master and leak state between runs.
        self.types = {}
        self.tlvs = {}
        self.messages = {}
        self.extension_msgs = {}
        self.inclusions = []
        self.top_comments = []

    def add_comments(self, comments):
        self.top_comments += comments

    def add_include(self, inclusion):
        self.inclusions.append(inclusion)

    def add_tlv(self, tlv_name):
        """Return the named TLV stream, creating it (and a same-named
        `Type`) if it does not exist yet."""
        if tlv_name not in self.tlvs:
            self.tlvs[tlv_name] = Tlv(tlv_name)

        if tlv_name not in self.types:
            self.types[tlv_name] = Type(tlv_name)

        return self.tlvs[tlv_name]

    def add_message(self, tokens, comments=None):
        """ tokens -> (name, value[, option])"""
        self.messages[tokens[0]] = Message(tokens[0], tokens[1],
                                           option=tokens[2:],
                                           comments=comments)

    def add_extension_msg(self, name, msg):
        self.extension_msgs[name] = msg

    def add_type(self, type_name, field_name=None, outer_name=None):
        """Look up (creating if needed) the type for `type_name`.

        A leading '?' marks the field optional and is stripped first.

        Returns:
            ``(type_obj, collapse_original, optional)``, where
            `collapse_original` is the flag returned by `Type.true_type`.
        """
        optional = False
        if type_name.startswith('?'):
            type_name = type_name[1:]
            optional = True
        # Check for special type name re-mapping
        type_name, collapse_original = Type.true_type(type_name, field_name,
                                                      outer_name)

        if type_name not in self.types:
            self.types[type_name] = Type(type_name)
        return self.types[type_name], collapse_original, optional

    def find_type(self, type_name):
        """Return the named type, or None if it is unknown."""
        return self.types.get(type_name)

    def find_message(self, msg_name):
        """Return the named message or extension message, or None."""
        if msg_name in self.messages:
            return self.messages[msg_name]
        if msg_name in self.extension_msgs:
            return self.extension_msgs[msg_name]
        return None

    def find_tlv(self, tlv_name):
        """Return the named TLV stream, or None if it is unknown."""
        return self.tlvs.get(tlv_name)

    def get_ordered_subtypes(self):
        """ We want to order subtypes such that the 'no dependency'
        types are printed first

        Raises ValueError if the remaining subtypes depend on each other in
        a cycle, since no valid order exists.
        """
        subtypes = [s for s in self.types.values() if s.is_subtype()]

        # Start with subtypes without subtype dependencies
        sorted_types = [s for s in subtypes if not s.subtype_deps()]
        unsorted = [s for s in subtypes if s.subtype_deps()]
        while unsorted:
            names = {s.name for s in sorted_types}
            progressed = False
            for s in list(unsorted):
                if all(dependency.name in names
                       for dependency in s.subtype_deps()):
                    sorted_types.append(s)
                    unsorted.remove(s)
                    progressed = True
            if not progressed:
                # Nothing could be placed in a whole pass: without this
                # check the loop would never terminate.
                raise ValueError('Circular subtype dependency among: {}'
                                 .format(', '.join(s.name for s in unsorted)))
        return sorted_types

    def tlv_structs(self):
        """Return every TLV record that is not a single-field record."""
        ret = []
        for tlv in self.tlvs.values():
            for v in tlv.messages.values():
                if not v.singleton():
                    ret.append(v)
        return ret

    def find_template(self, options):
        """Load the Mako template selected by `options.page` and
        `options.print_wire` from the 'gen' directory beside this script."""
        # Imported here rather than at module level so the parsing model can
        # be used without the template engine; it is only needed to render.
        from mako.template import Template

        dirpath = os.path.dirname(os.path.abspath(__file__))
        filename = dirpath + '/gen/{}{}_template'.format(
            'print_' if options.print_wire else '', options.page)

        return Template(filename=filename)

    def post_process(self):
        """ method to handle any 'post processing' that needs to be done.
            for now, we just need match up types to TLVs """
        for tlv_name, tlv in self.tlvs.items():
            if tlv_name in self.types:
                self.types[tlv_name].mark_tlv(tlv)

    def write(self, options, output):
        """Render the selected template with the parsed data and print the
        result to `output`."""
        template = self.find_template(options)
        enum_sets = []
        if self.messages:
            enum_sets.append({
                'name': options.enum_name,
                'set': self.messages.values(),
            })
        stuff = {}
        stuff['top_comments'] = self.top_comments
        stuff['options'] = options
        stuff['idem'] = re.sub(r'[^A-Z]+', '_', options.header_filename.upper())
        stuff['header_filename'] = options.header_filename
        stuff['includes'] = self.inclusions
        stuff['enum_sets'] = enum_sets
        subtypes = self.get_ordered_subtypes()
        stuff['structs'] = subtypes + self.tlv_structs()
        stuff['tlvs'] = self.tlvs

        # We leave out extension messages in the printing pages. Any extension
        # fields will get printed under the 'original' message, if present
        if options.print_wire:
            stuff['messages'] = list(self.messages.values())
        else:
            stuff['messages'] = list(self.messages.values()) + list(self.extension_msgs.values())
        stuff['subtypes'] = subtypes

        print(template.render(**stuff), file=output)


def main(options, args=None, output=sys.stdout, lines=None):
    """Parse the CSV input and write the generated page to `output`.

    Args:
        options: parsed command-line options; `include` and `print_wire`
            are read here, the rest by `Master.write` and the template.
        args: file names handed to `fileinput` when `lines` is None.
        output: file object the rendered template is printed to.
        lines: optional iterable of input lines, used instead of `args`.

    Raises:
        ValueError: on an unknown token type, a reference to an undeclared
            subtype/TLV/message, or a disallowed optional field.
    """
    comment_set = []
    token_name = None

    # Create a new 'master' that serves as the coordinator for the file generation
    master = Master()
    for i in options.include:
        master.add_include('#include <{}>'.format(i))

    for ln, line in next_line(args, lines):
        tokens = line.split(',')
        token_type = tokens[0]

        if not line:
            # A blank line flushes pending comments to the file header.
            master.add_comments(comment_set)
            comment_set = []
            token_name = None
            continue

        if len(tokens) > 2:
            token_name = tokens[1]

        if token_type == 'subtype':
            subtype, _, _ = master.add_type(tokens[1])

            subtype.add_comments(list(comment_set))
            comment_set = []
        elif token_type == 'subtypedata':
            subtype = master.find_type(tokens[1])
            if not subtype:
                raise ValueError('Unknown subtype {} for data.\nat {}:{}'
                                 .format(tokens[1], ln, line))
            type_obj, collapse, optional = master.add_type(tokens[3], tokens[2], tokens[1])
            if optional:
                raise ValueError('Subtypes cannot have optional fields {}.{}\n at {}:{}'
                                 .format(subtype.name, tokens[2], ln, line))
            if collapse:
                count = 1
            else:
                count = tokens[4]

            subtype.add_data_field(tokens[2], type_obj, count, comments=list(comment_set),
                                   optional=optional)
            comment_set = []
        elif token_type == 'tlvtype':
            tlv = master.add_tlv(tokens[1])
            tlv.add_message(tokens[2:], comments=list(comment_set))

            comment_set = []
        elif token_type == 'tlvdata':
            type_obj, collapse, optional = master.add_type(tokens[4], tokens[3], tokens[1])
            if optional:
                raise ValueError('TLV messages cannot have optional fields {}.{}\n at {}:{}'
                                 .format(tokens[2], tokens[3], ln, line))

            tlv = master.find_tlv(tokens[1])
            if not tlv:
                raise ValueError('tlvdata for unknown tlv {}.\nat {}:{}'
                                 .format(tokens[1], ln, line))
            msg = tlv.find_message(tokens[2])
            if not msg:
                raise ValueError('tlvdata for unknown tlv-message {}.\nat {}:{}'
                                 .format(tokens[2], ln, line))
            if collapse:
                count = 1
            else:
                count = tokens[5]

            msg.add_data_field(tokens[3], type_obj, count, comments=list(comment_set),
                               optional=optional, implicit_len_ok=True)
            comment_set = []
        elif token_type == 'msgtype':
            master.add_message(tokens[1:], comments=list(comment_set))
            comment_set = []
        elif token_type == 'msgdata':
            msg = master.find_message(tokens[1])
            if not msg:
                raise ValueError('Unknown message type {}. {}:{}'.format(tokens[1], ln, line))
            type_obj, collapse, optional = master.add_type(tokens[3], tokens[2], tokens[1])

            if collapse:
                count = 1
            elif len(tokens) < 5:
                raise ValueError('problem with parsing {}:{}'.format(ln, line))
            else:
                count = tokens[4]

            # if this is an 'extension' field*, we want to add a new 'message' type
            # in the future, extensions will be handled as TLV's
            #
            # *(in the spec they're called 'optional', but that term is overloaded
            #   in that internal wire messages have 'optional' fields that are treated
            #   differently. for the sake of clarity here, for bolt-wire messages,
            #   we'll refer to 'optional' message fields as 'extensions')
            #
            if tokens[5:] == []:
                msg.add_data_field(tokens[2], type_obj, count, comments=list(comment_set),
                                   optional=optional)
            else:  # is one or more extension fields
                if optional:
                    raise ValueError("Extension fields cannot be optional. {}:{}"
                                     .format(ln, line))
                orig_msg = msg
                for extension in tokens[5:]:
                    extension_name = "{}_{}".format(tokens[1], extension)
                    msg = master.find_message(extension_name)
                    if not msg:
                        # First field of this extension: start from a copy
                        # of the original message under the extension name.
                        msg = copy.deepcopy(orig_msg)
                        msg.enumname = msg.name
                        msg.name = extension_name
                        master.add_extension_msg(msg.name, msg)
                    msg.add_data_field(tokens[2], type_obj, count, comments=list(comment_set),
                                       optional=optional)

                # If this is a print_wire page, add the extension fields to the
                # original message, so we can print them if present.
                if options.print_wire:
                    orig_msg.add_data_field(tokens[2], type_obj, count=count,
                                            extensions=tokens[5:],
                                            comments=list(comment_set),
                                            optional=optional)

            comment_set = []
        elif token_type.startswith('#include'):
            master.add_include(token_type)
        elif token_type.startswith('#if'):
            msg = master.find_message(token_name)
            if msg:
                if_token = token_type[token_type.index(' ') + 1:]
                msg.add_if(if_token)
        elif token_type.startswith('#'):
            comment_set.append(token_type[1:])
        else:
            raise ValueError("Unknown token type {} on line {}:{}".format(token_type, ln, line))

    master.post_process()
    master.write(options, output)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-s", "--expose-subtypes", help="print subtypes in header",
                        action="store_true", default=False)
    parser.add_argument("-P", "--print_wire", help="generate wire printing source files",
                        action="store_true", default=False)
    parser.add_argument("--page", choices=['header', 'impl'], help="page to print")
    parser.add_argument('--expose-tlv-type', action='append', default=[])
    parser.add_argument('--include', action='append', default=[])
    parser.add_argument('header_filename', help='The filename of the header')
    parser.add_argument('enum_name', help='The name of the enum to produce')
    parser.add_argument("files", help='Files to read in (or stdin)', nargs=REMAINDER)
    parsed_args = parser.parse_args()

    main(parsed_args, parsed_args.files)