/*  Lziprecover - Data recovery tool for the lzip format
    Copyright (C) 2009-2015 Antonio Diaz Diaz.

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

/* Range decoder that takes its compressed input from a memory buffer.
   Reading starts File_header::size bytes into the buffer. The buffer is
   only read through this class; it is never modified nor freed here. */
class Range_mtester
  {
  const uint8_t * const buffer;	// input buffer
  const long buffer_size;	// number of bytes in buffer
  long pos;			// current pos in buffer
  uint32_t code;
  uint32_t range;

  void operator=( const Range_mtester & );	// declared as private

public:
  // 'buf' must stay valid for as long as this object is used.
  Range_mtester( const uint8_t * const buf, const long buf_size )
    :
    buffer( buf ),
    buffer_size( buf_size ),
    pos( File_header::size ),
    code( 0 ),
    range( 0xFFFFFFFFU )
    {}

  // Reads the next 5 input bytes into 'code'. As 'code' is 32 bits wide,
  // the first of those bytes is shifted out and only the last 4 remain.
  void load()
    {
    for( int i = 0; i < 5; ++i ) code = (code << 8) | get_byte();
    code &= range;		// make sure that first byte is discarded
    }

  bool code_is_zero() const { return ( code == 0 ); }
  // True once every byte of the input buffer has been consumed.
  bool finished() const { return pos >= buffer_size; }
  long member_position() const { return pos; }

  // Returns the next File_trailer::size bytes of the buffer viewed as a
  // File_trailer and advances 'pos' past them. Returns 0, leaving 'pos'
  // untouched, if fewer than File_trailer::size bytes remain.
  const File_trailer * get_trailer()
    {
    if( buffer_size - pos < File_trailer::size ) return 0;
    // Keep the const qualifier of the input buffer in the cast.
    // NOTE(review): assumes File_trailer has no alignment requirement
    // stricter than uint8_t -- confirm against its definition.
    const File_trailer * const p =
      reinterpret_cast< const File_trailer * >( buffer + pos );
    pos += File_trailer::size;
    return p;
    }

  // Returns the next input byte, or the filler value 0xAA (without
  // advancing) once the buffer is exhausted.
  uint8_t get_byte()
    {
    if( finished() ) return 0xAA;		// make code != 0
    return buffer[pos++];
    }

  // Keeps 'range' above 24 bits by shifting one more input byte into
  // 'code' whenever 'range' has dropped to 0x00FFFFFF or below.
  void normalize()
    {
    if( range <= 0x00FFFFFFU )
      { range <<= 8; code = (code << 8) | get_byte(); }
    }

  // Decodes 'num_bits' bits without using any probability model (every
  // bit halves 'range'). The first bit decoded is the most significant
  // bit of the result.
  int decode( const int num_bits )
    {
    int symbol = 0;
    for( int i = num_bits; i > 0; --i )
      {
      normalize();
      range >>= 1;
      // Branchless form of:
      //   symbol <<= 1;
      //   if( code >= range ) { code -= range; symbol |= 1; }
      // 'mask' is all ones when code < range (bit 0), else zero (bit 1).
      const uint32_t mask = 0U - (code < range);
      code -= range;
      code += range & mask;		// undo the subtraction for a 0 bit
      symbol = (symbol << 1) + (mask + 1);
      }
    return symbol;
    }

  // Decodes one bit using the adaptive model 'bm' and then updates
  // bm.probability. 'bound' splits 'range' in proportion to
  // bm.probability; a 'code' below 'bound' decodes as 0 and raises the
  // probability, any other 'code' decodes as 1 and lowers it.
  int decode_bit( Bit_model & bm )
    {
    normalize();
    const uint32_t bound = ( range >> bit_model_total_bits ) * bm.probability;
    if( code < bound )
      {
      range = bound;
      bm.probability += (bit_model_total - bm.probability) >> bit_model_move_bits;
      return 0;
      }
    else
      {
      range -= bound;
      code -= bound;
      bm.probability -= bm.probability >> bit_model_move_bits;
      return 1;
      }
    }

  // Decodes a 3-bit symbol, most significant bit first. The bits decoded
  // so far (below a leading 1) select the model used for the next bit.
  int decode_tree3( Bit_model bm[] )
    {
    int symbol = 1;
    symbol = ( symbol << 1 ) | decode_bit( bm[symbol] );
    symbol = ( symbol << 1 ) | decode_bit( bm[symbol] );
    symbol = ( symbol << 1 ) | decode_bit( bm[symbol] );
    return symbol & 7;			// strip the leading 1
    }

  // Same as decode_tree3 but for a 6-bit symbol.
  int decode_tree6( Bit_model bm[] )
    {
    int symbol = 1;
    symbol = ( symbol << 1 ) | decode_bit( bm[symbol] );
    symbol = ( symbol << 1 ) | decode_bit( bm[symbol] );
    symbol = ( symbol << 1 ) | decode_bit( bm[symbol] );
    symbol = ( symbol << 1 ) | decode_bit( bm[symbol] );
    symbol = ( symbol << 1 ) | decode_bit( bm[symbol] );
    symbol = ( symbol << 1 ) | decode_bit( bm[symbol] );
    return symbol & 0x3F;		// strip the leading 1
    }

  // Same as decode_tree3 but for an 8-bit symbol.
  int decode_tree8( Bit_model bm[] )
    {
    int symbol = 1;
    while( symbol < 0x100 )
      symbol = ( symbol << 1 ) | decode_bit( bm[symbol] );
    return symbol & 0xFF;		// strip the leading 1
    }

  // Decodes a 'num_bits'-bit symbol, least significant bit first. The
  // model index is built from the bits decoded so far, as in decode_tree3.
  int decode_tree_reversed( Bit_model bm[], const int num_bits )
    {
    int model = 1;
    int symbol = 0;
    for( int i = 0; i < num_bits; ++i )
      {
      const bool bit = decode_bit( bm[model] );
      model <<= 1;
      if( bit ) { ++model; symbol |= (1 << i); }
      }
    return symbol;
    }

  // Unrolled 4-bit version of decode_tree_reversed.
  int decode_tree_reversed4( Bit_model bm[] )
    {
    int model = 1;
    int symbol = decode_bit( bm[model] );
    model = (model << 1) + symbol;
    int bit = decode_bit( bm[model] );
    model = (model << 1) + bit; symbol |= (bit << 1);
    bit = decode_bit( bm[model] );
    model = (model << 1) + bit; symbol |= (bit << 2);
    if( decode_bit( bm[model] ) ) symbol |= 8;
    return symbol;
    }

  // Decodes an 8-bit symbol, most significant bit first. While the
  // decoded bits agree with the corresponding bits of 'match_byte', the
  // model is taken from the 0x200 entries starting at bm + 0x100, indexed
  // by the match bit and the bits decoded so far. After the first
  // disagreement the remaining bits are decoded as in decode_tree8.
  int decode_matched( Bit_model bm[], int match_byte )
    {
    Bit_model * const bm1 = bm + 0x100;
    int symbol = 1;
    while( symbol < 0x100 )
      {
      match_byte <<= 1;
      const int match_bit = match_byte & 0x100;	// 0 or 0x100
      const int bit = decode_bit( bm1[match_bit+symbol] );
      symbol = ( symbol << 1 ) | bit;
      if( match_bit != bit << 8 )		// decoded bit differs
        {
        while( symbol < 0x100 )
          symbol = ( symbol << 1 ) | decode_bit( bm[symbol] );
        break;
        }
      }
    return symbol & 0xFF;
    }

  // Decodes a length symbol with the models in 'lm'. Two choice bits
  // select among the low tree (3 bits), the mid tree (3 bits, offset by
  // len_low_symbols) and the high tree (8 bits, offset by
  // len_low_symbols + len_mid_symbols).
  int decode_len( Len_model & lm, const int pos_state )
    {
    if( decode_bit( lm.choice1 ) == 0 )
      return decode_tree3( lm.bm_low[pos_state] );
    if( decode_bit( lm.choice2 ) == 0 )
      return len_low_symbols + decode_tree3( lm.bm_mid[pos_state] );
    return len_low_symbols + len_mid_symbols + decode_tree8( lm.bm_high );
    }
  };


/* LZ decoder that reads from memory through a Range_mtester and writes the
   decoded bytes into an output buffer allocated with new[].
   The implicitly generated copy constructor copies the 'buffer' pointer,
   not the bytes it points to, while the destructor deletes it. */
class LZ_mtester
  {
  unsigned long long partial_data_pos;
  Range_mtester rdec;
  const unsigned dictionary_size;
  const int buffer_size;	// max( 65536, dictionary_size )
  uint8_t * buffer;		// output buffer
  int pos;			// current pos in buffer
  int stream_pos;		// first byte not yet written to file
  uint32_t crc_;
  unsigned rep0;		// rep[0-3] latest four distances
  unsigned rep1;		// used for efficient coding of
  unsigned rep2;		// repeated distances
  unsigned rep3;
  State state;

  Bit_model bm_literal[1<<literal_context_bits][0x300];
  Bit_model bm_match[State::states][pos_states];
  Bit_model bm_rep[State::states];
  Bit_model bm_rep0[State::states];
  Bit_model bm_rep1[State::states];
  Bit_model bm_rep2[State::states];
  Bit_model bm_len[State::states][pos_states];
  Bit_model bm_dis_slot[len_states][1<<dis_slot_bits];
  Bit_model bm_dis[modeled_distances-end_dis_model];
  Bit_model bm_align[dis_align_size];

  Len_model match_len_model;
  Len_model rep_len_model;

  unsigned long long stream_position() const
    {
    return partial_data_pos + stream_pos;
    }

  void flush_data();
  bool verify_trailer();

  // Byte stored just before 'pos'; at pos == 0 that is the last byte of
  // the buffer.
  uint8_t get_prev_byte() const
    {
    if( pos > 0 ) return buffer[pos-1];
    return buffer[buffer_size-1];
    }

  // Byte stored distance + 1 positions before 'pos', wrapping around the
  // start of the buffer.
  uint8_t get_byte( const int distance ) const
    {
    const int idx = pos - distance - 1;
    return buffer[ ( idx < 0 ) ? idx + buffer_size : idx ];
    }

  // Stores 'b' at 'pos' and advances; flush_data() is called as soon as
  // 'pos' reaches buffer_size.
  void put_byte( const uint8_t b )
    {
    buffer[pos++] = b;
    if( pos >= buffer_size ) flush_data();
    }

  // Appends 'len' bytes copied from distance + 1 positions before 'pos'.
  void copy_block( const int distance, int len )
    {
    int src = pos - distance - 1;
    if( src < 0 ) src += buffer_size;

    // Fast path: neither region reaches the end of the buffer and source
    // and destination do not overlap.
    const bool no_wrap = ( len < buffer_size - std::max( pos, src ) );
    const bool no_overlap = ( len <= std::abs( pos - src ) );
    if( no_wrap && no_overlap )
      {
      std::memcpy( buffer + pos, buffer + src, len );
      pos += len;
      return;
      }

    // Slow path: byte by byte, so that overlapping copies see the bytes
    // just written and both indexes can pass the end of the buffer.
    while( len > 0 )
      {
      buffer[pos] = buffer[src];
      ++pos;
      if( pos >= buffer_size ) flush_data();
      ++src;
      if( src >= buffer_size ) src = 0;
      --len;
      }
    }

  void operator=( const LZ_mtester & );		// declared as private

public:
  // 'ibuf' / 'ibuf_size' are handed to the range decoder; the output
  // buffer gets max( 65536, dict_size ) bytes.
  LZ_mtester( const uint8_t * const ibuf, const long ibuf_size,
              const int dict_size )
    :
    partial_data_pos( 0 ),
    rdec( ibuf, ibuf_size ),
    dictionary_size( dict_size ),
    buffer_size( std::max( 65536U, dictionary_size ) ),
    buffer( new uint8_t[buffer_size] ),
    pos( 0 ),
    stream_pos( 0 ),
    crc_( 0xFFFFFFFFU ),
    rep0( 0 ),
    rep1( 0 ),
    rep2( 0 ),
    rep3( 0 )
    {
    // get_prev_byte() reads this slot while pos is still 0
    buffer[buffer_size-1] = 0;
    }

  ~LZ_mtester()
    {
    delete[] buffer;
    }

  // crc_ with all of its bits inverted
  unsigned crc() const
    {
    return crc_ ^ 0xFFFFFFFFU;
    }

  unsigned long long data_position() const
    {
    return partial_data_pos + pos;
    }

  // The next two simply forward to the range decoder.
  bool finished()
    {
    return rdec.finished();
    }

  long member_position() const
    {
    return rdec.member_position();
    }

  void duplicate_buffer();
  int test_member( const long pos_limit = LONG_MAX );
  };


// The three functions below are only declared here; their definitions are
// not part of this header.

// NOTE(review): presumably reads the 'msize' bytes found at offset 'mpos'
// of file descriptor 'infd' into a newly allocated buffer and returns it;
// ownership and the error return value must be confirmed in the definition.
uint8_t * read_member( const int infd, const long long mpos,
                       const long long msize );

// NOTE(review): presumably builds an LZ_mtester over 'buffer' and advances
// it up to 'pos_limit' so that it can be copied for later tests; who frees
// the returned object must be confirmed in the definition.
const LZ_mtester * prepare_master( const uint8_t * const buffer,
                                   const long buffer_size,
                                   const long pos_limit );

// NOTE(review): presumably tests the remainder of the member starting from
// the state saved in 'master'. 'failure_posp' is optional (defaults to 0);
// what is stored through it must be confirmed in the definition.
bool test_member_rest( const LZ_mtester & master, long * const failure_posp = 0 );
