// 7zAes.cpp

#include "StdAfx.h"

#include "../../../C/Sha256.h"

#include "../../Common/ComTry.h"

#ifndef _7ZIP_ST
#include "../../Windows/Synchronization.h"
#endif

#include "../Common/StreamUtils.h"

#include "7zAes.h"
#include "MyAes.h"

#ifndef EXTRACT_ONLY
#include "RandGen.h"
#endif

namespace NCrypto {
namespace N7z {

static const unsigned k_NumCyclesPower_Supported_MAX = 24;

bool CKeyInfo::IsEqualTo(const CKeyInfo &a) const
{
  if (SaltSize != a.SaltSize || NumCyclesPower != a.NumCyclesPower)
    return false;
  for (unsigned i = 0; i < SaltSize; i++)
    if (Salt[i] != a.Salt[i])
      return false;
  return (Password == a.Password);
}

void CKeyInfo::CalcKey()
{
  if (NumCyclesPower == 0x3F)
  {
    unsigned pos;
    for (pos = 0; pos < SaltSize; pos++)
      Key[pos] = Salt[pos];
    for (unsigned i = 0; i < Password.Size() && pos < kKeySize; i++)
      Key[pos++] = Password[i];
    for (; pos < kKeySize; pos++)
      Key[pos] = 0;
  }
  else
  {
    size_t bufSize = 8 + SaltSize + Password.Size();
    CObjArray<Byte> buf(bufSize);
    memcpy(buf, Salt, SaltSize);
    memcpy(buf + SaltSize, Password, Password.Size());
    
    CSha256 sha;
    Sha256_Init(&sha);
    
    Byte *ctr = buf + SaltSize + Password.Size();
    
    for (unsigned i = 0; i < 8; i++)
      ctr[i] = 0;
    
    UInt64 numRounds = (UInt64)1 << NumCyclesPower;

    do
    {
      Sha256_Update(&sha, buf, bufSize);
      for (unsigned i = 0; i < 8; i++)
        if (++(ctr[i]) != 0)
          break;
    }
    while (--numRounds != 0);

    Sha256_Final(&sha, Key);
  }
}

bool CKeyInfoCache::GetKey(CKeyInfo &key)
{
  FOR_VECTOR (i, Keys)
  {
    const CKeyInfo &cached = Keys[i];
    if (key.IsEqualTo(cached))
    {
      for (unsigned j = 0; j < kKeySize; j++)
        key.Key[j] = cached.Key[j];
      if (i != 0)
        Keys.MoveToFront(i);
      return true;
    }
  }
  return false;
}

void CKeyInfoCache::FindAndAdd(const CKeyInfo &key)
{
  FOR_VECTOR (i, Keys)
  {
    const CKeyInfo &cached = Keys[i];
    if (key.IsEqualTo(cached))
    {
      if (i != 0)
        Keys.MoveToFront(i);
      return;
    }
  }
  Add(key);
}

void CKeyInfoCache::Add(const CKeyInfo &key)
{
  if (Keys.Size() >= Size)
    Keys.DeleteBack();
  Keys.Insert(0, key);
}

static CKeyInfoCache g_GlobalKeyCache(32);

#ifndef _7ZIP_ST
  static NWindows::NSynchronization::CCriticalSection g_GlobalKeyCacheCriticalSection;
  #define MT_LOCK NWindows::NSynchronization::CCriticalSectionLock lock(g_GlobalKeyCacheCriticalSection);
#else
  #define MT_LOCK
#endif

CBase::CBase():
  _cachedKeys(16),
  _ivSize(0)
{
  for (unsigned i = 0; i < sizeof(_iv); i++)
    _iv[i] = 0;
}

void CBase::PrepareKey()
{
  // BCJ2 threads use same password. So we use long lock.
  MT_LOCK
  
  bool finded = false;
  if (!_cachedKeys.GetKey(_key))
  {
    finded = g_GlobalKeyCache.GetKey(_key);
    if (!finded)
      _key.CalcKey();
    _cachedKeys.Add(_key);
  }
  if (!finded)
    g_GlobalKeyCache.FindAndAdd(_key);
}

#ifndef EXTRACT_ONLY

/*
STDMETHODIMP CEncoder::ResetSalt()
{
  _key.SaltSize = 4;
  g_RandomGenerator.Generate(_key.Salt, _key.SaltSize);
  return S_OK;
}
*/

STDMETHODIMP CEncoder::ResetInitVector()
{
  for (unsigned i = 0; i < sizeof(_iv); i++)
    _iv[i] = 0;
  _ivSize = 8;
  g_RandomGenerator.Generate(_iv, _ivSize);
  return S_OK;
}

STDMETHODIMP CEncoder::WriteCoderProperties(ISequentialOutStream *outStream)
{
  Byte props[2 + sizeof(_key.Salt) + sizeof(_iv)];
  unsigned propsSize = 1;

  props[0] = (Byte)(_key.NumCyclesPower
      | (_key.SaltSize == 0 ? 0 : (1 << 7))
      | (_ivSize       == 0 ? 0 : (1 << 6)));

  if (_key.SaltSize != 0 || _ivSize != 0)
  {
    props[1] = (Byte)(
        ((_key.SaltSize == 0 ? 0 : _key.SaltSize - 1) << 4)
        | (_ivSize      == 0 ? 0 : _ivSize - 1));
    memcpy(props + 2, _key.Salt, _key.SaltSize);
    propsSize = 2 + _key.SaltSize;
    memcpy(props + propsSize, _iv, _ivSize);
    propsSize += _ivSize;
  }

  return WriteStream(outStream, props, propsSize);
}

CEncoder::CEncoder()
{
  // _key.SaltSize = 4; g_RandomGenerator.Generate(_key.Salt, _key.SaltSize);
  // _key.NumCyclesPower = 0x3F;
  _key.NumCyclesPower = 19;
  _aesFilter = new CAesCbcEncoder(kKeySize);
}

#endif

CDecoder::CDecoder()
{
  _aesFilter = new CAesCbcDecoder(kKeySize);
}

STDMETHODIMP CDecoder::SetDecoderProperties2(const Byte *data, UInt32 size)
{
  _key.ClearProps();
 
  _ivSize = 0;
  unsigned i;
  for (i = 0; i < sizeof(_iv); i++)
    _iv[i] = 0;
  
  if (size == 0)
    return S_OK;
  
  Byte b0 = data[0];

  _key.NumCyclesPower = b0 & 0x3F;
  if ((b0 & 0xC0) == 0)
    return size == 1 ? S_OK : E_INVALIDARG;

  if (size <= 1)
    return E_INVALIDARG;

  Byte b1 = data[1];

  unsigned saltSize = ((b0 >> 7) & 1) + (b1 >> 4);
  unsigned ivSize   = ((b0 >> 6) & 1) + (b1 & 0x0F);
  
  if (size != 2 + saltSize + ivSize)
    return E_INVALIDARG;
  _key.SaltSize = saltSize;
  data += 2;
  for (i = 0; i < saltSize; i++)
    _key.Salt[i] = *data++;
  for (i = 0; i < ivSize; i++)
    _iv[i] = *data++;
  return (_key.NumCyclesPower <= k_NumCyclesPower_Supported_MAX
      || _key.NumCyclesPower == 0x3F) ? S_OK : E_NOTIMPL;
}


STDMETHODIMP CBaseCoder::CryptoSetPassword(const Byte *data, UInt32 size)
{
  COM_TRY_BEGIN
  
  _key.Password.CopyFrom(data, (size_t)size);
  return S_OK;
  
  COM_TRY_END
}

STDMETHODIMP CBaseCoder::Init()
{
  COM_TRY_BEGIN
  
  PrepareKey();
  CMyComPtr<ICryptoProperties> cp;
  RINOK(_aesFilter.QueryInterface(IID_ICryptoProperties, &cp));
  if (!cp)
    return E_FAIL;
  RINOK(cp->SetKey(_key.Key, kKeySize));
  RINOK(cp->SetInitVector(_iv, sizeof(_iv)));
  return _aesFilter->Init();
  
  COM_TRY_END
}

STDMETHODIMP_(UInt32) CBaseCoder::Filter(Byte *data, UInt32 size)
{
  return _aesFilter->Filter(data, size);
}

}}