You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

170 lines
4.9 KiB

from ephemeral_port_reserve import reserve
import logging
import os
import psycopg2
import random
import shutil
import signal
import sqlite3
import string
import subprocess
import time
class Sqlite3Db(object):
def __init__(self, path):
self.path = path
def get_dsn(self):
"""SQLite3 doesn't provide a DSN, resulting in no CLI-option.
"""
return None
def query(self, query):
orig = os.path.join(self.path)
copy = self.path + ".copy"
shutil.copyfile(orig, copy)
db = sqlite3.connect(copy)
db.row_factory = sqlite3.Row
c = db.cursor()
c.execute(query)
rows = c.fetchall()
result = []
for row in rows:
result.append(dict(zip(row.keys(), row)))
db.commit()
c.close()
db.close()
return result
def execute(self, query):
db = sqlite3.connect(self.path)
c = db.cursor()
c.execute(query)
db.commit()
c.close()
db.close()
class PostgresDb(object):
def __init__(self, dbname, port):
self.dbname = dbname
self.port = port
self.conn = psycopg2.connect("dbname={dbname} user=postgres host=localhost port={port}".format(
dbname=dbname, port=port
))
cur = self.conn.cursor()
cur.execute('SELECT 1')
cur.close()
def get_dsn(self):
return "postgres://postgres:password@localhost:{port}/{dbname}".format(
port=self.port, dbname=self.dbname
)
def query(self, query):
cur = self.conn.cursor()
cur.execute(query)
# Collect the results into a list of dicts.
res = []
for r in cur:
t = {}
# Zip the column definition with the value to get its name.
for c, v in zip(cur.description, r):
t[c.name] = v
res.append(t)
cur.close()
return res
def execute(self, query):
with self.conn, self.conn.cursor() as cur:
cur.execute(query)
class SqliteDbProvider(object):
def __init__(self, directory):
self.directory = directory
def start(self):
pass
def get_db(self, node_directory, testname, node_id):
path = os.path.join(
node_directory,
'lightningd.sqlite3'
)
return Sqlite3Db(path)
def stop(self):
pass
class PostgresDbProvider(object):
def __init__(self, directory):
self.directory = directory
self.port = None
self.proc = None
print("Starting PostgresDbProvider")
def start(self):
passfile = os.path.join(self.directory, "pgpass.txt")
self.pgdir = os.path.join(self.directory, 'pgsql')
# Need to write a tiny file containing the password so `initdb` can pick it up
with open(passfile, 'w') as f:
f.write('cltest\n')
subprocess.check_call([
'/usr/lib/postgresql/10/bin/initdb',
'--pwfile={}'.format(passfile),
'--pgdata={}'.format(self.pgdir),
'--auth=trust',
'--username=postgres',
])
self.port = reserve()
self.proc = subprocess.Popen([
'/usr/lib/postgresql/10/bin/postgres',
'-k', '/tmp/', # So we don't use /var/lib/...
'-D', self.pgdir,
'-p', str(self.port),
'-F',
'-i',
])
# Hacky but seems to work ok (might want to make the postgres proc a TailableProc as well if too flaky).
time.sleep(1)
self.conn = psycopg2.connect("dbname=template1 user=postgres host=localhost port={}".format(self.port))
# Required for CREATE DATABASE to work
self.conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
def get_db(self, node_directory, testname, node_id):
# Random suffix to avoid collisions on repeated tests
nonce = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(8))
dbname = "{}_{}_{}".format(testname, node_id, nonce)
cur = self.conn.cursor()
cur.execute("CREATE DATABASE {};".format(dbname))
cur.close()
db = PostgresDb(dbname, self.port)
return db
def stop(self):
# Send fast shutdown signal see [1] for details:
#
# SIGINT
#
# This is the Fast Shutdown mode. The server disallows new connections
# and sends all existing server processes SIGTERM, which will cause
# them to abort their current transactions and exit promptly. It then
# waits for all server processes to exit and finally shuts down. If
# the server is in online backup mode, backup mode will be terminated,
# rendering the backup useless.
#
# [1] https://www.postgresql.org/docs/9.1/server-shutdown.html
self.proc.send_signal(signal.SIGINT)
self.proc.wait()