#include "common/psbt_open.h"
#include <assert.h>
#include <bitcoin/psbt.h>
#include <bitcoin/script.h>
#include <bitcoin/tx.h>
#include <ccan/asort/asort.h>
#include <ccan/ccan/endian/endian.h>
#include <ccan/ccan/mem/mem.h>
#include <common/channel_id.h>
#include <common/utils.h>
#include <wire/peer_wire.h>

/* Looks up the PSBT_TYPE_SERIAL_ID entry in @map via psbt_get_lightning()
 * and decodes it as a big-endian 16-bit value.
 *
 * Returns true and writes *@serial_id on success.  Returns false, leaving
 * *@serial_id untouched, if the entry is missing or its length is not
 * exactly two bytes. */
bool psbt_get_serial_id(const struct wally_map *map, u16 *serial_id)
{
	beint16_t raw;
	size_t len;
	void *value;

	value = psbt_get_lightning(map, PSBT_TYPE_SERIAL_ID, &len);

	/* len is only examined once we know the lookup found something */
	if (!value || len != sizeof(raw))
		return false;

	/* Copy out rather than cast: nothing here guarantees alignment */
	memcpy(&raw, value, sizeof(raw));
	*serial_id = be16_to_cpu(raw);
	return true;
}

/* Three-way comparison of the serial ids stored in two unknowns maps.
 *
 * Both maps must carry a readable serial id; this is asserted, not
 * reported.  Returns -1 if @map_a's id is smaller, 1 if it is larger,
 * and 0 if the ids are equal. */
static int compare_serials(const struct wally_map *map_a,
			   const struct wally_map *map_b)
{
	u16 lhs, rhs;
	bool found;

	found = psbt_get_serial_id(map_a, &lhs);
	assert(found);
	found = psbt_get_serial_id(map_b, &rhs);
	assert(found);

	/* Evaluates to exactly -1, 0 or 1; psbt_get_changeset() tests for
	 * those literal values. */
	return (lhs > rhs) - (lhs < rhs);
}

/* asort() comparator for struct input_set: orders entries by the serial
 * id found in each PSBT input's unknowns map.  The context argument is
 * ignored. */
static int compare_inputs_at(const struct input_set *a,
			     const struct input_set *b,
			     void *unused UNUSED)
{
	const struct wally_map *left = &a->input.unknowns;
	const struct wally_map *right = &b->input.unknowns;

	return compare_serials(left, right);
}

/* asort() comparator for struct output_set: orders entries by the serial
 * id found in each PSBT output's unknowns map.  The context argument is
 * ignored. */
static int compare_outputs_at(const struct output_set *a,
			      const struct output_set *b,
			      void *unused UNUSED)
{
	const struct wally_map *left = &a->output.unknowns;
	const struct wally_map *right = &b->output.unknowns;

	return compare_serials(left, right);
}

static const u8 *linearize_input(const tal_t *ctx,
				 const struct wally_psbt_input *in,
				 const struct wally_tx_input *tx_in)
{
	struct wally_psbt *psbt = create_psbt(NULL, 1, 0, 0);
	size_t byte_len;

	if (wally_tx_add_input(psbt->tx, tx_in) != WALLY_OK)
		abort();
	psbt->inputs[0] = *in;
	psbt->num_inputs++;


	/* Sort the inputs, so serializing them is ok */
	wally_map_sort(&psbt->inputs[0].unknowns, 0);

	/* signatures, keypaths, etc - we dont care if they change */
	psbt->inputs[0].final_witness = NULL;
	psbt->inputs[0].final_scriptsig_len = 0;
	psbt->inputs[0].witness_script_len = 0;
	psbt->inputs[0].redeem_script_len = 0;
	psbt->inputs[0].keypaths.num_items = 0;
	psbt->inputs[0].signatures.num_items = 0;


	const u8 *bytes = psbt_get_bytes(ctx, psbt, &byte_len);

	/* Hide the inputs we added, so it doesn't get freed */
	psbt->num_inputs--;
	tal_free(psbt);
	return bytes;
}

static const u8 *linearize_output(const tal_t *ctx,
				  const struct wally_psbt_output *out,
				  const struct wally_tx_output *tx_out)
{
	struct wally_psbt *psbt = create_psbt(NULL, 1, 1, 0);
	size_t byte_len;
	struct bitcoin_txid txid;

	/* Add a 'fake' input so this will linearize the tx */
	memset(&txid, 0, sizeof(txid));
	psbt_append_input(psbt, &txid, 0, 0, NULL, NULL, NULL);

	if (wally_tx_add_output(psbt->tx, tx_out) != WALLY_OK)
		abort();

	psbt->outputs[0] = *out;
	psbt->num_outputs++;
	/* Sort the outputs, so serializing them is ok */
	wally_map_sort(&psbt->outputs[0].unknowns, 0);

	/* We don't care if the keypaths change */
	psbt->outputs[0].keypaths.num_items = 0;
	/* And you can add scripts, no problem */
	psbt->outputs[0].witness_script_len = 0;
	psbt->outputs[0].redeem_script_len = 0;

	const u8 *bytes = psbt_get_bytes(ctx, psbt, &byte_len);

	/* Hide the outputs we added, so it doesn't get freed */
	psbt->num_outputs--;
	tal_free(psbt);
	return bytes;
}

/* Returns true if input @a_index of @a and input @b_index of @b produce
 * the same bytes from linearize_input(), i.e. they match in everything
 * that function does not blank out.  The temporaries live on tmpctx. */
static bool input_identical(const struct wally_psbt *a,
			    size_t a_index,
			    const struct wally_psbt *b,
			    size_t b_index)
{
	const u8 *lin_a, *lin_b;

	lin_a = linearize_input(tmpctx, &a->inputs[a_index],
				&a->tx->inputs[a_index]);
	lin_b = linearize_input(tmpctx, &b->inputs[b_index],
				&b->tx->inputs[b_index]);

	return memeq(lin_a, tal_bytelen(lin_a), lin_b, tal_bytelen(lin_b));
}

/* Returns true if output @a_index of @a and output @b_index of @b produce
 * the same bytes from linearize_output(), i.e. they match in everything
 * that function does not blank out.  The temporaries live on tmpctx. */
static bool output_identical(const struct wally_psbt *a,
			     size_t a_index,
			     const struct wally_psbt *b,
			     size_t b_index)
{
	const u8 *lin_a, *lin_b;

	lin_a = linearize_output(tmpctx, &a->outputs[a_index],
				 &a->tx->outputs[a_index]);
	lin_b = linearize_output(tmpctx, &b->outputs[b_index],
				 &b->tx->outputs[b_index]);

	return memeq(lin_a, tal_bytelen(lin_a), lin_b, tal_bytelen(lin_b));
}

/* Reorders @psbt's inputs in place by ascending serial id.
 *
 * The PSBT-level input and the matching tx-level input are copied out as
 * pairs, sorted together, and copied back, so the two arrays stay in
 * step.  Every input must carry a serial id (compare_serials() asserts). */
static void sort_inputs(struct wally_psbt *psbt)
{
	struct input_set *pairs;
	size_t count;

	pairs = tal_arr(NULL, struct input_set, psbt->num_inputs);
	count = tal_count(pairs);

	/* Snapshot each (tx input, psbt input) pair */
	for (size_t k = 0; k < count; k++) {
		pairs[k].tx_input = psbt->tx->inputs[k];
		pairs[k].input = psbt->inputs[k];
	}

	asort(pairs, count, compare_inputs_at, NULL);

	/* Write the pairs back in their sorted order */
	for (size_t k = 0; k < count; k++) {
		psbt->inputs[k] = pairs[k].input;
		psbt->tx->inputs[k] = pairs[k].tx_input;
	}

	tal_free(pairs);
}

/* Reorders @psbt's outputs in place by ascending serial id.
 *
 * The PSBT-level output and the matching tx-level output are copied out
 * as pairs, sorted together, and copied back, so the two arrays stay in
 * step.  Every output must carry a serial id (compare_serials() asserts). */
static void sort_outputs(struct wally_psbt *psbt)
{
	struct output_set *pairs;
	size_t count;

	pairs = tal_arr(NULL, struct output_set, psbt->num_outputs);
	count = tal_count(pairs);

	/* Snapshot each (tx output, psbt output) pair */
	for (size_t k = 0; k < count; k++) {
		pairs[k].tx_output = psbt->tx->outputs[k];
		pairs[k].output = psbt->outputs[k];
	}

	asort(pairs, count, compare_outputs_at, NULL);

	/* Write the pairs back in their sorted order */
	for (size_t k = 0; k < count; k++) {
		psbt->outputs[k] = pairs[k].output;
		psbt->tx->outputs[k] = pairs[k].tx_output;
	}

	tal_free(pairs);
}

/* Sorts @psbt's inputs, then its outputs, in place by serial id.
 * Every input and output must already carry a serial id: the comparison
 * used by the sort asserts that one can be read. */
void psbt_sort_by_serial_id(struct wally_psbt *psbt)
{
	sort_inputs(psbt);
	sort_outputs(psbt);
}

/* Appends entry @index of the PSBT @from to the tal array @add_to as a
 * struct <type>_set, pairing the PSBT-level entry with its counterpart
 * in from->tx.  @type must be the bare token `input` or `output`.
 *
 * The entry is a struct copy, so it shares pointers with @from.
 * @from and @index are each evaluated twice; pass side-effect-free
 * expressions. */
#define ADD(type, add_to, from, index)					\
	do {								\
		struct type##_set add_entry_;				\
		add_entry_.type = (from)->type##s[(index)];		\
		add_entry_.tx_##type = (from)->tx->type##s[(index)];	\
		tal_arr_expand(&(add_to), add_entry_);			\
	} while (0)

/* Allocates a psbt_changeset off @ctx with all four change lists present
 * but empty.  The lists are allocated off the changeset itself. */
static struct psbt_changeset *new_changeset(const tal_t *ctx)
{
	struct psbt_changeset *cs;

	cs = tal(ctx, struct psbt_changeset);

	/* Input-side lists */
	cs->added_ins = tal_arr(cs, struct input_set, 0);
	cs->rm_ins = tal_arr(cs, struct input_set, 0);

	/* Output-side lists */
	cs->added_outs = tal_arr(cs, struct output_set, 0);
	cs->rm_outs = tal_arr(cs, struct output_set, 0);

	return cs;
}

/* Computes what must be removed from / added to @orig to arrive at @new,
 * matching inputs and outputs up by serial id.
 *
 * This requires having a serial_id entry on everything.  Both @orig and
 * @new are sorted in place by serial id before being walked.
 *
 * An entry whose serial id appears on only one side is recorded as
 * removed (orig only) or added (new only).  An entry present on both
 * sides but not identical is recorded as both removed and added.
 *
 * The returned changeset is allocated off @ctx.
 *
 * YOU MUST KEEP orig + new AROUND TO USE THE RESULTING SETS */
struct psbt_changeset *psbt_get_changeset(const tal_t *ctx,
					  struct wally_psbt *orig,
					  struct wally_psbt *new)
{
	struct psbt_changeset *changes;
	size_t oi, ni;

	psbt_sort_by_serial_id(orig);
	psbt_sort_by_serial_id(new);

	changes = new_changeset(ctx);

	/* Find the input diff: merge-walk the two sorted input lists */
	oi = 0;
	ni = 0;
	while (oi < orig->num_inputs || ni < new->num_inputs) {
		if (oi >= orig->num_inputs) {
			/* orig is exhausted: the rest of new was added */
			ADD(input, changes->added_ins, new, ni);
			ni++;
		} else if (ni >= new->num_inputs) {
			/* new is exhausted: the rest of orig was removed */
			ADD(input, changes->rm_ins, orig, oi);
			oi++;
		} else {
			switch (compare_serials(&orig->inputs[oi].unknowns,
						&new->inputs[ni].unknowns)) {
			case -1:
				/* Serial only in orig */
				ADD(input, changes->rm_ins, orig, oi);
				oi++;
				break;
			case 1:
				/* Serial only in new */
				ADD(input, changes->added_ins, new, ni);
				ni++;
				break;
			default:
				/* Same serial: a changed entry is a
				 * remove followed by an add */
				if (!input_identical(orig, oi, new, ni)) {
					ADD(input, changes->rm_ins, orig, oi);
					ADD(input, changes->added_ins, new, ni);
				}
				oi++;
				ni++;
				break;
			}
		}
	}

	/* Find the output diff: same walk over the output lists */
	oi = 0;
	ni = 0;
	while (oi < orig->num_outputs || ni < new->num_outputs) {
		if (oi >= orig->num_outputs) {
			ADD(output, changes->added_outs, new, ni);
			ni++;
		} else if (ni >= new->num_outputs) {
			ADD(output, changes->rm_outs, orig, oi);
			oi++;
		} else {
			switch (compare_serials(&orig->outputs[oi].unknowns,
						&new->outputs[ni].unknowns)) {
			case -1:
				ADD(output, changes->rm_outs, orig, oi);
				oi++;
				break;
			case 1:
				ADD(output, changes->added_outs, new, ni);
				ni++;
				break;
			default:
				if (!output_identical(orig, oi, new, ni)) {
					ADD(output, changes->rm_outs, orig, oi);
					ADD(output, changes->added_outs, new, ni);
				}
				oi++;
				ni++;
				break;
			}
		}
	}

	return changes;
}

/* Takes the next pending change out of @set and returns the wire message
 * that announces it, or NULL once all four lists are empty.
 *
 * Lists are drained in a fixed order: added inputs, removed inputs,
 * added outputs, removed outputs.  Each call consumes the first element
 * of the first non-empty list.
 *
 * @ctx is handed to the towire_* call and to the helper allocations made
 * along the way; @cid is passed through to the message.
 *
 * Aborts if the entry being consumed has no readable serial id. */
u8 *psbt_changeset_get_next(const tal_t *ctx, struct channel_id *cid,
			    struct psbt_changeset *set)
{
	u16 serial_id;
	u8 *msg;

	if (tal_count(set->added_ins) > 0) {
		const struct input_set *added = &set->added_ins[0];
		const u8 *prevtx;
		u8 *redeemscript = NULL;

		if (!psbt_get_serial_id(&added->input.unknowns, &serial_id))
			abort();

		prevtx = linearize_wtx(ctx, added->input.utxo);

		/* Only carry a redeem script if the input has one */
		if (added->input.redeem_script_len)
			redeemscript = tal_dup_arr(ctx, u8,
						   added->input.redeem_script,
						   added->input.redeem_script_len,
						   0);

		msg = towire_tx_add_input(ctx, cid, serial_id,
					  prevtx, added->tx_input.index,
					  added->tx_input.sequence,
					  redeemscript,
					  NULL);

		/* `added` points into the array: drop the entry only
		 * after the message has been built */
		tal_arr_remove(&set->added_ins, 0);
	} else if (tal_count(set->rm_ins) > 0) {
		if (!psbt_get_serial_id(&set->rm_ins[0].input.unknowns,
					&serial_id))
			abort();

		msg = towire_tx_remove_input(ctx, cid, serial_id);

		tal_arr_remove(&set->rm_ins, 0);
	} else if (tal_count(set->added_outs) > 0) {
		const struct output_set *added = &set->added_outs[0];
		struct amount_asset asset_amt;
		struct amount_sat sats;
		const u8 *script;

		if (!psbt_get_serial_id(&added->output.unknowns, &serial_id))
			abort();

		asset_amt = wally_tx_output_get_amount(&added->tx_output);
		sats = amount_asset_to_sat(&asset_amt);
		script = wally_tx_output_get_script(ctx, &added->tx_output);

		msg = towire_tx_add_output(ctx, cid, serial_id,
					   sats.satoshis, /* Raw: wire interface */
					   script);

		tal_arr_remove(&set->added_outs, 0);
	} else if (tal_count(set->rm_outs) > 0) {
		if (!psbt_get_serial_id(&set->rm_outs[0].output.unknowns,
					&serial_id))
			abort();

		msg = towire_tx_remove_output(ctx, cid, serial_id);

		/* Is this a kosher way to move the list forward? */
		tal_arr_remove(&set->rm_outs, 0);
	} else {
		/* Nothing left to send */
		msg = NULL;
	}

	return msg;
}

/* Records @serial_id on @input as an unknown entry keyed by
 * PSBT_TYPE_SERIAL_ID.  The value is stored as two big-endian bytes;
 * the key is built on tmpctx. */
void psbt_input_add_serial_id(struct wally_psbt_input *input,
			      u16 serial_id)
{
	beint16_t serial_be = cpu_to_be16(serial_id);
	u8 *serial_key;

	serial_key = psbt_make_key(tmpctx, PSBT_TYPE_SERIAL_ID, NULL);
	psbt_input_add_unknown(input, serial_key,
			       &serial_be, sizeof(serial_be));
}


/* Records @serial_id on @output as an unknown entry keyed by
 * PSBT_TYPE_SERIAL_ID.  The value is stored as two big-endian bytes;
 * the key is built on tmpctx. */
void psbt_output_add_serial_id(struct wally_psbt_output *output,
			       u16 serial_id)
{
	beint16_t serial_be = cpu_to_be16(serial_id);
	u8 *serial_key;

	serial_key = psbt_make_key(tmpctx, PSBT_TYPE_SERIAL_ID, NULL);
	psbt_output_add_unknown(output, serial_key,
				&serial_be, sizeof(serial_be));
}

/* Returns the index of the first input of @psbt whose serial id equals
 * @serial_id, or -1 if there is none.  Inputs without a readable serial
 * id are skipped. */
int psbt_find_serial_input(struct wally_psbt *psbt, u16 serial_id)
{
	size_t idx = 0;

	while (idx < psbt->num_inputs) {
		u16 candidate;

		if (psbt_get_serial_id(&psbt->inputs[idx].unknowns, &candidate)
		    && candidate == serial_id)
			return idx;
		idx++;
	}
	return -1;
}

/* Returns the index of the first output of @psbt whose serial id equals
 * @serial_id, or -1 if there is none.  Outputs without a readable serial
 * id are skipped. */
int psbt_find_serial_output(struct wally_psbt *psbt, u16 serial_id)
{
	size_t idx = 0;

	while (idx < psbt->num_outputs) {
		u16 candidate;

		if (psbt_get_serial_id(&psbt->outputs[idx].unknowns, &candidate)
		    && candidate == serial_id)
			return idx;
		idx++;
	}
	return -1;
}

/* Checks that @psbt carries everything the code in this file relies on.
 *
 * Returns true only if:
 *  - every input has a readable serial id,
 *  - every input has its utxo (previous transaction) attached,
 *  - every input's outpoint index refers to an output that utxo has,
 *  - every input spending a P2SH output has a redeem script, and
 *  - every output has a readable serial id.
 * Returns false otherwise; it never aborts on a malformed PSBT. */
bool psbt_has_required_fields(struct wally_psbt *psbt)
{
	u16 serial_id;

	for (size_t i = 0; i < psbt->num_inputs; i++) {
		struct wally_psbt_input *input = &psbt->inputs[i];
		const u8 *outscript;
		size_t outnum;

		if (!psbt_get_serial_id(&input->unknowns, &serial_id))
			return false;

		/* Required because we send the full tx over the wire now */
		if (!input->utxo)
			return false;

		/* The outpoint must name an output the utxo really has.
		 * This is a property of the PSBT's contents, so a mismatch
		 * is reported to the caller rather than asserted: an
		 * assert would either crash on bad data or, compiled out,
		 * let the lookup below read past the outputs array. */
		outnum = psbt->tx->inputs[i].index;
		if (outnum >= input->utxo->num_outputs)
			return false;

		/* If is P2SH, redeemscript must be present */
		outscript = wally_tx_output_get_script(tmpctx,
					&input->utxo->outputs[outnum]);
		if (is_p2sh(outscript, NULL) && input->redeem_script_len == 0)
			return false;
	}

	for (size_t i = 0; i < psbt->num_outputs; i++) {
		if (!psbt_get_serial_id(&psbt->outputs[i].unknowns, &serial_id))
			return false;
	}

	return true;
}