From c0d0272eac3bf77c780149645c7ade7a6884663b Mon Sep 17 00:00:00 2001 From: niftynei Date: Wed, 22 Jul 2020 21:15:38 -0500 Subject: [PATCH] psbt-common: shared psbt utilities includes facilities for - sorting psbt inputs by serial_id - sorting psbt outputs by serial_id - adding a serial_id - getting a serial_id - finding the diffset between two psbts - adding a max_len to a psbt input - getting a max_len from a psbt input --- common/Makefile | 1 + common/psbt_open.c | 389 +++++++++++++++++++++++++++++++++++++++++++++ common/psbt_open.h | 145 +++++++++++++++++ 3 files changed, 535 insertions(+) create mode 100644 common/psbt_open.c create mode 100644 common/psbt_open.h diff --git a/common/Makefile b/common/Makefile index c56be7466..bcc433906 100644 --- a/common/Makefile +++ b/common/Makefile @@ -55,6 +55,7 @@ COMMON_SRC_NOGEN := \ common/peer_failed.c \ common/permute_tx.c \ common/ping.c \ + common/psbt_open.c \ common/pseudorand.c \ common/read_peer_msg.c \ common/setup.c \ diff --git a/common/psbt_open.c b/common/psbt_open.c new file mode 100644 index 000000000..eb9201b42 --- /dev/null +++ b/common/psbt_open.c @@ -0,0 +1,389 @@ +#include "common/psbt_open.h" +#include +#include +#include +#include +#include +#include +#include + +bool psbt_get_serial_id(const struct wally_map *map, u16 *serial_id) +{ + size_t value_len; + void *result = psbt_get_lightning(map, PSBT_TYPE_SERIAL_ID, &value_len); + if (!result) + return false; + + if (value_len != sizeof(*serial_id)) + return false; + + memcpy(serial_id, result, value_len); + return true; +} + +static int compare_serials(const struct wally_map *map_a, + const struct wally_map *map_b) +{ + u16 serial_left, serial_right; + bool ok; + + ok = psbt_get_serial_id(map_a, &serial_left); + assert(ok); + ok = psbt_get_serial_id(map_b, &serial_right); + assert(ok); + if (serial_left > serial_right) + return 1; + if (serial_left < serial_right) + return -1; + return 0; +} + +static int compare_inputs_at(const struct input_set *a, + const struct input_set *b, + void *unused UNUSED) +{ + return compare_serials(&a->input.unknowns, + &b->input.unknowns); +} + +static int compare_outputs_at(const struct output_set *a, + const struct output_set *b, + void *unused UNUSED) +{ + return compare_serials(&a->output.unknowns, + &b->output.unknowns); +} + +static const u8 *linearize_input(const tal_t *ctx, + const struct wally_psbt_input *in, + const struct wally_tx_input *tx_in) +{ + struct wally_psbt *psbt = create_psbt(NULL, 1, 0); + size_t byte_len; + + + if (wally_tx_add_input(psbt->tx, tx_in) != WALLY_OK) + abort(); + psbt->inputs[0] = *in; + psbt->num_inputs++; + + const u8 *bytes = psbt_get_bytes(ctx, psbt, &byte_len); + + /* Hide the inputs we added, so it doesn't get freed */ + psbt->num_inputs--; + tal_free(psbt); + return bytes; +} + +static const u8 *linearize_output(const tal_t *ctx, + const struct wally_psbt_output *out, + const struct wally_tx_output *tx_out) +{ + struct wally_psbt *psbt = create_psbt(NULL, 1, 1); + size_t byte_len; + struct bitcoin_txid txid; + + /* Add a 'fake' input so this will linearize the tx */ + memset(&txid, 0, sizeof(txid)); + psbt_append_input(psbt, &txid, 0, 0); + + if (wally_tx_add_output(psbt->tx, tx_out) != WALLY_OK) + abort(); + psbt->outputs[0] = *out; + psbt->num_outputs++; + + const u8 *bytes = psbt_get_bytes(ctx, psbt, &byte_len); + + /* Hide the outputs we added, so it doesn't get freed */ + psbt->num_outputs--; + tal_free(psbt); + return bytes; +} + +static bool input_identical(const struct wally_psbt *a, + size_t a_index, + const struct wally_psbt *b, + size_t b_index) +{ + const u8 *a_in = linearize_input(tmpctx, + &a->inputs[a_index], + &a->tx->inputs[a_index]); + const u8 *b_in = linearize_input(tmpctx, + &b->inputs[b_index], + &b->tx->inputs[b_index]); + + return memeq(a_in, tal_bytelen(a_in), + b_in, tal_bytelen(b_in)); +} + +static bool output_identical(const struct wally_psbt *a, + size_t a_index, + const struct wally_psbt *b, + size_t b_index) +{ + const u8 *a_out = linearize_output(tmpctx, + &a->outputs[a_index], + &a->tx->outputs[a_index]); + const u8 *b_out = linearize_output(tmpctx, + &b->outputs[b_index], + &b->tx->outputs[b_index]); + + return memeq(a_out, tal_bytelen(a_out), + b_out, tal_bytelen(b_out)); +} + +static void sort_inputs(struct wally_psbt *psbt) +{ + /* Build an input map */ + struct input_set *set = tal_arr(NULL, + struct input_set, + psbt->num_inputs); + + for (size_t i = 0; i < tal_count(set); i++) { + set[i].tx_input = psbt->tx->inputs[i]; + set[i].input = psbt->inputs[i]; + } + + asort(set, tal_count(set), + compare_inputs_at, NULL); + + /* Put PSBT parts into place */ + for (size_t i = 0; i < tal_count(set); i++) { + psbt->inputs[i] = set[i].input; + psbt->tx->inputs[i] = set[i].tx_input; + } + + tal_free(set); +} + +static void sort_outputs(struct wally_psbt *psbt) +{ + /* Build an output map */ + struct output_set *set = tal_arr(NULL, + struct output_set, + psbt->num_outputs); + for (size_t i = 0; i < tal_count(set); i++) { + set[i].tx_output = psbt->tx->outputs[i]; + set[i].output = psbt->outputs[i]; + } + + asort(set, tal_count(set), + compare_outputs_at, NULL); + + /* Put PSBT parts into place */ + for (size_t i = 0; i < tal_count(set); i++) { + psbt->outputs[i] = set[i].output; + psbt->tx->outputs[i] = set[i].tx_output; + } + + tal_free(set); +} + +void psbt_sort_by_serial_id(struct wally_psbt *psbt) +{ + sort_inputs(psbt); + sort_outputs(psbt); +} + +#define ADD(type, add_to, from, index) \ + do { \ + struct type##_set a; \ + a.type = from->type##s[index]; \ + a.tx_##type = from->tx->type##s[index]; \ + tal_arr_expand(add_to, a); \ + } while (0) + +/* this requires having a serial_id entry on everything */ +/* YOU MUST KEEP orig + new AROUND TO USE THE RESULTING SETS */ +bool psbt_has_diff(const tal_t *ctx, + struct wally_psbt *orig, + struct wally_psbt *new, + struct input_set **added_ins, + struct input_set **rm_ins, + struct output_set **added_outs, + struct output_set **rm_outs) +{ + int result; + size_t i = 0, j = 0; + + psbt_sort_by_serial_id(orig); + psbt_sort_by_serial_id(new); + + *added_ins = tal_arr(ctx, struct input_set, 0); + *rm_ins = tal_arr(ctx, struct input_set, 0); + *added_outs = tal_arr(ctx, struct output_set, 0); + *rm_outs = tal_arr(ctx, struct output_set, 0); + + /* Find the input diff */ + while (i < orig->num_inputs || j < new->num_inputs) { + if (i >= orig->num_inputs) { + ADD(input, added_ins, new, j); + j++; + continue; + } + if (j >= new->num_inputs) { + ADD(input, rm_ins, orig, i); + i++; + continue; + } + + result = compare_serials(&orig->inputs[i].unknowns, + &new->inputs[j].unknowns); + if (result == -1) { + ADD(input, rm_ins, orig, i); + i++; + continue; + } + if (result == 1) { + ADD(input, added_ins, new, j); + j++; + continue; + } + + if (!input_identical(orig, i, new, j)) { + ADD(input, rm_ins, orig, i); + ADD(input, added_ins, new, j); + } + i++; + j++; + } + /* Find the output diff */ + i = 0; + j = 0; + while (i < orig->num_outputs || j < new->num_outputs) { + if (i >= orig->num_outputs) { + ADD(output, added_outs, new, j); + j++; + continue; + } + if (j >= new->num_outputs) { + ADD(output, rm_outs, orig, i); + i++; + continue; + } + + result = compare_serials(&orig->outputs[i].unknowns, + &new->outputs[j].unknowns); + if (result == -1) { + ADD(output, rm_outs, orig, i); + i++; + continue; + } + if (result == 1) { + ADD(output, added_outs, new, j); + j++; + continue; + } + if (!output_identical(orig, i, new, j)) { + ADD(output, rm_outs, orig, i); + ADD(output, added_outs, new, j); + } + i++; + j++; + } + + return tal_count(*added_ins) != 0 || + tal_count(*rm_ins) != 0 || + tal_count(*added_outs) != 0 || + tal_count(*rm_outs) != 0; +} + +void psbt_input_add_serial_id(struct wally_psbt_input *input, + u16 serial_id) +{ + u8 *key = psbt_make_key(tmpctx, PSBT_TYPE_SERIAL_ID, NULL); + psbt_input_add_unknown(input, key, &serial_id, sizeof(serial_id)); +} + + +void psbt_output_add_serial_id(struct wally_psbt_output *output, + u16 serial_id) +{ + u8 *key = psbt_make_key(tmpctx, PSBT_TYPE_SERIAL_ID, NULL); + psbt_output_add_unknown(output, key, &serial_id, sizeof(serial_id)); +} + +bool psbt_has_serial_input(struct wally_psbt *psbt, u16 serial_id) +{ + for (size_t i = 0; i < psbt->num_inputs; i++) { + u16 in_serial; + if (!psbt_get_serial_id(&psbt->inputs[i].unknowns, &in_serial)) + continue; + if (in_serial == serial_id) + return true; + } + return false; +} + +bool psbt_has_serial_output(struct wally_psbt *psbt, u16 serial_id) +{ + for (size_t i = 0; i < psbt->num_outputs; i++) { + u16 out_serial; + if (!psbt_get_serial_id(&psbt->outputs[i].unknowns, &out_serial)) + continue; + if (out_serial == serial_id) + return true; + } + return false; +} + +void psbt_input_add_max_witness_len(struct wally_psbt_input *input, + u16 max_witness_len) +{ + u8 *key = psbt_make_key(NULL, PSBT_TYPE_MAX_WITNESS_LEN, NULL); + psbt_input_add_unknown(input, key, &max_witness_len, sizeof(max_witness_len)); + tal_free(key); +} + + +bool psbt_input_get_max_witness_len(struct wally_psbt_input *input, + u16 *max_witness_len) +{ + size_t value_len; + void *result = psbt_get_lightning(&input->unknowns, + PSBT_TYPE_MAX_WITNESS_LEN, + &value_len); + if (!result) + return false; + + if (value_len != sizeof(*max_witness_len)) + return false; + + memcpy(max_witness_len, result, value_len); + return true; +} + +bool psbt_has_required_fields(struct wally_psbt *psbt) +{ + u16 max_witness, serial_id; + for (size_t i = 0; i < psbt->num_inputs; i++) { + struct wally_psbt_input *input = &psbt->inputs[i]; + + if (!psbt_get_serial_id(&input->unknowns, &serial_id)) + return false; + + /* Inputs had also better have their max_witness_lens + * filled in! */ + if (!psbt_input_get_max_witness_len(input, &max_witness)) + return false; + + /* Required because we send the full tx over the wire now */ + if (!input->utxo) + return false; + + /* If is P2SH, redeemscript must be present */ + const u8 *outscript = + wally_tx_output_get_script(tmpctx, + &input->utxo->outputs[psbt->tx->inputs[i].index]); + if (is_p2sh(outscript, NULL) && input->redeem_script_len == 0) + return false; + + } + + for (size_t i = 0; i < psbt->num_outputs; i++) { + if (!psbt_get_serial_id(&psbt->outputs[i].unknowns, &serial_id)) + return false; + } + + return true; +} diff --git a/common/psbt_open.h b/common/psbt_open.h new file mode 100644 index 000000000..776811e13 --- /dev/null +++ b/common/psbt_open.h @@ -0,0 +1,145 @@ +#ifndef LIGHTNING_COMMON_PSBT_OPEN_H +#define LIGHTNING_COMMON_PSBT_OPEN_H +#include "config.h" +#include +#include +#include +#include +#include + +struct wally_tx_input; +struct wally_tx_output; +struct wally_psbt; +struct wally_psbt_input; +struct wally_psbt_output; +struct wally_map; + +struct input_set { + struct wally_tx_input tx_input; + struct wally_psbt_input input; +}; + +struct output_set { + struct wally_tx_output tx_output; + struct wally_psbt_output output; +}; + +#define PSBT_TYPE_SERIAL_ID 0x01 +#define PSBT_TYPE_MAX_WITNESS_LEN 0x02 + +/* psbt_get_serial_id - Returns the serial_id from an unknowns map + * + * @map - the map to find the serial id entry within + * @serial_id - found serial_id + * + * Returns false if serial_id is not present + */ +WARN_UNUSED_RESULT bool psbt_get_serial_id(const struct wally_map *map, + u16 *serial_id); + +/* psbt_sort_by_serial_id - Sort PSBT by serial_ids + * + * MUST have a serial_id on every input/output. + * + * @psbt - psbt to sort + */ +void psbt_sort_by_serial_id(struct wally_psbt *psbt); + +/* psbt_has_diff - Returns set of diffs btw orig + new psbt + * + * All inputs+outputs MUST have a serial_id field present before + * calling this. + * + * @ctx - allocation context for returned diffsets + * @orig - original psbt + * @new - updated psbt + * @added_ins - inputs added {new} + * @rm_ins - inputs removed {orig} + * @added_outs - outputs added {new} + * @rm_outs - outputs removed {orig} + * + * Note that the input + output data returned in the diff sets + * contain references to the originating PSBT; they are not copies. + * + * Returns true if changes are found. + */ +bool psbt_has_diff(const tal_t *ctx, + struct wally_psbt *orig, + struct wally_psbt *new, + struct input_set **added_ins, + struct input_set **rm_ins, + struct output_set **added_outs, + struct output_set **rm_outs); + +/* psbt_input_add_serial_id - Adds a serial id to given input + * + * @input - to add serial_id to + * @serial_id - to add + */ +void psbt_input_add_serial_id(struct wally_psbt_input *input, + u16 serial_id); +/* psbt_output_add_serial_id - Adds a serial id to given output + * + * @output - to add serial_id to + * @serial_id - to add + */ +void psbt_output_add_serial_id(struct wally_psbt_output *output, + u16 serial_id); + +/* psbt_sort_by_serial_id - Sorts the inputs + outputs by serial_id + * + * Requires every input/output to have a serial_id entry. + * + * @psbt - psbt to sort inputs/outputs + */ +void psbt_sort_by_serial_id(struct wally_psbt *psbt); + +/* psbt_has_serial_input - Checks inputs for provided serial_id + * + * @psbt - psbt's inputs to check + * @serial_id - id to look for + * Returns true if serial_id found. + */ +WARN_UNUSED_RESULT bool +psbt_has_serial_input(struct wally_psbt *psbt, u16 serial_id); + +/* psbt_has_serial_output - Checks outputs for provided serial_id + * + * @psbt - psbt's outputs to check + * @serial_id - id to look for + * Returns true if serial_id found. + */ +WARN_UNUSED_RESULT bool +psbt_has_serial_output(struct wally_psbt *psbt, u16 serial_id); + +/* psbt_input_add_max_witness_len - Put a max witness len on a thing + * + * @input - input to add max-witness-len to + * @max_witness_len - value + */ +void psbt_input_add_max_witness_len(struct wally_psbt_input *input, + u16 max_witness_len); + +/* psbt_input_get_max_witness_len - Get the max_witness_len + * + * @input - psbt input to look for max witness len on + * @max_witness_len - found length + * + * Returns false if key not present */ +WARN_UNUSED_RESULT bool +psbt_input_get_max_witness_len(struct wally_psbt_input *input, + u16 *max_witness_len); + +/* psbt_has_required_fields - Validates psbt field completion + * + * Required fields are: + * - a serial_id; input+output + * - a prev_tx; input,non_witness_utxo + * - redeemscript; input,iff is P2SH-P2W* + * @psbt - psbt to validate + * + * Returns true if all required fields are present + */ +bool psbt_has_required_fields(struct wally_psbt *psbt); + +#endif /* LIGHTNING_COMMON_PSBT_OPEN_H */