Browse Source

PaymentInfo: use msat precision

patch-4
ThomasV 4 years ago
parent
commit
e477a43385
  1. 7
      electrum/lnpeer.py
  2. 35
      electrum/lnworker.py
  3. 20
      electrum/tests/test_lnpeer.py
  4. 14
      electrum/wallet_db.py

7
electrum/lnpeer.py

@ -1388,7 +1388,7 @@ class Peer(Logger):
if payment_secret_from_onion != derive_payment_secret_from_payment_preimage(preimage):
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
return None, reason
expected_received_msat = int(info.amount * 1000) if info.amount is not None else None
expected_received_msat = info.amount_msat
if expected_received_msat is not None and \
not (expected_received_msat <= htlc.amount_msat <= 2 * expected_received_msat):
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
@ -1410,8 +1410,9 @@ class Peer(Logger):
reason = OnionRoutingFailureMessage(code=OnionFailureCode.INVALID_ONION_PAYLOAD, data=b'\x00\x00\x00')
return None, reason
if cltv_from_onion != htlc.cltv_expiry:
reason = OnionRoutingFailureMessage(code=OnionFailureCode.FINAL_INCORRECT_CLTV_EXPIRY,
data=htlc.cltv_expiry.to_bytes(4, byteorder="big"))
reason = OnionRoutingFailureMessage(
code=OnionFailureCode.FINAL_INCORRECT_CLTV_EXPIRY,
data=htlc.cltv_expiry.to_bytes(4, byteorder="big"))
return None, reason
try:
amount_from_onion = processed_onion.hop_data.payload["amt_to_forward"]["amt_to_forward"]

35
electrum/lnworker.py

@ -138,7 +138,7 @@ FALLBACK_NODE_LIST_MAINNET = [
class PaymentInfo(NamedTuple):
payment_hash: bytes
amount: Optional[int] # in satoshis # TODO make it msat and rename to amount_msat
amount_msat: Optional[int]
direction: int
status: int
@ -564,7 +564,7 @@ class LNWallet(LNWorker):
self.config = wallet.config
self.lnwatcher = None
self.lnrater: LNRater = None
self.payments = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid # FIXME amt should be msat
self.payments = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid
self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage
# note: this sweep_address is only used as fallback; as it might result in address-reuse
self.sweep_address = wallet.get_new_sweep_address_for_channel()
@ -687,8 +687,8 @@ class LNWallet(LNWorker):
fee_msat = None
for chan_id, htlc, _direction in plist:
amount_msat += int(_direction) * htlc.amount_msat
if _direction == SENT and info and info.amount:
fee_msat = (fee_msat or 0) - info.amount*1000 - amount_msat
if _direction == SENT and info and info.amount_msat:
fee_msat = (fee_msat or 0) - info.amount_msat - amount_msat
timestamp = min([htlc.timestamp for chan_id, htlc, _direction in plist])
return amount_msat, fee_msat, timestamp
@ -948,13 +948,13 @@ class LNWallet(LNWorker):
lnaddr = self._check_invoice(invoice, amount_msat=amount_msat)
payment_hash = lnaddr.paymenthash
key = payment_hash.hex()
amount = int(lnaddr.amount * COIN)
amount_msat = lnaddr.get_amount_msat()
status = self.get_payment_status(payment_hash)
if status == PR_PAID:
raise PaymentFailure(_("This invoice has been paid already"))
if status == PR_INFLIGHT:
raise PaymentFailure(_("A payment was already initiated for this invoice"))
info = PaymentInfo(lnaddr.paymenthash, amount, SENT, PR_UNPAID)
info = PaymentInfo(payment_hash, amount_msat, SENT, PR_UNPAID)
self.save_payment_info(info)
self.wallet.set_label(key, lnaddr.get_description())
self.logs[key] = log = []
@ -1217,16 +1217,16 @@ class LNWallet(LNWorker):
raise Exception(_("add invoice timed out"))
@log_exceptions
async def create_invoice(self, amount_sat: Optional[int], message, expiry: int):
async def create_invoice(self, amount_msat: Optional[int], message, expiry: int):
timestamp = int(time.time())
routing_hints = await self._calc_routing_hints_for_invoice(amount_sat)
routing_hints = await self._calc_routing_hints_for_invoice(amount_msat)
if not routing_hints:
self.logger.info("Warning. No routing hints added to invoice. "
"Other clients will likely not be able to send to us.")
payment_preimage = os.urandom(32)
payment_hash = sha256(payment_preimage)
info = PaymentInfo(payment_hash, amount_sat, RECEIVED, PR_UNPAID)
amount_btc = amount_sat/Decimal(COIN) if amount_sat else None
info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID)
amount_btc = amount_msat/Decimal(COIN*1000) if amount_msat else None
if expiry == 0:
expiry = LN_EXPIRY_NEVER
lnaddr = LnAddr(paymenthash=payment_hash,
@ -1244,7 +1244,8 @@ class LNWallet(LNWorker):
return lnaddr, invoice
async def _add_request_coro(self, amount_sat: Optional[int], message, expiry: int) -> str:
lnaddr, invoice = await self.create_invoice(amount_sat, message, expiry)
amount_msat = amount_sat * 1000 if amount_sat is not None else None
lnaddr, invoice = await self.create_invoice(amount_msat, message, expiry)
key = bh2u(lnaddr.paymenthash)
req = LNInvoice.from_bech32(invoice)
self.wallet.add_payment_request(req)
@ -1265,14 +1266,14 @@ class LNWallet(LNWorker):
key = payment_hash.hex()
with self.lock:
if key in self.payments:
amount, direction, status = self.payments[key]
return PaymentInfo(payment_hash, amount, direction, status)
amount_msat, direction, status = self.payments[key]
return PaymentInfo(payment_hash, amount_msat, direction, status)
def save_payment_info(self, info: PaymentInfo) -> None:
key = info.payment_hash.hex()
assert info.status in SAVED_PR_STATUS
with self.lock:
self.payments[key] = info.amount, info.direction, info.status
self.payments[key] = info.amount_msat, info.direction, info.status
self.wallet.save_db()
def get_payment_status(self, payment_hash):
@ -1355,16 +1356,14 @@ class LNWallet(LNWorker):
util.trigger_callback('request_status', self.wallet, payment_hash.hex(), PR_PAID)
util.trigger_callback('ln_payment_completed', payment_hash, chan.channel_id)
async def _calc_routing_hints_for_invoice(self, amount_sat: Optional[int]):
async def _calc_routing_hints_for_invoice(self, amount_msat: Optional[int]):
"""calculate routing hints (BOLT-11 'r' field)"""
routing_hints = []
channels = list(self.channels.values())
random.shuffle(channels) # not sure this has any benefit but let's not leak channel order
scid_to_my_channels = {chan.short_channel_id: chan for chan in channels
if chan.short_channel_id is not None}
if amount_sat:
amount_msat = 1000 * amount_sat
else:
if not amount_msat:
# for no amt invoices, check if channel can receive at least 1 msat
amount_msat = 1
# note: currently we add *all* our channels; but this might be a privacy leak?

20
electrum/tests/test_lnpeer.py

@ -373,17 +373,17 @@ class TestPeer(ElectrumTestCase):
async def prepare_invoice(
w2: MockLNWallet, # receiver
*,
amount_sat=100_000,
amount_msat=100_000_000,
include_routing_hints=False,
):
amount_btc = amount_sat/Decimal(COIN)
amount_btc = amount_msat/Decimal(COIN*1000)
payment_preimage = os.urandom(32)
RHASH = sha256(payment_preimage)
info = PaymentInfo(RHASH, amount_sat, RECEIVED, PR_UNPAID)
info = PaymentInfo(RHASH, amount_msat, RECEIVED, PR_UNPAID)
w2.save_preimage(RHASH, payment_preimage)
w2.save_payment_info(info)
if include_routing_hints:
routing_hints = await w2._calc_routing_hints_for_invoice(amount_sat)
routing_hints = await w2._calc_routing_hints_for_invoice(amount_msat)
else:
routing_hints = []
lnaddr = LnAddr(
@ -541,14 +541,14 @@ class TestPeer(ElectrumTestCase):
alice_init_balance_msat = alice_channel.balance(HTLCOwner.LOCAL)
bob_init_balance_msat = bob_channel.balance(HTLCOwner.LOCAL)
num_payments = 50
payment_value_sat = 10000 # make it large enough so that there are actually HTLCs on the ctx
payment_value_msat = 10_000_000 # make it large enough so that there are actually HTLCs on the ctx
max_htlcs_in_flight = asyncio.Semaphore(5)
async def single_payment(pay_req):
async with max_htlcs_in_flight:
await w1._pay(pay_req)
async def many_payments():
async with TaskGroup() as group:
pay_reqs_tasks = [await group.spawn(self.prepare_invoice(w2, amount_sat=payment_value_sat))
pay_reqs_tasks = [await group.spawn(self.prepare_invoice(w2, amount_msat=payment_value_msat))
for i in range(num_payments)]
async with TaskGroup() as group:
for pay_req_task in pay_reqs_tasks:
@ -560,10 +560,10 @@ class TestPeer(ElectrumTestCase):
await gath
with self.assertRaises(concurrent.futures.CancelledError):
run(f())
self.assertEqual(alice_init_balance_msat - num_payments * payment_value_sat * 1000, alice_channel.balance(HTLCOwner.LOCAL))
self.assertEqual(alice_init_balance_msat - num_payments * payment_value_sat * 1000, bob_channel.balance(HTLCOwner.REMOTE))
self.assertEqual(bob_init_balance_msat + num_payments * payment_value_sat * 1000, bob_channel.balance(HTLCOwner.LOCAL))
self.assertEqual(bob_init_balance_msat + num_payments * payment_value_sat * 1000, alice_channel.balance(HTLCOwner.REMOTE))
self.assertEqual(alice_init_balance_msat - num_payments * payment_value_msat, alice_channel.balance(HTLCOwner.LOCAL))
self.assertEqual(alice_init_balance_msat - num_payments * payment_value_msat, bob_channel.balance(HTLCOwner.REMOTE))
self.assertEqual(bob_init_balance_msat + num_payments * payment_value_msat, bob_channel.balance(HTLCOwner.LOCAL))
self.assertEqual(bob_init_balance_msat + num_payments * payment_value_msat, alice_channel.balance(HTLCOwner.REMOTE))
@needs_test_with_all_chacha20_implementations
def test_payment_multihop(self):

14
electrum/wallet_db.py

@ -52,7 +52,7 @@ if TYPE_CHECKING:
OLD_SEED_VERSION = 4 # electrum versions < 2.0
NEW_SEED_VERSION = 11 # electrum versions >= 2.0
FINAL_SEED_VERSION = 36 # electrum >= 2.7 will set this to prevent
FINAL_SEED_VERSION = 37 # electrum >= 2.7 will set this to prevent
# old versions from overwriting new format
@ -184,6 +184,7 @@ class WalletDB(JsonDB):
self._convert_version_34()
self._convert_version_35()
self._convert_version_36()
self._convert_version_37()
self.put('seed_version', FINAL_SEED_VERSION) # just to be sure
self._after_upgrade_tasks()
@ -740,6 +741,17 @@ class WalletDB(JsonDB):
self.data['frozen_coins'] = new_frozen_coins
self.data['seed_version'] = 36
def _convert_version_37(self):
if not self._is_upgrade_method_needed(36, 36):
return
payments = self.data.get('lightning_payments', {})
for k, v in list(payments.items()):
amount_sat, direction, status = v
amount_msat = amount_sat * 1000 if amount_sat is not None else None
payments[k] = amount_msat, direction, status
self.data['lightning_payments'] = payments
self.data['seed_version'] = 37
def _convert_imported(self):
if not self._is_upgrade_method_needed(0, 13):
return

Loading…
Cancel
Save