# arch-tag: da321d26-9379-469d-b966-1cca8e5822ae
# Copyright (C) 2004 David Allouche <david@allouche.net>
#               2004 Canonical Ltd.
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program; if not, write to the Free Software
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

"""Testing framework
"""

import sys
import types
import inspect
import os
import unittest
import logging

from pybaz.util import DirName
import fixtures


class Sandbox(fixtures.Sandbox):
    """fixtures.Sandbox that notifies subscribed observers after setUp.

    Observers are any objects with a ``notify_setup(sandbox)`` method; each
    one is called, in subscription order, once the base setUp has finished.
    tearDown only delegates to the base class if setUp was entered.
    """

    def __init__(self):
        super(Sandbox, self).__init__()
        self._observers = []
        self._set_up = False

    def subscribe(self, observer):
        """Register observer; subscribing the same observer twice is a no-op."""
        if observer not in self._observers:
            self._observers.append(observer)

    def setUp(self):
        # The flag is raised before delegating, so tearDown will still call
        # the base tearDown even if the base setUp raised part-way through.
        self._set_up = True
        super(Sandbox, self).setUp()
        for observer in self._observers:
            observer.notify_setup(self)

    def tearDown(self):
        # Skip cleanup entirely for a sandbox that was never set up.
        if self._set_up:
            super(Sandbox, self).tearDown()


class TestParams(object):
    """Names, paths and helper actions shared by the pybaz tests.

    If a (truthy) sandbox is passed, the instance subscribes to it; the
    directory attributes (arch_dir, sandbox_dir, working_dir, ...) are only
    computed later, in notify_setup, once the sandbox has been set up.
    """

    def __init__(self, sandbox=None):
        if sandbox:
            sandbox.subscribe(self)
        self.my_id = "John Doe <jdoe@example.com>"
        self.arch_name = 'jdoe@example.com--example--9999'
        self.arch_dir_base = DirName(r'pyarch tests')
        self.tla_archive_format = False

    def notify_setup(self, sandbox):
        """Sandbox callback: create arch_dir under sandbox.tmp_dir and derive
        the tree paths from it.

        Only arch_dir itself is created on disk here; the working-tree
        directories are merely named.
        """
        self.arch_dir = sandbox.tmp_dir / self.arch_dir_base
        self.sandbox_dir = sandbox.tmp_dir
        os.mkdir(self.arch_dir)
        root = self.arch_dir
        self.working_dir = DirName(root/'workingtree')
        self.nested_dir = DirName(self.working_dir/'nestedtree')
        # A directory name containing a newline, for path-quoting tests.
        self.working_dir_newline = DirName(root/'working\ntree')
        self.other_working_dir = DirName(root/'othertree')

    def set_my_id(self):
        """Register self.my_id as the pybaz user id."""
        import pybaz
        pybaz.set_my_id(self.my_id)

    def create_archive(self):
        """Create the test archive in arch_dir and keep it as self.archive."""
        import pybaz
        self.archive = pybaz.make_archive(
            self.arch_name,
            self.arch_dir/self.arch_name,
            tla=self.tla_archive_format)

    def _get_version(self):
        import pybaz
        return pybaz.Version(self.arch_name+'/cat--brn--1.0')
    # Main test version in the test archive.
    version = property(_get_version)

    def _get_other_version(self):
        import pybaz
        return pybaz.Version(self.arch_name+'/cat--other--0')
    # Secondary test version in the test archive.
    other_version = property(_get_other_version)

    def create_archive_and_mirror(self):
        """Create the archive, then a '-MIRROR' copy of it next to it."""
        self.create_archive()
        mirror_name = self.archive.name + '-MIRROR'
        self.mirror = self.archive.make_mirror(
            mirror_name, self.arch_dir/mirror_name)

    def _create_tree_helper(self, dirname, version):
        """Make dirname and init a tree in it, for version when not None."""
        import pybaz
        os.mkdir(dirname)
        init_args = [dirname]
        if version is not None:
            init_args.append(version)
        return pybaz.init_tree(*init_args)

    def create_working_tree(self, version=None):
        self.working_tree = self._create_tree_helper(self.working_dir, version)

    def create_other_tree(self, version=None):
        self.other_tree = self._create_tree_helper(
            self.other_working_dir, version)

    def commit_summary(self, tree, summary):
        """Commit tree with a log message whose Summary and body are summary."""
        message = tree.log_message()
        message['Summary'] = summary
        message.description = summary
        tree.commit(message)

    def make_history(self, tree, history):
        """Replay history (see fixtures.History) against tree."""
        replay = fixtures.History(history)
        replay.run(tree)


class TestCase(unittest.TestCase):
    """Test case that runs every test inside a fresh Sandbox.

    setUp builds a Sandbox, creates ``self.params`` by calling
    ``params_factory`` (TestParams by default) with it, sets the sandbox up
    and finally calls the ``extraSetup`` hook. tearDown tears the sandbox
    down again.
    """

    sandbox = None
    params_factory = TestParams
    params = None

    def setUp(self):
        sandbox = self.sandbox = Sandbox()
        # params is built before sandbox.setUp(); with the default TestParams
        # this means it is already subscribed when notify_setup fires.
        self.params = self.params_factory(sandbox)
        sandbox.setUp()
        self.extraSetup()

    def extraSetup(self):
        """Hook for subclasses, called at the end of setUp. Does nothing here."""

    def tearDown(self):
        sandbox = self.sandbox
        if sandbox:
            sandbox.tearDown()


class NewTestCase(unittest.TestCase):
    """Test case whose fixture state is set up once per memento key.

    The first test needing a given ``fixture.memento_key()`` runs the full
    fixture setUp and stores ``fixture.create_memento()``; later tests with
    the same key restore that memento via ``fixture.set_memento`` instead.
    """

    # Memento cache: memento_key() -> memento. Name-mangled to
    # _NewTestCase__mementos and never reassigned, so this single dict is
    # shared by NewTestCase and all of its subclasses.
    __mementos = {}

    fixture = fixtures.NullFixture()

    def setUp(self):
        key = self.fixture.memento_key()
        cache = self.__mementos
        try:
            memento = cache[key]
        except KeyError:
            # First use of this key: build the fixture for real and cache it.
            self.fixture.setUp()
            cache[key] = self.fixture.create_memento()
        else:
            self.fixture.set_memento(memento)

    def tearDown(self):
        self.fixture.tearDown()


class OrderedTestLoader(unittest.TestLoader):
    """TestLoader that orders TestCases and methods according to the source.

    Test classes of a module are ordered by the line their definition starts
    on.  Test methods are ordered class by class along the MRO (most derived
    class first), by the line they start on within that class; an overriding
    method takes the position of the override.  A TestCase may bypass this by
    defining a ``tests`` sequence of method names.
    """

    def getTestCaseNames(self, testCaseClass):
        """Return the test method names of testCaseClass in source order.

        If testCaseClass has a ``tests`` attribute, it is returned unchanged
        as an explicit ordering.  Otherwise, for each class in the MRO (most
        derived first), the callable attributes defined directly in that
        class whose name starts with ``testMethodPrefix`` are sorted by
        starting source line; names already taken from a more derived class
        (overrides) are skipped.
        """
        if hasattr(testCaseClass, 'tests'):
            return testCaseClass.tests  # explicit ordering
        all_methods = []
        seen = set()
        # Process subclasses first, so the position of overriding methods is
        # what is considered for ordering.
        for base_class in inspect.getmro(testCaseClass):
            methods = []
            # Only names defined in this very class: dir() would also list
            # inherited methods, mixing line numbers of base classes (possibly
            # from other files) into this class's ordering.
            for name in vars(base_class):
                if not name.startswith(self.testMethodPrefix):
                    continue
                if name in seen:
                    continue  # overridden in a more derived class
                method = getattr(base_class, name)
                if not callable(method):
                    continue  # data attribute matching the prefix; no source
                ignored, starting = inspect.getsourcelines(method)
                methods.append((starting, name))
            # Sort per class: line numbers are only comparable within one
            # class, since base classes may be defined in other modules.
            methods.sort()
            for ignored, name in methods:
                seen.add(name)
                all_methods.append(name)
        return all_methods

    def loadTestsFromModule(self, module):
        """Return a suite of all test cases in module, in source order.

        TestCase subclasses found in ``dir(module)`` are ordered by the line
        their class statement starts on, as reported by inspect.getsourcelines.
        """
        # types.ClassType (old-style classes) only exists on Python 2.
        class_types = (type, getattr(types, 'ClassType', type))
        classes = []
        for name in dir(module):
            obj = getattr(module, name)
            if (isinstance(obj, class_types) and
                issubclass(obj, unittest.TestCase)):
                ignored, starting = inspect.getsourcelines(obj)
                classes.append((starting, obj))
        # Sort on the line number only: on a tie (e.g. classes imported from
        # different files) class objects must not be compared. The sort is
        # stable, so ties keep dir() order.
        classes.sort(key=lambda pair: pair[0])
        tests = [self.loadTestsFromTestCase(class_)
                 for ignored, class_ in classes]
        return self.suiteClass(tests)

def _set_logging_level(level_name):
    mapping = {'debug': logging.DEBUG,
               'info': logging.INFO,
               'warning': logging.WARNING,
               'error': logging.ERROR,
               'critical': logging.critical}
    logging.getLogger().setLevel(mapping[level_name])

def main(**kwargs):
    """Configure logging, then run the tests through unittest.main.

    The root logger level is read from the LOGGING environment variable
    (default 'warning'). Keyword arguments are forwarded to unittest.main.
    """
    logging.basicConfig()
    level_name = os.environ.get('LOGGING', 'warning')
    _set_logging_level(level_name)
    # To run tests in source order instead, pass
    # testLoader=OrderedTestLoader() to unittest.main.
    unittest.main(**kwargs)

def register(name):
    """Hook the module called name into the test harness.

    Installs a ``test_suite()`` function on the module that returns a suite
    of its TestCase classes, records name in test_modules (see
    collect_test_modules) and, when name is "__main__", runs main().
    """
    module = sys.modules[name]

    def test_suite():
        # Re-resolve the module at call time rather than closing over it.
        # defaultTestLoader.loadTestsFromModule is what the deprecated
        # unittest.findTestCases (removed in Python 3.13) delegated to.
        # For source-ordered tests use OrderedTestLoader().loadTestsFromModule.
        return unittest.defaultTestLoader.loadTestsFromModule(sys.modules[name])

    module.test_suite = test_suite
    test_modules.append(name)
    if name == "__main__":
        main()

def collect_test_modules():
    """Return the module names registered so far and reset the registry.

    Each call hands back the accumulated list and replaces it with a fresh
    empty one, so a registered name is returned only once.
    """
    global test_modules
    collected = test_modules
    test_modules = []
    return collected

# Names passed to register(), in registration order.
test_modules = []
