#!/usr/bin/env python
#
# Copyright (c) 2006, 2007 Canonical
#
# Written by Gustavo Niemeyer <gustavo@niemeyer.net>
#
# This file is part of Storm Object Relational Mapper.
#
# Storm is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation; either version 2.1 of
# the License, or (at your option) any later version.
#
# Storm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
import glob
import optparse
import os
import re
import shutil
import socket
import subprocess
import sys
import time
import unittest

from pkg_resources import parse_version


def add_eggs_to_path():
    """Prepend every ``*.egg`` entry found next to this file to sys.path.

    The eggs are located with a glob in the directory containing this
    file and are inserted, as absolute paths, ahead of the entries that
    are already on ``sys.path``.
    """
    base_dir = os.path.dirname(__file__)
    pattern = os.path.join(base_dir, "*.egg")
    absolute_eggs = [os.path.abspath(egg) for egg in glob.glob(pattern)]
    sys.path[0:0] = absolute_eggs

# Running "python setup.py test [--dry-run]" puts $package.egg directories
# in the top directory, so we add them to sys.path here for convenience.
# This happens at module level, before any test runner function is called.
add_eggs_to_path()


def test_with_runner(runner):
    """Run the selected Storm tests with ``runner``.

    The command line is parsed with optparse: ``--verbose`` sets
    ``runner.verbosity`` to 2, and the remaining positional arguments are
    passed to ``storm.tests.find_tests`` to build the suite that
    ``runner.run`` executes.

    Returns True when the run was not successful and False otherwise.
    """
    parser = optparse.OptionParser(
        usage="test.py [options] [<test filename>, ...]")
    parser.add_option("--verbose", action="store_true")
    options, test_args = parser.parse_args()

    if options.verbose:
        runner.verbosity = 2

    # Import late, after any and all sys.path jiggery pokery.
    from storm.tests import find_tests

    outcome = runner.run(find_tests(test_args))
    return not outcome.wasSuccessful()


def test_with_trial():
    """Run the tests with Twisted's TrialRunner and a TreeReporter.

    Returns whatever ``test_with_runner`` returns for that runner.
    """
    # twisted.trial.unittest is only imported to work around
    # https://twistedmatrix.com/trac/ticket/8267; keep it ahead of the
    # other trial imports.
    from twisted.trial import unittest as trial_unittest
    from twisted.trial.reporter import TreeReporter
    from twisted.trial.runner import TrialRunner

    # Bare reference to the otherwise unused workaround import.
    trial_unittest

    return test_with_runner(
        TrialRunner(reporterFactory=TreeReporter, realTimeErrors=True))


def test_with_unittest():
    """Run the tests with the standard library's ``TextTestRunner``.

    Returns whatever ``test_with_runner`` returns for that runner.
    """
    return test_with_runner(unittest.TextTestRunner())


def with_postgresfixture(runner_func):
    """If possible, wrap a test runner with code to set up PostgreSQL.

    When ``postgresfixture`` cannot be imported, ``runner_func`` is
    returned unchanged.  Otherwise the returned callable creates a
    ``ClusterFixture("db")``, enables prepared transactions in its
    ``postgresql.conf``, creates a ``storm_test`` database, exports
    ``STORM_POSTGRES_URI`` and ``STORM_POSTGRES_HOST_URI``, and returns
    the result of ``runner_func()`` while the cluster context is active.
    """
    try:
        from postgresfixture import ClusterFixture
    except ImportError:
        return runner_func

    from six.moves.urllib.parse import urlunsplit
    from storm.uri import escape

    prepared_setting = re.compile(r"^#(max_prepared_transactions.*)= 0")

    def enable_prepared_transactions(conf_path):
        # Write an edited copy next to the original (uncommenting the
        # max_prepared_transactions line and changing "= 0" to "= 200"),
        # give it the original's mode, then rename it over the original.
        new_path = conf_path + ".new"
        with open(conf_path) as old_conf, open(new_path, "w") as new_conf:
            new_conf.writelines(
                prepared_setting.sub(r"\1= 200", line) for line in old_conf)
            old_mode = os.fstat(old_conf.fileno()).st_mode
            os.fchmod(new_conf.fileno(), old_mode)
        os.rename(new_path, conf_path)

    def wrapper():
        cluster = ClusterFixture("db")
        cluster.create()
        # Allow use of prepared transactions, which some tests need.
        enable_prepared_transactions(
            os.path.join(cluster.datadir, "postgresql.conf"))

        with cluster:
            cluster.createdb("storm_test")
            netloc = escape(cluster.datadir)
            uri = urlunsplit(("postgres", netloc, "/storm_test", "", ""))
            for variable in ("STORM_POSTGRES_URI", "STORM_POSTGRES_HOST_URI"):
                os.environ[variable] = uri
            return runner_func()

    return wrapper


def _mysql_create_datadir(basedir, datadir, logfile):
    """Initialise a fresh data directory for the installed ``mysqld``.

    Non-MariaDB servers from version 5.7.6 onwards are initialised with
    ``mysqld --initialize-insecure``; everything else falls back to
    ``mysql_install_db``.

    Raises an Exception if the version cannot be parsed from the output
    of ``mysqld --version``, and ``subprocess.CalledProcessError`` if any
    of the commands fail.
    """
    mysqld_version_output = subprocess.check_output(
        ["mysqld", "--version"], universal_newlines=True).rstrip("\n")
    match = re.search(r"Ver ([\d.]+)", mysqld_version_output, flags=re.I)
    if match is None:
        # Fail with the offending output rather than an AttributeError.
        raise Exception(
            "Cannot determine mysqld version from %r" %
            mysqld_version_output)
    version = match.group(1)
    if ("MariaDB" not in mysqld_version_output and
            parse_version(version) >= parse_version("5.7.6")):
        subprocess.check_call([
            "mysqld",
            "--initialize-insecure",
            "--datadir=%s" % datadir,
            "--tmpdir=%s" % basedir,
            "--log-error=%s" % logfile,
            ])
    else:
        subprocess.check_call([
            "mysql_install_db",
            "--datadir=%s" % datadir,
            "--tmpdir=%s" % basedir,
            ])


def _mysql_wait_for_socket(server_proc, unix_socket, timeout=60):
    """Block until a connection to ``unix_socket`` succeeds.

    Polls roughly every 0.1 seconds.  Raises an Exception if
    ``server_proc`` exits with a non-zero status, or if no connection
    succeeds within ``timeout`` seconds.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        code = server_proc.poll()
        if code is not None and code != 0:
            raise Exception("mysqld_safe exited %d" % code)

        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.connect(unix_socket)
            return
        except socket.error:
            pass  # try again
        finally:
            sock.close()

        time.sleep(0.1)

    # Report the timeout here instead of letting a later connection
    # attempt fail with a less specific error.
    raise Exception(
        "Timed out after %d seconds waiting for MySQL socket %s" %
        (timeout, unix_socket))


def with_mysql(runner_func):
    """If possible, wrap a test runner with code to set up MySQL.

    Loosely based on the approach taken by pytest-mysql, although
    implemented separately.

    When ``MySQLdb`` cannot be imported, ``runner_func`` is returned
    unchanged.  Otherwise the returned callable recreates the ``db-mysql``
    directory, initialises a data directory in it, starts ``mysqld_safe``
    listening only on a Unix socket, creates a ``storm_test`` database,
    exports ``STORM_MYSQL_URI`` and ``STORM_MYSQL_HOST_URI``, and returns
    the result of ``runner_func()``.  The server is asked to shut down
    with ``mysqladmin`` afterwards, and the ``mysqld_safe`` process is
    killed if it is still running.
    """
    try:
        import MySQLdb
    except ImportError:
        return runner_func

    from contextlib import closing

    from six.moves.urllib.parse import (
        urlencode,
        urlunsplit,
        )

    def wrapper():
        basedir = os.path.abspath("db-mysql")
        datadir = os.path.join(basedir, "data")
        unix_socket = os.path.join(basedir, "mysql.sock")
        logfile = os.path.join(basedir, "mysql.log")
        # Always start from an empty base directory.
        if os.path.exists(basedir):
            shutil.rmtree(basedir)
        os.makedirs(basedir)

        _mysql_create_datadir(basedir, datadir, logfile)

        with open(os.devnull, "w") as devnull:
            server_proc = subprocess.Popen([
                "mysqld_safe",
                "--datadir=%s" % datadir,
                "--pid-file=%s" % os.path.join(basedir, "mysql.pid"),
                "--socket=%s" % unix_socket,
                "--skip-networking",
                "--log-error=%s" % logfile,
                "--tmpdir=%s" % basedir,
                "--skip-syslog",
                # We don't care about durability of test data.  Try to
                # persuade MySQL to agree.
                "--innodb-doublewrite=0",
                "--innodb-flush-log-at-trx-commit=0",
                "--innodb-flush-method=O_DIRECT_NO_FSYNC",
                "--skip-innodb-file-per-table",
                "--sync-binlog=0",
                ], stdout=devnull)

        try:
            _mysql_wait_for_socket(server_proc, unix_socket)

            try:
                connection = MySQLdb.connect(
                    user="root", unix_socket=unix_socket)
                with closing(connection):
                    with closing(connection.cursor()) as cursor:
                        cursor.execute(
                            "CREATE DATABASE storm_test CHARACTER SET utf8;")
                uri = urlunsplit(
                    ("mysql", "root@localhost", "/storm_test",
                     urlencode({"unix_socket": unix_socket}), ""))
                os.environ["STORM_MYSQL_URI"] = uri
                os.environ["STORM_MYSQL_HOST_URI"] = uri
                return runner_func()
            finally:
                # Ask the server to shut down once the socket has been
                # seen accepting connections, whether or not the tests
                # passed.
                subprocess.check_call([
                    "mysqladmin",
                    "--socket=%s" % unix_socket,
                    "--user=root",
                    "shutdown",
                    ])
        finally:
            # Last resort: make sure the mysqld_safe process is reaped.
            if server_proc.poll() is None:
                server_proc.kill()
            server_proc.wait()

    return wrapper


if __name__ == "__main__":
    # The runner is selected through STORM_TEST_RUNNER (unset or empty
    # means "unittest"); dots in the name become underscores when looking
    # up the matching test_with_* function in this module.
    runner = os.environ.get("STORM_TEST_RUNNER") or "unittest"
    runner_func = globals().get("test_with_" + runner.replace(".", "_"))
    if not runner_func:
        sys.exit("Test runner not found: %s" % runner)
    # Layer the optional database set-up wrappers around the runner; each
    # one returns the function unchanged when its driver is unavailable.
    for wrap in (with_postgresfixture, with_mysql):
        runner_func = wrap(runner_func)
    sys.exit(runner_func())

# vim:ts=4:sw=4:et
