# Copyright (C) 2013 Canonical Ltd.
# Author: Robie Basak <robie.basak@canonical.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import print_function
from __future__ import unicode_literals

import argparse
import contextlib
import functools
import socket
import sys
import time

import uvtool.libvirt

# TCP port probed by has_open_ssh_port (the standard SSH port).
SSH_PORT = 22


def wait_for_ip(name, timeout):
    """Wait until libvirt reports at least one IP address for a domain.

    Polls ``uvtool.libvirt.name_to_ips(name)`` roughly once per second.

    :param name: domain name passed to ``uvtool.libvirt.name_to_ips``.
    :param timeout: number of seconds to keep polling.
    :returns: True as soon as an address is reported, False once the
        deadline has passed with none found.
    """
    deadline = time.time() + timeout
    while True:
        addresses = uvtool.libvirt.name_to_ips(name)
        if len(addresses) > 0:
            return True
        # Only give up after a failed lookup, so a late lease still counts.
        if time.time() > deadline:
            return False
        time.sleep(1)


def has_open_ssh_port(host, timeout=4, port=None):
    """Return whether a TCP connection to *host* on the SSH port succeeds.

    :param host: address (or hostname) to connect to over IPv4.
    :param timeout: seconds to allow for the connection attempt.
    :param port: TCP port to probe; defaults to ``SSH_PORT`` when None.
    :returns: True if ``connect()`` succeeded, False on any socket-level
        failure (refused, unreachable, timed out, name resolution error).
    :raises: non-socket errors (e.g. TypeError for a malformed address)
        propagate instead of being reported as "port closed".
    """
    if port is None:
        port = SSH_PORT
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with contextlib.closing(s):
        s.settimeout(timeout)
        try:
            s.connect((host, port))
        except socket.error:
            # socket.error is OSError on Python 3 and the base class of
            # socket.timeout / socket.gaierror on both 2 and 3. Unlike a bare
            # ``except:``, this lets KeyboardInterrupt and SystemExit escape,
            # so Ctrl-C actually stops a polling loop built on this probe.
            return False
        return True


def poll_for_true(fn, interval, timeout):
    """Call *fn* repeatedly until it returns a truthy value or time runs out.

    *fn* is always called at least once. Between calls the function sleeps
    for *interval* seconds, but never past the deadline; once the deadline
    is reached *fn* gets one last try before giving up.

    :param fn: zero-argument callable; a truthy result ends the polling.
    :param interval: seconds to wait between calls.
    :param timeout: total seconds to keep polling, measured from entry.
    :returns: True if *fn* returned truthy, otherwise False.
    """
    deadline = time.time() + timeout
    while True:
        if fn():
            return True
        remaining = deadline - time.time()
        if remaining <= 0:
            return False
        # Clamp the wait so we never sleep beyond the deadline; the next
        # iteration then makes a final attempt right at the deadline.
        time.sleep(min(interval, remaining))


def wait_for_open_ssh_port(host, interval, timeout):
    """Poll *host* until its SSH port accepts connections.

    Each probe is a ``has_open_ssh_port(host)`` call with that function's
    default per-connection timeout. Polling is delegated to
    ``poll_for_true``.

    :param host: address to probe.
    :param interval: seconds between probes.
    :param timeout: overall seconds to keep probing.
    :returns: True if a probe succeeded, otherwise False.
    """
    def probe():
        return has_open_ssh_port(host)

    return poll_for_true(probe, interval, timeout)


def main_libvirt_ipwait(parser, args):
    """Handle the ``libvirt-dnsmasq-lease`` subcommand.

    Waits up to ``args.timeout`` seconds for ``args.name`` to obtain an IP.
    On timeout it prints a message to stderr and exits with status 1.
    *parser* is unused; it is accepted because ``main`` calls every handler
    as ``func(parser, args)``.
    """
    got_address = wait_for_ip(name=args.name, timeout=args.timeout)
    if got_address:
        return
    print("cloud-wait: timed out", file=sys.stderr)
    sys.exit(1)


def main_ssh(parser, args):
    """Handle the ``ssh`` subcommand.

    Probes ``args.host`` every ``args.interval`` seconds for up to
    ``args.timeout`` seconds. On timeout it prints a message to stderr and
    exits with status 1. *parser* is unused; it is accepted because ``main``
    calls every handler as ``func(parser, args)``.
    """
    reachable = wait_for_open_ssh_port(args.host, args.interval, args.timeout)
    if reachable:
        return
    print("cloud-wait: timed out", file=sys.stderr)
    sys.exit(1)


def main(argv=None):
    """Command-line entry point for cloud-wait.

    Subcommands:
      libvirt-dnsmasq-lease NAME   wait until uvtool.libvirt reports an IP
                                   address for NAME
      ssh [--interval S] HOST      wait until HOST accepts TCP connections
                                   on the SSH port (default interval 8s)

    ``--timeout`` (seconds, default 120) belongs to the top-level parser,
    so it must appear before the subcommand name.

    :param argv: argument list to parse; None means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--timeout', type=float, default=120.0)
    # On Python 3 subcommands are optional by default, and omitting one
    # would leave ``args.func`` unset and crash with AttributeError below.
    # Making it required yields a proper usage error instead; ``dest`` is
    # needed so argparse can name the missing argument in that message.
    subparsers = parser.add_subparsers(dest='subcommand')
    subparsers.required = True

    libvirt_dnsmasq_ipwait = subparsers.add_parser(
        'libvirt-dnsmasq-lease')
    libvirt_dnsmasq_ipwait.set_defaults(func=main_libvirt_ipwait)
    libvirt_dnsmasq_ipwait.add_argument('name')

    ssh_parser = subparsers.add_parser('ssh')
    ssh_parser.set_defaults(func=main_ssh)
    ssh_parser.add_argument('--interval', type=float, default=8.0)
    ssh_parser.add_argument('host')

    args = parser.parse_args(argv)
    # Each subparser installs its handler as ``func`` via set_defaults.
    args.func(parser, args)


# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
