diff --git a/tests/db.py b/tests/db.py new file mode 100644 index 000000000..0d12309e2 --- /dev/null +++ b/tests/db.py @@ -0,0 +1,169 @@ +from ephemeral_port_reserve import reserve + +import logging +import os +import psycopg2 +import random +import shutil +import signal +import sqlite3 +import string +import subprocess +import time + + +class Sqlite3Db(object): + def __init__(self, path): + self.path = path + + def get_dsn(self): + """SQLite3 doesn't provide a DSN, resulting in no CLI-option. + """ + return None + + def query(self, query): + orig = os.path.join(self.path) + copy = self.path + ".copy" + shutil.copyfile(orig, copy) + db = sqlite3.connect(copy) + + db.row_factory = sqlite3.Row + c = db.cursor() + c.execute(query) + rows = c.fetchall() + + result = [] + for row in rows: + result.append(dict(zip(row.keys(), row))) + + db.commit() + c.close() + db.close() + return result + + def execute(self, query): + db = sqlite3.connect(self.path) + c = db.cursor() + c.execute(query) + db.commit() + c.close() + db.close() + + +class PostgresDb(object): + def __init__(self, dbname, port): + self.dbname = dbname + self.port = port + + self.conn = psycopg2.connect("dbname={dbname} user=postgres host=localhost port={port}".format( + dbname=dbname, port=port + )) + cur = self.conn.cursor() + cur.execute('SELECT 1') + cur.close() + + def get_dsn(self): + return "postgres://postgres:password@localhost:{port}/{dbname}".format( + port=self.port, dbname=self.dbname + ) + + def query(self, query): + cur = self.conn.cursor() + cur.execute(query) + + # Collect the results into a list of dicts. + res = [] + for r in cur: + t = {} + # Zip the column definition with the value to get its name. + for c, v in zip(cur.description, r): + t[c.name] = v + res.append(t) + cur.close() + return res + + def execute(self, query): + with self.conn, self.conn.cursor() as cur: + cur.execute(query) + + +class SqliteDbProvider(object): + def __init__(self, directory): + self.directory = directory + + def start(self): + pass + + def get_db(self, node_directory, testname, node_id): + path = os.path.join( + node_directory, + 'lightningd.sqlite3' + ) + return Sqlite3Db(path) + + def stop(self): + pass + + +class PostgresDbProvider(object): + def __init__(self, directory): + self.directory = directory + self.port = None + self.proc = None + print("Starting PostgresDbProvider") + + def start(self): + passfile = os.path.join(self.directory, "pgpass.txt") + self.pgdir = os.path.join(self.directory, 'pgsql') + # Need to write a tiny file containing the password so `initdb` can pick it up + with open(passfile, 'w') as f: + f.write('cltest\n') + subprocess.check_call([ + '/usr/lib/postgresql/10/bin/initdb', + '--pwfile={}'.format(passfile), + '--pgdata={}'.format(self.pgdir), + '--auth=trust', + '--username=postgres', + ]) + self.port = reserve() + self.proc = subprocess.Popen([ + '/usr/lib/postgresql/10/bin/postgres', + '-k', '/tmp/', # So we don't use /var/lib/... + '-D', self.pgdir, + '-p', str(self.port), + '-F', + '-i', + ]) + # Hacky but seems to work ok (might want to make the postgres proc a TailableProc as well if too flaky). + time.sleep(1) + self.conn = psycopg2.connect("dbname=template1 user=postgres host=localhost port={}".format(self.port)) + + # Required for CREATE DATABASE to work + self.conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) + + def get_db(self, node_directory, testname, node_id): + # Random suffix to avoid collisions on repeated tests + nonce = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(8)) + dbname = "{}_{}_{}".format(testname, node_id, nonce) + + cur = self.conn.cursor() + cur.execute("CREATE DATABASE {};".format(dbname)) + cur.close() + db = PostgresDb(dbname, self.port) + return db + + def stop(self): + # Send fast shutdown signal see [1] for details: + # + # SIGINT + # + # This is the Fast Shutdown mode. The server disallows new connections + # and sends all existing server processes SIGTERM, which will cause + # them to abort their current transactions and exit promptly. It then + # waits for all server processes to exit and finally shuts down. If + # the server is in online backup mode, backup mode will be terminated, + # rendering the backup useless. + # + # [1] https://www.postgresql.org/docs/9.1/server-shutdown.html + self.proc.send_signal(signal.SIGINT) + self.proc.wait() diff --git a/tests/fixtures.py b/tests/fixtures.py index f6c4456c6..0b38b2de2 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,4 +1,5 @@ from concurrent import futures +from db import SqliteDbProvider, PostgresDbProvider from utils import NodeFactory, BitcoinD import logging @@ -149,12 +150,13 @@ def teardown_checks(request): @pytest.fixture -def node_factory(request, directory, test_name, bitcoind, executor, teardown_checks): +def node_factory(request, directory, test_name, bitcoind, executor, db_provider, teardown_checks): nf = NodeFactory( test_name, bitcoind, executor, directory=directory, + db_provider=db_provider, ) yield nf @@ -275,6 +277,21 @@ def checkMemleak(node): return 0 +# Mapping from TEST_DB_PROVIDER env variable to class to be used +providers = { + 'sqlite3': SqliteDbProvider, + 'postgres': PostgresDbProvider, +} + + +@pytest.fixture(scope="session") +def db_provider(test_base_dir): + provider = providers[os.getenv('TEST_DB_PROVIDER', 'sqlite3')](test_base_dir) + provider.start() + yield provider + provider.stop() + + @pytest.fixture def executor(teardown_checks): ex = futures.ThreadPoolExecutor(max_workers=20) diff --git a/tests/requirements.txt b/tests/requirements.txt index 4208596b7..bfa85f963 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -10,3 +10,4 @@ pytest-xdist==1.29.0 python-bitcoinlib==0.10.1 tqdm==4.32.2 pytest-timeout==1.3.3 +psycopg2==2.8.3 diff --git a/tests/test_misc.py b/tests/test_misc.py index 4248b2cc8..bae6fa1b2 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -969,12 +969,11 @@ def test_reserve_enforcement(node_factory, executor): l2.stop() # They should both aim for 1%. - reserves = l2.db_query('SELECT channel_reserve_satoshis FROM channel_configs') + reserves = l2.db.query('SELECT channel_reserve_satoshis FROM channel_configs') assert reserves == [{'channel_reserve_satoshis': 10**6 // 100}] * 2 # Edit db to reduce reserve to 0 so it will try to violate it. - l2.db_query('UPDATE channel_configs SET channel_reserve_satoshis=0', - use_copy=False) + l2.db.execute('UPDATE channel_configs SET channel_reserve_satoshis=0') l2.start() wait_for(lambda: only_one(l2.rpc.listpeers(l1.info['id'])['peers'])['connected']) diff --git a/tests/utils.py b/tests/utils.py index 92be43f5e..b9f946add 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -462,7 +462,9 @@ class LightningD(TailableProc): class LightningNode(object): - def __init__(self, daemon, rpc, btc, executor, may_fail=False, may_reconnect=False, allow_broken_log=False, allow_bad_gossip=False): + def __init__(self, daemon, rpc, btc, executor, may_fail=False, + may_reconnect=False, allow_broken_log=False, + allow_bad_gossip=False, db=None): self.rpc = rpc self.daemon = daemon self.bitcoin = btc @@ -471,6 +473,7 @@ class LightningNode(object): self.may_reconnect = may_reconnect self.allow_broken_log = allow_broken_log self.allow_bad_gossip = allow_bad_gossip + self.db = db def connect(self, remote_node): self.rpc.connect(remote_node.info['id'], '127.0.0.1', remote_node.daemon.port) @@ -510,28 +513,8 @@ class LightningNode(object): def getactivechannels(self): return [c for c in self.rpc.listchannels()['channels'] if c['active']] - def db_query(self, query, use_copy=True): - orig = os.path.join(self.daemon.lightning_dir, "lightningd.sqlite3") - if use_copy: - copy = os.path.join(self.daemon.lightning_dir, "lightningd-copy.sqlite3") - shutil.copyfile(orig, copy) - db = sqlite3.connect(copy) - else: - db = sqlite3.connect(orig) - - db.row_factory = sqlite3.Row - c = db.cursor() - c.execute(query) - rows = c.fetchall() - - result = [] - for row in rows: - result.append(dict(zip(row.keys(), row))) - - db.commit() - c.close() - db.close() - return result + def db_query(self, query): + return self.db.query(query) # Assumes node is stopped! def db_manip(self, query): @@ -771,7 +754,7 @@ class LightningNode(object): class NodeFactory(object): """A factory to setup and start `lightningd` daemons. """ - def __init__(self, testname, bitcoind, executor, directory): + def __init__(self, testname, bitcoind, executor, directory, db_provider): self.testname = testname self.next_id = 1 self.nodes = [] @@ -779,6 +762,7 @@ class NodeFactory(object): self.bitcoind = bitcoind self.directory = directory self.lock = threading.Lock() + self.db_provider = db_provider def split_options(self, opts): """Split node options from cli options @@ -880,11 +864,17 @@ class NodeFactory(object): if options is not None: daemon.opts.update(options) + # Get the DB backend DSN we should be using for this test and this node. + db = self.db_provider.get_db(lightning_dir, self.testname, node_id) + dsn = db.get_dsn() + if dsn is not None: + daemon.opts['wallet'] = dsn + rpc = LightningRpc(socket_path, self.executor) node = LightningNode(daemon, rpc, self.bitcoind, self.executor, may_fail=may_fail, may_reconnect=may_reconnect, allow_broken_log=allow_broken_log, - allow_bad_gossip=allow_bad_gossip) + allow_bad_gossip=allow_bad_gossip, db=db) # Regtest estimatefee are unusable, so override. node.set_feerates(feerates, False)