Browse Source
includes facilities for - sorting psbt inputs by serial_id - sorting psbt outputs by serial_id - adding a serial_id - getting a serial_id - finding the diffset between two psbts - adding a max_len to a psbt input - getting a max_len from a psbt inputbump-pyln-proto
niftynei
4 years ago
committed by
Rusty Russell
3 changed files with 535 additions and 0 deletions
@ -0,0 +1,389 @@ |
|||
#include "common/psbt_open.h" |
|||
#include <assert.h> |
|||
#include <bitcoin/psbt.h> |
|||
#include <bitcoin/script.h> |
|||
#include <bitcoin/tx.h> |
|||
#include <ccan/asort/asort.h> |
|||
#include <ccan/ccan/mem/mem.h> |
|||
#include <common/utils.h> |
|||
|
|||
bool psbt_get_serial_id(const struct wally_map *map, u16 *serial_id) |
|||
{ |
|||
size_t value_len; |
|||
void *result = psbt_get_lightning(map, PSBT_TYPE_SERIAL_ID, &value_len); |
|||
if (!result) |
|||
return false; |
|||
|
|||
if (value_len != sizeof(*serial_id)) |
|||
return false; |
|||
|
|||
memcpy(serial_id, result, value_len); |
|||
return true; |
|||
} |
|||
|
|||
static int compare_serials(const struct wally_map *map_a, |
|||
const struct wally_map *map_b) |
|||
{ |
|||
u16 serial_left, serial_right; |
|||
bool ok; |
|||
|
|||
ok = psbt_get_serial_id(map_a, &serial_left); |
|||
assert(ok); |
|||
ok = psbt_get_serial_id(map_b, &serial_right); |
|||
assert(ok); |
|||
if (serial_left > serial_right) |
|||
return 1; |
|||
if (serial_left < serial_right) |
|||
return -1; |
|||
return 0; |
|||
} |
|||
|
|||
static int compare_inputs_at(const struct input_set *a, |
|||
const struct input_set *b, |
|||
void *unused UNUSED) |
|||
{ |
|||
return compare_serials(&a->input.unknowns, |
|||
&b->input.unknowns); |
|||
} |
|||
|
|||
static int compare_outputs_at(const struct output_set *a, |
|||
const struct output_set *b, |
|||
void *unused UNUSED) |
|||
{ |
|||
return compare_serials(&a->output.unknowns, |
|||
&b->output.unknowns); |
|||
} |
|||
|
|||
static const u8 *linearize_input(const tal_t *ctx, |
|||
const struct wally_psbt_input *in, |
|||
const struct wally_tx_input *tx_in) |
|||
{ |
|||
struct wally_psbt *psbt = create_psbt(NULL, 1, 0); |
|||
size_t byte_len; |
|||
|
|||
|
|||
if (wally_tx_add_input(psbt->tx, tx_in) != WALLY_OK) |
|||
abort(); |
|||
psbt->inputs[0] = *in; |
|||
psbt->num_inputs++; |
|||
|
|||
const u8 *bytes = psbt_get_bytes(ctx, psbt, &byte_len); |
|||
|
|||
/* Hide the inputs we added, so it doesn't get freed */ |
|||
psbt->num_inputs--; |
|||
tal_free(psbt); |
|||
return bytes; |
|||
} |
|||
|
|||
static const u8 *linearize_output(const tal_t *ctx, |
|||
const struct wally_psbt_output *out, |
|||
const struct wally_tx_output *tx_out) |
|||
{ |
|||
struct wally_psbt *psbt = create_psbt(NULL, 1, 1); |
|||
size_t byte_len; |
|||
struct bitcoin_txid txid; |
|||
|
|||
/* Add a 'fake' input so this will linearize the tx */ |
|||
memset(&txid, 0, sizeof(txid)); |
|||
psbt_append_input(psbt, &txid, 0, 0); |
|||
|
|||
if (wally_tx_add_output(psbt->tx, tx_out) != WALLY_OK) |
|||
abort(); |
|||
psbt->outputs[0] = *out; |
|||
psbt->num_outputs++; |
|||
|
|||
const u8 *bytes = psbt_get_bytes(ctx, psbt, &byte_len); |
|||
|
|||
/* Hide the outputs we added, so it doesn't get freed */ |
|||
psbt->num_outputs--; |
|||
tal_free(psbt); |
|||
return bytes; |
|||
} |
|||
|
|||
static bool input_identical(const struct wally_psbt *a, |
|||
size_t a_index, |
|||
const struct wally_psbt *b, |
|||
size_t b_index) |
|||
{ |
|||
const u8 *a_in = linearize_input(tmpctx, |
|||
&a->inputs[a_index], |
|||
&a->tx->inputs[a_index]); |
|||
const u8 *b_in = linearize_input(tmpctx, |
|||
&b->inputs[b_index], |
|||
&b->tx->inputs[b_index]); |
|||
|
|||
return memeq(a_in, tal_bytelen(a_in), |
|||
b_in, tal_bytelen(b_in)); |
|||
} |
|||
|
|||
static bool output_identical(const struct wally_psbt *a, |
|||
size_t a_index, |
|||
const struct wally_psbt *b, |
|||
size_t b_index) |
|||
{ |
|||
const u8 *a_out = linearize_output(tmpctx, |
|||
&a->outputs[a_index], |
|||
&a->tx->outputs[a_index]); |
|||
const u8 *b_out = linearize_output(tmpctx, |
|||
&b->outputs[b_index], |
|||
&b->tx->outputs[b_index]); |
|||
|
|||
return memeq(a_out, tal_bytelen(a_out), |
|||
b_out, tal_bytelen(b_out)); |
|||
} |
|||
|
|||
static void sort_inputs(struct wally_psbt *psbt) |
|||
{ |
|||
/* Build an input map */ |
|||
struct input_set *set = tal_arr(NULL, |
|||
struct input_set, |
|||
psbt->num_inputs); |
|||
|
|||
for (size_t i = 0; i < tal_count(set); i++) { |
|||
set[i].tx_input = psbt->tx->inputs[i]; |
|||
set[i].input = psbt->inputs[i]; |
|||
} |
|||
|
|||
asort(set, tal_count(set), |
|||
compare_inputs_at, NULL); |
|||
|
|||
/* Put PSBT parts into place */ |
|||
for (size_t i = 0; i < tal_count(set); i++) { |
|||
psbt->inputs[i] = set[i].input; |
|||
psbt->tx->inputs[i] = set[i].tx_input; |
|||
} |
|||
|
|||
tal_free(set); |
|||
} |
|||
|
|||
static void sort_outputs(struct wally_psbt *psbt) |
|||
{ |
|||
/* Build an output map */ |
|||
struct output_set *set = tal_arr(NULL, |
|||
struct output_set, |
|||
psbt->num_outputs); |
|||
for (size_t i = 0; i < tal_count(set); i++) { |
|||
set[i].tx_output = psbt->tx->outputs[i]; |
|||
set[i].output = psbt->outputs[i]; |
|||
} |
|||
|
|||
asort(set, tal_count(set), |
|||
compare_outputs_at, NULL); |
|||
|
|||
/* Put PSBT parts into place */ |
|||
for (size_t i = 0; i < tal_count(set); i++) { |
|||
psbt->outputs[i] = set[i].output; |
|||
psbt->tx->outputs[i] = set[i].tx_output; |
|||
} |
|||
|
|||
tal_free(set); |
|||
} |
|||
|
|||
void psbt_sort_by_serial_id(struct wally_psbt *psbt) |
|||
{ |
|||
sort_inputs(psbt); |
|||
sort_outputs(psbt); |
|||
} |
|||
|
|||
#define ADD(type, add_to, from, index) \ |
|||
do { \ |
|||
struct type##_set a; \ |
|||
a.type = from->type##s[index]; \ |
|||
a.tx_##type = from->tx->type##s[index]; \ |
|||
tal_arr_expand(add_to, a); \ |
|||
} while (0) |
|||
|
|||
/* this requires having a serial_id entry on everything */ |
|||
/* YOU MUST KEEP orig + new AROUND TO USE THE RESULTING SETS */ |
|||
bool psbt_has_diff(const tal_t *ctx, |
|||
struct wally_psbt *orig, |
|||
struct wally_psbt *new, |
|||
struct input_set **added_ins, |
|||
struct input_set **rm_ins, |
|||
struct output_set **added_outs, |
|||
struct output_set **rm_outs) |
|||
{ |
|||
int result; |
|||
size_t i = 0, j = 0; |
|||
|
|||
psbt_sort_by_serial_id(orig); |
|||
psbt_sort_by_serial_id(new); |
|||
|
|||
*added_ins = tal_arr(ctx, struct input_set, 0); |
|||
*rm_ins = tal_arr(ctx, struct input_set, 0); |
|||
*added_outs = tal_arr(ctx, struct output_set, 0); |
|||
*rm_outs = tal_arr(ctx, struct output_set, 0); |
|||
|
|||
/* Find the input diff */ |
|||
while (i < orig->num_inputs || j < new->num_inputs) { |
|||
if (i >= orig->num_inputs) { |
|||
ADD(input, added_ins, new, j); |
|||
j++; |
|||
continue; |
|||
} |
|||
if (j >= new->num_inputs) { |
|||
ADD(input, rm_ins, orig, i); |
|||
i++; |
|||
continue; |
|||
} |
|||
|
|||
result = compare_serials(&orig->inputs[i].unknowns, |
|||
&new->inputs[j].unknowns); |
|||
if (result == -1) { |
|||
ADD(input, rm_ins, orig, i); |
|||
i++; |
|||
continue; |
|||
} |
|||
if (result == 1) { |
|||
ADD(input, added_ins, new, j); |
|||
j++; |
|||
continue; |
|||
} |
|||
|
|||
if (!input_identical(orig, i, new, j)) { |
|||
ADD(input, rm_ins, orig, i); |
|||
ADD(input, added_ins, new, j); |
|||
} |
|||
i++; |
|||
j++; |
|||
} |
|||
/* Find the output diff */ |
|||
i = 0; |
|||
j = 0; |
|||
while (i < orig->num_outputs || j < new->num_outputs) { |
|||
if (i >= orig->num_outputs) { |
|||
ADD(output, added_outs, new, j); |
|||
j++; |
|||
continue; |
|||
} |
|||
if (j >= new->num_outputs) { |
|||
ADD(output, rm_outs, orig, i); |
|||
i++; |
|||
continue; |
|||
} |
|||
|
|||
result = compare_serials(&orig->outputs[i].unknowns, |
|||
&new->outputs[j].unknowns); |
|||
if (result == -1) { |
|||
ADD(output, rm_outs, orig, i); |
|||
i++; |
|||
continue; |
|||
} |
|||
if (result == 1) { |
|||
ADD(output, added_outs, new, j); |
|||
j++; |
|||
continue; |
|||
} |
|||
if (!output_identical(orig, i, new, j)) { |
|||
ADD(output, rm_outs, orig, i); |
|||
ADD(output, added_outs, new, j); |
|||
} |
|||
i++; |
|||
j++; |
|||
} |
|||
|
|||
return tal_count(*added_ins) != 0 || |
|||
tal_count(*rm_ins) != 0 || |
|||
tal_count(*added_outs) != 0 || |
|||
tal_count(*rm_outs) != 0; |
|||
} |
|||
|
|||
void psbt_input_add_serial_id(struct wally_psbt_input *input, |
|||
u16 serial_id) |
|||
{ |
|||
u8 *key = psbt_make_key(tmpctx, PSBT_TYPE_SERIAL_ID, NULL); |
|||
psbt_input_add_unknown(input, key, &serial_id, sizeof(serial_id)); |
|||
} |
|||
|
|||
|
|||
void psbt_output_add_serial_id(struct wally_psbt_output *output, |
|||
u16 serial_id) |
|||
{ |
|||
u8 *key = psbt_make_key(tmpctx, PSBT_TYPE_SERIAL_ID, NULL); |
|||
psbt_output_add_unknown(output, key, &serial_id, sizeof(serial_id)); |
|||
} |
|||
|
|||
bool psbt_has_serial_input(struct wally_psbt *psbt, u16 serial_id) |
|||
{ |
|||
for (size_t i = 0; i < psbt->num_inputs; i++) { |
|||
u16 in_serial; |
|||
if (!psbt_get_serial_id(&psbt->inputs[i].unknowns, &in_serial)) |
|||
continue; |
|||
if (in_serial == serial_id) |
|||
return true; |
|||
} |
|||
return false; |
|||
} |
|||
|
|||
bool psbt_has_serial_output(struct wally_psbt *psbt, u16 serial_id) |
|||
{ |
|||
for (size_t i = 0; i < psbt->num_outputs; i++) { |
|||
u16 out_serial; |
|||
if (!psbt_get_serial_id(&psbt->outputs[i].unknowns, &out_serial)) |
|||
continue; |
|||
if (out_serial == serial_id) |
|||
return true; |
|||
} |
|||
return false; |
|||
} |
|||
|
|||
void psbt_input_add_max_witness_len(struct wally_psbt_input *input, |
|||
u16 max_witness_len) |
|||
{ |
|||
u8 *key = psbt_make_key(NULL, PSBT_TYPE_MAX_WITNESS_LEN, NULL); |
|||
psbt_input_add_unknown(input, key, &max_witness_len, sizeof(max_witness_len)); |
|||
tal_free(key); |
|||
} |
|||
|
|||
|
|||
bool psbt_input_get_max_witness_len(struct wally_psbt_input *input, |
|||
u16 *max_witness_len) |
|||
{ |
|||
size_t value_len; |
|||
void *result = psbt_get_lightning(&input->unknowns, |
|||
PSBT_TYPE_MAX_WITNESS_LEN, |
|||
&value_len); |
|||
if (!result) |
|||
return false; |
|||
|
|||
if (value_len != sizeof(*max_witness_len)) |
|||
return false; |
|||
|
|||
memcpy(max_witness_len, result, value_len); |
|||
return true; |
|||
} |
|||
|
|||
bool psbt_has_required_fields(struct wally_psbt *psbt) |
|||
{ |
|||
u16 max_witness, serial_id; |
|||
for (size_t i = 0; i < psbt->num_inputs; i++) { |
|||
struct wally_psbt_input *input = &psbt->inputs[i]; |
|||
|
|||
if (!psbt_get_serial_id(&input->unknowns, &serial_id)) |
|||
return false; |
|||
|
|||
/* Inputs had also better have their max_witness_lens
|
|||
* filled in! */ |
|||
if (!psbt_input_get_max_witness_len(input, &max_witness)) |
|||
return false; |
|||
|
|||
/* Required because we send the full tx over the wire now */ |
|||
if (!input->utxo) |
|||
return false; |
|||
|
|||
/* If is P2SH, redeemscript must be present */ |
|||
const u8 *outscript = |
|||
wally_tx_output_get_script(tmpctx, |
|||
&input->utxo->outputs[psbt->tx->inputs[i].index]); |
|||
if (is_p2sh(outscript, NULL) && input->redeem_script_len == 0) |
|||
return false; |
|||
|
|||
} |
|||
|
|||
for (size_t i = 0; i < psbt->num_outputs; i++) { |
|||
if (!psbt_get_serial_id(&psbt->outputs[i].unknowns, &serial_id)) |
|||
return false; |
|||
} |
|||
|
|||
return true; |
|||
} |
@ -0,0 +1,145 @@ |
|||
#ifndef LIGHTNING_COMMON_PSBT_OPEN_H |
|||
#define LIGHTNING_COMMON_PSBT_OPEN_H |
|||
#include "config.h" |
|||
#include <ccan/short_types/short_types.h> |
|||
#include <ccan/tal/tal.h> |
|||
#include <stdbool.h> |
|||
#include <wally_psbt.h> |
|||
#include <wally_transaction.h> |
|||
|
|||
struct wally_tx_input; |
|||
struct wally_tx_output; |
|||
struct wally_psbt; |
|||
struct wally_psbt_input; |
|||
struct wally_psbt_output; |
|||
struct wally_map; |
|||
|
|||
struct input_set { |
|||
struct wally_tx_input tx_input; |
|||
struct wally_psbt_input input; |
|||
}; |
|||
|
|||
struct output_set { |
|||
struct wally_tx_output tx_output; |
|||
struct wally_psbt_output output; |
|||
}; |
|||
|
|||
#define PSBT_TYPE_SERIAL_ID 0x01 |
|||
#define PSBT_TYPE_MAX_WITNESS_LEN 0x02 |
|||
|
|||
/* psbt_get_serial_id - Returns the serial_id from an unknowns map
|
|||
* |
|||
* @map - the map to find the serial id entry within |
|||
* @serial_id - found serial_id |
|||
* |
|||
* Returns false if serial_id is not present |
|||
*/ |
|||
WARN_UNUSED_RESULT bool psbt_get_serial_id(const struct wally_map *map, |
|||
u16 *serial_id); |
|||
|
|||
/* psbt_sort_by_serial_id - Sort PSBT by serial_ids
|
|||
* |
|||
* MUST have a serial_id on every input/output. |
|||
* |
|||
* @psbt - psbt to sort |
|||
*/ |
|||
void psbt_sort_by_serial_id(struct wally_psbt *psbt); |
|||
|
|||
/* psbt_has_diff - Returns set of diffs btw orig + new psbt
|
|||
* |
|||
* All inputs+outputs MUST have a serial_id field present before |
|||
* calling this. |
|||
* |
|||
* @ctx - allocation context for returned diffsets |
|||
* @orig - original psbt |
|||
* @new - updated psbt |
|||
* @added_ins - inputs added {new} |
|||
* @rm_ins - inputs removed {orig} |
|||
* @added_outs - outputs added {new} |
|||
* @rm_outs - outputs removed {orig} |
|||
* |
|||
* Note that the input + output data returned in the diff sets |
|||
* contain references to the originating PSBT; they are not copies. |
|||
* |
|||
* Returns true if changes are found. |
|||
*/ |
|||
bool psbt_has_diff(const tal_t *ctx, |
|||
struct wally_psbt *orig, |
|||
struct wally_psbt *new, |
|||
struct input_set **added_ins, |
|||
struct input_set **rm_ins, |
|||
struct output_set **added_outs, |
|||
struct output_set **rm_outs); |
|||
|
|||
/* psbt_input_add_serial_id - Adds a serial id to given input
|
|||
* |
|||
* @input - to add serial_id to |
|||
* @serial_id - to add |
|||
*/ |
|||
void psbt_input_add_serial_id(struct wally_psbt_input *input, |
|||
u16 serial_id); |
|||
/* psbt_output_add_serial_id - Adds a serial id to given output
|
|||
* |
|||
* @output - to add serial_id to |
|||
* @serial_id - to add |
|||
*/ |
|||
void psbt_output_add_serial_id(struct wally_psbt_output *output, |
|||
u16 serial_id); |
|||
|
|||
/* psbt_sort_by_serial_id - Sorts the inputs + outputs by serial_id
|
|||
* |
|||
* Requires every input/output to have a serial_id entry. |
|||
* |
|||
* @psbt - psbt to sort inputs/outputs |
|||
*/ |
|||
void psbt_sort_by_serial_id(struct wally_psbt *psbt); |
|||
|
|||
/* psbt_has_serial_input - Checks inputs for provided serial_id
|
|||
* |
|||
* @psbt - psbt's inputs to check |
|||
* @serial_id - id to look for |
|||
* Returns true if serial_id found. |
|||
*/ |
|||
WARN_UNUSED_RESULT bool |
|||
psbt_has_serial_input(struct wally_psbt *psbt, u16 serial_id); |
|||
|
|||
/* psbt_has_serial_output - Checks outputs for provided serial_id
|
|||
* |
|||
* @psbt - psbt's outputs to check |
|||
* @serial_id - id to look for |
|||
* Returns true if serial_id found. |
|||
*/ |
|||
WARN_UNUSED_RESULT bool |
|||
psbt_has_serial_output(struct wally_psbt *psbt, u16 serial_id); |
|||
|
|||
/* psbt_input_add_max_witness_len - Put a max witness len on a thing
|
|||
* |
|||
* @input - input to add max-witness-len to |
|||
* @max_witness_len - value |
|||
*/ |
|||
void psbt_input_add_max_witness_len(struct wally_psbt_input *input, |
|||
u16 max_witness_len); |
|||
|
|||
/* psbt_input_get_max_witness_len - Get the max_witness_len
|
|||
* |
|||
* @input - psbt input to look for max witness len on |
|||
* @max_witness_len - found length |
|||
* |
|||
* Returns false if key not present */ |
|||
WARN_UNUSED_RESULT bool |
|||
psbt_input_get_max_witness_len(struct wally_psbt_input *input, |
|||
u16 *max_witness_len); |
|||
|
|||
/* psbt_has_required_fields - Validates psbt field completion
|
|||
* |
|||
* Required fields are: |
|||
* - a serial_id; input+output |
|||
* - a prev_tx; input,non_witness_utxo |
|||
* - redeemscript; input,iff is P2SH-P2W* |
|||
* @psbt - psbt to validate |
|||
* |
|||
* Returns true if all required fields are present |
|||
*/ |
|||
bool psbt_has_required_fields(struct wally_psbt *psbt); |
|||
|
|||
#endif /* LIGHTNING_COMMON_PSBT_OPEN_H */ |
Loading…
Reference in new issue