#include "THGeneral.h"
#include "THAtomic.h"

#ifdef _OPENMP
#include <omp.h>
#endif

#ifndef TH_HAVE_THREAD
#define __thread
#elif _MSC_VER
#define __thread __declspec( thread )
#endif

#if defined(__APPLE__)
#include <malloc/malloc.h>
#endif

#if defined(__linux__)
#include <malloc.h>
#endif

#if defined(__FreeBSD__)
#include <malloc_np.h>
#endif

/* Torch Error Handling */
/* Built-in error handler: writes the message to stdout and aborts the
 * process. `data` is accepted only to match the handler signature. */
static void defaultErrorHandlerFunction(const char *msg, void *data)
{
  (void)data; /* unused by the built-in handler */

  printf("$ Error: %s\n", msg);
  abort();
}

/* Fallback error handler and its user data; _THError uses these whenever the
 * calling thread has no handler of its own installed. */
static THErrorHandlerFunction defaultErrorHandler = defaultErrorHandlerFunction;
static void *defaultErrorHandlerData;
/* Per-thread override and its user data (NULL handler means "use the
 * default"). When TH_HAVE_THREAD is not defined, __thread expands to nothing
 * and these become ordinary file-scope variables. */
static __thread THErrorHandlerFunction threadErrorHandler = NULL;
static __thread void *threadErrorHandlerData;

/* Formats an error message, appends " at <file>:<line>" when there is room,
 * and passes the result to the active error handler.
 *
 * file, line : source location appended to the message
 * fmt, ...   : printf-style format string and its arguments
 *
 * The thread-local handler is used when one is installed; otherwise the
 * default handler is called. The message is truncated to fit a 2048-byte
 * buffer.
 */
void _THError(const char *file, const int line, const char *fmt, ...)
{
  char msg[2048];
  va_list args;

  /* vasprintf not standard */
  /* vsnprintf: how to handle if does not exists? */
  va_start(args, fmt);
  int n = vsnprintf(msg, sizeof(msg), fmt, args);
  va_end(args);

  if(n < 0) {
    /* vsnprintf reported an error: the buffer contents cannot be trusted and
     * a negative offset must never be used to index msg, so start from an
     * empty message and still report the location. */
    msg[0] = '\0';
    n = 0;
  }

  /* n >= sizeof(msg) means the message was truncated and no room is left */
  if(n < (int)sizeof(msg)) {
    snprintf(msg + n, sizeof(msg) - n, " at %s:%d", file, line);
  }

  if (threadErrorHandler)
    (*threadErrorHandler)(msg, threadErrorHandlerData);
  else
    (*defaultErrorHandler)(msg, defaultErrorHandlerData);
}

/* Reports a failed assertion through _THError.
 *
 * file, line : source location of the assertion
 * exp        : text of the expression that failed
 * fmt, ...   : printf-style explanation, truncated to a 1024-byte buffer
 */
void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...) {
  char detail[1024];
  va_list ap;

  va_start(ap, fmt);
  vsnprintf(detail, sizeof(detail), fmt, ap);
  va_end(ap);

  _THError(file, line, "Assertion `%s' failed. %s", exp, detail);
}

/* Installs an error handler for the calling thread.
 *
 * new_handler : handler to call from _THError; NULL makes _THError fall back
 *               to the default handler
 * data        : opaque pointer passed back to the handler
 */
void THSetErrorHandler(THErrorHandlerFunction new_handler, void *data)
{
  threadErrorHandlerData = data;
  threadErrorHandler = new_handler;
}

/* Replaces the default (non thread-local) error handler.
 *
 * new_handler : handler to install; NULL restores the built-in handler
 * data        : opaque pointer passed back to the handler
 */
void THSetDefaultErrorHandler(THErrorHandlerFunction new_handler, void *data)
{
  defaultErrorHandler = new_handler ? new_handler : defaultErrorHandlerFunction;
  defaultErrorHandlerData = data;
}

/* Torch Arg Checking Handling */
/* Built-in argument-error handler: prints the offending argument number (and
 * the message, when one is given) to stdout, then exits with status -1.
 * `data` is accepted only to match the handler signature. */
static void defaultArgErrorHandlerFunction(int argNumber, const char *msg, void *data)
{
  (void)data; /* unused by the built-in handler */

  if(!msg) {
    printf("$ Invalid argument %d\n", argNumber);
  } else {
    printf("$ Invalid argument %d: %s\n", argNumber, msg);
  }

  exit(-1);
}

/* Fallback argument-error handler and its user data; _THArgCheck uses these
 * whenever the calling thread has no handler of its own installed. */
static THArgErrorHandlerFunction defaultArgErrorHandler = defaultArgErrorHandlerFunction;
static void *defaultArgErrorHandlerData;
/* Per-thread override and its user data (NULL handler means "use the
 * default"). When TH_HAVE_THREAD is not defined, __thread expands to nothing
 * and these become ordinary file-scope variables. */
static __thread THArgErrorHandlerFunction threadArgErrorHandler = NULL;
static __thread void *threadArgErrorHandlerData;

/* Checks an argument condition and reports a failure to the active
 * argument-error handler.
 *
 * file, line : source location appended to the message
 * condition  : nonzero when the argument is valid; nothing happens then
 * argNumber  : index of the argument being checked, forwarded to the handler
 * fmt, ...   : printf-style description of the problem
 *
 * The thread-local handler is used when one is installed; otherwise the
 * default handler is called. The message is truncated to fit a 2048-byte
 * buffer.
 */
void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...)
{
  if(condition)
    return;

  char msg[2048];
  va_list args;

  /* vasprintf not standard */
  /* vsnprintf: how to handle if does not exists? */
  va_start(args, fmt);
  int n = vsnprintf(msg, sizeof(msg), fmt, args);
  va_end(args);

  if(n < 0) {
    /* vsnprintf reported an error: the buffer contents cannot be trusted and
     * a negative offset must never be used to index msg, so start from an
     * empty message and still report the location. */
    msg[0] = '\0';
    n = 0;
  }

  /* n >= sizeof(msg) means the message was truncated and no room is left */
  if(n < (int)sizeof(msg)) {
    snprintf(msg + n, sizeof(msg) - n, " at %s:%d", file, line);
  }

  if (threadArgErrorHandler)
    (*threadArgErrorHandler)(argNumber, msg, threadArgErrorHandlerData);
  else
    (*defaultArgErrorHandler)(argNumber, msg, defaultArgErrorHandlerData);
}

/* Installs an argument-error handler for the calling thread.
 *
 * new_handler : handler to call from _THArgCheck; NULL makes _THArgCheck fall
 *               back to the default handler
 * data        : opaque pointer passed back to the handler
 */
void THSetArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data)
{
  threadArgErrorHandlerData = data;
  threadArgErrorHandler = new_handler;
}

/* Replaces the default (non thread-local) argument-error handler.
 *
 * new_handler : handler to install; NULL restores the built-in handler
 * data        : opaque pointer passed back to the handler
 */
void THSetDefaultArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data)
{
  defaultArgErrorHandler = new_handler ? new_handler : defaultArgErrorHandlerFunction;
  defaultArgErrorHandlerData = data;
}

/* State for heap tracking and the optional GC hook. */

/* GC callback and its user data, set by THSetGCHandler (thread-local when
 * __thread is available). */
static __thread void (*torchGCFunction)(void *data) = NULL;
static __thread void *torchGCData;
/* Tracked heap size shared across threads; only modified through
 * THAtomicAddPtrdiff in applyHeapDelta. */
static ptrdiff_t heapSize = 0;
/* Change accumulated by THHeapUpdate that has not yet been added to heapSize. */
static __thread ptrdiff_t heapDelta = 0;
static const ptrdiff_t heapMaxDelta = (ptrdiff_t)1e6; // limit to +/- 1MB before updating heapSize
static const ptrdiff_t heapMinDelta = (ptrdiff_t)-1e6;
/* Heap size above which maybeTriggerGC invokes the GC callback. */
static __thread ptrdiff_t heapSoftmax = (ptrdiff_t)3e8; // 300MB, adjusted upward dynamically
static const double heapSoftmaxGrowthThresh = 0.8; // grow softmax if >80% max after GC
static const double heapSoftmaxGrowthFactor = 1.4; // grow softmax by 40%

/* Registers an optional garbage-collection hook.
 *
 * A garbage-collected frontend (e.g. Lua) cannot see memory allocated by TH,
 * so it may not realize a collection is due. Once a hook is registered, TH
 * calls it in two situations:
 *
 * (1) a memory allocation (malloc, realloc, ...) fails;
 * (2) the total TH-allocated memory reaches a soft maximum that is adjusted
 *     dynamically.
 *
 * torchGCFunction_ : callback to run, or NULL to disable the hook
 * data             : opaque pointer passed back to the callback
 */
void THSetGCHandler( void (*torchGCFunction_)(void *data), void *data )
{
  torchGCData = data;
  torchGCFunction = torchGCFunction_;
}

/* Returns the size the platform allocator reports for `ptr`, or 0 on
 * platforms where no such query is compiled in.
 * it is guaranteed the allocated size is not bigger than PTRDIFF_MAX */
static ptrdiff_t getAllocSize(void *ptr) {
  ptrdiff_t usable = 0;

#if defined(__unix) && defined(HAVE_MALLOC_USABLE_SIZE)
  usable = malloc_usable_size(ptr);
#elif defined(__APPLE__)
  usable = malloc_size(ptr);
#elif defined(_WIN32)
  /* _msize is only queried for non-NULL pointers */
  if(ptr) {
    usable = _msize(ptr);
  }
#else
  (void)ptr; /* no size query available on this platform */
#endif

  return usable;
}

/* Adds this thread's pending heapDelta to the shared heapSize, resets
 * heapDelta to zero, and returns the heap size computed for after the update.
 */
static ptrdiff_t applyHeapDelta() {
  /* NOTE(review): the code treats the return value of THAtomicAddPtrdiff as
   * the value of heapSize *before* the add - confirm against THAtomic.h. */
  ptrdiff_t oldHeapSize = THAtomicAddPtrdiff(&heapSize, heapDelta);
#ifdef DEBUG
  /* Debug-only range checks; they run after the atomic add has happened. */
  if (heapDelta > 0 && oldHeapSize > PTRDIFF_MAX - heapDelta)
    THError("applyHeapDelta: heapSize(%td) + increased(%td) > PTRDIFF_MAX, heapSize overflow!", oldHeapSize, heapDelta);
  if (heapDelta < 0 && oldHeapSize < PTRDIFF_MIN - heapDelta)
    THError("applyHeapDelta: heapSize(%td) + decreased(%td) < PTRDIFF_MIN, heapSize underflow!", oldHeapSize, heapDelta);
#endif
  /* Recompute the new size locally rather than re-reading the shared counter. */
  ptrdiff_t newHeapSize = oldHeapSize + heapDelta;
  heapDelta = 0;
  return newHeapSize;
}

/* Runs the registered GC callback when the tracked heap size is above the
 * soft maximum.
 *
 * (1) nothing happens unless a callback is registered and curHeapSize exceeds
 *     the soft max
 * (2) if the heap size measured after the GC still exceeds 80% of the soft
 *     max, the soft max is increased by 40%
 */
static void maybeTriggerGC(ptrdiff_t curHeapSize) {
  if (!torchGCFunction || curHeapSize <= heapSoftmax) {
    return;
  }

  torchGCFunction(torchGCData);

  // fold the pending delta into heapSize so the post-GC figure is accurate
  ptrdiff_t postGCHeapSize = applyHeapDelta();

  if (postGCHeapSize > heapSoftmax * heapSoftmaxGrowthThresh) {
    heapSoftmax = (ptrdiff_t)(heapSoftmax * heapSoftmaxGrowthFactor);
  }
}

// Heap-tracking hook: records that the TH-allocated heap changed by `size`
// bytes (positive for growth, negative for shrinkage). Changes are collected
// in heapDelta and only pushed to the shared heapSize once they reach
// heapMaxDelta / heapMinDelta; a push caused by growth may also trigger the GC.
void THHeapUpdate(ptrdiff_t size) {
#ifdef DEBUG
  if (size > 0 && heapDelta > PTRDIFF_MAX - size)
    THError("THHeapUpdate: heapDelta(%td) + increased(%td) > PTRDIFF_MAX, heapDelta overflow!", heapDelta, size);
  if (size < 0 && heapDelta < PTRDIFF_MIN - size)
    THError("THHeapUpdate: heapDelta(%td) + decreased(%td) < PTRDIFF_MIN, heapDelta underflow!", heapDelta, size);
#endif

  heapDelta += size;

  // batching keeps contention on the shared heapSize low: only flush once the
  // pending delta has left the (heapMinDelta, heapMaxDelta) window
  if (heapDelta >= heapMaxDelta || heapDelta <= heapMinDelta) {
    ptrdiff_t flushedHeapSize = applyHeapDelta();

    // only growth can push the heap over the soft maximum
    if (size > 0) {
      maybeTriggerGC(flushedHeapSize);
    }
  }
}

/* Allocates `size` bytes and records the allocation with THHeapUpdate.
 *
 * Requests larger than 5120 bytes are 64-byte aligned via posix_memalign
 * where that is compiled in; everything else uses malloc. Returns NULL on
 * failure; no error is raised here.
 */
static void* THAllocInternal(ptrdiff_t size)
{
  void *ptr = NULL;

#if (defined(__unix) || defined(__APPLE__)) && (!defined(DISABLE_POSIX_MEMALIGN))
  if (size > 5120)
  {
    /* posix_memalign does not promise to leave ptr untouched on failure */
    if (posix_memalign(&ptr, 64, size) != 0)
      ptr = NULL;
  }
  else
  {
    ptr = malloc(size);
  }
#else
  /* no aligned path on this platform (a _WIN32 _aligned_malloc variant was
   * left disabled in the original code) */
  ptr = malloc(size);
#endif

  THHeapUpdate(getAllocSize(ptr));
  return ptr;
}

/* Allocates `size` bytes with heap tracking.
 *
 * size : number of bytes; 0 returns NULL, a negative value raises an error
 *
 * If the first attempt fails and a GC callback is registered, the callback is
 * run and the allocation is retried once. If that fails too, an error is
 * raised through THError.
 */
void* THAlloc(ptrdiff_t size)
{
  void *ptr;

  if(size < 0)
    THError("$ Torch: invalid memory size -- maybe an overflow?");

  if(size == 0)
    return NULL;

  ptr = THAllocInternal(size);

  if(!ptr && torchGCFunction) {
    /* give the frontend a chance to release memory, then retry once */
    torchGCFunction(torchGCData);
    ptr = THAllocInternal(size);
  }

  /* size is a ptrdiff_t, so it must not be passed to %d; the GB count is
   * converted to long and printed with %ld instead */
  if(!ptr)
    THError("$ Torch: not enough memory: you tried to allocate %ldGB. Buy new RAM!", (long)(size/1073741824));

  return ptr;
}

/* Resizes a block obtained from THAlloc / THRealloc, with heap tracking.
 *
 * ptr  : block to resize; NULL behaves like THAlloc(size)
 * size : new size in bytes; 0 frees the block and returns NULL, a negative
 *        value raises an error
 *
 * If realloc fails and a GC callback is registered, the callback is run and
 * realloc is retried once. If that fails too, an error is raised through
 * THError.
 */
void* THRealloc(void *ptr, ptrdiff_t size)
{
  if(!ptr)
    return(THAlloc(size));

  if(size == 0)
  {
    THFree(ptr);
    return NULL;
  }

  if(size < 0)
    THError("$ Torch: invalid memory size -- maybe an overflow?");

  /* the old size must be read before realloc, which may release ptr */
  ptrdiff_t oldSize = -getAllocSize(ptr);
  void *newptr = realloc(ptr, size);

  if(!newptr && torchGCFunction) {
    /* a failed realloc leaves ptr valid, so it can be retried after the GC */
    torchGCFunction(torchGCData);
    newptr = realloc(ptr, size);
  }

  /* size is a ptrdiff_t, so it must not be passed to %d; the GB count is
   * converted to long and printed with %ld instead */
  if(!newptr)
    THError("$ Torch: not enough memory: you tried to reallocate %ldGB. Buy new RAM!", (long)(size/1073741824));

  // update heapSize only after successfully reallocated
  THHeapUpdate(oldSize + getAllocSize(newptr));

  return newptr;
}

/* Releases a block obtained from THAlloc / THRealloc and subtracts its
 * allocator-reported size from the tracked heap. */
void THFree(void *ptr)
{
  /* the size has to be queried while the block is still allocated */
  const ptrdiff_t released = getAllocSize(ptr);

  THHeapUpdate(-released);
  free(ptr);
}

/* Returns log(1 + x).
 *
 * Under MSVC or MinGW the value is computed from log() plus a correction
 * term; everywhere else the call is forwarded to log1p().
 */
double THLog1p(const double x)
{
#if (defined(_MSC_VER) || defined(__MINGW32__))
  /* NOTE(review): volatile presumably forces y to be rounded to double so the
   * correction term below sees the same value log() does - confirm. */
  volatile double y = 1 + x;
  return log(y) - ((y-1)-x)/y ;  /* cancels errors with IEEE arithmetic */
#else
  return log1p(x);
#endif
}

/* Forwards the requested thread count to every threading backend that was
 * compiled in: OpenMP, OpenBLAS (TH_BLAS_OPEN) and MKL (TH_BLAS_MKL). With
 * none of them enabled this function does nothing.
 *
 * num_threads : value passed unchanged to each backend's setter
 */
void THSetNumThreads(int num_threads)
{
#ifdef _OPENMP
  omp_set_num_threads(num_threads);
#endif
#ifdef TH_BLAS_OPEN
  /* declared locally; no OpenBLAS header is included by this file */
  extern void openblas_set_num_threads(int);
  openblas_set_num_threads(num_threads);
#endif
#ifdef TH_BLAS_MKL
  /* declared locally; no MKL header is included by this file */
  extern void mkl_set_num_threads(int);
  mkl_set_num_threads(num_threads);

#endif
}

/* Returns the smallest thread count reported by the compiled-in backends.
 *
 * The result starts at 1, is replaced by omp_get_max_threads() when OpenMP is
 * enabled, and is then capped by the OpenBLAS (TH_BLAS_OPEN) and MKL
 * (TH_BLAS_MKL) thread counts when those backends are enabled.
 */
int THGetNumThreads(void)
{
  int nthreads = 1;
#ifdef _OPENMP
  nthreads = omp_get_max_threads();
#endif
  /* Each BLAS section lives in its own scope so that enabling both
   * TH_BLAS_OPEN and TH_BLAS_MKL does not declare bl_threads twice. */
#ifdef TH_BLAS_OPEN
  {
    extern int openblas_get_num_threads(void);
    const int bl_threads = openblas_get_num_threads();
    if (bl_threads < nthreads)
      nthreads = bl_threads;
  }
#endif
#ifdef TH_BLAS_MKL
  {
    extern int mkl_get_max_threads(void);
    const int bl_threads = mkl_get_max_threads();
    if (bl_threads < nthreads)
      nthreads = bl_threads;
  }
#endif
  return nthreads;
}

/* Returns omp_get_num_procs() when built with OpenMP, and 1 otherwise. */
int THGetNumCores(void)
{
  int cores = 1;
#ifdef _OPENMP
  cores = omp_get_num_procs();
#endif
  return cores;
}

#ifdef TH_BLAS_MKL
/* File-scope declaration for THInferNumThreads below; no MKL header is
 * included by this file. */
extern int mkl_get_max_threads(void);
#endif

/* Sets the OpenMP thread count to MKL's maximum thread count. Does nothing
 * unless the build has both OpenMP and TH_BLAS_MKL enabled. */
TH_API void THInferNumThreads(void)
{
#if defined(_OPENMP) && defined(TH_BLAS_MKL)
  // If we are using MKL and OpenMP, make sure the number of threads match.
  // Otherwise, MKL and our OpenMP-enabled functions will keep changing the
  // size of the OpenMP thread pool, resulting in worse performance (and memory
  // leaks in GCC 5.4)
  omp_set_num_threads(mkl_get_max_threads());
#endif
}

/* Renders a size array as text, e.g. "[2 x 3 x 4]", into a fixed-size buffer
 * that is returned by value.
 *
 * size : array of `ndim` dimension sizes
 * ndim : number of entries in `size`
 *
 * When the text does not fit, the end of the buffer is overwritten with
 * "...]" so the result is always terminated and visibly truncated.
 * NOTE(review): assumes THDescBuff.str holds at least TH_DESC_BUFF_LEN bytes
 * and that TH_DESC_BUFF_LEN >= 5 - confirm against THGeneral.h.
 */
TH_API THDescBuff _THSizeDesc(const long *size, const long ndim) {
  const int L = TH_DESC_BUFF_LEN;
  THDescBuff buf;
  char *str = buf.str;
  int n = 0;
  n += snprintf(str, L-n, "[");
  int i;
  for(i = 0; i < ndim; i++) {
    if(n >= L) break;
    /* snprintf returns the length it *wanted* to write, so n may now be
     * larger than L if the number was truncated */
    n += snprintf(str+n, L-n, "%ld", size[i]);
    if(i < ndim-1) {
      /* re-check before the separator: with n > L, str+n would point past
       * the buffer and L-n would turn into a huge size_t */
      if(n >= L) break;
      n += snprintf(str+n, L-n, " x ");
    }
  }
  if(n < L - 2) {
    snprintf(str+n, L-n, "]");
  } else {
    /* out of room: mark the truncation in the last 5 bytes */
    snprintf(str+L-5, 5, "...]");
  }
  return buf;
}

