diff --git a/gossipd/queries.c b/gossipd/queries.c index e8ca796a1..c6254649e 100644 --- a/gossipd/queries.c +++ b/gossipd/queries.c @@ -20,7 +20,7 @@ #include #if DEVELOPER -static u32 max_encoding_bytes = -1U; +static u32 dev_max_encoding_bytes = -1U; #endif /* BOLT #7: @@ -121,7 +121,7 @@ static bool encoding_end_prepend_type(u8 **encoded, size_t max_bytes) } #if DEVELOPER - if (tal_count(*encoded) > max_encoding_bytes) + if (tal_count(*encoded) > dev_max_encoding_bytes) return false; #endif return tal_count(*encoded) <= max_bytes; @@ -345,11 +345,12 @@ const u8 *handle_query_short_channel_ids(struct peer *peer, const u8 *msg) /*~ We can send multiple replies when the peer queries for all channels in * a given range of blocks; each one indicates the range of blocks it covers. */ -static void reply_channel_range(struct peer *peer, - u32 first_blocknum, u32 number_of_blocks, - const u8 *encoded_scids, - struct tlv_reply_channel_range_tlvs_timestamps_tlv *timestamps, - struct channel_update_checksums *checksums) +static void send_reply_channel_range(struct peer *peer, + u32 first_blocknum, u32 number_of_blocks, + const struct short_channel_id *scids, + struct channel_update_timestamps *tstamps, + struct channel_update_checksums *csums, + size_t num_scids) { /* BOLT #7: * @@ -363,10 +364,29 @@ static void reply_channel_range(struct peer *peer, * - otherwise: * - SHOULD set `full_information` to 1. */ + u8 *encoded_scids = encoding_start(tmpctx); + u8 *encoded_timestamps = encoding_start(tmpctx); struct tlv_reply_channel_range_tlvs *tlvs = tlv_reply_channel_range_tlvs_new(tmpctx); - tlvs->timestamps_tlv = timestamps; - tlvs->checksums_tlv = checksums; + + /* Encode them all */ + for (size_t i = 0; i < num_scids; i++) + encoding_add_short_channel_id(&encoded_scids, &scids[i]); + encoding_end_prepend_type(&encoded_scids, tal_bytelen(encoded_scids)); + + if (tstamps) { + for (size_t i = 0; i < num_scids; i++) + encoding_add_timestamps(&encoded_timestamps, &tstamps[i]); + + tlvs->timestamps_tlv = tal(tlvs, struct tlv_reply_channel_range_tlvs_timestamps_tlv); + encoding_end_external_type(&encoded_timestamps, + &tlvs->timestamps_tlv->encoding_type, + tal_bytelen(encoded_timestamps)); + tlvs->timestamps_tlv->encoded_timestamps + = tal_steal(tlvs, encoded_timestamps); + } + + tlvs->checksums_tlv = csums; u8 *msg = towire_reply_channel_range(NULL, &chainparams->genesis_blockhash, @@ -414,31 +434,14 @@ static void get_checksum_and_timestamp(struct routing_state *rstate, } /* FIXME: This assumes that the tlv type encodes into 1 byte! */ -static size_t tlv_len(const tal_t *msg) +static size_t tlv_len(size_t num_entries, size_t size) { - return 1 + bigsize_len(tal_count(msg)) + tal_count(msg); + return 1 + bigsize_len(num_entries * size) + num_entries * size; } -/*~ When we need to send an array of channels, it might go over our 64k packet - * size. If it doesn't, we recurse, splitting in two, etc. Each message - * indicates what blocks it contains, so the recipient knows when we're - * finished. - * - * tail_blocks is the empty blocks at the end, in case they asked for all - * blocks to 4 billion. - */ -static bool queue_channel_ranges(struct peer *peer, - u32 first_blocknum, u32 number_of_blocks, - u32 tail_blocks, - enum query_option_flags query_option_flags) +/* How many entries can I fit in a reply? */ +static size_t max_entries(enum query_option_flags query_option_flags) { - struct routing_state *rstate = peer->daemon->rstate; - u8 *encoded_scids = encoding_start(tmpctx); - struct tlv_reply_channel_range_tlvs_timestamps_tlv *tstamps; - struct channel_update_checksums *csums; - struct short_channel_id scid; - bool scid_ok; - /* BOLT #7: * * 1. type: 264 (`reply_channel_range`) (`gossip_queries`) @@ -451,20 +454,59 @@ static bool queue_channel_ranges(struct peer *peer, * * [`len*byte`:`encoded_short_ids`] */ const size_t reply_overhead = 32 + 4 + 4 + 1 + 2; - const size_t max_encoded_bytes = 65535 - 2 - reply_overhead; - size_t extension_bytes; + size_t max_encoded_bytes = 65535 - 2 - reply_overhead; + size_t per_entry_size, max_num; + per_entry_size = sizeof(struct short_channel_id); + + /* Upper bound to start. */ + max_num = max_encoded_bytes / per_entry_size; + + /* If we add timestamps, we need to encode tlv */ if (query_option_flags & QUERY_ADD_TIMESTAMPS) { - tstamps = tal(tmpctx, - struct tlv_reply_channel_range_tlvs_timestamps_tlv); - tstamps->encoded_timestamps = encoding_start(tstamps); - } else - tstamps = NULL; + max_encoded_bytes -= tlv_len(max_num, + sizeof(struct channel_update_timestamps)); + per_entry_size += sizeof(struct channel_update_timestamps); + } if (query_option_flags & QUERY_ADD_CHECKSUMS) { - csums = tal_arr(tmpctx, struct channel_update_checksums, 0); - } else - csums = NULL; + max_encoded_bytes -= tlv_len(max_num, + sizeof(struct channel_update_checksums)); + per_entry_size += sizeof(struct channel_update_checksums); + } + +#if DEVELOPER + if (max_encoded_bytes > dev_max_encoding_bytes) + max_encoded_bytes = dev_max_encoding_bytes; + /* Always let one through! */ + if (max_encoded_bytes < per_entry_size) + max_encoded_bytes = per_entry_size; +#endif + + return max_encoded_bytes / per_entry_size; +} + +/* This gets all the scids they asked for, and optionally the timestamps and checksums */ +static struct short_channel_id *gather_range(const tal_t *ctx, + struct routing_state *rstate, + u32 first_blocknum, u32 number_of_blocks, + enum query_option_flags query_option_flags, + struct channel_update_timestamps **tstamps, + struct channel_update_checksums **csums) +{ + struct short_channel_id scid, *scids; + u32 end_block; + bool scid_ok; + + scids = tal_arr(ctx, struct short_channel_id, 0); + if (query_option_flags & QUERY_ADD_TIMESTAMPS) + *tstamps = tal_arr(ctx, struct channel_update_timestamps, 0); + else + *tstamps = NULL; + if (query_option_flags & QUERY_ADD_CHECKSUMS) + *csums = tal_arr(ctx, struct channel_update_checksums, 0); + else + *csums = NULL; /* Avoid underflow: we don't use block 0 anyway */ if (first_blocknum == 0) @@ -472,8 +514,17 @@ static bool queue_channel_ranges(struct peer *peer, else scid_ok = mk_short_channel_id(&scid, first_blocknum, 0, 0); scid.u64--; + /* Out of range? No blocks then. */ if (!scid_ok) - return false; + return NULL; + + if (number_of_blocks == 0) + return NULL; + + /* Fix up number_of_blocks to avoid overflow. */ + end_block = first_blocknum + number_of_blocks - 1; + if (end_block <= first_blocknum) + end_block = UINT_MAX; /* We keep a `uintmap` of `short_channel_id` to `struct chan *`. * Unlike a htable, it's efficient to iterate through, but it only @@ -485,8 +536,8 @@ static bool queue_channel_ranges(struct peer *peer, struct chan *chan; struct channel_update_timestamps ts; struct channel_update_checksums cs; - u32 blocknum = short_channel_id_blocknum(&scid); - if (blocknum >= first_blocknum + number_of_blocks) + + if (short_channel_id_blocknum(&scid) > end_block) break; /* FIXME: Store csum in header. */ @@ -494,7 +545,12 @@ static bool queue_channel_ranges(struct peer *peer, if (!is_chan_public(chan)) continue; - encoding_add_short_channel_id(&encoded_scids, &scid); + tal_arr_expand(&scids, scid); + + /* Don't calc csums if we don't even care */ + if (!(query_option_flags + & (QUERY_ADD_TIMESTAMPS|QUERY_ADD_CHECKSUMS))) + continue; get_checksum_and_timestamp(rstate, chan, 0, &ts.timestamp_node_id_1, @@ -502,72 +558,89 @@ static bool queue_channel_ranges(struct peer *peer, get_checksum_and_timestamp(rstate, chan, 1, &ts.timestamp_node_id_2, &cs.checksum_node_id_2); - - if (csums) - tal_arr_expand(&csums, cs); - if (tstamps) - encoding_add_timestamps(&tstamps->encoded_timestamps, - &ts); + if (query_option_flags & QUERY_ADD_TIMESTAMPS) + tal_arr_expand(tstamps, ts); + if (query_option_flags & QUERY_ADD_CHECKSUMS) + tal_arr_expand(csums, cs); } - extension_bytes = 0; - - /* If either of these can't fit in max_encoded_bytes by itself, - * it's over. */ - if (csums) { - extension_bytes += tlv_len(csums); - } - - if (tstamps) { - if (!encoding_end_external_type(&tstamps->encoded_timestamps, - &tstamps->encoding_type, - max_encoded_bytes)) - goto wont_fit; - /* 1 byte for encoding_type, too */ - extension_bytes += 1 + tlv_len(tstamps->encoded_timestamps); - } - - /* If we can encode that, fine: send it */ - if (extension_bytes <= max_encoded_bytes - && encoding_end_prepend_type(&encoded_scids, - max_encoded_bytes - extension_bytes)) { - reply_channel_range(peer, first_blocknum, - number_of_blocks + tail_blocks, - encoded_scids, - tstamps, csums); - return true; - } + return scids; +} -wont_fit: - /* It wouldn't all fit: divide in half */ - /* We assume we can always send one block! */ - if (number_of_blocks <= 1) { - /* We always assume we can send 1 blocks worth */ - status_broken("Could not fit scids for single block %u", - first_blocknum); - return false; - } - status_debug("queue_channel_ranges full: splitting %u+%u and %u+%u(+%u)", - first_blocknum, - number_of_blocks / 2, - first_blocknum + number_of_blocks / 2, - number_of_blocks - number_of_blocks / 2, - tail_blocks); - return queue_channel_ranges(peer, first_blocknum, number_of_blocks / 2, - 0, query_option_flags) - && queue_channel_ranges(peer, first_blocknum + number_of_blocks / 2, - number_of_blocks - number_of_blocks / 2, - tail_blocks, query_option_flags); +/*~ When we need to send an array of channels, it might go over our 64k packet + * size. But because we use compression, we can't actually tell how much + * we'll use. We pack them into the maximum amount for uncompressed, then + * compress afterwards. + */ +static void queue_channel_ranges(struct peer *peer, + u32 first_blocknum, u32 number_of_blocks, + enum query_option_flags query_option_flags) +{ + struct routing_state *rstate = peer->daemon->rstate; + struct channel_update_timestamps *tstamps; + struct channel_update_checksums *csums; + struct short_channel_id *scids; + size_t off, limit; + + scids = gather_range(tmpctx, rstate, first_blocknum, number_of_blocks, + query_option_flags, &tstamps, &csums); + + limit = max_entries(query_option_flags); + off = 0; + + /* We need to send an empty msg if we have nothing! */ + do { + size_t n = tal_count(scids) - off; + u32 this_num_blocks; + + if (n > limit) { + status_debug("reply_channel_range: splitting %zu-%zu of %zu", + off, off + limit, tal_count(scids)); + n = limit; + + /* ... and reduce to a block boundary. */ + while (short_channel_id_blocknum(&scids[off + n - 1]) + == short_channel_id_blocknum(&scids[off + limit])) { + /* We assume one block doesn't have limit # + * channels. If it does, we have to violate + * spec and send over multiple blocks. */ + if (n == 0) { + status_broken("reply_channel_range: " + "could not fit %zu scids for %u!", + limit, + short_channel_id_blocknum(&scids[off + n - 1])); + n = limit; + break; + } + n--; + } + /* Get *next* channel, add num blocks */ + this_num_blocks + = short_channel_id_blocknum(&scids[off + n]) + - first_blocknum; + } else + /* Last one must end with correct total */ + this_num_blocks = number_of_blocks; + + send_reply_channel_range(peer, first_blocknum, this_num_blocks, + scids + off, + query_option_flags & QUERY_ADD_TIMESTAMPS + ? tstamps + off : NULL, + query_option_flags & QUERY_ADD_CHECKSUMS + ? csums + off : NULL, + n); + first_blocknum += this_num_blocks; + number_of_blocks -= this_num_blocks; + off += n; + } while (number_of_blocks); } /*~ The peer can ask for all channels in a series of blocks. We reply with one * or more messages containing the short_channel_ids. */ const u8 *handle_query_channel_range(struct peer *peer, const u8 *msg) { - struct routing_state *rstate = peer->daemon->rstate; struct bitcoin_blkid chain_hash; - u32 first_blocknum, number_of_blocks, tail_blocks; - struct short_channel_id last_scid; + u32 first_blocknum, number_of_blocks; enum query_option_flags query_option_flags; struct tlv_query_channel_range_tlvs *tlvs = tlv_query_channel_range_tlvs_new(msg); @@ -602,29 +675,12 @@ const u8 *handle_query_channel_range(struct peer *peer, const u8 *msg) return NULL; } - /* If they ask for number_of_blocks UINTMAX, and we have to divide - * and conquer, we'll do a lot of unnecessary work. Cap it at the - * last value we have, then send an empty reply. */ - if (uintmap_last(&rstate->chanmap, &last_scid.u64)) { - u32 last_block = short_channel_id_blocknum(&last_scid); - - /* u64 here avoids overflow on number_of_blocks - UINTMAX for example */ - if ((u64)first_blocknum + number_of_blocks > last_block) { - tail_blocks = first_blocknum + number_of_blocks - - last_block - 1; - number_of_blocks -= tail_blocks; - } else - tail_blocks = 0; - } else - tail_blocks = 0; - - if (!queue_channel_ranges(peer, first_blocknum, number_of_blocks, - tail_blocks, query_option_flags)) - return towire_errorfmt(peer, NULL, - "Invalid query_channel_range %u+%u", - first_blocknum, number_of_blocks + tail_blocks); + /* Fix up number_of_blocks to avoid overflow. */ + if (first_blocknum + number_of_blocks < first_blocknum) + number_of_blocks = UINT_MAX - first_blocknum; + queue_channel_ranges(peer, first_blocknum, number_of_blocks, + query_option_flags); return NULL; } @@ -1066,10 +1122,10 @@ struct io_plan *dev_set_max_scids_encode_size(struct io_conn *conn, const u8 *msg) { if (!fromwire_gossipd_dev_set_max_scids_encode_size(msg, - &max_encoding_bytes)) + &dev_max_encoding_bytes)) master_badmsg(WIRE_GOSSIPD_DEV_SET_MAX_SCIDS_ENCODE_SIZE, msg); - status_debug("Set max_scids_encode_bytes to %u", max_encoding_bytes); + status_debug("Set max_scids_encode_bytes to %u", dev_max_encoding_bytes); return daemon_conn_read_next(conn, daemon->master); } #endif /* DEVELOPER */ diff --git a/tests/test_gossip.py b/tests/test_gossip.py index 3838c1a89..b598a6d25 100644 --- a/tests/test_gossip.py +++ b/tests/test_gossip.py @@ -700,8 +700,8 @@ def test_gossip_query_channel_range(node_factory, bitcoind, chainparams): 0, 1000000, filters=['0109']) # It should definitely have split - l2.daemon.wait_for_log('queue_channel_ranges full: splitting') - # Turns out it sends: 0+53, 53+26, 79+13, 92+7, 99+3, 102+2, 104+1, 105+999895 + l2.daemon.wait_for_log('reply_channel_range: splitting 0-1 of 2') + start = 0 scids = '00' for m in msgs: @@ -719,12 +719,12 @@ def test_gossip_query_channel_range(node_factory, bitcoind, chainparams): stdout=subprocess.PIPE).stdout.strip().decode() assert scids == encoded - # Test overflow case doesn't split forever; should still only get 8 for this + # Test overflow case doesn't split forever; should still only get 2 for this msgs = l2.query_gossip('query_channel_range', genesis_blockhash, 1, 429496000, filters=['0109']) - assert len(msgs) == 8 + assert len(msgs) == 2 # This should actually be large enough for zlib to kick in! scid34, _ = l3.fundchannel(l4, 10**5)