diff --git a/lightningd/hsm_control.c b/lightningd/hsm_control.c index 5d67e3bfa..b0c313d54 100644 --- a/lightningd/hsm_control.c +++ b/lightningd/hsm_control.c @@ -42,6 +42,24 @@ int hsm_get_client_fd(struct lightningd *ld, return hsm_fd; } +static unsigned int hsm_msg(struct subd *hsmd, + const u8 *msg, const int *fds UNUSED) +{ + /* We only expect one thing from the HSM that's not a STATUS message */ + struct pubkey client_id; + u8 *bad_msg; + + if (!fromwire_hsmstatus_client_bad_request(tmpctx, msg, &client_id, + &bad_msg)) + fatal("Bad status message from hsmd: %s", tal_hex(tmpctx, msg)); + + /* This should, of course, never happen. */ + log_broken(hsmd->log, "client %s sent bad hsm request %s", + type_to_string(tmpctx, struct pubkey, &client_id), + tal_hex(tmpctx, bad_msg)); + return 0; +} + void hsm_init(struct lightningd *ld) { u8 *msg; @@ -51,7 +69,9 @@ void hsm_init(struct lightningd *ld) if (socketpair(AF_LOCAL, SOCK_STREAM, 0, fds) != 0) err(1, "Could not create hsm socketpair"); - ld->hsm = new_global_subd(ld, "lightning_hsmd", NULL, NULL, + ld->hsm = new_global_subd(ld, "lightning_hsmd", + hsm_client_wire_type_name, + hsm_msg, take(&fds[1]), NULL); if (!ld->hsm) err(1, "Could not subd hsm"); diff --git a/lightningd/subd.c b/lightningd/subd.c index 589f1a412..c6a667879 100644 --- a/lightningd/subd.c +++ b/lightningd/subd.c @@ -400,6 +400,8 @@ static struct io_plan *sd_msg_read(struct io_conn *conn, struct subd *sd) struct subd_req *sr; struct db *db = sd->ld->wallet->db; struct io_plan *plan; + unsigned int i; + bool freed = false; /* Everything we do, we wrap in a database transaction */ db_begin_transaction(db); @@ -464,29 +466,25 @@ static struct io_plan *sd_msg_read(struct io_conn *conn, struct subd *sd) } log_debug(sd->log, "UPDATE %s", sd->msgname(type)); - if (sd->msgcb) { - unsigned int i; - bool freed = false; - /* Might free sd (if returns negative); save/restore sd->conn */ - sd->conn = NULL; - tal_add_destructor2(sd, mark_freed, &freed); + /* Might free sd (if returns negative); save/restore sd->conn */ + sd->conn = NULL; + tal_add_destructor2(sd, mark_freed, &freed); - i = sd->msgcb(sd, sd->msg_in, sd->fds_in); - if (freed) - goto close; - tal_del_destructor2(sd, mark_freed, &freed); + i = sd->msgcb(sd, sd->msg_in, sd->fds_in); + if (freed) + goto close; + tal_del_destructor2(sd, mark_freed, &freed); - sd->conn = conn; + sd->conn = conn; - if (i != 0) { - /* Don't ask for fds twice! */ - assert(!sd->fds_in); - /* Don't free msg_in: we go around again. */ - tal_steal(sd, sd->msg_in); - plan = sd_collect_fds(conn, sd, i); - goto out; - } + if (i != 0) { + /* Don't ask for fds twice! */ + assert(!sd->fds_in); + /* Don't free msg_in: we go around again. */ + tal_steal(sd, sd->msg_in); + plan = sd_collect_fds(conn, sd, i); + goto out; } next: @@ -651,7 +649,9 @@ static struct subd *new_subd(struct lightningd *ld, sd->must_not_exit = false; sd->talks_to_peer = talks_to_peer; sd->msgname = msgname; + assert(msgname); sd->msgcb = msgcb; + assert(msgcb); sd->errcb = errcb; sd->billboardcb = billboardcb; sd->fds_in = NULL; diff --git a/lightningd/subd.h b/lightningd/subd.h index 72e929eab..a0ea87a36 100644 --- a/lightningd/subd.h +++ b/lightningd/subd.h @@ -76,7 +76,7 @@ struct subd { * @ld: global state * @name: basename of daemon * @msgname: function to get name from messages - * @msgcb: function to call (inside db transaction) when non-fatal message received (or NULL) + * @msgcb: function to call (inside db transaction) when non-fatal message received * @...: NULL-terminated list of pointers to fds to hand as fd 3, 4... * (can be take, if so, set to -1) * diff --git a/tests/fixtures.py b/tests/fixtures.py index a097e2498..b9a2e07a4 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -134,6 +134,11 @@ def node_factory(request, directory, test_name, bitcoind, executor): err_count += checkBadReestablish(node) check_errors(request, err_count, "{} nodes had bad reestablish") + for node in nf.nodes: + err_count += checkBadHSMRequest(node) + if err_count: + raise ValueError("{} nodes had bad hsm requests".format(err_count)) + if not ok: request.node.has_errors = True raise Exception("At least one lightning exited with unexpected non-zero return code") @@ -201,6 +206,12 @@ def checkBadReestablish(node): return 0 +def checkBadHSMRequest(node): + if node.daemon.is_in_log('bad hsm request'): + return 1 + return 0 + + @pytest.fixture def executor(): ex = futures.ThreadPoolExecutor(max_workers=20)