#!/usr/bin/env python3

import numpy as np
from unittest import TestCase, main
from vbz import compress, decompress, compression_options
from numpy.testing import assert_array_equal


class Tests:
    vbz_version = None

    """ Base test functions for all 3 encoding/decoding functions """
    def test_decode(self):
        """ S1: simple decode """
        options = None
        if self.vbz_version:
            zigzag = np.issubdtype(self.data.dtype, np.signedinteger)
            options = compression_options(zigzag, self.data.dtype.itemsize, 1, self.vbz_version)

        res = compress(self.data, options)
        rec = decompress(res, self.dtype, options)
        assert_array_equal(self.data, rec)


class BasicTest(Tests):
    """Fixture mixin: a short fixed sequence 1..10 stored as ``self.dtype``."""

    def setUp(self):
        """Build the ten-element input array before each test."""
        values = list(range(1, 11))
        self.data = np.array(values, dtype=self.dtype)

class RandTest(Tests):
    """Fixture mixin: an array of random integers stored as ``self.dtype``.

    The bounds ``rmin``/``rmax``, the element count ``size`` and ``dtype`` are
    read from the instance, so concrete classes must define them.
    """

    def setUp(self):
        """Draw ``self.size`` integers from ``[self.rmin, self.rmax)``."""
        draw = np.random.randint
        self.data = draw(low=self.rmin, high=self.rmax, size=self.size, dtype=self.dtype)

# Basic tests with the default configuration (no explicit compression options)
class BasicTestInt8(BasicTest, TestCase):
    """Basic round-trip test on int8 data with default options."""
    dtype = np.int8

class BasicTestUInt8(BasicTest, TestCase):
    """Basic round-trip test on uint8 data with default options."""
    dtype = np.uint8


class BasicTestInt16(BasicTest, TestCase):
    """Basic round-trip test on int16 data with default options."""
    dtype = np.int16

class BasicTestUInt16(BasicTest, TestCase):
    """Basic round-trip test on uint16 data with default options."""
    dtype = np.uint16


class BasicTestInt32(BasicTest, TestCase):
    """Basic round-trip test on int32 data with default options."""
    dtype = np.int32

class BasicTestUInt32(BasicTest, TestCase):
    """Basic round-trip test on uint32 data with default options."""
    dtype = np.uint32


# Basic tests with explicit v1 options (8- and 16-bit types only)
class BasicV1TestInt8(BasicTestInt8):
    """BasicTestInt8 re-run with ``vbz_version = 1`` (explicit compression options)."""
    vbz_version = 1

class BasicV1TestUInt8(BasicTestUInt8):
    """BasicTestUInt8 re-run with ``vbz_version = 1`` (explicit compression options)."""
    vbz_version = 1

class BasicV1TestInt16(BasicTestInt16):
    """BasicTestInt16 re-run with ``vbz_version = 1`` (explicit compression options)."""
    vbz_version = 1

class BasicV1TestUInt16(BasicTestUInt16):
    """BasicTestUInt16 re-run with ``vbz_version = 1`` (explicit compression options)."""
    vbz_version = 1

# Random-data tests with the default configuration (no explicit compression options)
class RandTestInt8(RandTest, TestCase):
    """Random round-trip test on int8 data with default options."""
    dtype = np.int8
    # Bounds are handed to np.random.randint, whose upper bound is exclusive.
    rmin = -(1 << 7)
    rmax = (1 << 7) - 1
    size = 200000

class RandTestUInt8(RandTest, TestCase):
    """Random round-trip test on uint8 data with default options."""
    dtype = np.uint8
    # Bounds are handed to np.random.randint, whose upper bound is exclusive.
    # NOTE(review): only the lower half of the uint8 range is sampled --
    # confirm this is intentional.
    rmin = 0
    rmax = (1 << 7) - 1
    size = 200000


class RandTestInt16(RandTest, TestCase):
    """Random round-trip test on int16 data with default options."""
    dtype = np.int16
    # Bounds are handed to np.random.randint, whose upper bound is exclusive.
    rmin = -(1 << 15)
    rmax = (1 << 15) - 1
    size = 200000

class RandTestUInt16(RandTest, TestCase):
    """Random round-trip test on uint16 data with default options."""
    dtype = np.uint16
    # Bounds are handed to np.random.randint, whose upper bound is exclusive.
    # NOTE(review): only the lower half of the uint16 range is sampled --
    # confirm this is intentional.
    rmin = 0
    rmax = (1 << 15) - 1
    size = 200000


class RandTestInt32(RandTest, TestCase):
    """Random round-trip test on int32 data with default options."""
    dtype = np.int32
    # Bounds are handed to np.random.randint, whose upper bound is exclusive.
    rmin = -(1 << 31)
    rmax = (1 << 31) - 1
    size = 200000

# NOTE(review): the name breaks the "UInt32" pattern used by the other classes;
# kept as-is because renaming would change the reported test ids.
class RandTestIntU32(RandTest, TestCase):
    """Random round-trip test on uint32 data with default options."""
    dtype = np.uint32
    # Bounds are handed to np.random.randint, whose upper bound is exclusive.
    # NOTE(review): only the lower half of the uint32 range is sampled --
    # confirm this is intentional.
    rmin = 0
    rmax = (1 << 31) - 1
    size = 200000

# Random-data tests with explicit v1 options (8- and 16-bit types only)
class RandV1TestInt8(RandTestInt8):
    """RandTestInt8 re-run with ``vbz_version = 1`` (explicit compression options)."""
    vbz_version = 1

class RandV1TestUInt8(RandTestUInt8):
    """RandTestUInt8 re-run with ``vbz_version = 1`` (explicit compression options)."""
    vbz_version = 1

class RandV1TestInt16(RandTestInt16):
    """RandTestInt16 re-run with ``vbz_version = 1`` (explicit compression options)."""
    # Must derive from the signed fixture: inheriting RandTestUInt16 would only
    # duplicate RandV1TestUInt16 and leave signed 16-bit data untested with v1.
    vbz_version = 1

class RandV1TestUInt16(RandTestUInt16):
    """RandTestUInt16 re-run with ``vbz_version = 1`` (explicit compression options)."""
    vbz_version = 1

# Run every TestCase in this module when the file is executed as a script.
if __name__ == '__main__':
    main()
