/*

Copyright (C) 2002 Hayato Fujiwara <h_fujiwara@users.sourceforge.net>
Copyright (C) 2010-2020 Olaf Till <i7tiol@t-online.de.de>

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; If not, see <http://www.gnu.org/licenses/>.

*/

// PKG_ADD: autoload ("pconnect", "parallel_interface.oct");
// PKG_DEL: autoload ("pconnect", "parallel_interface.oct", "remove");

#include "parallel-gnutls.h"

// Checks whether PATH names a regular file.
//
// Returns 0 if the file status reports a regular file, -1 otherwise.
// PATH is only read, so it is taken by const reference.
static
int assert_file (const std::string &path)
{
  // Windows headers expand stat(), so stat_ is used as object name
  OCTAVE__SYS__FILE_STAT stat_ (path);

  return stat_.is_reg () ? 0 : -1;
}

// pconnect (HOSTS [, OPTIONS])
//
// For each server named in HOSTS, opens a command connection (port
// 12502) and a data connection (port 12501), runs the initialization
// handshake on both, and returns all connections wrapped in one
// octave_parallel_connections value. Position 0 of the returned network
// is a pseudo-connection standing for this client itself.
DEFUN_DLD (pconnect, args, nargout,
           "-*- texinfo -*-\n\
@deftypefn {Loadable Function} {@var{connections} =} pconnect (@var{hosts})\n\
@deftypefnx {Loadable Function} {@var{connections} =} pconnect (@var{hosts}, @var{options})\n\
Connects to a network of parallel cluster servers.\n\
\n\
As a precondition, a server must have been started at each machine of\n\
the cluster, see @code{pserver}. Connections are not guaranteed to\n\
work if client and server are from @code{parallel} packages of\n\
different versions, so versions should be kept equal.\n\
\n\
@var{hosts} is a cell-array of strings, holding the names of all\n\
server machines. The machines must be unique, and their names must be\n\
resolvable to the correct addresses also at each server machine, not\n\
only at the client. This means e.g. that the name @code{localhost} is\n\
not acceptable (exception: @code{localhost} is acceptable as the first\n\
of all names).\n\
\n\
Alternatively, but deprecated, @var{hosts} can be given as previously,\n\
as a character array with a machine name in each row. If it is given\n\
in this way, the first row must contain the name of the client machine\n\
(for backwards compatibility), so that there is one row more than the\n\
now preferred cell-array @var{hosts} would have entries.\n\
\n\
@code{pconnect} returns an opaque variable holding the network\n\
connections. This variable can be indexed to obtain a subset of\n\
connections or even a single connection. (For backwards compatibility,\n\
a second index of @code{:} is allowed, which has no effect). At the\n\
first index position is the client machine, so this position does not\n\
correspond to a real connection. At the following index positions are\n\
the server machines in the same order as specified in the cell-array\n\
@var{hosts}. So in the whole the variable of network connections has\n\
one position more than the number of servers given in @var{hosts}\n\
(except if @var{hosts} was given in the above mentioned deprecated\n\
way). You can display the variable of network connections to see what\n\
is in it. The variable of network connections, or subsets of it, is\n\
passed to the other functions for parallel cluster execution\n\
(@code{reval}, @code{psend}, @code{precv}, @code{sclose},\n\
@code{select_sockets} among others -- see documentation of these\n\
functions).\n\
\n\
@var{options}: structure of options; field @code{use_tls} is\n\
@code{true} by default (TLS with SRP authentication); if set to\n\
@code{false}, there will be no encryption or authentication. Field\n\
@code{password_file} can be set to an alternative path to the file\n\
with authentication information (see below). Field @code{user} can\n\
specify the username for authentication; if the username is so\n\
specified, no file with authentication information will be used at the\n\
client, but the password will be queried from the user.\n\
\n\
The client and the server must both use or both not use TLS. If TLS is\n\
switched off, different measures must be taken to protect ports 12501\n\
and 12502 at the servers and the client against unauthorized access;\n\
e.g. by a firewall or by physical isolation of the network.\n\
\n\
For using TLS, authorization data must be present at the server\n\
machine. These data can conveniently be generated by\n\
@code{parallel_generate_srp_data}. By default, the client\n\
authentication file is created in the same run. The helptext of\n\
@code{parallel_generate_srp_data} documents the expected locations of\n\
the authentication data.\n\
\n\
The SRP password will be sent over the encrypted TLS channel from the\n\
client to each server, to avoid permanently storing passwords at the\n\
server for server-to-server data connections. Due to inevitable usage\n\
of external libraries, memory with sensitive data can, however, be on\n\
the swap device even after shutdown of the application, both at the\n\
client and at the server machines.\n\
\n\
Example (let data travel through all machines), assuming\n\
@code{pserver} was called on each remote machine and authentication\n\
data is present (e.g. generated with\n\
@code{parallel_generate_srp_data}):\n\
\n\
@example\n\
@group\n\
sockets = pconnect (@{'remote.machine.1', 'remote.machine.2'@});\n\
reval ('psend (precv (sockets(2)), sockets(1))', sockets(3));\n\
reval ('psend (precv (sockets(1)), sockets(3))', sockets(2));\n\
psend ('some data', sockets(2));\n\
precv (sockets(3))\n\
--> ans = some data\n\
sclose (sockets);\n\
@end group\n\
@end example\n\
\n\
@seealso{pserver, reval, psend, precv, sclose, parallel_generate_srp_data, select_sockets, rfeval}\n\
@end deftypefn")
{
  std::string fname ("pconnect");

  octave_value retval;

  if (args.length () < 1 || args.length () > 2)
    {
      print_usage ();
      return retval;
    }

  if (nargout == 0)
    {
      error ("%s: An output argument must be given to hold the network connections.",
             fname.c_str ());
      return retval;
    }

  // A negative integer might be sent as Octave data, and Octave
  // doesn't care about coding of negative integers. (I know, there
  // probably will never be a current C-compiler with something
  // different than twos complement, but the C99-standard allows for
  // it.)
  if (octave_parallel_stream::signed_int_rep ())
    {
      error ("This machine doesn't seem to use twos complement as negative integer representation. If you want this machine to be supported, please file a bug report.");

      return retval;
    }

  // The original character array argument had client name in first
  // row. We still accept such an argument. But currently the client
  // and server names of connections at the server sides are taken
  // from the systems connection structures. This means e.g. that the
  // name under which a client appears to servers can be different on
  // different servers and different from the client hostname
  // retrieved at the client itself.
  //
  // The new cell-array argument does not give the client name; so it
  // has one entry less than the previous character array would have
  // had rows.
  Array<std::string> hosts;
  bool err;
  SET_ERR (hosts = args(0).cellstr_value (), err);
  if (err)
    {
      // Deprecated form: character matrix, one name per row, client
      // name in the first row (which is skipped below).
      //
      // NOTE(review): some Octave versions implement cellstr_value ()
      // for character matrices too, in which case a character matrix
      // argument never reaches this branch -- verify.
      charMatrix cm;
      CHECK_ERROR (cm = args(0).char_matrix_value (), retval,
                   "%s: first argument must be a cell array of strings or a character matrix",
                   fname.c_str ());
      int rows = cm.rows ();
      int cols = cm.columns ();
      if (rows < 1)
        {
          // Would otherwise request a negative size from resize1 ().
          error ("%s: character matrix argument must have at least one row (the client name)",
                 fname.c_str ());
          return retval;
        }
      // After transposing, each original row is contiguous in the
      // column-major data array.
      cm = cm.transpose ();
      hosts.resize1 (rows - 1, std::string ());
      for (int i = 1; i < rows; i++)
        {
          const char *pts = cm.data () + cols * i;
          // A name ends at the first blank in its row, if there is one.
          const char *pte =
            static_cast<const char *> (memchr (pts, ' ', cols));
          if (pte)
            hosts(i - 1) = std::string (pts, pte - pts);
          else
            hosts(i - 1) = std::string (pts, cols);
        }
    }
  uint32_t nhosts = hosts.numel ();

  inthandler_dont_restart_syscalls __inthandler_guard__;

  // "canonicalize" host names and check for uniqueness
  Array<std::string> canhosts (dim_vector (nhosts, 1));
  for (uint32_t i = 0; i < nhosts; i++)
    {
      struct gnulib_guard __guard (&gnulib_freeaddrinfo);
      // Constructing std::string from a null pointer is undefined, so
      // a null result is mapped to "" and treated as a failure below.
      const char *canon = gnulib_get_canonname (hosts(i).c_str ());
      canhosts(i) = canon ? std::string (canon) : std::string ();
      if (canhosts(i).empty ())
        {
          error ("getaddrinfo returned an error");
          return retval;
        }
      dcprintf ("orig: %s, canon: %s\n",
                hosts(i).c_str (), canhosts(i).c_str ());
    }
  // After sorting, duplicates are adjacent.
  Array<std::string> s_canhosts (canhosts.sort ());
  for (uint32_t i = 1; i < nhosts; i++)
    {
      if (! s_canhosts(i - 1).compare (s_canhosts(i)))
        {
          error ("%s: hostnames not unique after canonicalizing",
                 fname.c_str ());
          return retval;
        }
    }

  // default options
  bool use_gnutls = true;
  bool use_pfile = true;
  std::string cpfile;
  std::string user;

  // get options, if any
  if (args.length () == 2)
    {
      octave_scalar_map options;

      CHECK_ERROR (options = args(1).scalar_map_value (), retval,
                   "%s: could not convert second argument to scalar structure",
                   fname.c_str ());

      octave_value tmp;

      // use TLS
      tmp = options.contents ("use_tls");

      if (tmp.is_defined ())
        {
          CHECK_ERROR (use_gnutls = tmp.bool_value (), retval,
                       "%s: could not convert option 'use_tls' to bool",
                       fname.c_str ());
        }

      // custom password file
      tmp = options.contents ("password_file");

      if (tmp.is_defined ())
        {
          CHECK_ERROR (cpfile = tmp.string_value (), retval,
                       "%s: could not convert option 'password_file' to string",
                       fname.c_str ());
        }

      // user name
      tmp = options.contents ("user");

      if (tmp.is_defined ())
        {
          CHECK_ERROR (user = tmp.string_value (), retval,
                       "%s: could not convert option 'user' to string",
                       fname.c_str ());

          // A non-empty user name means the password is queried
          // instead of being read from a file.
          if (user.length ())
            use_pfile = false;
        }
    } // args.length () == 2

  // check options integrity
#ifndef HAVE_LIBGNUTLS
  if (use_gnutls)
    {
      error ("TLS not available");
      return retval;
    }
#endif
  if (! use_gnutls && (! use_pfile || cpfile.length ()))
    warning ("no TLS used, options 'user' and 'password_file' have no effect");
  if (use_gnutls && ! use_pfile && cpfile.length ())
    warning ("option 'password_file' has no effect since option 'user' was given and is not empty");


#ifdef HAVE_LIBGNUTLS

  // if necessary, initialize gnutls and create credentials

  // Deletes the credentials on scope exit unless released or still
  // referenced (check_ref ()).
  struct __credguard
  {
    octave_parallel_gnutls_srp_client_credentials *__c;
    __credguard (void) : __c (NULL) { }
    ~__credguard (void)
    {
      if (__c && ! __c->check_ref ())
        {
          dcprintf ("__credguard will delete cred\n");
          delete __c;
        }
    }
    octave_parallel_gnutls_srp_client_credentials *__get (void) { return __c; }
    void __set (octave_parallel_gnutls_srp_client_credentials *__ic)
    { __c = __ic; }
    void __release (void) { __c = NULL; }
  }  __cg;
  // Only set (and only used) if use_gnutls is true.
  octave_parallel_gnutls_srp_client_credentials *ccred = NULL;

  if (use_gnutls)
    {
      // There is no deinitialization routine for extra, _init_extra()
      // requires _init() to have been called before, and only the
      // first call to each of these functions does anything beyond
      // incrementing a counter. So we can't ever deinitialize even
      // the general globals, since we would not be able to
      // reinitialize _extra after this. So we can just as well call
      // the init functions for each pconnect, as long as their global
      // counters don't overflow.
      gnutls_global_init ();
#ifdef HAVE_LIBGNUTLS_EXTRA
      gnutls_global_init_extra ();  // for SRP
      parallel_gnutls_set_mem_functions ();
#endif

      if (use_pfile)
        {
          if (! cpfile.length ())
            {
              // Default password file: first under the Octave home
              // directory, then under the pkg prefix.
#ifdef HAVE_OCTAVE_CONFIG_FCNS
              std::string octave_home = octave::config::octave_home ();
#else
              extern std::string Voctave_home;
              std::string octave_home = Voctave_home;
#endif
              cpfile = octave_home + OCTAVE__SYS__FILE_OPS::dir_sep_str () +
                "share" + OCTAVE__SYS__FILE_OPS::dir_sep_str () + "octave" +
                OCTAVE__SYS__FILE_OPS::dir_sep_str () + "parallel-srp-data" +
                OCTAVE__SYS__FILE_OPS::dir_sep_str () + "client" +
                OCTAVE__SYS__FILE_OPS::dir_sep_str () + "user_passwd";
              if (assert_file (cpfile))
                {
                  octave_value_list f_args (1);
                  f_args(0) = octave_value ("prefix");
                  octave_value_list f_ret;
                  CHECK_ERROR (f_ret = OCTAVE__FEVAL ("pkg", f_args, 1), retval,
                               "%s: could not get prefix from pkg",
                               fname.c_str ());
                  CHECK_ERROR (cpfile = f_ret(0).string_value (), retval,
                               "%s: could not convert output of pkg ('prefix') to string",
                               fname.c_str ());
                  cpfile = cpfile + OCTAVE__SYS__FILE_OPS::dir_sep_str () +
                    "parallel-srp-data" +
                    OCTAVE__SYS__FILE_OPS::dir_sep_str () + "client" +
                    OCTAVE__SYS__FILE_OPS::dir_sep_str () + "user_passwd";
                  if (assert_file (cpfile))
                    {
                      error ("%s: no regular file found at default password file paths",
                             fname.c_str ());
                      return retval;
                    }
                }
            }
          else if (assert_file (cpfile))
            {
              error ("%s: no regular file found at password file path given by user",
                     fname.c_str ());
              return retval;
            }
          __cg.__set (ccred = new octave_parallel_gnutls_srp_client_credentials
                      (cpfile)); // arg is std::string
        }
      else
        __cg.__set (ccred = new octave_parallel_gnutls_srp_client_credentials
                    (user.c_str ())); // arg is char* string
      if (! __cg.__get ()->check_cred ())
        {
          error ("%s: could not create credentials",
                 fname.c_str ());
          return retval;
        }
    }
#endif // HAVE_LIBGNUTLS


  // 36 characters plus terminating NUL
  char tuuid[37];
  char *uuid = (char *) tuuid;
  if (oct_parallel_store_unique_identifier (uuid))
    return retval;

  octave_parallel_network *network;
  // Holds a reference to the network and deletes it on scope exit if
  // that was the last reference.
  struct __netwguard
  {
    octave_parallel_network *__n;
    __netwguard (octave_parallel_network *__an): __n (__an) { __n->get_ref (); }
    ~__netwguard (void)
    {
      if (__n->release_ref () <= 0)
        {
          dcprintf ("__netwguard will delete network\n");
          delete __n;
        }
    }
  } __ng (network = new octave_parallel_network (nhosts + 1));

  // a pseudo-connection, representing the own node in the network
  octave_parallel_connection *conn =
    new octave_parallel_connection (false, uuid);

  network->insert_connection (conn, 0);

  // store number of processor cores available in client
  conn->set_nproc (num_processors (NPROC_CURRENT));

  // Command connections (port 12502), one per server.
  for (uint32_t i = 0; i < nhosts; i++)
    {
      dcprintf ("host number %i\n", i);

      struct gnulib_guard __guard (&gnulib_freeaddrinfo);
      if (gnulib_set_port (hosts(i).c_str (), "12502"))
        {
          error ("getaddrinfo returned an error");
          return retval;
        }

      // POSIX leaves the state of a socket unspecified after a failed
      // connect (), and after EINTR the connection attempt continues
      // asynchronously. So every attempt uses a fresh socket, as the
      // data connection loop below does too.
      int sock = -1;
      for (int j = 0; j < N_CONNECT_RETRIES; j++)
        {
          int tsock = gnulib_socket_pfinet_sockstream (0);
          if (tsock == -1)
            {
              error ("socket error");
              return retval;
            }
          struct gnulib_close_guard __tsockg (tsock);

          if (gnulib_connect (tsock) == 0)
            {
              __tsockg.release ();
              sock = tsock;
              break;
            }
          else if (errno != ECONNREFUSED && errno != EINTR)
            {
              _p_error ("connect error");
              break;
            }
          else
            usleep (5000);
        }

      if (sock == -1)
        {
          error ("Unable to connect to %s", hosts(i).c_str ());
          return retval;
        }
      struct gnulib_close_guard __sockg (sock);

      conn = new octave_parallel_connection
        (hosts(i).c_str (), false, uuid);

      network->insert_connection (conn, i + 1);

#ifdef HAVE_LIBGNUTLS
      if (use_gnutls)
        {
          conn->insert_cmd_stream
            (new octave_parallel_stream
             (new octave_parallel_gnutls_streambuf
              (sock, ccred, false)));
          __cg.__release ();
        }
      else
#endif
        conn->insert_cmd_stream
          (new octave_parallel_stream
           (new octave_parallel_socket_streambuf (sock, false)));
      // the command stream now owns the socket
      __sockg.release ();
      if (! conn->get_cmd_stream ()->good ())
        {
          error ("could not create command stream to %s",
                 hosts(i).c_str ());
          return retval;
        }

      // Initialization on the command stream: number of servers,
      // server's core count (received), this server's index, uuid,
      // all server names, current directory, and (TLS) the password.
      // Stream state is checked once after all of this.
      conn->get_cmd_stream ()->network_send_4byteint (nhosts, true);
      dcprintf ("nhosts written\n");

      uint32_t nproc = 0;
      conn->get_cmd_stream ()->network_recv_4byteint (nproc);
      dcprintf ("nproc read (%u)\n", nproc);

      conn->set_nproc (nproc);

      conn->get_cmd_stream ()->network_send_4byteint (i + 1, true);
      dcprintf ("current host number written (%i)\n", i + 1);

      conn->get_cmd_stream ()->network_send_string (uuid);
      dcprintf ("uuid written (%s)\n", uuid);

      for (uint32_t j = 0; j < nhosts; j++)
        {
          conn->get_cmd_stream ()->network_send_string (hosts(j).c_str ());
          dcprintf ("hostname %i written (%s)\n", j, hosts(j).c_str ());
        }

      std::string directory = OCTAVE__SYS__ENV::get_current_directory ();

      conn->get_cmd_stream ()->network_send_string (directory.c_str ());
      dcprintf ("directory written (%s)\n", directory.c_str ());

#ifdef HAVE_LIBGNUTLS
      // This is to enable the servers to authenticate to each
      // other with TLS-SRP without having passwords stored at the
      // servers. The username can be read from the TLS-SRP
      // structures by the server.
      if (use_gnutls)
        {
          conn->get_cmd_stream ()->network_send_string
            (ccred->get_passwd ());
          // never write the secret itself to debug output
          dcprintf ("password written\n");
        }
#endif

      if (! conn->get_cmd_stream ()->good ())
        {
          error ("communication error in initialization");
          return retval;
        }
    }

  // go through the (now short) chain of deallocation-safeguarding
  // objects up to octave_value
  octave_parallel_connections *cconns = new octave_parallel_connections
    (network, uuid, false);
  retval = octave_value (cconns);
  octave_parallel_connections_rep *conns = cconns->get_rep ();

  // usleep (100);

  // Data connections (port 12501), one per server.
  dcprintf ("\n data socket \n\n");
  for (uint32_t i = 0; i < nhosts; i++)
    {
      dcprintf ("host number %i\n", i);
      struct gnulib_guard __guard (&gnulib_freeaddrinfo);
      if (gnulib_set_port (hosts(i).c_str (), "12501"))
        {
          error ("getaddrinfo returned an error");
          return retval;
        }

      int not_connected = 1;
      for (int j = 0; j < N_CONNECT_RETRIES; j++)
        {
          dcprintf ("host %i, connect retry %i\n", i, j);
          int sock = gnulib_socket_pfinet_sockstream (0);
          if (sock == -1)
            {
              error ("socket error");
              return retval;
            }
          struct gnulib_close_guard __sockg (sock);

          dcprintf ("%i, %i, trying to connect \n", i, j);
          if (gnulib_connect (sock) == 0)
            {
              dcprintf ("%i, %i, connect successful\n", i, j);
              octave_parallel_stream *data_stream;
#ifdef HAVE_LIBGNUTLS
              if (use_gnutls)
                conns->get_connections ()(i + 1)->insert_data_stream
                  (data_stream = new octave_parallel_stream
                   (new octave_parallel_gnutls_streambuf
                    (sock, ccred, false)));
              else
#endif
                conns->get_connections ()(i + 1)->insert_data_stream
                  (data_stream = new octave_parallel_stream
                   (new octave_parallel_socket_streambuf (sock, false)));
              __sockg.release ();
              if (! data_stream->good ())
                {
                  error ("%i, %i, could not create data stream to %s", i, j,
                         hosts(i).c_str ());
                  return retval;
                }

              dcprintf ("data connection established\n");
              data_stream->network_send_string (uuid);
              dcprintf ("%i, %i, uuid written (%s)\n", i, j, uuid);

              // send host number 0 (me, the master)
              data_stream->network_send_4byteint (0, true);
              dcprintf ("%i, %i, me (0) written\n", i, j);

              // recv result code
              int32_t res;
              data_stream->network_recv_4byteint (res);

              if (! data_stream->good ())
                {
                  error ("communication error in initialization");
                  return retval;
                }

              if (res == -1)
                {
                  // server not ready yet: drop this stream and retry
                  if (conns->get_connections ()(i + 1)->delete_data_stream ())
                    {
                      error ("could not delete data stream");
                      return retval;
                    }

                  dcprintf ("%i, %i, sleeping after receiving bad result (%i)\n", i, j, res);
                  usleep (5000);
                }
              else if (res)
                {
                  error ("unexpected server error");
                  return retval;
                }
              else
                {
                  minimal_write_header (data_stream->get_ostream ());
                  if (conns->get_connections ()(i + 1)->
                      connection_read_header () ||
                      ! data_stream->good ())
                    {
                      error ("communication error in initialization");
                      return retval;
                    }
                  not_connected = 0;
                  dcprintf ("%i, %i, good result read, header written and read and datastream good, breaking\n", i, j);
                  break;
                }
            }
          else if (errno != ECONNREFUSED && errno != EINTR)
            {
              _p_error ("connect error");
              break;
            }
          else
            {
              dcprintf ("%i, %i, sleeping after failed connect\n", i, j);
              usleep (5000);
            }
        }

      if (not_connected)
        {
          error ("unable to connect to %s", hosts(i).c_str ());
          return retval;
        }
    }

  // Final handshake on each command stream: read one character, then
  // answer with a newline.
  char dummy = 0;
  for (uint32_t i = 1; i <= nhosts; i++)
    {
      conns->get_connections ()(i)->get_cmd_stream ()->get_istream ()
        >> std::noskipws >> dummy;
      dcprintf ("host %i, final dummy read from command stream (%i)\n", i, dummy);
      dummy = '\n';
      conns->get_connections ()(i)->get_cmd_stream ()->get_ostream () << dummy;
      dcprintf ("host %i, final newline written to command stream\n", i);

      if (! conns->get_connections ()(i)->get_cmd_stream ()->good ())
        {
          error ("could not finish initialization");
          return retval;
        }
      dcprintf ("host %i, command stream good\n", i);
    }

  dcprintf ("returning retval\n");
  return retval;
}


/*
;;; Local Variables: ***
;;; mode: C++ ***
;;; End: ***
*/
