diff --git a/tests/plugins/reject_some_invoices.py b/tests/plugins/reject_some_invoices.py new file mode 100755 index 000000000..6f520effd --- /dev/null +++ b/tests/plugins/reject_some_invoices.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +"""Simple plugin to test the invoice_payment_hook. + +We just refuse to let them pay invoices with preimages divisible by 16. +""" + +from lightning import Plugin + +plugin = Plugin() + + +@plugin.hook('invoice_payment') +def on_payment(payment, plugin): + print("label={}".format(payment['label'])) + print("msat={}".format(payment['msat'])) + print("preimage={}".format(payment['preimage'])) + + if payment['preimage'].endswith('0'): + # FIXME: Define this! + WIRE_TEMPORARY_NODE_FAILURE = 0x2002 + return {'result': {'failure_code': WIRE_TEMPORARY_NODE_FAILURE}} + + return {'result': {}} + + +plugin.run() diff --git a/tests/test_plugin.py b/tests/test_plugin.py index ebba4bde8..64dc97d72 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -233,3 +233,30 @@ def test_utf8_passthrough(node_factory, executor): 'utf8', 'ナンセンス 1杯']).decode('utf-8') assert '\\u' not in out assert out == '{\n "utf8" : "ナンセンス 1杯"\n}\n' + + +def test_invoice_payment_hook(node_factory): + """ l1 uses the reject-payment plugin to reject invoices with odd preimages. + """ + opts = [{}, {'plugin': 'tests/plugins/reject_some_invoices.py'}] + l1, l2 = node_factory.line_graph(2, opts=opts) + + # This one works + inv1 = l2.rpc.invoice(123000, 'label', 'description', preimage='1' * 64) + l1.rpc.pay(inv1['bolt11']) + + l2.daemon.wait_for_log('label=label') + l2.daemon.wait_for_log('msat=') + l2.daemon.wait_for_log('preimage=' + '1' * 64) + + # This one will be rejected. + inv2 = l2.rpc.invoice(123000, 'label2', 'description', preimage='0' * 64) + with pytest.raises(RpcError): + l1.rpc.pay(inv2['bolt11']) + + pstatus = l1.rpc.call('paystatus', [inv2['bolt11']])['pay'][0] + assert pstatus['attempts'][0]['failure']['data']['failcodename'] == 'WIRE_TEMPORARY_NODE_FAILURE' + + l2.daemon.wait_for_log('label=label2') + l2.daemon.wait_for_log('msat=') + l2.daemon.wait_for_log('preimage=' + '0' * 64)