# Copyright 2013 Canonical Ltd.  This software is licensed under the
# GNU Affero General Public License version 3 (see the file LICENSE).

"""Tests for `maastest.utils`."""

from __future__ import (
    absolute_import,
    print_function,
    unicode_literals,
    )

__metaclass__ = type
__all__ = []

from io import BytesIO
import os.path
from pipes import quote
import platform
from subprocess import PIPE
from textwrap import dedent
import time
import re

import distro_info
from fixtures import TempDir
from maastest import utils
from maastest.utils import binary_content
import mock
from six import text_type
import testtools
from testtools.matchers import (
    ContainsAll,
    MatchesRegex,
    MatchesStructure,
    )


class TestRunCommand(testtools.TestCase):
    """Tests for `read_file`, `run_command` and `make_exception`."""

    def test_read_file_returns_file_contents_as_bytes(self):
        temp_dir = self.useFixture(TempDir())
        sample_file = os.path.join(temp_dir.path, self.getUniqueString())
        contents = self.getUniqueString().encode('ascii')
        with open(sample_file, 'wb') as f:
            f.write(contents)

        read_contents = utils.read_file(sample_file)

        self.assertEqual(contents, read_contents)
        # Check the type of what read_file() returned; the input we wrote
        # is bytes by construction, so checking it would prove nothing.
        self.assertIsInstance(read_contents, bytes)

    def test_run_command_calls_Popen(self):
        mock_Popen = mock.MagicMock()
        expected_retcode = self.getUniqueInteger()
        expected_stdout = self.getUniqueString()
        expected_stderr = self.getUniqueString()
        mock_Popen.return_value.returncode = expected_retcode
        mock_Popen.return_value.communicate.return_value = (
            expected_stdout, expected_stderr)
        self.patch(utils, 'Popen', mock_Popen)
        args = ['one', 'two']

        retcode, stdout, stderr = utils.run_command(args)

        # run_command() must pipe all three standard streams, never use a
        # shell, and pass no input to communicate() by default.
        self.assertEqual(
            [
                mock.call(
                    ['one', 'two'], stdout=PIPE, stderr=PIPE,
                    stdin=PIPE, shell=False),
                mock.call().communicate(None),
            ],
            mock_Popen.mock_calls)
        self.assertEqual(
            (expected_retcode, expected_stdout, expected_stderr),
            (retcode, stdout, stderr))

    def test_run_command_runs_command(self):
        # NOTE(review): assumes the host's root directory contains a
        # 'boot' entry.
        retcode, stdout, stderr = utils.run_command(['ls', '/'])

        self.assertEqual((0, b''), (retcode, stderr))
        self.assertIn(b"boot", stdout)

    def test_run_command_checks_return_value(self):
        mock_Popen = mock.MagicMock()
        expected_retcode = 2
        expected_stdout = self.getUniqueString()
        expected_stderr = self.getUniqueString()
        mock_Popen.return_value.returncode = expected_retcode
        mock_Popen.return_value.communicate.return_value = (
            expected_stdout, expected_stderr)
        self.patch(utils, 'Popen', mock_Popen)
        args = ['one', 'two']

        # With check_call=True a non-zero return code must raise, and the
        # error text must carry the command's output.
        error = self.assertRaises(
            Exception, utils.run_command, args, check_call=True)
        self.assertIn(expected_stdout, text_type(error))
        self.assertIn(expected_stderr, text_type(error))

    def test_run_command_uses_input(self):
        input_string = self.getUniqueString().encode("ascii")
        retcode, stdout, stderr = utils.run_command(
            ['cat', '-'], input=input_string)

        self.assertEqual(
            (0, b'', input_string), (retcode, stderr, stdout))

    def test_make_exception_contains_details(self):
        args = ['ls', '/']
        retcode = 58723
        stdout = self.getUniqueString()
        stderr = self.getUniqueString()
        exception = utils.make_exception(args, retcode, stdout, stderr)
        self.assertThat(
            text_type(exception),
            ContainsAll([
                stdout,
                stderr,
                str(retcode),
                " ".join(quote(arg) for arg in args),
            ]))


class TestBinaryContent(testtools.TestCase):
    """Tests for `binary_content`: bytes wrapped as a testtools Content."""

    def test_returns_same_content(self):
        payload = b'abc123'
        wrapped = binary_content(payload)
        # Reassembling the chunks must give back the original bytes.
        roundtrip = b''.join(wrapped.iter_bytes())
        self.assertEqual(roundtrip, payload)

    def test_has_appropriate_content_type(self):
        wrapped = binary_content(b'abc123')
        expected_type = MatchesStructure.byEquality(
            type="application",
            subtype="octet-stream",
            parameters={},
        )
        self.assertThat(wrapped.content_type, expected_type)


class TestGetURI(testtools.TestCase):
    """Tests for `get_uri`: paths are prefixed with the API root."""

    def test_returns_api_root_plus_path(self):
        relative_path = "this/is/a/path"
        expected_uri = "/api/1.0/%s" % relative_path
        self.assertEqual(expected_uri, utils.get_uri(relative_path))


class TestRetries(testtools.TestCase):
    """Tests for `retries`."""

    def test_returns_retry_iterator(self):
        # Patch utils.sleep() so that no time is actually spent
        # sleeping.
        mock_sleep = mock.MagicMock()
        self.patch(utils, 'sleep', mock_sleep)
        # Patch utils.time() so that it will return [0, 0, 1, 2, ...] thus
        # simulating a sleep() of one second during each loop.  The
        # double '0' at the beginning of the list is there to cope with
        # retries() calling time() twice before the first iteration
        # starts.
        mock_time = mock.MagicMock()
        # list(range(...)) and next() work on both Python 2 and 3, unlike
        # `[0] + range(20)` and `iterator.next()` which are Python 2 only.
        values = iter([0] + list(range(20)))
        mock_time.side_effect = lambda: next(values)
        self.patch(utils, 'time', mock_time)
        timeout = 20
        tries = utils.retries(timeout=timeout)
        self.assertEqual(
            (timeout, timeout),
            (len(list(tries)), len(mock_sleep.mock_calls)))

    def test_returns_immediately(self):
        # When iterating over a generator returned by retries(), the
        # first object is returned immediately; sleep() is only called
        # between iterations.
        delay = 10
        tries = utils.retries(delay=delay)
        start = time.time()
        # Take only the first item.  If the iterator were empty this
        # raises StopIteration (a clear test error) rather than leaving
        # `elapsed` unbound.
        next(iter(tries))
        elapsed = time.time() - start
        self.assertLess(elapsed, delay)


class TestDetermineVMArchitecture(testtools.TestCase):
    """Tests for `determine_vm_architecture`."""

    def test_determine_vm_architecture_gets_system_arch(self):
        # The host running the tests must be one of these; more entries
        # may need adding over time.
        known_arches = {'amd64', 'armhf', 'i386'}
        self.assertIn(utils.determine_vm_architecture(), known_arches)

    def test_determine_vm_architecture_returns_ubuntu_arch_name(self):
        # Maps names as reported by Python's platform module to their
        # Ubuntu equivalents.  Names without a translation are expected
        # to come through unchanged.
        expected = {
            'i386': 'i386',
            'i686': 'i386',
            'x86_64': 'amd64',
            'armhf': 'armhf',
            'power': 'power',
            'sparc64': 'sparc64',
            'm68k': 'm68k',
        }
        fake_machine = mock.MagicMock()
        self.patch(platform, 'machine', fake_machine)
        # Record what determine_vm_architecture() reports for each
        # simulated machine name.
        observed = {}
        for machine_name in expected:
            fake_machine.return_value = machine_name
            observed[machine_name] = utils.determine_vm_architecture()

        self.assertEqual(expected, observed)

    def test_determine_vm_architecture_reports_failure(self):
        # platform.machine() yields an empty string when it cannot work
        # out the architecture; simulate that.  The empty value is built
        # from the real return type (text on Python 3, bytes on Python 2).
        string_type = type(platform.machine())
        self.patch(platform, 'machine', lambda: string_type())

        self.assertRaises(
            utils.UnknownCPUArchitecture,
            utils.determine_vm_architecture)


class TestDetermineVMSeries(testtools.TestCase):
    """Tests for `determine_vm_series`."""

    def test_determine_vm_series_returns_system_series(self):
        # The exact series of the host running this test is unknown, but
        # it will be a lower-case codename.  Supported releases begin at
        # Precise and, for the foreseeable future, move forward through
        # the alphabet from there.
        series_name = utils.determine_vm_series()
        self.assertThat(series_name, MatchesRegex('[p-z][a-z]+$'))

    def test_determine_vm_series_uses_distro_information(self):
        fake_series = 'nerdy'
        fake_distro = ('Ubuntu', '1.01', fake_series)
        self.patch(platform, 'linux_distribution', lambda: fake_distro)

        self.assertEqual(fake_series, utils.determine_vm_series())


class TestGetSupportedNodeSeries(testtools.TestCase):
    """Tests for `get_supported_node_series`."""

    def test_get_supported_node_series_return_series(self):
        supported = (
            distro_info.UbuntuDistroInfo().supported(result="codename"))
        # The list of supported series will evolve in time.  At the time
        # of this writing, 'lucid' is still a supported series but it's
        # not supported by MAAS, so drop it from the expectation when
        # present.  The assertion is the same either way.
        if 'lucid' in supported:
            supported.remove('lucid')
        self.assertEqual(supported, utils.get_supported_node_series())


class TestGetSupportedMAASSeries(testtools.TestCase):
    """Tests for `get_supported_maas_series`."""

    def test_get_supported_maas_series_return_series(self):
        expected_series = ['trusty']
        self.assertEqual(expected_series, utils.get_supported_maas_series())


class TestExtractMAPIPMapping(testtools.TestCase):
    """Tests for `extract_mac_ip_mapping`, which parses nmap XML output."""

    def _assert_mapping(self, expected, nmap_xml):
        # Parse `nmap_xml` and compare the MAC -> IP dict to `expected`.
        self.assertEqual(expected, utils.extract_mac_ip_mapping(nmap_xml))

    def test_extract_mac_ip_mapping_parses_output(self):
        nmap_xml = dedent("""
            <nmaprun scanner="nmap" args="nmap -sP -oX - 192.168.2.0/24">
              <host>
                <address addr="192.168.2.2" addrtype="ipv4"/>
                <address addr="00:9C:02:A2:82:74" addrtype="mac"/>
              </host>
              <host>
                <address addr="192.168.2.4" addrtype="ipv4"/>
                <address addr="00:9C:02:A0:4D:0A" addrtype="mac"/>
              </host>
            </nmaprun>
        """)
        self._assert_mapping(
            {
                '00:9C:02:A2:82:74': '192.168.2.2',
                '00:9C:02:A0:4D:0A': '192.168.2.4',
            },
            nmap_xml)

    def test_extract_mac_ip_mapping_returns_uppercase_mac(self):
        nmap_xml = dedent("""
            <nmaprun scanner="nmap" args="nmap -sP -oX - 192.168.2.0/24">
              <host>
                <address addr="192.168.2.2" addrtype="ipv4"/>
                <address addr="AA:bb:cc:dd:ee:FF" addrtype="mac"/>
              </host>
            </nmaprun>
        """)
        # Mixed-case MACs in the input come back upper-cased.
        self._assert_mapping({'AA:BB:CC:DD:EE:FF': '192.168.2.2'}, nmap_xml)

    def test_extract_mac_ip_mapping_with_empty_doc(self):
        self._assert_mapping({}, '<nmaprun></nmaprun>')

    def test_extract_mac_ip_mapping_with_missing_info(self):
        # A host entry lacking either its IP address or its MAC address
        # is left out of the mapping.
        nmap_xml = dedent("""
            <nmaprun scanner="nmap" args="nmap -sP -oX - 192.168.2.0/24">
              <host>
                <address addr="192.168.2.2" addrtype="ipv4"/>
              </host>
              <host>
                <address addr="00:9C:02:A0:4D:0A" addrtype="mac"/>
              </host>
            </nmaprun>
        """)
        self._assert_mapping({}, nmap_xml)


class TestMipfArchList(testtools.TestCase):
    """Tests for `mipf_arch_list`."""

    def test_defaults_to_generic(self):
        # assertItemsEqual() only exists in Python 2's unittest; a plain
        # list comparison works everywhere and matches the test below,
        # which already relies on mipf_arch_list() returning a list.
        self.assertEqual(['i386/generic'], utils.mipf_arch_list('i386'))

    def test_adds_i386_if_missing(self):
        self.assertEqual(
            (['amd64/generic', 'i386/generic'], ['i386/generic']),
            (utils.mipf_arch_list('amd64'), utils.mipf_arch_list('i386')))


class TestVirtualizationType(testtools.TestCase):
    """Tests for `virtualization_type`."""

    def _fake_run_command(self, result):
        # Patch utils.run_command to return `result`, a
        # (retcode, stdout, stderr) tuple.
        self.patch(utils, 'run_command', mock.MagicMock(return_value=result))

    def test_returns_None_if_not_virtualised(self):
        self._fake_run_command((0, '', ''))
        self.assertIsNone(utils.virtualization_type())

    def test_returns_virt_type_if_virtualised(self):
        detected = self.getUniqueString()
        self._fake_run_command((0, detected, ''))
        self.assertEqual(detected, utils.virtualization_type())


class TestCheckKVM(testtools.TestCase):
    """Tests for `check_kvm_ok`: the result follows the command's retcode."""

    def _fake_retcode(self, retcode):
        # Patch utils.run_command to report `retcode` with empty output.
        fake = mock.MagicMock(return_value=(retcode, '', ''))
        self.patch(utils, 'run_command', fake)

    def test_returns_false_if_kvm_not_available(self):
        self._fake_retcode(1)
        self.assertFalse(utils.check_kvm_ok())

    def test_returns_true_if_kvm_available(self):
        self._fake_retcode(0)
        self.assertTrue(utils.check_kvm_ok())


class TestCasesLoader(testtools.TestCase):
    """Tests for `CasesLoader`."""

    def test_that_cases_are_sorted_by_lineno(self):
        # The methods below are deliberately defined out of alphabetical
        # order; the loader is expected to keep their source order.
        class ExampleTest(testtools.TestCase):

            def test_333(self):
                pass

            def test_111(self):
                pass

            def test_222(self):
                pass

        suite = utils.CasesLoader().loadTestsFromTestCase(ExampleTest)
        loaded_names = [case._testMethodName for case in suite]
        self.assertEqual(["test_333", "test_111", "test_222"], loaded_names)


class TestCachingOutputStream(testtools.TestCase):
    """Tests for `CachingOutputStream`."""

    def test_writes_to_stream(self):
        underlying = BytesIO()
        wrapper = utils.CachingOutputStream(underlying)
        text = self.getUniqueString()
        wrapper.write(text)
        self.assertEqual(text, underlying.getvalue())

    def test_caches_output(self):
        underlying = BytesIO()
        wrapper = utils.CachingOutputStream(underlying)
        wrapper.write(self.getUniqueString())
        # Whatever reached the wrapped stream must also be in the cache.
        self.assertEqual(underlying.getvalue(), wrapper.cache.getvalue())

    def test_init_sets_values(self):
        underlying = BytesIO()
        wrapper = utils.CachingOutputStream(underlying)
        self.assertEqual(underlying, wrapper.stream)
        self.assertIsInstance(wrapper.cache, BytesIO)

    def test_attributes_looked_up_on_stream(self):
        underlying = BytesIO()
        marker = self.getUniqueString()
        underlying.foo_bar_baz = mock.MagicMock(return_value=marker)
        wrapper = utils.CachingOutputStream(underlying)
        # An attribute the wrapper lacks should resolve on the wrapped
        # stream.
        self.assertEqual(marker, wrapper.foo_bar_baz())


class TestExtractPackageVersion(testtools.TestCase):
    """Tests for `extract_package_version`."""

    def test_extracts_package_version(self):
        version = '1.4+bzr1693+dfsg-0ubuntu2.2'
        # Text shaped like `apt-cache policy` output (presumably what the
        # real code parses), with the candidate version filled in.
        template = dedent("""
            maas:
              Installed: (none)
              Candidate: %s
              Version table:
                 1.4+bzr1693+dfsg-0ubuntu2.2 0
                    500 http://example.com/ubuntu/ saucy/main amd64 Packages
        """)
        policy = template % version
        self.assertEqual(version, utils.extract_package_version(policy))

    def test_returns_None_if_policy_is_unparseable(self):
        result = utils.extract_package_version("unparsable")
        self.assertIsNone(result)


class TestComposeFilter(testtools.TestCase):
    """Tests for `compose_filter`: builds `key~(lit1|lit2|...)` filters."""

    def test_compose_filter_returns_single_literal(self):
        key = self.getUniqueString()
        literal = self.getUniqueString()
        expected = '%s~(%s)' % (key, re.escape(literal))
        self.assertEqual(expected, utils.compose_filter(key, [literal]))

    def test_compose_filter_combines_literals(self):
        key = self.getUniqueString()
        first = self.getUniqueString()
        second = self.getUniqueString()
        values = (first, second)
        # Alternatives are regex-escaped and joined with '|'.
        alternatives = '|'.join(re.escape(value) for value in values)
        expected = '%s~(%s)' % (key, alternatives)
        self.assertEqual(expected, utils.compose_filter(key, values))

    def test_compose_filter_escapes_literals_for_regex_use(self):
        key = self.getUniqueString()
        # '.' and '*' must come out backslash-escaped.
        expected = key + '~(x\\.y\\*)'
        self.assertEqual(expected, utils.compose_filter(key, ['x.y*']))
