Browse Source

common/random_select: central place for reservoir sampling.

Turns out we can make quite a simple API out of it.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
bump-pyln-proto
Rusty Russell 5 years ago
parent
commit
496c0dd1e6
  1. 1
      common/Makefile
  2. 11
      common/random_select.c
  3. 20
      common/random_select.h
  4. 1
      gossipd/Makefile
  5. 8
      gossipd/seeker.c
  6. 3
      gossipd/test/run-next_block_range.c
  7. 1
      lightningd/Makefile
  8. 52
      lightningd/invoice.c
  9. 1
      lightningd/test/Makefile
  10. 1
      plugins/Makefile
  11. 36
      plugins/libplugin-pay.c

1
common/Makefile

@ -57,6 +57,7 @@ COMMON_SRC_NOGEN := \
common/ping.c \
common/psbt_open.c \
common/pseudorand.c \
common/random_select.c \
common/read_peer_msg.c \
common/setup.c \
common/socket_close.c \

11
common/random_select.c

@ -0,0 +1,11 @@
#include <common/pseudorand.h>
#include <common/random_select.h>
bool random_select(double weight, double *tot_weight)
{
*tot_weight += weight;
if (weight == 0)
return false;
return pseudorand_double() <= weight / *tot_weight;
}

20
common/random_select.h

@ -0,0 +1,20 @@
#ifndef LIGHTNING_COMMON_RANDOM_SELECT_H
#define LIGHTNING_COMMON_RANDOM_SELECT_H
#include "config.h"
#include <stdbool.h>
/* Use weighted reservoir sampling, see:
* https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao
* But (currently) the result will consist of only one sample (k=1)
*/
/**
* random_select: return true if we should select this one.
* @weight: weight for this option (use 1.0 if all the same)
* @tot_wieght: returns with sum of weights (must be initialized to zero)
*
* This always returns true on the first non-zero weight, and weighted
* randomly from then on.
*/
bool random_select(double weight, double *tot_weight);
#endif /* LIGHTNING_COMMON_RANDOM_SELECT_H */

1
gossipd/Makefile

@ -65,6 +65,7 @@ GOSSIPD_COMMON_OBJS := \
common/per_peer_state.o \
common/ping.o \
common/pseudorand.o \
common/random_select.o \
common/setup.o \
common/status.o \
common/status_wire.o \

8
gossipd/seeker.c

@ -9,6 +9,7 @@
#include <ccan/tal/tal.h>
#include <common/decode_array.h>
#include <common/pseudorand.h>
#include <common/random_select.h>
#include <common/status.h>
#include <common/timeout.h>
#include <common/type_to_string.h>
@ -454,7 +455,7 @@ static bool get_unannounced_nodes(const tal_t *ctx,
{
size_t num = 0;
u64 offset;
u64 threshold = pseudorand_u64();
double total_weight = 0.0;
/* Pick an example short_channel_id at random to query. As a
* side-effect this gets the node. */
@ -475,11 +476,8 @@ static bool get_unannounced_nodes(const tal_t *ctx,
(*scids)[num++] = c->scid;
} else {
/* Maybe replace one: approx. reservoir sampling */
u64 p = pseudorand_u64();
if (p > threshold) {
if (random_select(1.0, &total_weight))
(*scids)[pseudorand(max)] = c->scid;
threshold = p;
}
}
}

3
gossipd/test/run-next_block_range.c

@ -63,6 +63,9 @@ void queue_peer_msg(struct peer *peer UNNEEDED, const u8 *msg TAKES UNNEEDED)
struct peer *random_peer(struct daemon *daemon UNNEEDED,
bool (*check_peer)(const struct peer *peer))
{ fprintf(stderr, "random_peer called!\n"); abort(); }
/* Generated stub for random_select */
bool random_select(double weight UNNEEDED, double *tot_weight UNNEEDED)
{ fprintf(stderr, "random_select called!\n"); abort(); }
/* Generated stub for status_failed */
void status_failed(enum status_failreason code UNNEEDED,
const char *fmt UNNEEDED, ...)

1
lightningd/Makefile

@ -60,6 +60,7 @@ LIGHTNINGD_COMMON_OBJS := \
common/per_peer_state.o \
common/permute_tx.o \
common/pseudorand.o \
common/random_select.o \
common/setup.o \
common/sphinx.o \
common/status_wire.o \

52
lightningd/invoice.c

@ -16,7 +16,7 @@
#include <common/jsonrpc_errors.h>
#include <common/overflows.h>
#include <common/param.h>
#include <common/pseudorand.h>
#include <common/random_select.h>
#include <common/timeout.h>
#include <common/utils.h>
#include <errno.h>
@ -489,15 +489,8 @@ static struct route_info **select_inchan(const tal_t *ctx,
bool *any_offline)
{
/* BOLT11 struct wants an array of arrays (can provide multiple routes) */
struct route_info **R;
double wsum, p;
struct sample {
const struct route_info *route;
double weight;
};
struct sample *S = tal_arr(tmpctx, struct sample, 0);
struct route_info **r = NULL;
double total_weight = 0.0;
*any_offline = false;
@ -505,7 +498,6 @@ static struct route_info **select_inchan(const tal_t *ctx,
for (size_t i = 0; i < tal_count(inchans); i++) {
struct peer *peer;
struct channel *c;
struct sample sample;
struct amount_msat capacity_to_pay_us, excess, capacity;
struct amount_sat cumulative_reserve;
double excess_frac;
@ -564,33 +556,23 @@ static struct route_info **select_inchan(const tal_t *ctx,
continue;
}
/* We don't want a 0 probability if 0 excess; it might be the
* only one! So bump it by 1 msat */
if (!amount_msat_add(&excess, excess, AMOUNT_MSAT(1))) {
log_broken(ld->log, "Channel %s excess overflow!",
type_to_string(tmpctx, struct short_channel_id, c->scid));
continue;
}
excess_frac = amount_msat_ratio(excess, capacity);
sample.route = &inchans[i];
sample.weight = excess_frac;
tal_arr_expand(&S, sample);
}
if (!tal_count(S))
return NULL;
/* Use weighted reservoir sampling, see:
* https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao
* But (currently) the result will consist of only one sample (k=1) */
R = tal_arr(ctx, struct route_info *, 1);
R[0] = tal_dup(R, struct route_info, S[0].route);
wsum = S[0].weight;
for (size_t i = 1; i < tal_count(S); i++) {
wsum += S[i].weight;
p = S[i].weight / wsum;
double random_1 = pseudorand_double(); /* range [0,1) */
if (random_1 <= p)
R[0] = tal_dup(R, struct route_info, S[i].route);
if (random_select(excess_frac, &total_weight)) {
tal_free(r);
r = tal_arr(ctx, struct route_info *, 1);
r[0] = tal_dup(r, struct route_info, &inchans[i]);
}
}
return R;
return r;
}
/** select_inchan_mpp
@ -1414,6 +1396,7 @@ static struct command_result *json_waitanyinvoice(struct command *cmd,
" is non-trivial.");
}
static const struct json_command waitanyinvoice_command = {
"waitanyinvoice",
"payment",
@ -1423,7 +1406,6 @@ static const struct json_command waitanyinvoice_command = {
};
AUTODATA(json_command, &waitanyinvoice_command);
/* Wait for an incoming payment matching the `label` in the JSON
* command. This will either return immediately if the payment has
* already been received or it may add the `cmd` to the list of

1
lightningd/test/Makefile

@ -16,6 +16,7 @@ LIGHTNINGD_TEST_COMMON_OBJS := \
common/json.o \
common/key_derive.o \
common/pseudorand.o \
common/random_select.o \
common/memleak.o \
common/msg_queue.o \
common/utils.o \

1
plugins/Makefile

@ -70,6 +70,7 @@ PLUGIN_COMMON_OBJS := \
common/node_id.o \
common/param.o \
common/pseudorand.o \
common/random_select.o \
common/setup.o \
common/type_to_string.o \
common/utils.o \

36
plugins/libplugin-pay.c

@ -3,6 +3,7 @@
#include <ccan/tal/str/str.h>
#include <common/json_stream.h>
#include <common/pseudorand.h>
#include <common/random_select.h>
#include <common/type_to_string.h>
#include <plugins/libplugin-pay.h>
@ -2421,12 +2422,11 @@ static struct command_result *shadow_route_listchannels(struct command *cmd,
const jsmntok_t *result,
struct payment *p)
{
/* Use reservoir sampling across the capable channels. */
struct shadow_route_data *d = payment_mod_shadowroute_get_data(p);
struct payment_constraints *cons = &d->constraints;
struct route_info *best = NULL;
double total_weight = 0.0;
size_t i;
u64 sample = 0;
struct amount_msat best_fee;
const jsmntok_t *sattok, *delaytok, *basefeetok, *propfeetok, *desttok,
*channelstok, *chan, *scidtok;
@ -2438,7 +2438,6 @@ static struct command_result *shadow_route_listchannels(struct command *cmd,
channelstok = json_get_member(buf, result, "channels");
json_for_each_arr(i, chan, channelstok) {
u64 v = pseudorand(UINT64_MAX);
struct route_info curr;
struct amount_sat capacity;
struct amount_msat fee;
@ -2465,28 +2464,27 @@ static struct command_result *shadow_route_listchannels(struct command *cmd,
json_to_sat(buf, sattok, &capacity);
json_to_node_id(buf, desttok, &curr.pubkey);
if (!best || v > sample) {
/* If the capacity is insufficient to pass the amount
* it's not a plausible extension. */
if (amount_msat_greater_sat(p->amount, capacity))
continue;
/* If the capacity is insufficient to pass the amount
* it's not a plausible extension. */
if (amount_msat_greater_sat(p->amount, capacity))
continue;
if (curr.cltv_expiry_delta > cons->cltv_budget)
continue;
if (curr.cltv_expiry_delta > cons->cltv_budget)
continue;
if (!amount_msat_fee(
&fee, p->amount, curr.fee_base_msat,
curr.fee_proportional_millionths)) {
/* Fee computation failed... */
continue;
}
if (!amount_msat_fee(
&fee, p->amount, curr.fee_base_msat,
curr.fee_proportional_millionths)) {
/* Fee computation failed... */
continue;
}
if (amount_msat_greater_eq(fee, cons->fee_budget))
continue;
if (amount_msat_greater_eq(fee, cons->fee_budget))
continue;
if (random_select(1.0, &total_weight)) {
best = tal_dup(tmpctx, struct route_info, &curr);
best_fee = fee;
sample = v;
}
}

Loading…
Cancel
Save