Browse Source

bolt-gen: add field optional handling

we'll need this for internal wire message formats. also disambiguates
from 'bolt message optional fields', which we rename to extensions here.

example of an optional field declaration (note the ? prefixing the
type):

    msgdata,msg_name,field_name,?type,count

these are handled with either a boolean if they're not present,
or a true value and then the object if they are.
pull/2938/head
lisa neigut 6 years ago
committed by Rusty Russell
parent
commit
181e1916b2
  1. 59
      tools/generate-bolts.py

59
tools/generate-bolts.py

@ -38,13 +38,16 @@ def next_line(args, lines):
# Class definitions, to keep things classy # Class definitions, to keep things classy
class Field(object): class Field(object):
def __init__(self, name, type_obj, optional=False, field_comments=[]): def __init__(self, name, type_obj, extension=False,
field_comments=[], optional=False):
self.name = name self.name = name
self.type_obj = type_obj self.type_obj = type_obj
self.count = 1 self.count = 1
self.is_optional = optional
self.len_field_of = None self.len_field_of = None
self.len_field = None self.len_field = None
self.is_extension = extension
self.is_optional = optional
self.field_comments = field_comments self.field_comments = field_comments
def add_count(self, count): def add_count(self, count):
@ -66,6 +69,9 @@ class Field(object):
def is_optional(self): def is_optional(self):
return self.is_optional return self.is_optional
def is_extension(self):
return self.is_extension
def size(self): def size(self):
if self.count: if self.count:
return self.count return self.count
@ -101,15 +107,16 @@ class Field(object):
class FieldSet(object): class FieldSet(object):
def __init__(self): def __init__(self):
self.fields = OrderedDict() self.fields = OrderedDict()
self.optional_fields = False self.extension_fields = False
self.len_fields = {} self.len_fields = {}
def add_data_field(self, field_name, type_obj, count=1, is_optional=[], comments=[]): def add_data_field(self, field_name, type_obj, count=1,
# FIXME: use this somewhere? is_extension=[], comments=[], optional=False):
if is_optional: if is_extension:
self.optional_fields = True self.extension_fields = True
field = Field(field_name, type_obj, bool(is_optional), comments) field = Field(field_name, type_obj, extension=bool(is_extension),
field_comments=comments, optional=optional)
if bool(count): if bool(count):
try: try:
field.add_count(int(count)) field.add_count(int(count))
@ -201,8 +208,10 @@ class Type(FieldSet):
self.is_enum = False self.is_enum = False
self.type_comments = [] self.type_comments = []
def add_data_field(self, field_name, type_obj, count=1, is_optional=[], comments=[]): def add_data_field(self, field_name, type_obj, count=1,
FieldSet.add_data_field(self, field_name, type_obj, count, is_optional, comments) is_extension=[], comments=[], optional=False):
FieldSet.add_data_field(self, field_name, type_obj, count,
is_extension, comments=comments, optional=optional)
if type_obj.name not in self.depends_on: if type_obj.name not in self.depends_on:
self.depends_on[type_obj.name] = type_obj self.depends_on[type_obj.name] = type_obj
@ -305,12 +314,16 @@ class Master(object):
self.extension_msgs[name] = msg self.extension_msgs[name] = msg
def add_type(self, type_name, field_name=None): def add_type(self, type_name, field_name=None):
optional = False
if type_name.startswith('?'):
type_name = type_name[1:]
optional = True
# Check for special type name re-mapping # Check for special type name re-mapping
type_name, collapse_original = Type.true_type(type_name, field_name) type_name, collapse_original = Type.true_type(type_name, field_name)
if type_name not in self.types: if type_name not in self.types:
self.types[type_name] = Type(type_name) self.types[type_name] = Type(type_name)
return self.types[type_name], collapse_original return self.types[type_name], collapse_original, optional
def find_type(self, type_name): def find_type(self, type_name):
return self.types[type_name] return self.types[type_name]
@ -398,7 +411,7 @@ def main(options, args=None, output=sys.stdout, lines=None):
continue continue
if token_type == 'subtype': if token_type == 'subtype':
subtype, _ = master.add_type(tokens[1]) subtype, _, _ = master.add_type(tokens[1])
subtype.add_comments(list(comment_set)) subtype.add_comments(list(comment_set))
comment_set = [] comment_set = []
@ -407,13 +420,17 @@ def main(options, args=None, output=sys.stdout, lines=None):
if not subtype: if not subtype:
raise ValueError('Unknown subtype {} for data.\nat {}:{}' raise ValueError('Unknown subtype {} for data.\nat {}:{}'
.format(tokens[1], ln, line)) .format(tokens[1], ln, line))
type_obj, collapse = master.add_type(tokens[3], tokens[2]) type_obj, collapse, optional = master.add_type(tokens[3], tokens[2])
if optional:
raise ValueError('Subtypes cannot have optional fields {}.{}\n at {}:{}'
.format(subtype.name, tokens[2], ln, line))
if collapse: if collapse:
count = 1 count = 1
else: else:
count = tokens[4] count = tokens[4]
subtype.add_data_field(tokens[2], type_obj, count, list(comment_set)) subtype.add_data_field(tokens[2], type_obj, count, comments=list(comment_set),
optional=optional)
comment_set = [] comment_set = []
elif token_type == 'tlvtype': elif token_type == 'tlvtype':
tlv = master.add_tlv(tokens[1]) tlv = master.add_tlv(tokens[1])
@ -421,7 +438,11 @@ def main(options, args=None, output=sys.stdout, lines=None):
comment_set = [] comment_set = []
elif token_type == 'tlvdata': elif token_type == 'tlvdata':
type_obj, collapse = master.add_type(tokens[4], tokens[3]) type_obj, collapse, optional = master.add_type(tokens[4], tokens[3])
if optional:
raise ValueError('TLV messages cannot have optional fields {}.{}\n at {}:{}'
.format(tokens[2], tokens[3], ln, line))
tlv = master.find_tlv(tokens[1]) tlv = master.find_tlv(tokens[1])
if not tlv: if not tlv:
raise ValueError('tlvdata for unknown tlv {}.\nat {}:{}' raise ValueError('tlvdata for unknown tlv {}.\nat {}:{}'
@ -435,7 +456,8 @@ def main(options, args=None, output=sys.stdout, lines=None):
else: else:
count = tokens[5] count = tokens[5]
msg.add_data_field(tokens[3], type_obj, count, list(comment_set)) msg.add_data_field(tokens[3], type_obj, count, comments=list(comment_set),
optional=optional)
comment_set = [] comment_set = []
elif token_type == 'msgtype': elif token_type == 'msgtype':
master.add_message(tokens[1:], comments=list(comment_set)) master.add_message(tokens[1:], comments=list(comment_set))
@ -444,7 +466,7 @@ def main(options, args=None, output=sys.stdout, lines=None):
msg = master.find_message(tokens[1]) msg = master.find_message(tokens[1])
if not msg: if not msg:
raise ValueError('Unknown message type {}. {}:{}'.format(tokens[1], ln, line)) raise ValueError('Unknown message type {}. {}:{}'.format(tokens[1], ln, line))
type_obj, collapse = master.add_type(tokens[3], tokens[2]) type_obj, collapse, optional = master.add_type(tokens[3], tokens[2])
# if this is an 'extension' field*, we want to add a new 'message' type # if this is an 'extension' field*, we want to add a new 'message' type
# in the future, extensions will be handled as TLV's # in the future, extensions will be handled as TLV's
@ -469,7 +491,8 @@ def main(options, args=None, output=sys.stdout, lines=None):
else: else:
count = tokens[4] count = tokens[4]
msg.add_data_field(tokens[2], type_obj, count, list(comment_set)) msg.add_data_field(tokens[2], type_obj, count, comments=list(comment_set),
optional=optional)
comment_set = [] comment_set = []
elif token_type.startswith('#include'): elif token_type.startswith('#include'):
master.add_include(token_type) master.add_include(token_type)

Loading…
Cancel
Save