// UpdateCallback.cpp

#include "StdAfx.h"

#include "../../../Common/ComTry.h"
#include "../../../Common/IntToString.h"
#include "../../../Common/StringConvert.h"
#include "../../../Common/Wildcard.h"

#include "../../../Windows/FileDir.h"
#include "../../../Windows/FileName.h"
#include "../../../Windows/PropVariant.h"
#include "../../../Windows/Synchronization.h"

#include "../../Common/FileStreams.h"
#include "../../Common/StreamObjects.h"

#include "UpdateCallback.h"

#if defined(_WIN32) && !defined(UNDER_CE)
#define _USE_SECURITY_CODE
#include "../../../Windows/SecurityUtils.h"
#endif

using namespace NWindows;
using namespace NFile;

#ifdef _USE_SECURITY_CODE
bool InitLocalPrivileges();
#endif

CArchiveUpdateCallback::CArchiveUpdateCallback():
  Callback(0),
  ShareForWrite(false),
  StdInMode(false),
  DirItems(0),
  ArcItems(0),
  UpdatePairs(0),
  NewNames(0),
  KeepOriginalItemNames(false),
  ProcessedItemsStatuses(NULL),
  ParentDirItem(NULL),
  StoreNtSecurity(false),
  StoreHardLinks(false),
  StoreSymLinks(false),
  _hardIndex_From((UInt32)(Int32)-1)
{
  #ifdef _USE_SECURITY_CODE
  _saclEnabled = InitLocalPrivileges();
  #endif
}


STDMETHODIMP CArchiveUpdateCallback::SetTotal(UInt64 size)
{
  COM_TRY_BEGIN
  return Callback->SetTotal(size);
  COM_TRY_END
}

STDMETHODIMP CArchiveUpdateCallback::SetCompleted(const UInt64 *completeValue)
{
  COM_TRY_BEGIN
  return Callback->SetCompleted(completeValue);
  COM_TRY_END
}

STDMETHODIMP CArchiveUpdateCallback::SetRatioInfo(const UInt64 *inSize, const UInt64 *outSize)
{
  COM_TRY_BEGIN
  return Callback->SetRatioInfo(inSize, outSize);
  COM_TRY_END
}


/*
static const STATPROPSTG kProps[] =
{
  { NULL, kpidPath, VT_BSTR},
  { NULL, kpidIsDir, VT_BOOL},
  { NULL, kpidSize, VT_UI8},
  { NULL, kpidCTime, VT_FILETIME},
  { NULL, kpidATime, VT_FILETIME},
  { NULL, kpidMTime, VT_FILETIME},
  { NULL, kpidAttrib, VT_UI4},
  { NULL, kpidIsAnti, VT_BOOL}
};

STDMETHODIMP CArchiveUpdateCallback::EnumProperties(IEnumSTATPROPSTG **)
{
  return CStatPropEnumerator::CreateEnumerator(kProps, ARRAY_SIZE(kProps), enumerator);
}
*/

STDMETHODIMP CArchiveUpdateCallback::GetUpdateItemInfo(UInt32 index,
      Int32 *newData, Int32 *newProps, UInt32 *indexInArchive)
{
  COM_TRY_BEGIN
  RINOK(Callback->CheckBreak());
  const CUpdatePair2 &up = (*UpdatePairs)[index];
  if (newData) *newData = BoolToInt(up.NewData);
  if (newProps) *newProps = BoolToInt(up.NewProps);
  if (indexInArchive)
  {
    *indexInArchive = (UInt32)(Int32)-1;
    if (up.ExistInArchive())
      *indexInArchive = (ArcItems == 0) ? up.ArcIndex : (*ArcItems)[up.ArcIndex].IndexInServer;
  }
  return S_OK;
  COM_TRY_END
}

STDMETHODIMP CArchiveUpdateCallback::GetRootProp(PROPID propID, PROPVARIANT *value)
{
  NCOM::CPropVariant prop;
  switch (propID)
  {
    case kpidIsDir:  prop = true; break;
    case kpidAttrib: if (ParentDirItem) prop = ParentDirItem->Attrib; break;
    case kpidCTime:  if (ParentDirItem) prop = ParentDirItem->CTime; break;
    case kpidATime:  if (ParentDirItem) prop = ParentDirItem->ATime; break;
    case kpidMTime:  if (ParentDirItem) prop = ParentDirItem->MTime; break;
  }
  prop.Detach(value);
  return S_OK;
}

STDMETHODIMP CArchiveUpdateCallback::GetParent(UInt32 /* index */, UInt32 *parent, UInt32 *parentType)
{
  *parentType = NParentType::kDir;
  *parent = (UInt32)(Int32)-1;
  return S_OK;
}

STDMETHODIMP CArchiveUpdateCallback::GetNumRawProps(UInt32 *numProps)
{
  *numProps = 0;
  if (StoreNtSecurity)
    *numProps = 1;
  return S_OK;
}

STDMETHODIMP CArchiveUpdateCallback::GetRawPropInfo(UInt32 /* index */, BSTR *name, PROPID *propID)
{
  *name = NULL;
  *propID = kpidNtSecure;
  return S_OK;
}

STDMETHODIMP CArchiveUpdateCallback::GetRootRawProp(PROPID
    #ifdef _USE_SECURITY_CODE
    propID
    #endif
    , const void **data, UInt32 *dataSize, UInt32 *propType)
{
  *data = 0;
  *dataSize = 0;
  *propType = 0;
  if (!StoreNtSecurity)
    return S_OK;
  #ifdef _USE_SECURITY_CODE
  if (propID == kpidNtSecure)
  {
    if (StdInMode)
      return S_OK;

    if (ParentDirItem)
    {
      if (ParentDirItem->SecureIndex < 0)
        return S_OK;
      const CByteBuffer &buf = DirItems->SecureBlocks.Bufs[ParentDirItem->SecureIndex];
      *data = buf;
      *dataSize = (UInt32)buf.Size();
      *propType = NPropDataType::kRaw;
      return S_OK;
    }

    if (GetRootProps)
      return GetRootProps->GetRootRawProp(propID, data, dataSize, propType);
  }
  #endif
  return S_OK;
}

//    #ifdef _USE_SECURITY_CODE
//    #endif

STDMETHODIMP CArchiveUpdateCallback::GetRawProp(UInt32 index, PROPID propID, const void **data, UInt32 *dataSize, UInt32 *propType)
{
  *data = 0;
  *dataSize = 0;
  *propType = 0;

  if (propID == kpidNtSecure ||
      propID == kpidNtReparse)
  {
    if (StdInMode)
      return S_OK;

    const CUpdatePair2 &up = (*UpdatePairs)[index];
    if (up.UseArcProps && up.ExistInArchive() && GetRawProps)
      return GetRawProps->GetRawProp(
          ArcItems ? (*ArcItems)[up.ArcIndex].IndexInServer : up.ArcIndex,
          propID, data, dataSize, propType);

    {
      const CUpdatePair2 &up = (*UpdatePairs)[index];
      /*
      if (!up.NewData)
        return E_FAIL;
      */
      if (up.IsAnti)
        return S_OK;
      
      #ifndef UNDER_CE
      const CDirItem &di = DirItems->Items[up.DirIndex];
      #endif

      #ifdef _USE_SECURITY_CODE
      if (propID == kpidNtSecure)
      {
        if (!StoreNtSecurity)
          return S_OK;
        if (di.SecureIndex < 0)
          return S_OK;
        const CByteBuffer &buf = DirItems->SecureBlocks.Bufs[di.SecureIndex];
        *data = buf;
        *dataSize = (UInt32)buf.Size();
        *propType = NPropDataType::kRaw;
      }
      else
      #endif
      {
        // propID == kpidNtReparse
        if (!StoreSymLinks)
          return S_OK;
        #ifndef UNDER_CE
        const CByteBuffer *buf = &di.ReparseData2;
        if (buf->Size() == 0)
          buf = &di.ReparseData;
        if (buf->Size() != 0)
        {
          *data = *buf;
          *dataSize = (UInt32)buf->Size();
          *propType = NPropDataType::kRaw;
        }
        #endif
      }
 
      return S_OK;
    }
  }

  return S_OK;
}

#ifndef UNDER_CE

static UString GetRelativePath(const UString &to, const UString &from)
{
  UStringVector partsTo, partsFrom;
  SplitPathToParts(to, partsTo);
  SplitPathToParts(from, partsFrom);

  unsigned i;
  for (i = 0;; i++)
  {
    if (i + 1 >= partsFrom.Size() ||
        i + 1 >= partsTo.Size())
      break;
    if (CompareFileNames(partsFrom[i], partsTo[i]) != 0)
      break;
  }

  if (i == 0)
  {
    #ifdef _WIN32
    if (NName::IsDrivePath(to) ||
        NName::IsDrivePath(from))
      return to;
    #endif
  }

  UString s;
  unsigned k;
  
  for (k = i + 1; k < partsFrom.Size(); k++)
    s += L".." WSTRING_PATH_SEPARATOR;
  
  for (k = i; k < partsTo.Size(); k++)
  {
    if (k != i)
      s += WCHAR_PATH_SEPARATOR;
    s += partsTo[k];
  }

  return s;
}

#endif

STDMETHODIMP CArchiveUpdateCallback::GetProperty(UInt32 index, PROPID propID, PROPVARIANT *value)
{
  COM_TRY_BEGIN
  const CUpdatePair2 &up = (*UpdatePairs)[index];
  NCOM::CPropVariant prop;

  if (up.NewData)
  {
    /*
    if (propID == kpidIsHardLink)
    {
      prop = _isHardLink;
      prop.Detach(value);
      return S_OK;
    }
    */
    if (propID == kpidSymLink)
    {
      if (index == _hardIndex_From)
      {
        prop.Detach(value);
        return S_OK;
      }
      if (up.DirIndex >= 0)
      {
        #ifndef UNDER_CE
        const CDirItem &di = DirItems->Items[up.DirIndex];
        // if (di.IsDir())
        {
          CReparseAttr attr;
          if (attr.Parse(di.ReparseData, di.ReparseData.Size()))
          {
            UString simpleName = attr.GetPath();
            if (attr.IsRelative())
              prop = simpleName;
            else
            {
              const UString phyPath = DirItems->GetPhyPath(up.DirIndex);
              FString fullPath;
              if (NDir::MyGetFullPathName(us2fs(phyPath), fullPath))
              {
                prop = GetRelativePath(simpleName, fs2us(fullPath));
              }
            }
            prop.Detach(value);
            return S_OK;
          }
        }
        #endif
      }
    }
    else if (propID == kpidHardLink)
    {
      if (index == _hardIndex_From)
      {
        const CKeyKeyValPair &pair = _map[_hardIndex_To];
        const CUpdatePair2 &up2 = (*UpdatePairs)[pair.Value];
        prop = DirItems->GetLogPath(up2.DirIndex);
        prop.Detach(value);
        return S_OK;
      }
      if (up.DirIndex >= 0)
      {
        prop.Detach(value);
        return S_OK;
      }
    }
  }
  
  if (up.IsAnti
      && propID != kpidIsDir
      && propID != kpidPath
      && propID != kpidIsAltStream)
  {
    switch (propID)
    {
      case kpidSize:  prop = (UInt64)0; break;
      case kpidIsAnti:  prop = true; break;
    }
  }
  else if (propID == kpidPath && up.NewNameIndex >= 0)
    prop = (*NewNames)[up.NewNameIndex];
  else if (propID == kpidShortName && up.NewNameIndex >= 0 && up.IsMainRenameItem)
  {
    // we can generate new ShortName here;
  }
  else if ((up.UseArcProps
      || (KeepOriginalItemNames && (propID == kpidPath || propID == kpidIsAltStream)))
      && up.ExistInArchive() && Archive)
    return Archive->GetProperty(ArcItems ? (*ArcItems)[up.ArcIndex].IndexInServer : up.ArcIndex, propID, value);
  else if (up.ExistOnDisk())
  {
    const CDirItem &di = DirItems->Items[up.DirIndex];
    switch (propID)
    {
      case kpidPath:  prop = DirItems->GetLogPath(up.DirIndex); break;
      case kpidIsDir:  prop = di.IsDir(); break;
      case kpidSize:  prop = di.Size; break;
      case kpidAttrib:  prop = di.Attrib; break;
      case kpidCTime:  prop = di.CTime; break;
      case kpidATime:  prop = di.ATime; break;
      case kpidMTime:  prop = di.MTime; break;
      case kpidIsAltStream:  prop = di.IsAltStream; break;
      #if defined(_WIN32) && !defined(UNDER_CE)
      // case kpidShortName:  prop = di.ShortName; break;
      #endif
    }
  }
  prop.Detach(value);
  return S_OK;
  COM_TRY_END
}

static NSynchronization::CCriticalSection CS;

STDMETHODIMP CArchiveUpdateCallback::GetStream(UInt32 index, ISequentialInStream **inStream)
{
  COM_TRY_BEGIN
  *inStream = NULL;
  const CUpdatePair2 &up = (*UpdatePairs)[index];
  if (!up.NewData)
    return E_FAIL;
  
  RINOK(Callback->CheckBreak());
  RINOK(Callback->Finilize());

  bool isDir = IsDir(up);

  if (up.IsAnti)
  {
    UString name;
    if (up.ArcIndex >= 0)
      name = (*ArcItems)[up.ArcIndex].Name;
    else if (up.DirIndex >= 0)
      name = DirItems->GetLogPath(up.DirIndex);
    RINOK(Callback->GetStream(name, true));
    
    /* 9.33: fixed. Handlers expect real stream object for files, even for anti-file.
       so we return empty stream */

    if (!isDir)
    {
      CBufInStream *inStreamSpec = new CBufInStream();
      CMyComPtr<ISequentialInStream> inStreamLoc = inStreamSpec;
      inStreamSpec->Init(NULL, 0);
      *inStream = inStreamLoc.Detach();
    }
    return S_OK;
  }
  
  RINOK(Callback->GetStream(DirItems->GetLogPath(up.DirIndex), false));
 
  if (isDir)
    return S_OK;

  if (StdInMode)
  {
    CStdInFileStream *inStreamSpec = new CStdInFileStream;
    CMyComPtr<ISequentialInStream> inStreamLoc(inStreamSpec);
    *inStream = inStreamLoc.Detach();
  }
  else
  {
    CInFileStream *inStreamSpec = new CInFileStream;
    CMyComPtr<ISequentialInStream> inStreamLoc(inStreamSpec);

    inStreamSpec->SupportHardLinks = StoreHardLinks;

    const UString path = DirItems->GetPhyPath(up.DirIndex);

    #if defined(_WIN32) && !defined(UNDER_CE)
    if (DirItems->Items[up.DirIndex].AreReparseData())
    {
      if (!inStreamSpec->File.OpenReparse(us2fs(path)))
      {
        return Callback->OpenFileError(path, ::GetLastError());
      }
    }
    else
    #endif
    if (!inStreamSpec->OpenShared(us2fs(path), ShareForWrite))
    {
      return Callback->OpenFileError(path, ::GetLastError());
    }

    if (StoreHardLinks)
    {
      CStreamFileProps props;
      if (inStreamSpec->GetProps2(&props) == S_OK)
      {
        if (props.NumLinks > 1)
        {
          CKeyKeyValPair pair;
          pair.Key1 = props.VolID;
          pair.Key2 = props.FileID_Low;
          pair.Value = index;
          unsigned numItems = _map.Size();
          unsigned pairIndex = _map.AddToUniqueSorted2(pair);
          if (numItems == _map.Size())
          {
            // const CKeyKeyValPair &pair2 = _map.Pairs[pairIndex];
            _hardIndex_From = index;
            _hardIndex_To = pairIndex;
            // we could return NULL as stream, but it's better to return real stream
            // return S_OK;
          }
        }
      }
    }

    if (ProcessedItemsStatuses)
    {
      NSynchronization::CCriticalSectionLock lock(CS);
      ProcessedItemsStatuses[up.DirIndex] = 1;
    }
    *inStream = inStreamLoc.Detach();
  }
  
  return S_OK;
  COM_TRY_END
}

STDMETHODIMP CArchiveUpdateCallback::SetOperationResult(Int32 operationResult)
{
  COM_TRY_BEGIN
  return Callback->SetOperationResult(operationResult);
  COM_TRY_END
}

STDMETHODIMP CArchiveUpdateCallback::GetVolumeSize(UInt32 index, UInt64 *size)
{
  if (VolumesSizes.Size() == 0)
    return S_FALSE;
  if (index >= (UInt32)VolumesSizes.Size())
    index = VolumesSizes.Size() - 1;
  *size = VolumesSizes[index];
  return S_OK;
}

STDMETHODIMP CArchiveUpdateCallback::GetVolumeStream(UInt32 index, ISequentialOutStream **volumeStream)
{
  COM_TRY_BEGIN
  FChar temp[16];
  ConvertUInt32ToString(index + 1, temp);
  FString res = temp;
  while (res.Len() < 2)
    res.InsertAtFront(FTEXT('0'));
  FString fileName = VolName;
  fileName += L'.';
  fileName += res;
  fileName += VolExt;
  COutFileStream *streamSpec = new COutFileStream;
  CMyComPtr<ISequentialOutStream> streamLoc(streamSpec);
  if (!streamSpec->Create(fileName, false))
    return ::GetLastError();
  *volumeStream = streamLoc.Detach();
  return S_OK;
  COM_TRY_END
}

STDMETHODIMP CArchiveUpdateCallback::CryptoGetTextPassword2(Int32 *passwordIsDefined, BSTR *password)
{
  COM_TRY_BEGIN
  return Callback->CryptoGetTextPassword2(passwordIsDefined, password);
  COM_TRY_END
}

STDMETHODIMP CArchiveUpdateCallback::CryptoGetTextPassword(BSTR *password)
{
  COM_TRY_BEGIN
  return Callback->CryptoGetTextPassword(password);
  COM_TRY_END
}