@ -6,7 +6,9 @@
# include <ccan/asort/asort.h>
# include <ccan/asort/asort.h>
# include <ccan/ccan/endian/endian.h>
# include <ccan/ccan/endian/endian.h>
# include <ccan/ccan/mem/mem.h>
# include <ccan/ccan/mem/mem.h>
# include <common/channel_id.h>
# include <common/utils.h>
# include <common/utils.h>
# include <wire/peer_wire.h>
bool psbt_get_serial_id ( const struct wally_map * map , u16 * serial_id )
bool psbt_get_serial_id ( const struct wally_map * map , u16 * serial_id )
{
{
@ -213,39 +215,45 @@ void psbt_sort_by_serial_id(struct wally_psbt *psbt)
struct type # # _set a ; \
struct type # # _set a ; \
a . type = from - > type # # s [ index ] ; \
a . type = from - > type # # s [ index ] ; \
a . tx_ # # type = from - > tx - > type # # s [ index ] ; \
a . tx_ # # type = from - > tx - > type # # s [ index ] ; \
tal_arr_expand ( add_to , a ) ; \
tal_arr_expand ( & add_to , a ) ; \
} while ( 0 )
} while ( 0 )
static struct psbt_changeset * new_changeset ( const tal_t * ctx )
{
struct psbt_changeset * set = tal ( ctx , struct psbt_changeset ) ;
set - > added_ins = tal_arr ( set , struct input_set , 0 ) ;
set - > rm_ins = tal_arr ( set , struct input_set , 0 ) ;
set - > added_outs = tal_arr ( set , struct output_set , 0 ) ;
set - > rm_outs = tal_arr ( set , struct output_set , 0 ) ;
return set ;
}
/* this requires having a serial_id entry on everything */
/* this requires having a serial_id entry on everything */
/* YOU MUST KEEP orig + new AROUND TO USE THE RESULTING SETS */
/* YOU MUST KEEP orig + new AROUND TO USE THE RESULTING SETS */
bool psbt_has_diff ( const tal_t * ctx ,
struct psbt_changeset * psbt_get_changeset ( const tal_t * ctx ,
struct wally_psbt * orig ,
struct wally_psbt * orig ,
struct wally_psbt * new ,
struct wally_psbt * new )
struct input_set * * added_ins ,
struct input_set * * rm_ins ,
struct output_set * * added_outs ,
struct output_set * * rm_outs )
{
{
int result ;
int result ;
size_t i = 0 , j = 0 ;
size_t i = 0 , j = 0 ;
struct psbt_changeset * set ;
psbt_sort_by_serial_id ( orig ) ;
psbt_sort_by_serial_id ( orig ) ;
psbt_sort_by_serial_id ( new ) ;
psbt_sort_by_serial_id ( new ) ;
* added_ins = tal_arr ( ctx , struct input_set , 0 ) ;
set = new_changeset ( ctx ) ;
* rm_ins = tal_arr ( ctx , struct input_set , 0 ) ;
* added_outs = tal_arr ( ctx , struct output_set , 0 ) ;
* rm_outs = tal_arr ( ctx , struct output_set , 0 ) ;
/* Find the input diff */
/* Find the input diff */
while ( i < orig - > num_inputs | | j < new - > num_inputs ) {
while ( i < orig - > num_inputs | | j < new - > num_inputs ) {
if ( i > = orig - > num_inputs ) {
if ( i > = orig - > num_inputs ) {
ADD ( input , added_ins , new , j ) ;
ADD ( input , set - > added_ins , new , j ) ;
j + + ;
j + + ;
continue ;
continue ;
}
}
if ( j > = new - > num_inputs ) {
if ( j > = new - > num_inputs ) {
ADD ( input , rm_ins , orig , i ) ;
ADD ( input , set - > rm_ins , orig , i ) ;
i + + ;
i + + ;
continue ;
continue ;
}
}
@ -253,19 +261,19 @@ bool psbt_has_diff(const tal_t *ctx,
result = compare_serials ( & orig - > inputs [ i ] . unknowns ,
result = compare_serials ( & orig - > inputs [ i ] . unknowns ,
& new - > inputs [ j ] . unknowns ) ;
& new - > inputs [ j ] . unknowns ) ;
if ( result = = - 1 ) {
if ( result = = - 1 ) {
ADD ( input , rm_ins , orig , i ) ;
ADD ( input , set - > rm_ins , orig , i ) ;
i + + ;
i + + ;
continue ;
continue ;
}
}
if ( result = = 1 ) {
if ( result = = 1 ) {
ADD ( input , added_ins , new , j ) ;
ADD ( input , set - > added_ins , new , j ) ;
j + + ;
j + + ;
continue ;
continue ;
}
}
if ( ! input_identical ( orig , i , new , j ) ) {
if ( ! input_identical ( orig , i , new , j ) ) {
ADD ( input , rm_ins , orig , i ) ;
ADD ( input , set - > rm_ins , orig , i ) ;
ADD ( input , added_ins , new , j ) ;
ADD ( input , set - > added_ins , new , j ) ;
}
}
i + + ;
i + + ;
j + + ;
j + + ;
@ -275,12 +283,12 @@ bool psbt_has_diff(const tal_t *ctx,
j = 0 ;
j = 0 ;
while ( i < orig - > num_outputs | | j < new - > num_outputs ) {
while ( i < orig - > num_outputs | | j < new - > num_outputs ) {
if ( i > = orig - > num_outputs ) {
if ( i > = orig - > num_outputs ) {
ADD ( output , added_outs , new , j ) ;
ADD ( output , set - > added_outs , new , j ) ;
j + + ;
j + + ;
continue ;
continue ;
}
}
if ( j > = new - > num_outputs ) {
if ( j > = new - > num_outputs ) {
ADD ( output , rm_outs , orig , i ) ;
ADD ( output , set - > rm_outs , orig , i ) ;
i + + ;
i + + ;
continue ;
continue ;
}
}
@ -288,27 +296,106 @@ bool psbt_has_diff(const tal_t *ctx,
result = compare_serials ( & orig - > outputs [ i ] . unknowns ,
result = compare_serials ( & orig - > outputs [ i ] . unknowns ,
& new - > outputs [ j ] . unknowns ) ;
& new - > outputs [ j ] . unknowns ) ;
if ( result = = - 1 ) {
if ( result = = - 1 ) {
ADD ( output , rm_outs , orig , i ) ;
ADD ( output , set - > rm_outs , orig , i ) ;
i + + ;
i + + ;
continue ;
continue ;
}
}
if ( result = = 1 ) {
if ( result = = 1 ) {
ADD ( output , added_outs , new , j ) ;
ADD ( output , set - > added_outs , new , j ) ;
j + + ;
j + + ;
continue ;
continue ;
}
}
if ( ! output_identical ( orig , i , new , j ) ) {
if ( ! output_identical ( orig , i , new , j ) ) {
ADD ( output , rm_outs , orig , i ) ;
ADD ( output , set - > rm_outs , orig , i ) ;
ADD ( output , added_outs , new , j ) ;
ADD ( output , set - > added_outs , new , j ) ;
}
}
i + + ;
i + + ;
j + + ;
j + + ;
}
}
return tal_count ( * added_ins ) ! = 0 | |
return set ;
tal_count ( * rm_ins ) ! = 0 | |
}
tal_count ( * added_outs ) ! = 0 | |
tal_count ( * rm_outs ) ! = 0 ;
u8 * psbt_changeset_get_next ( const tal_t * ctx , struct channel_id * cid ,
struct psbt_changeset * set )
{
u16 serial_id ;
u8 * msg ;
if ( tal_count ( set - > added_ins ) ! = 0 ) {
const struct input_set * in = & set - > added_ins [ 0 ] ;
u16 max_witness_len ;
u8 * script ;
if ( ! psbt_get_serial_id ( & in - > input . unknowns , & serial_id ) )
abort ( ) ;
const u8 * prevtx = linearize_wtx ( ctx ,
in - > input . utxo ) ;
if ( ! psbt_input_get_max_witness_len ( & in - > input ,
& max_witness_len ) )
abort ( ) ;
if ( in - > input . redeem_script_len )
script = tal_dup_arr ( ctx , u8 ,
in - > input . redeem_script ,
in - > input . redeem_script_len , 0 ) ;
else
script = NULL ;
msg = towire_tx_add_input ( ctx , cid , serial_id ,
prevtx , in - > tx_input . index ,
in - > tx_input . sequence ,
max_witness_len ,
script ,
NULL ) ;
tal_arr_remove ( & set - > added_ins , 0 ) ;
return msg ;
}
if ( tal_count ( set - > rm_ins ) ! = 0 ) {
if ( ! psbt_get_serial_id ( & set - > rm_ins [ 0 ] . input . unknowns ,
& serial_id ) )
abort ( ) ;
msg = towire_tx_remove_input ( ctx , cid , serial_id ) ;
tal_arr_remove ( & set - > rm_ins , 0 ) ;
return msg ;
}
if ( tal_count ( set - > added_outs ) ! = 0 ) {
struct amount_sat sats ;
struct amount_asset asset_amt ;
const struct output_set * out = & set - > added_outs [ 0 ] ;
if ( ! psbt_get_serial_id ( & out - > output . unknowns , & serial_id ) )
abort ( ) ;
asset_amt = wally_tx_output_get_amount ( & out - > tx_output ) ;
sats = amount_asset_to_sat ( & asset_amt ) ;
const u8 * script = wally_tx_output_get_script ( ctx ,
& out - > tx_output ) ;
msg = towire_tx_add_output ( ctx , cid , serial_id ,
sats . satoshis , /* Raw: wire interface */
script ) ;
tal_arr_remove ( & set - > added_outs , 0 ) ;
return msg ;
}
if ( tal_count ( set - > rm_outs ) ! = 0 ) {
if ( ! psbt_get_serial_id ( & set - > rm_outs [ 0 ] . output . unknowns ,
& serial_id ) )
abort ( ) ;
msg = towire_tx_remove_output ( ctx , cid , serial_id ) ;
/* Is this a kosher way to move the list forward? */
tal_arr_remove ( & set - > rm_outs , 0 ) ;
return msg ;
}
return NULL ;
}
}
void psbt_input_add_serial_id ( struct wally_psbt_input * input ,
void psbt_input_add_serial_id ( struct wally_psbt_input * input ,