Browse Source

lnpeer: ignore unknown 'odd' type messages

from BOLT-01:
A receiving node:
  - upon receiving a message of odd, unknown type:
    - MUST ignore the received message.

b201efe054/01-messaging.md (lightning-message-format)
patch-4
SomberNight 4 years ago
parent
commit
c912036180
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 9
      electrum/lnmsg.py
  2. 8
      electrum/lnpeer.py
  3. 90
      electrum/tests/test_lnpeer.py

9
electrum/lnmsg.py

@ -9,9 +9,11 @@ from .lnutil import OnionFailureCodeMetaFlag
class FailedToParseMsg(Exception): pass
class MalformedMsg(FailedToParseMsg): pass
class UnknownMsgType(FailedToParseMsg): pass
class UnknownOptionalMsgType(UnknownMsgType): pass
class UnknownMandatoryMsgType(UnknownMsgType): pass
class MalformedMsg(FailedToParseMsg): pass
class UnknownMsgFieldType(MalformedMsg): pass
class UnexpectedEndOfStream(MalformedMsg): pass
class FieldEncodingNotMinimal(MalformedMsg): pass
@ -479,7 +481,10 @@ class LNSerializer:
try:
scheme = self.msg_scheme_from_type[msg_type_bytes]
except KeyError:
raise UnknownMsgType(f"msg_type={msg_type_int}") # TODO even/odd type?
if msg_type_int % 2 == 0: # even types must be understood: "mandatory"
raise UnknownMandatoryMsgType(f"msg_type={msg_type_int}")
else: # odd types are ok not to understand: "optional"
raise UnknownOptionalMsgType(f"msg_type={msg_type_int}")
assert scheme[0][2] == msg_type_int
msg_type_name = scheme[0][1]
parsed = {}

8
electrum/lnpeer.py

@ -46,7 +46,7 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc,
UpfrontShutdownScriptViolation)
from .lnutil import FeeUpdate, channel_id_from_funding_tx
from .lntransport import LNTransport, LNTransportBase
from .lnmsg import encode_msg, decode_msg
from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType
from .interface import GracefulDisconnect
from .lnrouter import fee_for_edge_msat
from .lnutil import ln_dummy_address
@ -179,7 +179,11 @@ class Peer(Logger):
self.ping_time = time.time()
def process_message(self, message):
message_type, payload = decode_msg(message)
try:
message_type, payload = decode_msg(message)
except UnknownOptionalMsgType as e:
self.logger.info(f"received unknown message from peer. ignoring: {e!r}")
return
# only process INIT if we are a backup
if self.is_channel_backup is True and message_type != 'init':
return

90
electrum/tests/test_lnpeer.py

@ -29,6 +29,7 @@ from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound
from electrum.lnmsg import encode_msg, decode_msg
from electrum import lnmsg
from electrum.logging import console_stderr_handler, Logger
from electrum.lnworker import PaymentInfo, RECEIVED
from electrum.lnonion import OnionFailureCode
@ -1086,6 +1087,95 @@ class TestPeer(TestCaseForTestnet):
with self.assertRaises(PaymentFailure):
run(f())
@needs_test_with_all_chacha20_implementations
def test_sending_weird_messages_that_should_be_ignored(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
async def send_weird_messages():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
# peer1 sends known message with trailing garbage
# BOLT-01 says peer2 should ignore trailing garbage
raw_msg1 = encode_msg('ping', num_pong_bytes=4, byteslen=4) + bytes(range(55))
p1.transport.send_bytes(raw_msg1)
await asyncio.sleep(0.05)
# peer1 sends unknown 'odd-type' message
# BOLT-01 says peer2 should ignore whole message
raw_msg2 = (43333).to_bytes(length=2, byteorder="big") + bytes(range(55))
p1.transport.send_bytes(raw_msg2)
await asyncio.sleep(0.05)
raise TestSuccess()
async def f():
async with TaskGroup() as group:
for peer in [p1, p2]:
await group.spawn(peer._message_loop())
await group.spawn(peer.htlc_switch())
await asyncio.sleep(0.2)
await group.spawn(send_weird_messages())
with self.assertRaises(TestSuccess):
run(f())
@needs_test_with_all_chacha20_implementations
def test_sending_weird_messages__unknown_even_type(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
async def send_weird_messages():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
# peer1 sends unknown 'even-type' message
# BOLT-01 says peer2 should close the connection
raw_msg2 = (43334).to_bytes(length=2, byteorder="big") + bytes(range(55))
p1.transport.send_bytes(raw_msg2)
await asyncio.sleep(0.05)
failing_task = None
async def f():
nonlocal failing_task
async with TaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p1.htlc_switch())
failing_task = await group.spawn(p2._message_loop())
await group.spawn(p2.htlc_switch())
await asyncio.sleep(0.2)
await group.spawn(send_weird_messages())
with self.assertRaises(lnmsg.UnknownMandatoryMsgType):
run(f())
self.assertTrue(isinstance(failing_task.exception(), lnmsg.UnknownMandatoryMsgType))
@needs_test_with_all_chacha20_implementations
def test_sending_weird_messages__known_msg_with_insufficient_length(self):
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
async def send_weird_messages():
await asyncio.wait_for(p1.initialized, 1)
await asyncio.wait_for(p2.initialized, 1)
# peer1 sends known message with insufficient length for the contents
# BOLT-01 says peer2 should fail the connection
raw_msg1 = encode_msg('ping', num_pong_bytes=4, byteslen=4)[:-1]
p1.transport.send_bytes(raw_msg1)
await asyncio.sleep(0.05)
failing_task = None
async def f():
nonlocal failing_task
async with TaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p1.htlc_switch())
failing_task = await group.spawn(p2._message_loop())
await group.spawn(p2.htlc_switch())
await asyncio.sleep(0.2)
await group.spawn(send_weird_messages())
with self.assertRaises(lnmsg.UnexpectedEndOfStream):
run(f())
self.assertTrue(isinstance(failing_task.exception(), lnmsg.UnexpectedEndOfStream))
def run(coro):
return asyncio.run_coroutine_threadsafe(coro, loop=asyncio.get_event_loop()).result()

Loading…
Cancel
Save