diff --git a/wallet/db.c b/wallet/db.c index 2e741ab1f..965feee1b 100644 --- a/wallet/db.c +++ b/wallet/db.c @@ -10,7 +10,6 @@ #include #include -#define DB_FILE "lightningd.sqlite3" #define NSEC_IN_SEC 1000000000 struct migration { @@ -645,17 +644,6 @@ void db_commit_transaction(struct db *db) db->in_transaction = NULL; } -static void setup_open_db(struct db *db) -{ - /* This must be outside a transaction, so catch it */ - assert(!db->in_transaction); - - db_prepare_for_changes(db); - if (db->config->setup_fn) - db->config->setup_fn(db); - db_report_changes(db, NULL, 0); -} - static struct db_config *db_config_find(const char *driver_name) { size_t num_configs; @@ -671,38 +659,28 @@ static struct db_config *db_config_find(const char *driver_name) */ static struct db *db_open(const tal_t *ctx, char *filename) { - int err; struct db *db; - sqlite3 *sql; const char *driver_name = "sqlite3"; - int flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE; - err = sqlite3_open_v2(filename, &sql, flags, NULL); - - if (err != SQLITE_OK) { - db_fatal("failed to open database %s: %s", filename, - sqlite3_errstr(err)); - } - db = tal(ctx, struct db); db->filename = tal_strdup(db, filename); - db->sql = sql; - db->config = NULL; list_head_init(&db->pending_statements); db->config = db_config_find(driver_name); if (!db->config) db_fatal("Unable to find DB driver for %s", driver_name); - // FIXME(cdecker) Once we parse DB connection strings this needs to be - // instantiated correctly. - db->conn = sql; - tal_add_destructor(db, destroy_db); db->in_transaction = NULL; db->changes = NULL; - setup_open_db(db); + /* This must be outside a transaction, so catch it */ + assert(!db->in_transaction); + + db_prepare_for_changes(db); + if (db->config->setup_fn) + db->config->setup_fn(db); + db_report_changes(db, NULL, 0); return db; } @@ -787,8 +765,7 @@ static void db_migrate(struct lightningd *ld, struct db *db, struct log *log) struct db *db_setup(const tal_t *ctx, struct lightningd *ld, struct log *log) { - struct db *db = db_open(ctx, DB_FILE); - + struct db *db = db_open(ctx, ld->wallet_dsn); db_migrate(ld, db, log); return db; } diff --git a/wallet/db_common.h b/wallet/db_common.h index 74b21b3ed..7b570a7ce 100644 --- a/wallet/db_common.h +++ b/wallet/db_common.h @@ -15,7 +15,6 @@ struct db { char *filename; const char *in_transaction; - sqlite3 *sql; /* DB-specific context */ void *conn; diff --git a/wallet/db_sqlite3.c b/wallet/db_sqlite3.c index 1dd606460..d6523aeac 100644 --- a/wallet/db_sqlite3.c +++ b/wallet/db_sqlite3.c @@ -24,8 +24,25 @@ static const char *db_sqlite3_fmt_error(struct db_stmt *stmt) static bool db_sqlite3_setup(struct db *db) { + char *filename; sqlite3_stmt *stmt; - int err; + sqlite3 *sql; + int err, flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE; + + if (!strstarts(db->filename, "sqlite3://") || strlen(db->filename) < 10) + db_fatal("Could not parse the wallet DSN: %s", db->filename); + + /* Strip the scheme from the dsn. */ + filename = db->filename + strlen("sqlite3://"); + + err = sqlite3_open_v2(filename, &sql, flags, NULL); + + if (err != SQLITE_OK) { + db_fatal("failed to open database %s: %s", filename, + sqlite3_errstr(err)); + } + db->conn = sql; + sqlite3_prepare_v2(db->conn, "PRAGMA foreign_keys = ON;", -1, &stmt, NULL); err = sqlite3_step(stmt); sqlite3_finalize(stmt); @@ -197,7 +214,8 @@ static size_t db_sqlite3_count_changes(struct db_stmt *stmt) static void db_sqlite3_close(struct db *db) { - sqlite3_close(db->sql); + sqlite3_close(db->conn); + db->conn = NULL; } static u64 db_sqlite3_last_insert_id(struct db_stmt *stmt) diff --git a/wallet/test/run-db.c b/wallet/test/run-db.c index 1e444eff8..80a2b8815 100644 --- a/wallet/test/run-db.c +++ b/wallet/test/run-db.c @@ -49,14 +49,16 @@ void plugin_hook_db_sync(struct db *db UNNEEDED, const char **changes UNNEEDED, static struct db *create_test_db(void) { struct db *db; - char filename[] = "/tmp/ldb-XXXXXX"; + char *dsn, filename[] = "/tmp/ldb-XXXXXX"; int fd = mkstemp(filename); if (fd == -1) return NULL; close(fd); - db = db_open(NULL, filename); + dsn = tal_fmt(NULL, "sqlite3://%s", filename); + db = db_open(NULL, dsn); + tal_free(dsn); return db; } diff --git a/wallet/test/run-wallet.c b/wallet/test/run-wallet.c index 5406e669a..05b3c8b9d 100644 --- a/wallet/test/run-wallet.c +++ b/wallet/test/run-wallet.c @@ -727,14 +727,16 @@ static void cleanup_test_wallet(struct wallet *w, char *filename) static struct wallet *create_test_wallet(struct lightningd *ld, const tal_t *ctx) { - char *filename = tal_fmt(ctx, "/tmp/ldb-XXXXXX"); + char *dsn, *filename = tal_fmt(ctx, "/tmp/ldb-XXXXXX"); int fd = mkstemp(filename); struct wallet *w = tal(ctx, struct wallet); static unsigned char badseed[BIP32_ENTROPY_LEN_128]; CHECK_MSG(fd != -1, "Unable to generate temp filename"); close(fd); - w->db = db_open(w, filename); + dsn = tal_fmt(NULL, "sqlite3://%s", filename); + w->db = db_open(w, dsn); + tal_free(dsn); tal_add_destructor2(w, cleanup_test_wallet, filename); list_head_init(&w->unstored_payments);