/* SPDX-License-Identifier: MIT */
/* SPDX-FileCopyrightText: (c) Copyright 2024-2026 Andrew Bower <andrew@bower.uk> */

/* xchpst: eXtended Change Process State
 * A tool that is backwards compatible with chpst(8) from runit(8),
 * offering additional options to harden processes with namespace isolation
 * and more. */

#include <poll.h>
#include <signal.h>
#include <linux/prctl.h>
#include <strings.h>
#include <sys/file.h>
#include <sys/signalfd.h>
#include <sys/wait.h>

#include "xchpst.h"
#include "join.h"

/* Policy switch consulted by reap_one(): when true, a non-zero exit of the
 * main child sets the subreaper's terminate flag and records E_TERM as
 * 128 + SIGHUP, after which join() (in subreaper mode) sends SIGKILL to the
 * remaining children on each SIGCHLD. Currently disabled, so a failing main
 * child does not by itself cause the other children to be killed. */
static const bool terminate_if_main_child_fails = false;

/* Pass one PID parsed from a "children" list to the visitor.
 * Returns false if the value is not a usable (positive, in-range) PID, so a
 * malformed list can never produce fn(context, 0) - which for a signalling
 * visitor would mean kill(0, ...), i.e. our own process group. */
static bool visit_child_pid(unsigned long d,
                            void (*fn)(void *context, pid_t p), void *context) {
  pid_t p = (pid_t) d;

  if (p <= 0 || (unsigned long) p != d)
    return false;
  fn(context, p);
  return true;
}

/* Call fn(context, pid) for every PID in the whitespace-separated file
 * "children" inside directory fd dir (a procfs task directory opened by the
 * caller).
 *
 * Returns the number of PIDs visited, or a negative errno value: the error
 * from openat()/read(), or -EINVAL if the list contains anything other than
 * decimal PIDs separated by whitespace. PIDs visited before an error are not
 * undone. */
int for_all_children(int dir, void (*fn)(void *context, pid_t p), void *context) {
  /* Avoid dynamic allocations, including libc streams so we can clean up
   * well under memory stress */
  char buf[128];
  unsigned long d = 0;    /* PID being accumulated; may span read() chunks */
  bool in_token = false;  /* at least one digit has been accumulated into d */
  int count = 0;
  int fd;

  fd = openat(dir, "children", O_RDONLY | O_CLOEXEC);
  if (fd == -1)
    return -errno;

  for (;;) {
    ssize_t len = read(fd, buf, sizeof buf);

    if (len == -1) {
      if (errno == EINTR)
        continue;
      count = -errno;
      goto finish;
    }
    if (len == 0)
      break;

    for (ssize_t i = 0; i < len; i++) {
      /* <ctype.h> functions require an unsigned char value (or EOF) */
      unsigned char c = (unsigned char) buf[i];

      if (isdigit(c)) {
        d = d * 10 + (unsigned long) (c - '0');
        in_token = true;
      } else if (!isspace(c)) {
        goto bad_list;
      } else if (in_token) {
        /* Separator ends a PID; runs of separators yield nothing */
        if (!visit_child_pid(d, fn, context))
          goto bad_list;
        count++;
        d = 0;
        in_token = false;
      }
    }
  }

  /* The list need not end with a separator: flush a final PID */
  if (in_token) {
    if (!visit_child_pid(d, fn, context))
      goto bad_list;
    count++;
  }
  goto finish;

bad_list:
  fprintf(stderr, "Error reading child task list\n");
  count = -EINVAL;
finish:
  close(fd);
  return count;
}

/* for_all_children() visitor: send the signal number carried in context
 * (packed as an intptr_t) to process p.
 * Non-positive PIDs are ignored because kill() gives them broadcast meaning
 * (0: our own process group, -1: every process we may signal). A stray list
 * entry must never turn a forwarded SIGKILL into one of those. */
static void signal_child(void *context, pid_t p) {
  if (p <= 0)
    return;
  kill(p, (int) (intptr_t) context);
}

/* Candidate sources for the final return code, listed in priority order.
 * join() keeps a bitmask of the sources recorded so far; the lowest-valued
 * recorded source wins (first matching rule):
 *   E_MAIN_NZ            1. main child non-zero exit code
 *   E_ANY_NZ             2. last non-zero exit code before any "terminate" flag
 *   E_ANY_NZ_AFTER_TERM  3. last non-zero exit code after any "terminate" flag
 *   E_TERM               4. synthesised failure exit code after "terminate" flag
 *   E_MAIN_Z             5. main child zero exit code
 *   E_CHPST              6. CHPST_ERROR_CHANGING_STATE
 * _E_NUM counts the sources; it sizes per-source arrays and bitmasks. */
enum retcode_source {
  E_MAIN_NZ,
  E_ANY_NZ,
  E_ANY_NZ_AFTER_TERM,
  E_TERM,
  E_MAIN_Z,
  E_CHPST,
  _E_NUM
};

/* Exit-status bookkeeping shared by join() and reap_one(). */
struct subreaper_context {
  /* Candidate exit code for each enum retcode_source */
  int retcode[_E_NUM];
  /* Bit (1U << source) is set once retcode[source] holds a value */
  unsigned retcodes_from;
  /* PID whose exit feeds E_MAIN_NZ / E_MAIN_Z */
  pid_t main_child;
  /* Set once termination is requested; non-zero exits reaped while set are
   * recorded under E_ANY_NZ_AFTER_TERM */
  bool terminate;
};

static struct subreaper_context init_subreaper(pid_t main_child) {
  struct subreaper_context context = {
    .retcode = { [E_CHPST] = CHPST_ERROR_CHANGING_STATE, },
    .retcodes_from = 1 << E_CHPST,
    .main_child = main_child,
  };
  static_assert((sizeof context.retcodes_from) * 8 >= _E_NUM);
  return context;
}

/* Handle result of wait*(): classify one reaped child's exit into the
 * retcode_source slots of subreaper.
 * Reports that are not a SIGCHLD exit/kill (stopped/continued children, or an
 * empty WNOHANG result with si_pid == 0) are ignored. */
static void reap_one(struct subreaper_context *subreaper, siginfo_t *pidinf) {
  int retcode = 0;
  bool exited = false;

  if (pidinf->si_signo == SIGCHLD &&
      pidinf->si_pid != 0) {
    if (pidinf->si_code == CLD_KILLED || pidinf->si_code == CLD_DUMPED) {
      if (is_verbose())
        fprintf(stderr, "subreaper: child killed by signal %d\n", pidinf->si_status);
      /* Shell convention: death by signal N reports 128 + N */
      retcode = 128 + pidinf->si_status;
      exited = true;
    } else if (pidinf->si_code == CLD_EXITED) {
      retcode = pidinf->si_status;
      exited = true;
    }
  }

  if (!exited)
    return;

  bool is_main = pidinf->si_pid == subreaper->main_child;
  typeof(subreaper->retcodes_from) from = 0;

  if (retcode != 0) {
    if (is_main)
      from |= 1U << E_MAIN_NZ;
    /* Keep failures seen after termination was requested (often caused by
     * our own SIGKILL) apart from earlier, genuine failures, so that rule 2
     * is not overwritten by rule 3 material. */
    if (subreaper->terminate)
      from |= 1U << E_ANY_NZ_AFTER_TERM;
    else
      from |= 1U << E_ANY_NZ;
  } else if (is_main) {
    from |= 1U << E_MAIN_Z;
  }

  /* Record retcode in every source slot whose bit is set in from */
  subreaper->retcodes_from |= from;
  for (unsigned source; from; from &= ~(1U << source)) {
    source = ffs(from) - 1;
    subreaper->retcode[source] = retcode;
  }

  if (is_main && retcode != 0) {
    if (is_verbose())
      fprintf(stderr, "subreaper: main child exited uncleanly: %d\n", retcode);
    if (terminate_if_main_child_fails) {
      subreaper->terminate = true;
      subreaper->retcodes_from |= 1U << E_TERM;
      subreaper->retcode[E_TERM] = 128 + SIGHUP;
    }
  }
}

/* Block every signal except the synchronous fault signals, so that the rest
 * can be collected through a signalfd and proxied.
 * POSIX leaves the result undefined if SIGBUS, SIGFPE, SIGILL or SIGSEGV are
 * generated by a fault while blocked, so those stay deliverable.
 * newmask receives the installed mask; oldmask receives the previous one.
 * Returns false (after printing a diagnostic) if sigprocmask() fails. */
bool sig_proxy_mask(sigset_t *newmask, sigset_t *oldmask) {
  static const int fault_signals[] = { SIGBUS, SIGFPE, SIGILL, SIGSEGV };

  sigfillset(newmask);
  for (size_t k = 0; k < sizeof fault_signals / sizeof fault_signals[0]; k++)
    sigdelset(newmask, fault_signals[k]);

  if (sigprocmask(SIG_SETMASK, newmask, oldmask) != -1)
    return true;

  perror("join: setting up mask for signalfds");
  return false;
}

/* Forward sig on behalf of join(): to every direct child listed under
 * proc_task when acting as subreaper, otherwise (or when the children list
 * could not be opened) to the main child if it is still alive.
 * The child != -1 check matters: kill(-1, sig) would signal every process we
 * are permitted to signal. */
static void join_forward_signal(int sig, bool subreaper, int proc_task,
                                pid_t child) {
  if (subreaper && proc_task != -1)
    for_all_children(proc_task, signal_child, (void *)(intptr_t) sig);
  else if (child != -1)
    kill(child, sig);
}

/* Supervise child until it exits, proxying the signals in mask to it.
 * In subreaper mode, keep going until every child of this thread has been
 * reaped, and proxy signals to all of them.
 *
 *   child     - main child PID
 *   mask      - signals to read through a signalfd; callers are expected to
 *               have blocked them already (pico_init() uses sig_proxy_mask())
 *   oldmask   - signal mask restored before returning
 *   retcode   - receives the exit code selected by the retcode_source rules
 *   subreaper - reap and signal all children rather than only child
 *   proc_self - directory fd containing "task/<tid>" (presumably /proc/self)
 *   tool      - prefix for diagnostics
 *
 * SIGQUIT is translated to SIGKILL and starts termination.
 * Returns true once the children have been reaped. Returns false on error;
 * in that case the main child, if still alive, is sent SIGKILL. */
bool join(pid_t child, sigset_t *mask, sigset_t *oldmask, int *retcode,
          bool subreaper, int proc_self, const char *tool) {
  enum {
    /* Offsets into poll set */
    my_signalfd = 0,
  };
  struct subreaper_context subcontext = init_subreaper(child);
  struct signalfd_siginfo siginf;
  int ready;
  int sfd = -1;
  int proc_task = -1;
  ssize_t rc;
  bool success = false;
  bool terminating = false;

  if (subreaper) {

    /* If we become the subreaper, pre-open the procfs directory that will
       contain our list of children */
    char task_path[] = "task/4294967295";
    static_assert(sizeof(pid_t) <= 4);
    rc = snprintf(task_path, sizeof task_path, "task/%d", gettid());
    if (rc < 0 || rc >= (ssize_t) sizeof task_path) {
      fprintf(stderr, "%s: formatting task path: %s\n", tool, strerror(errno));
      goto finish;
    }
    proc_task = openat(proc_self, task_path, O_RDONLY | O_DIRECTORY | O_CLOEXEC);
    if (proc_task == -1) {
      /* Not fatal: signals can still be proxied to the main child */
      fprintf(stderr, "%s: warning: opening %s: %s\n", tool, task_path,
              strerror(errno));
    }
  }

  sfd = signalfd(-1, mask, SFD_NONBLOCK);
  if (sfd == -1) {
    fprintf(stderr, "%s: error setting up signal proxy: %s\n", tool, strerror(errno));
    goto finish;
  }

  struct pollfd pollset[1] = {
    [my_signalfd] = { .fd = sfd, .events = POLLIN },
  };
  while (!success) {
    ready = poll(pollset, 1, -1);
    if (ready == -1) {
      /* revents are not meaningful after a failed poll() */
      if (errno == EINTR)
        continue;
      fprintf(stderr, "%s: poll: %s\n", tool, strerror(errno));
      goto finish;
    }
    if (!(pollset[my_signalfd].revents & POLLIN))
      continue;

    /* Handle a signal received by parent process and pass to child */
    rc = read(sfd, &siginf, sizeof siginf);
    if (rc == -1 && (errno == EAGAIN || errno == EINTR))
      continue;
    if (rc != sizeof siginf) {
      fprintf(stderr, "%s: read signalfd: %s\n", tool, strerror(errno));
      goto finish;
    }

    if (is_debug())
      fprintf(stderr, "%s: got signal %d\n", tool, siginf.ssi_signo);

    if (siginf.ssi_signo == SIGQUIT) {
      if (is_verbose())
        fprintf(stderr, "%s: translating SIGQUIT to SIGKILL for child\n", tool);
      siginf.ssi_signo = SIGKILL;
      terminating = true;
      /* Failures reaped from now on are consequences of this termination */
      subcontext.terminate = true;
      subcontext.retcodes_from |= 1U << E_TERM;
      subcontext.retcode[E_TERM] = 128 + siginf.ssi_signo;
    }

    if (siginf.ssi_signo != SIGCHLD) {
      if (is_verbose())
        fprintf(stderr, "%s: passing on signal %d to child%s\n", tool,
                siginf.ssi_signo, subreaper ? "ren" : "");
      join_forward_signal(siginf.ssi_signo, subreaper, proc_task, child);
      continue;
    }

    /* SIGCHLD: reap every child that has exited so far. Linux reports
     * si_signo == 0 when WNOHANG finds nothing waitable. */
    siginfo_t pidinf = { 0 };
    int wrc;

    while ((wrc = waitid(P_ALL, -1, &pidinf, WEXITED | WNOHANG)) == 0 &&
           pidinf.si_signo == SIGCHLD) {
      if (pidinf.si_pid == child) {
        child = -1; /* main child died */
        if (!subreaper)
          success = true;
      }
      reap_one(&subcontext, &pidinf);
    }

    if (wrc == -1) {
      if (errno != ECHILD) {
        fprintf(stderr, "%s: reaping children: %s\n", tool, strerror(errno));
        goto finish;
      }
      /* OK, exited the loop because no children left */
      success = true;
    }

    if (subreaper && (subcontext.terminate || terminating)) {
      terminating = true;
      if (is_debug())
        fprintf(stderr, "%s: terminating\n", tool);
      join_forward_signal(SIGKILL, subreaper, proc_task, child);
    }
  }

  if (is_verbose())
    fprintf(stderr, "%s: child terminated; cleaning up\n", tool);

  /* Lowest set bit = highest-priority recorded source (E_CHPST always set) */
  if (ffs(subcontext.retcodes_from) - 1 < _E_NUM) {
    if (is_debug())
      fprintf(stderr, "%s: retcode sources: 0o%o\n", tool, subcontext.retcodes_from);
    *retcode = subcontext.retcode[ffs(subcontext.retcodes_from) - 1];
  }
finish:
  if (!success && child != -1)
    kill(child, SIGKILL);
  if (sfd != -1)
    close(sfd);
  if (proc_task != -1)
    close(proc_task);
  sigprocmask(SIG_SETMASK, oldmask, NULL);

  return success;
}

/* Split into payload and a minimal init.
 * The forked child restores the original signal mask and returns true, to
 * carry on as the payload. The parent stays behind with all proxyable signals
 * blocked, supervises its children through join() in subreaper mode, and then
 * exits with the selected return code (CHPST_ERROR_CHANGING_STATE if join()
 * fails). The parent never returns.
 * Returns false in the original process if setup fails before forking; the
 * caller's signal mask is then left as it was.
 * NOTE(review): join() is run as a subreaper, but nothing here calls
 * prctl(PR_SET_CHILD_SUBREAPER). Presumably the caller arranged it, or this
 * process is PID 1 of a PID namespace -- confirm. */
bool pico_init(int proc_self) {
  sigset_t newmask, oldmask;
  int last_retcode = CHPST_ERROR_CHANGING_STATE;
  pid_t child;

  if (!sig_proxy_mask(&newmask, &oldmask))
    return false;

  child = fork();
  if (child == -1) {
    perror("pico_init: fork");
    /* Do not hand control back with every signal blocked */
    if (sigprocmask(SIG_SETMASK, &oldmask, NULL) == -1)
      perror("pico_init: warning: could not restore signal mask");
    return false;
  }
  if (child == 0) {
    if (sigprocmask(SIG_SETMASK, &oldmask, NULL) == -1)
      perror("pico_init: warning: could not restore signal mask in child");
    return true;
  }

  if (!join(child, &newmask, &oldmask, &last_retcode, true, proc_self, "pico_init"))
    exit(CHPST_ERROR_CHANGING_STATE);

  if (is_verbose())
    fprintf(stderr, "pico_init: children terminated; cleaning up\n");

  exit(last_retcode);
}
