From b0c9059602ea2f84ba6507372919b5140e56e302 Mon Sep 17 00:00:00 2001 From: Rusty Russell Date: Wed, 6 May 2020 20:11:54 +0930 Subject: [PATCH] tools/generate-wire: no more lonely messages! When we have only a single member in a TLV (e.g. an optional u64), wrapping it in a struct is awkward. This changes it to directly access those fields. This is not only more elegant (60 fewer lines), it would also be more cache friendly. That's right: cache hot singles! Signed-off-by: Rusty Russell --- channeld/channeld.c | 52 +++++++++------------------ common/onion.c | 61 +++++++++----------------------- connectd/peer_exchange_initmsg.c | 7 ++-- devtools/blindedpath.c | 24 ++++++------- devtools/gossipwith.c | 5 ++- devtools/mkquery.c | 5 ++- gossipd/queries.c | 20 +++++------ gossipd/test/run-extended-info.c | 13 +++---- lightningd/onion_message.c | 19 +++++----- tools/gen/header_template | 4 +++ tools/gen/impl_template | 53 ++++++++++++++++++--------- tools/generate-wire.py | 18 ++++++++-- wire/test/run-peer-wire.c | 15 ++++---- wire/test/run-tlvstream.c | 28 +++++++-------- 14 files changed, 149 insertions(+), 175 deletions(-) diff --git a/channeld/channeld.c b/channeld/channeld.c index ce1aacbfc..fd37596c7 100644 --- a/channeld/channeld.c +++ b/channeld/channeld.c @@ -638,8 +638,7 @@ static void handle_peer_add_htlc(struct peer *peer, const u8 *msg) "Bad peer_add_htlc %s", tal_hex(msg, msg)); #if EXPERIMENTAL_FEATURES - if (tlvs->blinding) - blinding = &tlvs->blinding->blinding; + blinding = tlvs->blinding; #endif add_err = channel_add_htlc(peer->channel, REMOTE, id, amount, cltv_expiry, &payment_hash, @@ -1661,8 +1660,6 @@ static void handle_onion_message(struct peer *peer, const u8 *msg) const u8 *cursor; size_t max, maxlen; struct tlv_onionmsg_payload *om; - const struct short_channel_id *next_scid; - struct node_id *next_node; struct tlv_onion_message_tlvs *tlvs = tlv_onion_message_tlvs_new(msg); if (!fromwire_onion_message(msg, onion, tlvs)) @@ -1682,8 +1679,7 @@ static void handle_onion_message(struct peer *peer, const u8 *msg) struct secret hmac; /* E(i) */ - blinding_in = tal(msg, struct pubkey); - *blinding_in = tlvs->blinding->blinding; + blinding_in = tal_dup(msg, struct pubkey, tlvs->blinding); status_debug("blinding in = %s", type_to_string(tmpctx, struct pubkey, blinding_in)); blinding_ss = tal(msg, struct secret); @@ -1742,8 +1738,7 @@ static void handle_onion_message(struct peer *peer, const u8 *msg) /* If we weren't given a blinding factor, tlv can provide one. */ if (om->blinding && !blinding_ss) { /* E(i) */ - blinding_in = tal(msg, struct pubkey); - *blinding_in = om->blinding->blinding; + blinding_in = tal_dup(msg, struct pubkey, om->blinding); blinding_ss = tal(msg, struct secret); ecdh(blinding_in, blinding_ss); @@ -1764,19 +1759,19 @@ static void handle_onion_message(struct peer *peer, const u8 *msg) subkey_from_hmac("rho", blinding_ss, &rho); /* Overrides next_scid / next_node */ - if (tal_bytelen(om->enctlv->enctlv) + if (tal_bytelen(om->enctlv) < crypto_aead_chacha20poly1305_ietf_ABYTES) { status_debug("enctlv too short for mac"); return; } dec = tal_arr(msg, u8, - tal_bytelen(om->enctlv->enctlv) + tal_bytelen(om->enctlv) - crypto_aead_chacha20poly1305_ietf_ABYTES); ret = crypto_aead_chacha20poly1305_ietf_decrypt(dec, NULL, NULL, - om->enctlv->enctlv, - tal_bytelen(om->enctlv->enctlv), + om->enctlv, + tal_bytelen(om->enctlv), NULL, 0, npub, rho.data); @@ -1801,17 +1796,6 @@ static void handle_onion_message(struct peer *peer, const u8 *msg) return; } - if (om->next_short_channel_id) - next_scid = &om->next_short_channel_id->short_channel_id; - else - next_scid = NULL; - - if (om->next_node_id) { - next_node = tal(msg, struct node_id); - node_id_from_pubkey(next_node, &om->next_node_id->node_id); - } else - next_node = NULL; - if (om->enctlv) { status_broken("FIXME: Handle enctlv!"); return; @@ -1835,14 +1819,16 @@ static void handle_onion_message(struct peer *peer, const u8 *msg) path))); } else { struct pubkey *next_blinding; + struct node_id next_node; /* This *MUST* have instructions on where to go next. */ - if (!next_scid && !next_node) { + if (!om->next_short_channel_id && !om->next_node_id) { status_debug("onion msg: no next field in %s", tal_hex(tmpctx, rs->raw_payload)); return; } + node_id_from_pubkey(&next_node, om->next_node_id); if (blinding_ss) { /* E(i-1) = H(E(i) || ss(i)) * E(i) */ struct sha256 h; @@ -1854,8 +1840,8 @@ static void handle_onion_message(struct peer *peer, const u8 *msg) wire_sync_write(MASTER_FD, take(towire_got_onionmsg_forward(NULL, - next_scid, - next_node, + om->next_short_channel_id, + &next_node, next_blinding, serialize_onionpacket(tmpctx, rs->next)))); } @@ -1871,11 +1857,8 @@ static void send_onionmsg(struct peer *peer, const u8 *msg) if (!fromwire_send_onionmsg(msg, msg, onion_routing_packet, &blinding)) master_badmsg(WIRE_SEND_ONIONMSG, msg); - if (blinding) { - tlvs->blinding = tal(tlvs, - struct tlv_onion_message_tlvs_blinding); - tlvs->blinding->blinding = *blinding; - } + if (blinding) + tlvs->blinding = tal_dup(tlvs, struct pubkey, blinding); sync_crypto_write(peer->pps, take(towire_onion_message(NULL, onion_routing_packet, @@ -2087,8 +2070,8 @@ static void resend_commitment(struct peer *peer, const struct changed_htlc *last struct tlv_update_add_tlvs *tlvs; if (h->blinding) { tlvs = tlv_update_add_tlvs_new(tmpctx); - tlvs->blinding = tal(tlvs, struct tlv_update_add_tlvs_blinding); - tlvs->blinding->blinding = *h->blinding; + tlvs->blinding = tal_dup(tlvs, struct pubkey, + h->blinding); } else tlvs = NULL; #endif @@ -2724,8 +2707,7 @@ static void handle_offer_htlc(struct peer *peer, const u8 *inmsg) struct tlv_update_add_tlvs *tlvs; if (blinding) { tlvs = tlv_update_add_tlvs_new(tmpctx); - tlvs->blinding = tal(tlvs, struct tlv_update_add_tlvs_blinding); - tlvs->blinding->blinding = *blinding; + tlvs->blinding = tal_dup(tlvs, struct pubkey, blinding); } else tlvs = NULL; #endif diff --git a/common/onion.c b/common/onion.c index 2f3ce4c9f..f44bf75e3 100644 --- a/common/onion.c +++ b/common/onion.c @@ -67,9 +67,6 @@ u8 *onion_nonfinal_hop(const tal_t *ctx, { if (use_tlv) { struct tlv_tlv_payload *tlv = tlv_tlv_payload_new(tmpctx); - struct tlv_tlv_payload_amt_to_forward tlv_amt; - struct tlv_tlv_payload_outgoing_cltv_value tlv_cltv; - struct tlv_tlv_payload_short_channel_id tlv_scid; /* BOLT #4: * @@ -81,23 +78,13 @@ u8 *onion_nonfinal_hop(const tal_t *ctx, * - MUST include `short_channel_id` * - MUST NOT include `payment_data` */ - tlv_amt.amt_to_forward = forward.millisatoshis; /* Raw: TLV convert */ - tlv_cltv.outgoing_cltv_value = outgoing_cltv; - tlv_scid.short_channel_id = *scid; - tlv->amt_to_forward = &tlv_amt; - tlv->outgoing_cltv_value = &tlv_cltv; - tlv->short_channel_id = &tlv_scid; + tlv->amt_to_forward = &forward.millisatoshis; /* Raw: TLV convert */ + tlv->outgoing_cltv_value = &outgoing_cltv; + tlv->short_channel_id = cast_const(struct short_channel_id *, + scid); #if EXPERIMENTAL_FEATURES - struct tlv_tlv_payload_blinding_seed tlv_blinding; - struct tlv_tlv_payload_enctlv tlv_enctlv; - if (blinding) { - tlv_blinding.blinding_seed = *blinding; - tlv->blinding_seed = &tlv_blinding; - } - if (enctlv) { - tlv_enctlv.enctlv = cast_const(u8 *, enctlv); - tlv->enctlv = &tlv_enctlv; - } + tlv->blinding_seed = cast_const(struct pubkey *, blinding); + tlv->enctlv = cast_const(u8 *, enctlv); #endif return make_tlv_hop(ctx, tlv); } else { @@ -124,8 +111,6 @@ u8 *onion_final_hop(const tal_t *ctx, if (use_tlv) { struct tlv_tlv_payload *tlv = tlv_tlv_payload_new(tmpctx); - struct tlv_tlv_payload_amt_to_forward tlv_amt; - struct tlv_tlv_payload_outgoing_cltv_value tlv_cltv; struct tlv_tlv_payload_payment_data tlv_pdata; /* BOLT #4: @@ -142,10 +127,8 @@ u8 *onion_final_hop(const tal_t *ctx, * - MUST set `payment_secret` to the one provided * - MUST set `total_msat` to the total amount it will send */ - tlv_amt.amt_to_forward = forward.millisatoshis; /* Raw: TLV convert */ - tlv_cltv.outgoing_cltv_value = outgoing_cltv; - tlv->amt_to_forward = &tlv_amt; - tlv->outgoing_cltv_value = &tlv_cltv; + tlv->amt_to_forward = &forward.millisatoshis; /* Raw: TLV convert */ + tlv->outgoing_cltv_value = &outgoing_cltv; if (payment_secret) { tlv_pdata.payment_secret = *payment_secret; @@ -153,16 +136,8 @@ u8 *onion_final_hop(const tal_t *ctx, tlv->payment_data = &tlv_pdata; } #if EXPERIMENTAL_FEATURES - struct tlv_tlv_payload_blinding_seed tlv_blinding; - struct tlv_tlv_payload_enctlv tlv_enctlv; - if (blinding) { - tlv_blinding.blinding_seed = *blinding; - tlv->blinding_seed = &tlv_blinding; - } - if (enctlv) { - tlv_enctlv.enctlv = cast_const(u8 *, enctlv); - tlv->enctlv = &tlv_enctlv; - } + tlv->blinding_seed = cast_const(struct pubkey *, blinding); + tlv->enctlv = cast_const(u8 *, enctlv); #endif return make_tlv_hop(ctx, tlv); } else { @@ -351,9 +326,8 @@ struct onion_payload *onion_decode(const tal_t *ctx, if (!tlv->amt_to_forward || !tlv->outgoing_cltv_value) goto fail; - amount_msat_from_u64(&p->amt_to_forward, - tlv->amt_to_forward->amt_to_forward); - p->outgoing_cltv = tlv->outgoing_cltv_value->outgoing_cltv_value; + amount_msat_from_u64(&p->amt_to_forward, *tlv->amt_to_forward); + p->outgoing_cltv = *tlv->outgoing_cltv_value; /* BOLT #4: * @@ -365,9 +339,8 @@ struct onion_payload *onion_decode(const tal_t *ctx, if (rs->nextcase == ONION_FORWARD) { if (!tlv->short_channel_id) goto fail; - p->forward_channel = tal(p, struct short_channel_id); - *p->forward_channel - = tlv->short_channel_id->short_channel_id; + p->forward_channel = tal_dup(p, struct short_channel_id, + tlv->short_channel_id); p->total_msat = NULL; } else { p->forward_channel = NULL; @@ -391,7 +364,7 @@ struct onion_payload *onion_decode(const tal_t *ctx, if (tlv->blinding_seed) { p->blinding = tal_dup(p, struct pubkey, - &tlv->blinding_seed->blinding_seed); + tlv->blinding_seed); ecdh(p->blinding, &p->blinding_ss); } } else @@ -408,7 +381,7 @@ struct onion_payload *onion_decode(const tal_t *ctx, ntlv = decrypt_tlv(tmpctx, &p->blinding_ss, - tlv->enctlv->enctlv); + tlv->enctlv); if (!ntlv) goto fail; @@ -417,7 +390,7 @@ struct onion_payload *onion_decode(const tal_t *ctx, goto fail; *p->forward_channel - = ntlv->short_channel_id->short_channel_id; + = *ntlv->short_channel_id; } } #endif /* EXPERIMENTAL_FEATURES */ diff --git a/connectd/peer_exchange_initmsg.c b/connectd/peer_exchange_initmsg.c index 7c7427e5d..4dca9848c 100644 --- a/connectd/peer_exchange_initmsg.c +++ b/connectd/peer_exchange_initmsg.c @@ -74,7 +74,7 @@ static struct io_plan *peer_init_received(struct io_conn *conn, * - MAY fail the connection. */ if (tlvs->networks) { - if (!contains_common_chain(tlvs->networks->chains)) { + if (!contains_common_chain(tlvs->networks)) { status_peer_debug(&peer->id, "No common chain with this peer '%s', closing", tal_hex(tmpctx, msg)); @@ -160,9 +160,8 @@ struct io_plan *peer_exchange_initmsg(struct io_conn *conn, * channels for. */ tlvs = tlv_init_tlvs_new(tmpctx); - tlvs->networks = tal(tlvs, struct tlv_init_tlvs_networks); - tlvs->networks->chains = tal_arr(tlvs->networks, struct bitcoin_blkid, 1); - tlvs->networks->chains[0] = chainparams->genesis_blockhash; + tlvs->networks = tal_dup_arr(tlvs, struct bitcoin_blkid, + &chainparams->genesis_blockhash, 1, 0); /* Initially, there were two sets of feature bits: global and local. * Local affected peer nodes only, global affected everyone. Both were diff --git a/devtools/blindedpath.c b/devtools/blindedpath.c index 1da2c53ee..feb32ce91 100644 --- a/devtools/blindedpath.c +++ b/devtools/blindedpath.c @@ -158,21 +158,19 @@ int main(int argc, char **argv) /* Use scid if they provided one */ if (scids[i]) { inner->next_short_channel_id - = tal(inner, struct tlv_onionmsg_payload_next_short_channel_id); - inner->next_short_channel_id->short_channel_id - = *scids[i]; + = tal_dup(inner, struct short_channel_id, + scids[i]); } else { - inner->next_node_id = tal(inner, struct tlv_onionmsg_payload_next_node_id); - inner->next_node_id->node_id = nodes[i+1]; + inner->next_node_id + = tal_dup(inner, struct pubkey, &nodes[i+1]); } p = tal_arr(tmpctx, u8, 0); towire_encmsg_tlvs(&p, inner); outer = tlv_onionmsg_payload_new(tmpctx); - outer->enctlv = tal(outer, struct tlv_onionmsg_payload_enctlv); - outer->enctlv->enctlv = tal_arr(tmpctx, u8, tal_count(p) + outer->enctlv = tal_arr(outer, u8, tal_count(p) + crypto_aead_chacha20poly1305_ietf_ABYTES); - ret = crypto_aead_chacha20poly1305_ietf_encrypt(outer->enctlv->enctlv, NULL, + ret = crypto_aead_chacha20poly1305_ietf_encrypt(outer->enctlv, NULL, p, tal_bytelen(p), NULL, 0, @@ -188,7 +186,7 @@ int main(int argc, char **argv) printf("%s\n%s\n", type_to_string(tmpctx, struct pubkey, &b[i]), - tal_hex(tmpctx, outer->enctlv->enctlv)); + tal_hex(tmpctx, outer->enctlv)); } else { /* devtools/onion wants length explicitly prepended */ printf("%s/%.*s%s ", @@ -290,17 +288,17 @@ int main(int argc, char **argv) if (!outer->enctlv) errx(1, "No enctlv field"); - if (tal_bytelen(outer->enctlv->enctlv) + if (tal_bytelen(outer->enctlv) < crypto_aead_chacha20poly1305_ietf_ABYTES) errx(1, "enctlv field too short"); dec = tal_arr(tmpctx, u8, - tal_bytelen(outer->enctlv->enctlv) + tal_bytelen(outer->enctlv) - crypto_aead_chacha20poly1305_ietf_ABYTES); ret = crypto_aead_chacha20poly1305_ietf_decrypt(dec, NULL, NULL, - outer->enctlv->enctlv, - tal_bytelen(outer->enctlv->enctlv), + outer->enctlv, + tal_bytelen(outer->enctlv), NULL, 0, npub, rho.data); diff --git a/devtools/gossipwith.c b/devtools/gossipwith.c index 1d1a2f3a2..e5a785606 100644 --- a/devtools/gossipwith.c +++ b/devtools/gossipwith.c @@ -156,9 +156,8 @@ static struct io_plan *handshake_success(struct io_conn *conn, struct tlv_init_tlvs *tlvs = NULL; if (chainparams) { tlvs = tlv_init_tlvs_new(NULL); - tlvs->networks = tal(tlvs, struct tlv_init_tlvs_networks); - tlvs->networks->chains = tal_arr(tlvs->networks, struct bitcoin_blkid, 1); - tlvs->networks->chains[0] = chainparams->genesis_blockhash; + tlvs->networks = tal_arr(tlvs, struct bitcoin_blkid, 1); + tlvs->networks[0] = chainparams->genesis_blockhash; } msg = towire_init(NULL, NULL, features, tlvs); diff --git a/devtools/mkquery.c b/devtools/mkquery.c index 6e86fe8c1..c3b07ef6a 100644 --- a/devtools/mkquery.c +++ b/devtools/mkquery.c @@ -49,9 +49,8 @@ int main(int argc, char *argv[]) tlvs = NULL; else if (argc == 6) { tlvs = tlv_query_channel_range_tlvs_new(ctx); - tlvs->query_option = tal(tlvs, struct tlv_query_channel_range_tlvs_query_option); - tlvs->query_option->query_option_flags - = strtol(argv[5], NULL, 0); + tlvs->query_option = tal(tlvs, varint); + *tlvs->query_option = strtol(argv[5], NULL, 0); } else usage(); msg = towire_query_channel_range(ctx, &chainhash, diff --git a/gossipd/queries.c b/gossipd/queries.c index f40f617f5..e1ceaa9db 100644 --- a/gossipd/queries.c +++ b/gossipd/queries.c @@ -348,7 +348,7 @@ static void reply_channel_range(struct peer *peer, u32 first_blocknum, u32 number_of_blocks, const u8 *encoded_scids, struct tlv_reply_channel_range_tlvs_timestamps_tlv *timestamps, - struct tlv_reply_channel_range_tlvs_checksums_tlv *checksums) + struct channel_update_checksums *checksums) { /* BOLT #7: * @@ -437,7 +437,7 @@ static bool queue_channel_ranges(struct peer *peer, struct routing_state *rstate = peer->daemon->rstate; u8 *encoded_scids = encoding_start(tmpctx); struct tlv_reply_channel_range_tlvs_timestamps_tlv *tstamps; - struct tlv_reply_channel_range_tlvs_checksums_tlv *csums; + struct channel_update_checksums *csums; struct short_channel_id scid; bool scid_ok; @@ -464,10 +464,7 @@ static bool queue_channel_ranges(struct peer *peer, tstamps = NULL; if (query_option_flags & QUERY_ADD_CHECKSUMS) { - csums = tal(tmpctx, - struct tlv_reply_channel_range_tlvs_checksums_tlv); - csums->checksums - = tal_arr(csums, struct channel_update_checksums, 0); + csums = tal_arr(tmpctx, struct channel_update_checksums, 0); } else csums = NULL; @@ -509,7 +506,7 @@ static bool queue_channel_ranges(struct peer *peer, &cs.checksum_node_id_2); if (csums) - tal_arr_expand(&csums->checksums, cs); + tal_arr_expand(&csums, cs); if (tstamps) encoding_add_timestamps(&tstamps->encoded_timestamps, &ts); @@ -520,7 +517,7 @@ static bool queue_channel_ranges(struct peer *peer, /* If either of these can't fit in max_encoded_bytes by itself, * it's over. */ if (csums) { - extension_bytes += tlv_len(csums->checksums); + extension_bytes += tlv_len(csums); } if (tstamps) { @@ -585,7 +582,7 @@ const u8 *handle_query_channel_range(struct peer *peer, const u8 *msg) tal_hex(tmpctx, msg)); } if (tlvs->query_option) - query_option_flags = tlvs->query_option->query_option_flags; + query_option_flags = *tlvs->query_option; else query_option_flags = 0; @@ -1036,9 +1033,8 @@ bool query_channel_range(struct daemon *daemon, if (qflags) { tlvs = tlv_query_channel_range_tlvs_new(tmpctx); - tlvs->query_option - = tal(tlvs, struct tlv_query_channel_range_tlvs_query_option); - tlvs->query_option->query_option_flags = qflags; + tlvs->query_option = tal(tlvs, varint); + *tlvs->query_option = qflags; } else tlvs = NULL; status_peer_debug(&peer->id, diff --git a/gossipd/test/run-extended-info.c b/gossipd/test/run-extended-info.c index 9d9e0cd41..76d7b1ed7 100644 --- a/gossipd/test/run-extended-info.c +++ b/gossipd/test/run-extended-info.c @@ -341,9 +341,8 @@ static u8 *test_query_channel_range(const char *test_vector, const jsmntok_t *ob json_for_each_arr(i, t, opt) { assert(json_tok_streq(test_vector, t, "WANT_TIMESTAMPS | WANT_CHECKSUMS")); - tlvs->query_option = tal(tlvs, - struct tlv_query_channel_range_tlvs_query_option); - tlvs->query_option->query_option_flags = + tlvs->query_option = tal(tlvs, varint); + *tlvs->query_option = QUERY_ADD_TIMESTAMPS | QUERY_ADD_CHECKSUMS; } msg = towire_query_channel_range(NULL, &chain_hash, firstBlockNum, numberOfBlocks, tlvs); @@ -411,11 +410,7 @@ static u8 *test_reply_channel_range(const char *test_vector, const jsmntok_t *ob if (opt) { const jsmntok_t *cstok; tlvs->checksums_tlv - = tal(tlvs, struct tlv_reply_channel_range_tlvs_checksums_tlv); - - tlvs->checksums_tlv->checksums - = tal_arr(tlvs->checksums_tlv, - struct channel_update_checksums, 0); + = tal_arr(tlvs, struct channel_update_checksums, 0); cstok = json_get_member(test_vector, opt, "checksums"); json_for_each_arr(i, t, cstok) { @@ -428,7 +423,7 @@ static u8 *test_reply_channel_range(const char *test_vector, const jsmntok_t *ob json_get_member(test_vector, t, "checksum2"), &cs.checksum_node_id_2)); - tal_arr_expand(&tlvs->checksums_tlv->checksums, cs); + tal_arr_expand(&tlvs->checksums_tlv, cs); } } diff --git a/lightningd/onion_message.c b/lightningd/onion_message.c index 3de82b5eb..9e219c675 100644 --- a/lightningd/onion_message.c +++ b/lightningd/onion_message.c @@ -300,20 +300,19 @@ static void populate_tlvs(struct hop *hops, tlv = tlv_onionmsg_payload_new(tmpctx); /* If they don't give scid, use next node id */ if (hops[i].scid) { - tlv->next_short_channel_id = tal(tlv, struct tlv_onionmsg_payload_next_short_channel_id); - tlv->next_short_channel_id->short_channel_id = *hops[i].scid; + tlv->next_short_channel_id + = tal_dup(tlv, struct short_channel_id, + hops[i].scid); } else if (i != tal_count(hops)-1) { - tlv->next_node_id = tal(tlv, struct tlv_onionmsg_payload_next_node_id); - tlv->next_node_id->node_id = hops[i+1].id; + tlv->next_node_id = tal_dup(tlv, struct pubkey, + &hops[i+1].id); } if (hops[i].blinding) { - tlv->blinding = tal(tlv, struct tlv_onionmsg_payload_blinding); - tlv->blinding->blinding = *hops[i].blinding; - } - if (hops[i].enctlv) { - tlv->enctlv = tal(tlv, struct tlv_onionmsg_payload_enctlv); - tlv->enctlv->enctlv = hops[i].enctlv; + tlv->blinding = tal_dup(tlv, struct pubkey, + hops[i].blinding); } + /* Note: tal_dup_talarr returns NULL for NULL */ + tlv->enctlv = tal_dup_talarr(tlv, u8, hops[i].enctlv); if (i == tal_count(hops)-1 && reply_path) tlv->reply_path = reply_path; diff --git a/tools/gen/header_template b/tools/gen/header_template index 6175dcd5a..45d959f1d 100644 --- a/tools/gen/header_template +++ b/tools/gen/header_template @@ -68,7 +68,11 @@ struct ${tlv.struct_name()} { /* TODO The following explicit fields could just point into the * tlv_field entries above to save on memory. */ % for msg in tlv.messages.values(): + % if msg.singleton(): + ${msg.singleton().type_obj.type_name()} *${msg.name}; + % else: struct ${msg.struct_name()} *${msg.name}; + % endif % endfor }; % endfor diff --git a/tools/gen/impl_template b/tools/gen/impl_template index 1c4f2e231..bdd041ea5 100644 --- a/tools/gen/impl_template +++ b/tools/gen/impl_template @@ -46,26 +46,28 @@ bool ${enum_set['name']}_is_defined(u16 type) % endfor ## START PARTIALS ## Subtype and TLV-msg towire_ -<%def name="towire_subtype_field(fieldname, f, ptr)">\ +<%def name="towire_subtype_field(fieldname, f, type_obj, is_single_ptr, ptr)">\ % if f.is_array() or f.is_varlen(): - % if f.type_obj.has_array_helper(): -towire_${f.type_obj.name}_array(${ptr}, ${fieldname}, ${f.size('tal_count(' + fieldname + ')')}); + % if type_obj.has_array_helper(): +towire_${type_obj.name}_array(${ptr}, ${fieldname}, ${f.size('tal_count(' + fieldname + ')')}); % else: for (size_t i = 0; i < ${f.size('tal_count(' + fieldname + ')')}; i++) - % if f.type_obj.is_assignable() or f.type_obj.has_len_fields(): - towire_${f.type_obj.name}(${ptr}, ${fieldname}[i]); + % if type_obj.is_assignable() or type_obj.has_len_fields(): + towire_${type_obj.name}(${ptr}, ${fieldname}[i]); % else: - towire_${f.type_obj.name}(${ptr}, ${fieldname} + i); + towire_${type_obj.name}(${ptr}, ${fieldname} + i); % endif % endif % elif f.len_field_of: -towire_${f.type_obj.name}(${ptr}, ${f.name}); +towire_${type_obj.name}(${ptr}, ${f.name}); +% elif is_single_ptr: +towire_${type_obj.name}(${ptr}, ${'*' if type_obj.is_assignable() else ''}${fieldname}); % else: -towire_${f.type_obj.name}(${ptr}, ${'' if f.type_obj.is_assignable() else '&'}${fieldname}); +towire_${type_obj.name}(${ptr}, ${'' if type_obj.is_assignable() else '&'}${fieldname}); % endif ## Subtype and TLV-msg fromwire -<%def name="fromwire_subtype_field(fieldname, f, ctx)">\ +<%def name="fromwire_subtype_field(fieldname, f, ctx, is_ptr)">\ <% type_ = f.type_obj.name typename = f.type_obj.type_name() @@ -100,6 +102,10 @@ ${fieldname} = ${f.size('*plen')} ? tal_arr(${ctx}, ${typename}, 0) : NULL; } % endif % else: + % if is_ptr: + ${fieldname} = tal(${ctx}, ${typename}); + <% fieldname = '*' + fieldname %> + % endif % if f.type_obj.is_assignable(): ${ f.name if f.len_field_of else fieldname} = fromwire_${type_}(cursor, plen); % elif f.type_obj.is_varsize(): @@ -132,9 +138,9 @@ ${static}void towire_${subtype.name}(u8 **p, const ${subtype.type_name()} *${sub /*${c} */ % endfor <% - fieldname = '{}->{}'.format(subtype.name,f.name) + fieldname = '{}->{}'.format(subtype.name,f.name) %>\ - ${towire_subtype_field(fieldname, f, 'p')}\ + ${towire_subtype_field(fieldname, f, f.type_obj, False, 'p')}\ % endfor } % if subtype.is_varsize(): @@ -160,7 +166,7 @@ ${static}void fromwire_${subtype.name}(${'const tal_t *ctx, ' if subtype.needs_c fieldname = '{}->{}'.format(subtype.name,f.name) ctx = subtype.name %> \ - ${fromwire_subtype_field(fieldname, f, ctx)}\ + ${fromwire_subtype_field(fieldname, f, ctx, False)}\ % endfor % if subtype.is_varsize(): @@ -205,8 +211,15 @@ static u8 *towire_${msg.struct_name()}(const tal_t *ctx, const void *vrecord) ptr = tal_arr(ctx, u8, 0); % for f in msg.fields.values(): -<% fieldname = 'r->{}->{}'.format(msg.name, f.name) %>\ - ${towire_subtype_field(fieldname, f, '&ptr')}\ +<% + if msg.singleton(): + fieldname = 'r->{}'.format(msg.name) + type_obj = msg.singleton().type_obj + else: + fieldname = 'r->{}->{}'.format(msg.name, f.name) + type_obj = f.type_obj +%> + ${towire_subtype_field(fieldname, f, type_obj, msg.singleton(), '&ptr')}\ % endfor return ptr; } @@ -218,13 +231,19 @@ static void fromwire_${msg.struct_name()}(const u8 **cursor, size_t *plen, void ${f.type_obj.type_name()} ${f.name}; % endfor + % if not msg.singleton(): r->${msg.name} = tal(r, struct ${msg.struct_name()}); + % endif % for f in msg.fields.values(): <% - fieldname = 'r->{}->{}'.format(msg.name, f.name) - ctx = 'r->{}'.format(msg.name) + if msg.singleton(): + fieldname = 'r->{}'.format(msg.name) + ctx = 'r' + else: + fieldname = 'r->{}->{}'.format(msg.name, f.name) + ctx = 'r->{}'.format(msg.name) %>\ - ${fromwire_subtype_field(fieldname, f, ctx)}\ + ${fromwire_subtype_field(fieldname, f, ctx, msg.singleton())}\ % endfor } % endfor diff --git a/tools/generate-wire.py b/tools/generate-wire.py index 1e2552dba..e71ad1b5a 100755 --- a/tools/generate-wire.py +++ b/tools/generate-wire.py @@ -172,6 +172,12 @@ class FieldSet(object): def needs_context(self): return any([field.needs_context() or field.is_optional for field in self.fields.values()]) + def singleton(self): + """Return the single message, if there's only one, otherwise None""" + if len(self.fields) == 1: + return next(iter(self.fields.values())) + return None + class Type(FieldSet): assignables = [ @@ -480,8 +486,14 @@ class Master(object): unsorted.remove(s) return sorted_types - def tlv_messages(self): - return [m for tlv in self.tlvs.values() for m in tlv.messages.values()] + def tlv_structs(self): + ret = [] + for tlv in self.tlvs.values(): + for v in tlv.messages.values(): + if not v.singleton(): + ret.append(v) + + return ret def find_template(self, options): dirpath = os.path.dirname(os.path.abspath(__file__)) @@ -512,7 +524,7 @@ class Master(object): stuff['includes'] = self.inclusions stuff['enum_sets'] = enum_sets subtypes = self.get_ordered_subtypes() - stuff['structs'] = subtypes + self.tlv_messages() + stuff['structs'] = subtypes + self.tlv_structs() stuff['tlvs'] = self.tlvs # We leave out extension messages in the printing pages. Any extension diff --git a/wire/test/run-peer-wire.c b/wire/test/run-peer-wire.c index 07c08f7c4..a848bb00d 100644 --- a/wire/test/run-peer-wire.c +++ b/wire/test/run-peer-wire.c @@ -858,12 +858,12 @@ static bool init_eq(const struct msg_init *a, if (!a->tlvs->networks) return true; - if (tal_count(a->tlvs->networks->chains) - != tal_count(b->tlvs->networks->chains)) + if (tal_count(a->tlvs->networks) + != tal_count(b->tlvs->networks)) return false; - for (size_t i = 0; i < tal_count(a->tlvs->networks->chains); i++) - if (!bitcoin_blkid_eq(&a->tlvs->networks->chains[i], - &b->tlvs->networks->chains[i])) + for (size_t i = 0; i < tal_count(a->tlvs->networks); i++) + if (!bitcoin_blkid_eq(&a->tlvs->networks[i], + &b->tlvs->networks[i])) return false; return true; } @@ -1078,9 +1078,8 @@ int main(void) init.localfeatures = tal_arr(ctx, u8, 2); memset(init.localfeatures, 2, 2); init.tlvs = tlv_init_tlvs_new(ctx); - init.tlvs->networks = tal(init.tlvs, struct tlv_init_tlvs_networks); - init.tlvs->networks->chains = tal_arr(ctx, struct bitcoin_blkid, 1); - init.tlvs->networks->chains[0] = chains[i]->genesis_blockhash; + init.tlvs->networks = tal_arr(init.tlvs, struct bitcoin_blkid, 1); + init.tlvs->networks[0] = chains[i]->genesis_blockhash; msg = towire_struct_init(ctx, &init); init2 = fromwire_struct_init(ctx, msg); assert(init_eq(&init, init2)); diff --git a/wire/test/run-tlvstream.c b/wire/test/run-tlvstream.c index 6c4a3aa7d..14ecf9c67 100644 --- a/wire/test/run-tlvstream.c +++ b/wire/test/run-tlvstream.c @@ -294,19 +294,19 @@ struct valid_stream { const struct tlv_n1 expect; }; -static struct tlv_n1_tlv1 tlv1_0 = { .amount_msat = 0 }; -static struct tlv_n1_tlv1 tlv1_1 = { .amount_msat = 1 }; -static struct tlv_n1_tlv1 tlv1_256 = { .amount_msat = 256 }; -static struct tlv_n1_tlv1 tlv1_65536 = { .amount_msat = 65536 }; -static struct tlv_n1_tlv1 tlv1_16777216 = { .amount_msat = 16777216 }; -static struct tlv_n1_tlv1 tlv1_4294967296 = { .amount_msat = 4294967296ULL }; -static struct tlv_n1_tlv1 tlv1_1099511627776 = { .amount_msat = 1099511627776ULL}; -static struct tlv_n1_tlv1 tlv1_281474976710656 = { .amount_msat = 281474976710656ULL }; -static struct tlv_n1_tlv1 tlv1_72057594037927936 = { .amount_msat = 72057594037927936ULL }; -static struct tlv_n1_tlv2 tlv2_0x0x550 = { .scid.u64 = 0x000000000226 }; +static u64 tlv1_0 = 0; +static u64 tlv1_1 = 1; +static u64 tlv1_256 = 256; +static u64 tlv1_65536 = 65536; +static u64 tlv1_16777216 = 16777216; +static u64 tlv1_4294967296 = 4294967296ULL; +static u64 tlv1_1099511627776 = 1099511627776UL; +static u64 tlv1_281474976710656 = 281474976710656ULL; +static u64 tlv1_72057594037927936 = 72057594037927936ULL; +static struct short_channel_id tlv2_0x0x550 = { .u64 = 0x000000000226 }; /* filled in at runtime. */ static struct tlv_n1_tlv3 tlv3_node_id; -static struct tlv_n1_tlv4 tlv4_550 = { .cltv_delta = 550 }; +static u16 tlv4_550 = 550; static struct valid_stream valid_streams[] = { /* Valid but no (known) content. */ @@ -338,7 +338,7 @@ static bool tlv_n1_eq(const struct tlv_n1 *a, const struct tlv_n1 *b) if (a->tlv1) { if (!b->tlv1) return false; - if (a->tlv1->amount_msat != b->tlv1->amount_msat) + if (*a->tlv1 != *b->tlv1) return false; } else if (b->tlv1) return false; @@ -346,7 +346,7 @@ static bool tlv_n1_eq(const struct tlv_n1 *a, const struct tlv_n1 *b) if (a->tlv2) { if (!b->tlv2) return false; - if (!short_channel_id_eq(&a->tlv2->scid, &b->tlv2->scid)) + if (!short_channel_id_eq(a->tlv2, b->tlv2)) return false; } else if (b->tlv2) return false; @@ -368,7 +368,7 @@ static bool tlv_n1_eq(const struct tlv_n1 *a, const struct tlv_n1 *b) if (a->tlv4) { if (!b->tlv4) return false; - if (a->tlv4->cltv_delta != b->tlv4->cltv_delta) + if (*a->tlv4 != *b->tlv4) return false; } else if (b->tlv4) return false;