diff --git a/wallet/db.c b/wallet/db.c index 180a8d174..1e2f7958a 100644 --- a/wallet/db.c +++ b/wallet/db.c @@ -554,6 +554,50 @@ sqlite3_stmt *db_select_prepare_(const char *location, struct db *db, const char return stmt; } +struct db_stmt *db_prepare_v2_(const char *location, struct db *db, + const char *query_id) +{ + struct db_stmt *stmt = tal(db, struct db_stmt); + size_t num_slots; + stmt->query = NULL; + + /* Normalize query_id paths, because unit tests are compiled with this + * prefix. */ + if (strncmp(query_id, "./", 2) == 0) + query_id += 2; + + /* Look up the query by its ID */ + for (size_t i = 0; i < db->config->num_queries; i++) { + if (streq(query_id, db->config->queries[i].query)) { + stmt->query = &db->config->queries[i]; + break; + } + } + if (stmt->query == NULL) + fatal("Could not resolve query %s", query_id); + + num_slots = stmt->query->placeholders; + /* Allocate the slots for placeholders/bindings, zeroed next since + * that sets the type to DB_BINDING_UNINITIALIZED for later checks. */ + stmt->bindings = tal_arr(stmt, struct db_binding, num_slots); + for (size_t i=0; ibindings[i].type = DB_BINDING_UNINITIALIZED; + + stmt->location = location; + stmt->error = NULL; + stmt->db = db; + return stmt; +} + +void db_stmt_free(struct db_stmt *stmt) +{ + stmt->db->config->stmt_free_fn(stmt); +} + +#define db_prepare_v2(db,query) \ + db_prepare_v2_(__FILE__ ":" stringify(__LINE__), db, query) + + bool db_select_step_(const char *location, struct db *db, struct sqlite3_stmt *stmt) { int ret; @@ -1247,3 +1291,64 @@ struct timeabs sqlite3_column_timeabs(sqlite3_stmt *stmt, int col) return t; } + +void db_bind_null(struct db_stmt *stmt, int pos) +{ + assert(pos < tal_count(stmt->bindings)); + stmt->bindings[pos].type = DB_BINDING_NULL; +} + +void db_bind_int(struct db_stmt *stmt, int pos, int val) +{ + assert(pos < tal_count(stmt->bindings)); + stmt->bindings[pos].type = DB_BINDING_INT; + stmt->bindings[pos].v.i = val; +} + +void db_bind_u64(struct db_stmt *stmt, int pos, u64 val) +{ + assert(pos < tal_count(stmt->bindings)); + stmt->bindings[pos].type = DB_BINDING_UINT64; + stmt->bindings[pos].v.u64 = val; +} + +void db_bind_blob(struct db_stmt *stmt, int pos, u8 *val, size_t len) +{ + assert(pos < tal_count(stmt->bindings)); + stmt->bindings[pos].type = DB_BINDING_BLOB; + stmt->bindings[pos].v.blob = val; + stmt->bindings[pos].len = len; +} + +void db_bind_text(struct db_stmt *stmt, int pos, const char *val) +{ + assert(pos < tal_count(stmt->bindings)); + stmt->bindings[pos].type = DB_BINDING_TEXT; + stmt->bindings[pos].v.text = val; + stmt->bindings[pos].len = strlen(val); +} + +bool db_exec_prepared_v2(struct db_stmt *stmt TAKES) +{ + const char *expanded_sql; + bool ret = stmt->db->config->exec_fn(stmt); + if (stmt->db->config->expand_fn != NULL && ret && + !stmt->query->readonly) { + expanded_sql = stmt->db->config->expand_fn(stmt); + tal_arr_expand(&stmt->db->changes, + tal_strdup(stmt->db->changes, expanded_sql)); + } + + /* The driver itself doesn't call `fatal` since we want to override it + * for testing. Instead we check here that the error message is set if + * we report an error. */ + if (!ret) { + assert(stmt->error); + db_fatal("Error executing statement: %s", stmt->error); + } + + if (taken(stmt)) + tal_free(stmt); + + return ret; +} diff --git a/wallet/db.h b/wallet/db.h index 5c5d8c7b0..8a996d71b 100644 --- a/wallet/db.h +++ b/wallet/db.h @@ -18,7 +18,7 @@ struct lightningd; struct log; struct node_id; - +struct db_stmt; struct db; /** @@ -241,4 +241,19 @@ void sqlite3_bind_amount_sat(sqlite3_stmt *stmt, int col, void sqlite3_bind_timeabs(sqlite3_stmt *stmt, int col, struct timeabs t); struct timeabs sqlite3_column_timeabs(sqlite3_stmt *stmt, int col); + +void db_bind_null(struct db_stmt *stmt, int pos); +void db_bind_int(struct db_stmt *stmt, int pos, int val); +void db_bind_u64(struct db_stmt *stmt, int pos, u64 val); +void db_bind_blob(struct db_stmt *stmt, int pos, u8 *val, size_t len); +void db_bind_text(struct db_stmt *stmt, int pos, const char *val); +bool db_exec_prepared_v2(struct db_stmt *stmt TAKES); + +void db_stmt_free(struct db_stmt *stmt); + +struct db_stmt *db_prepare_v2_(const char *location, struct db *db, + const char *query_id); +#define db_prepare_v2(db,query) \ + db_prepare_v2_(__FILE__ ":" stringify(__LINE__), db, query) + #endif /* LIGHTNING_WALLET_DB_H */ diff --git a/wallet/test/run-db.c b/wallet/test/run-db.c index 355c6458b..ba36aa152 100644 --- a/wallet/test/run-db.c +++ b/wallet/test/run-db.c @@ -24,6 +24,9 @@ size_t bigsize_get(const u8 *p UNNEEDED, size_t max UNNEEDED, bigsize_t *val UNN /* Generated stub for bigsize_put */ size_t bigsize_put(u8 buf[BIGSIZE_MAX_LEN] UNNEEDED, bigsize_t v UNNEEDED) { fprintf(stderr, "bigsize_put called!\n"); abort(); } +/* Generated stub for fatal */ +void fatal(const char *fmt UNNEEDED, ...) +{ fprintf(stderr, "fatal called!\n"); abort(); } /* AUTOGENERATED MOCKS END */ static char *db_err;