/*
 * Copyright (C) 2008 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* ---- includes ----------------------------------------------------------- */

#include "b_TensorEm/Int32Mat.h"
#include "b_TensorEm/Functions.h"
#include "b_BasicEm/Math.h"
#include "b_BasicEm/Functions.h"
#include "b_BasicEm/Memory.h"

/* ------------------------------------------------------------------------- */

/* ========================================================================= */
/*                                                                           */
/* ---- \ghd{ auxiliary functions } ---------------------------------------- */
/*                                                                           */
/* ========================================================================= */

/* ------------------------------------------------------------------------- */

/** Rescales an array of fixed-point values so its largest magnitude fits into nBitsA bits.
 *
 *  ptrA    array of sizeA values, modified in place
 *  sizeA   number of elements in ptrA
 *  bbpPtrA fixed-point position of the values; decremented by the applied shift
 *  nBitsA  target bit count
 *
 *  The shift is bts_absIntLog2( max |x| ) + 1 - nBitsA. If it is positive, every element
 *  is right-shifted by it with half-up rounding. Otherwise nothing is changed.
 *  NOTE(review): this assumes bts_absIntLog2 returns the index of the highest set bit.
 *  Negating INT32_MIN and right-shifting negative values are inherited from the original
 *  implementation.
 */
void bts_Int32Mat_reduceToNBits( int32* ptrA, uint32 sizeA, int32* bbpPtrA, uint32 nBitsA )
{
	int32 largestL = 0;
	int32 excessL;
	uint32 idxL;

	/* largest magnitude over all elements */
	for( idxL = 0; idxL < sizeA; idxL++ )
	{
		int32 magL = ptrA[ idxL ];
		if( magL < 0 ) magL = -magL;
		if( magL > largestL ) largestL = magL;
	}

	/* number of bits exceeding the requested width */
	excessL = bts_absIntLog2( largestL ) + 1 - nBitsA;
	if( excessL <= 0 ) return;

	/* rounded right shift of every element, then compensate in the bit position */
	for( idxL = 0; idxL < sizeA; idxL++ )
	{
		ptrA[ idxL ] = ( ( ptrA[ idxL ] >> ( excessL - 1 ) ) + 1 ) >> 1;
	}
	*bbpPtrA -= excessL;
}

/* ------------------------------------------------------------------------- */

/* ========================================================================= */
/*                                                                           */
/* ---- \ghd{ constructor / destructor } ----------------------------------- */
/*                                                                           */
/* ========================================================================= */

/* ------------------------------------------------------------------------- */

/** Puts a matrix object into its empty initial state: an initialized element array and width 0. */
void bts_Int32Mat_init( struct bbs_Context* cpA,
					    struct bts_Int32Mat* ptrA )
{
	bbs_Int32Arr_init( cpA, &ptrA->arrE );	/* element storage */
	ptrA->widthE = 0;						/* no rows/columns yet */
}

/* ------------------------------------------------------------------------- */

/** Tears down a matrix object: releases the element array via bbs_Int32Arr_exit and resets the width to 0. */
void bts_Int32Mat_exit( struct bbs_Context* cpA,
					    struct bts_Int32Mat* ptrA )
{
	bbs_Int32Arr_exit( cpA, &ptrA->arrE );	/* element storage */
	ptrA->widthE = 0;						/* matrix is empty again */
}
/* ------------------------------------------------------------------------- */

/* ========================================================================= */
/*                                                                           */
/* ---- \ghd{ operators } -------------------------------------------------- */
/*                                                                           */
/* ========================================================================= */

/* ------------------------------------------------------------------------- */

/* ========================================================================= */
/*                                                                           */
/* ---- \ghd{ query functions } -------------------------------------------- */
/*                                                                           */
/* ========================================================================= */

/* ------------------------------------------------------------------------- */

/* ========================================================================= */
/*                                                                           */
/* ---- \ghd{ modify functions } ------------------------------------------- */
/*                                                                           */
/* ========================================================================= */

/* ------------------------------------------------------------------------- */
	
/** Allocates storage for a square widthA x widthA matrix from mspA and records its width.
 *  Does nothing if the context already carries an error.
 */
void bts_Int32Mat_create( struct bbs_Context* cpA,
						  struct bts_Int32Mat* ptrA, 
						  int32 widthA,
				          struct bbs_MemSeg* mspA )
{
	if( !bbs_Context_error( cpA ) )
	{
		/* square matrix: widthA rows of widthA elements each */
		bbs_Int32Arr_create( cpA, &ptrA->arrE, widthA * widthA, mspA );
		ptrA->widthE = widthA;
	}
}

/* ------------------------------------------------------------------------- */
	
/** Copies the elements of srcPtrA into ptrA.
 *  Both matrices must already have the same width. Otherwise an error is raised via
 *  bbs_ERROR0 and ptrA is left unchanged.
 */
void bts_Int32Mat_copy( struct bbs_Context* cpA,
					    struct bts_Int32Mat* ptrA, 
						const struct bts_Int32Mat* srcPtrA )
{
	if( ptrA->widthE == srcPtrA->widthE )
	{
		/* widths agree: element arrays can be copied directly */
		bbs_Int32Arr_copy( cpA, &ptrA->arrE, &srcPtrA->arrE );
	}
	else
	{
		bbs_ERROR0( "void bts_Int32Mat_copy( struct bts_Int32Mat* ptrA, struct bts_Int32Mat* srcPtrA ):\n"
			       "size mismatch" );
	}
}

/* ------------------------------------------------------------------------- */
	
/* ========================================================================= */
/*                                                                           */
/* ---- \ghd{ I/O } -------------------------------------------------------- */
/*                                                                           */
/* ========================================================================= */

/* ------------------------------------------------------------------------- */
	
/** Returns the serialized size of the matrix in 16-bit words.
 *  The size covers the size field, the version field, the width and the element array.
 */
uint32 bts_Int32Mat_memSize( struct bbs_Context* cpA,
							 const struct bts_Int32Mat *ptrA )
{
	uint32 wordsL = bbs_SIZEOF16( uint32 );				/* size field */
	wordsL += bbs_SIZEOF16( uint32 );					/* version */
	wordsL += bbs_SIZEOF16( ptrA->widthE );				/* width */
	wordsL += bbs_Int32Arr_memSize( cpA, &ptrA->arrE );	/* elements */
	return wordsL;
}

/* ------------------------------------------------------------------------- */
	
/** Serializes the matrix to memPtrA.
 *  Layout: size, version, width, element array.
 *  Returns the number of 16-bit words written, as given by bts_Int32Mat_memSize.
 */
uint32 bts_Int32Mat_memWrite( struct bbs_Context* cpA,
							  const struct bts_Int32Mat* ptrA, 
							  uint16* memPtrA )
{
	uint32 totalSizeL = bts_Int32Mat_memSize( cpA, ptrA );
	uint16* cursorL = memPtrA;

	cursorL += bbs_memWrite32( &totalSizeL, cursorL );
	cursorL += bbs_memWriteUInt32( bts_INT32MAT_VERSION, cursorL );
	cursorL += bbs_memWrite32( &ptrA->widthE, cursorL );
	cursorL += bbs_Int32Arr_memWrite( cpA, &ptrA->arrE, cursorL );

	return totalSizeL;
}

/* ------------------------------------------------------------------------- */
	
/** Deserializes a matrix written by bts_Int32Mat_memWrite, allocating elements from mspA.
 *  Returns the stored size field, or 0 if the context already carries an error.
 *  Raises bbs_ERR_CORRUPT_DATA when the stored size disagrees with the size recomputed
 *  from the data just read.
 */
uint32 bts_Int32Mat_memRead( struct bbs_Context* cpA,
							 struct bts_Int32Mat* ptrA, 
							 const uint16* memPtrA,
				             struct bbs_MemSeg* mspA )
{
	uint32 storedSizeL, versionL;
	const uint16* cursorL = memPtrA;

	if( bbs_Context_error( cpA ) ) return 0;

	/* fields in the order written by bts_Int32Mat_memWrite */
	cursorL += bbs_memRead32( &storedSizeL, cursorL );
	cursorL += bbs_memReadVersion32( cpA, &versionL, bts_INT32MAT_VERSION, cursorL );
	cursorL += bbs_memRead32( &ptrA->widthE, cursorL );
	cursorL += bbs_Int32Arr_memRead( cpA, &ptrA->arrE, cursorL, mspA );

	/* consistency check of the stored size against the actual content */
	if( storedSizeL != bts_Int32Mat_memSize( cpA, ptrA ) )
	{
		bbs_ERR0( bbs_ERR_CORRUPT_DATA, "uint32 bts_Int32Mat_memRead( const struct bts_Int32Mat* ptrA, const void* memPtrA ):\n"
                  "size mismatch" ); 
	}
	return storedSizeL;
}

/* ------------------------------------------------------------------------- */
	
/* ========================================================================= */
/*                                                                           */
/* ---- \ghd{ exec functions } --------------------------------------------- */
/*                                                                           */
/* ========================================================================= */

/* ------------------------------------------------------------------------- */

/** Non-destructive front end for bts_Int32Mat_solve2.
 *  Copies the matWidthA x matWidthA matrix matA into the scratch buffer tmpMatA, then
 *  solves on that copy, so matA itself is left unchanged. All remaining arguments, and
 *  the return value, are passed through unchanged.
 */
flag bts_Int32Mat_solve( struct bbs_Context* cpA,
						 const int32* matA,
						 int32 matWidthA,
						 const int32* inVecA,
						 int32* outVecA,
						 int32 bbpA,
						 int32* tmpMatA,
						 int32* tmpVecA )
{
	const int32 elemCountL = matWidthA * matWidthA;

	bbs_memcpy32( tmpMatA, matA, elemCountL * bbs_SIZEOF32( int32 ) );

	return bts_Int32Mat_solve2( cpA, tmpMatA, matWidthA, inVecA, outVecA, bbpA, tmpVecA );
}

/* ------------------------------------------------------------------------- */

/** Solves the linear system matA * x = inVecA in fixed-point arithmetic.
 *
 *  The method is Gauss-Jordan elimination with full pivoting. In each step the pivot is
 *  the largest-magnitude element over all rows and columns not yet used as a pivot. The
 *  pivot row is then subtracted from every other row, and from the right-hand side.
 *
 *  cpA       context; not referenced in this function body
 *  matA      matWidthA x matWidthA matrix, row-major; overwritten (destroyed) in place
 *  matWidthA number of rows/columns of matA and length of all vectors
 *  inVecA    right-hand side (matWidthA elements); not modified
 *  outVecA   receives the solution (matWidthA elements); also the working right-hand side
 *  bbpA      fixed-point position shared by matrix, right-hand side and solution
 *            (presumably "bits behind point", i.e. fractional bits -- confirm)
 *  tmpVecA   scratch buffer of matWidthA int32; marks which columns were pivots
 *
 *  Returns TRUE on success. Returns FALSE if:
 *  - no nonzero pivot can be found (singular or numerically singular matrix);
 *  - the pivot bookkeeping becomes inconsistent;
 *  - the solution cannot be shifted back to bbpA without overflow.
 *
 *  NOTE(review): several statements left-shift values that may be negative
 *  (e.g. vecL[ jPivL ] << shiftL). That is undefined behavior in ISO C.
 */
flag bts_Int32Mat_solve2( struct bbs_Context* cpA,
						  int32* matA,
						  int32 matWidthA,
						  const int32* inVecA,
						  int32* outVecA,
						  int32 bbpA,
						  int32* tmpVecA )
{
	int32 sizeL = matWidthA;
	int32 bbpL = bbpA;		/* current fixed-point position; may drop via overflow protection */
	int32 iL, jL, kL;
	int32 iPivL;
	int32 jPivL;

	int32* vecL      = outVecA;	/* working right-hand side, becomes the solution */
	int32* matL      = matA;
	int32* checkArrL = tmpVecA;	/* checkArrL[ c ] == 1 once column c has been used as pivot */

	for( iL = 0; iL < sizeL; iL++ )
	{
		checkArrL[ iL ] = 0;
	}
	
	bbs_memcpy32( outVecA, inVecA, sizeL * bbs_SIZEOF32( int32 ) );

	iPivL = 0;

	/* one elimination step per unknown */
	for( kL = 0; kL < sizeL; kL++ )
	{
		/* find pivot: largest |element| among rows/columns not yet used as pivot */
		int32 maxAbsL = 0;
		int32* pivRowL;

		int32 bbp_pivRowL, bbp_vecL, shiftL;

		jPivL = -1;
		for( iL = 0; iL < sizeL; iL++ )
		{
			/* rows whose index already served as pivot column hold a finished pivot */
			if( checkArrL[ iL ] != 1 )
			{
				int32* rowL = matL + ( iL * sizeL );
				for( jL = 0; jL < sizeL; jL++ )
				{
					if( checkArrL[ jL ] == 0 )
					{
						int32 absElemL = rowL[ jL ];
						if( absElemL < 0 ) absElemL = -absElemL;
						if( maxAbsL < absElemL )
						{
							maxAbsL = absElemL;
							iPivL = iL;
							jPivL = jL;
						}
					} 
					else if( checkArrL[ jL ] > 1 )
					{
						/* defensive: only columns with count 0 are ever chosen as pivots,
						 * so a count above 1 indicates inconsistent bookkeeping */
						return FALSE;
					}
				}
			}
		}

		/* successful? no nonzero pivot left means the matrix is (numerically) singular */
		if( jPivL < 0 )
		{
			return FALSE;
		}

		checkArrL[ jPivL ]++; 

		/* exchange rows to put the pivot on the diagonal, if necessary
		 * (the right-hand side entries are swapped along with the rows) */
		if( iPivL != jPivL )
		{
			int32* row1PtrL = matL + ( iPivL * sizeL );
			int32* row2PtrL = matL + ( jPivL * sizeL );
			for( jL = 0; jL < sizeL; jL++ )
			{
				int32 tmpL = *row1PtrL;
				*row1PtrL++ = *row2PtrL;
				*row2PtrL++ = tmpL;
			}

			{
				int32 tmpL = vecL[ jPivL ];
				vecL[ jPivL ] = vecL[ iPivL ];
				vecL[ iPivL ] = tmpL;
			}
		}
		/* now index jPivL specifies pivot row and maximum element */


		/**	Overflow protection: only if the highest bit of the largest matrix element is set,
		 *	we need to shift the whole matrix and the right side vector 1 bit to the right,
		 *	to make sure there can be no overflow when the pivot row gets subtracted from the
		 *	other rows.
		 *	Getting that close to overflow is a rare event, so this shift will happen only 
		 *	occasionally, or not at all.
		 */
		if( maxAbsL & 1073741824 )  /*( 1 << 30 )*/
		{
			/* right shift matrix by 1 (this iL shadows the outer loop index) */
			int32 iL = sizeL * sizeL;
			int32* ptrL = matL;
			while( iL-- )
			{
				*ptrL = ( *ptrL + 1 ) >> 1;
				ptrL++;
			}

			/* right shift right side vector by 1 */
			iL = sizeL;
			ptrL = vecL;
			while( iL-- )
			{
				*ptrL = ( *ptrL + 1 ) >> 1;
				ptrL++;
			}

			/* decrement bbpL; compensated for after the elimination loop */
			bbpL--;
		}


		/* reduce elements of pivot row to 15 bit */
		pivRowL = matL + jPivL * sizeL;
		bbp_pivRowL = bbpL;
		bts_Int32Mat_reduceToNBits( pivRowL, sizeL, &bbp_pivRowL, 15 );

		/* scale pivot row such that the pivot element equals 1 (at bbp 15) */
		{
			int32 maxL = pivRowL[ jPivL ];
			int32 bbp_maxL = bbp_pivRowL;
			int32 factorL = 1073741824 / maxL; /*( 1 << 30 )*/

			/* x * 2^30 / pivot, rounded and shifted down by 15 -> x / pivot at bbp 15 */
			for( jL = 0; jL < sizeL; jL++ )
			{
				pivRowL[ jL ] = ( pivRowL[ jL ] * factorL + ( 1 << 14 ) ) >> 15;
			}
			bbp_pivRowL = 15;

			/* set to 1 to avoid computational errors */
			pivRowL[ jPivL ] = ( int32 )1 << bbp_pivRowL; 

			/* divide the right-hand side entry by the pivot, pre-shifting it up
			 * to keep precision; the shift is compensated in bbp_vecL */
			shiftL = 30 - bts_absIntLog2( vecL[ jPivL ] );

			vecL[ jPivL ] = ( vecL[ jPivL ] << shiftL ) / maxL;
			bbp_vecL = bbpL + shiftL - bbp_maxL;

			bbs_int32ReduceToNBits( &( vecL[ jPivL ] ), &bbp_vecL, 15 );
		}

		/* subtract pivot row from all other rows */
		for( iL = 0; iL < sizeL; iL++ )
		{
			if( iL != jPivL )
			{
				int32* rowPtrL = matL + iL * sizeL;

				/* elimination factor = this row's entry in the pivot column, reduced to 15 bits */
				int32 tmpL = *( rowPtrL + jPivL );
				int32 bbp_tmpL = bbpL;
				bbs_int32ReduceToNBits( &tmpL, &bbp_tmpL, 15 );

				/* product has bbp bbp_tmpL + bbp_pivRowL; bring it back to bbpL */
				shiftL = bbp_tmpL + bbp_pivRowL - bbpL;
				if( shiftL > 0 )
				{
					for( jL = 0; jL < sizeL; jL++ )
					{
						*rowPtrL++ -= ( ( ( tmpL * pivRowL[ jL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1;
					}
				}
				else
				{
					for( jL = 0; jL < sizeL; jL++ )
					{
						*rowPtrL++ -= ( tmpL * pivRowL[ jL ] ) << -shiftL;
					}
				}

				/* same update for the right-hand side */
				shiftL = bbp_tmpL + bbp_vecL - bbpL;
				if( shiftL > 0 )
				{
					vecL[ iL ] -= ( ( ( tmpL * vecL[ jPivL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1;
				}
				else
				{
					vecL[ iL ] -= ( tmpL * vecL[ jPivL ] ) << -shiftL;
				}
			}
		}

		/* change bbp of pivot row back to bbpL */
		shiftL = bbpL - bbp_pivRowL;
		if( shiftL >= 0 )
		{
			for( jL = 0; jL < sizeL; jL++ )
			{
				pivRowL[ jL ] <<= shiftL;
			}
		}
		else
		{
			shiftL = -shiftL;
			for( jL = 0; jL < sizeL; jL++ )
			{
				pivRowL[ jL ] = ( ( pivRowL[ jL ] >> ( shiftL - 1 ) ) + 1 ) >> 1;
			}
		}

		/* change bbp of the pivot's right-hand side entry back to bbpL */
		shiftL = bbpL - bbp_vecL;
		if( shiftL >= 0 )
		{
			vecL[ jPivL ] <<= shiftL;
		}
		else
		{
			shiftL = -shiftL;
			vecL[ jPivL ] = ( ( vecL[ jPivL ] >> ( shiftL - 1 ) ) + 1 ) >> 1;
		}
/*
if( sizeL <= 5 ) bts_Int32Mat_print( matL, vecL, sizeL, bbpL );
*/
	}	/* of kL */

	/* in case bbpL has been decreased by the overflow protection, change it back now */
	if( bbpA > bbpL )
	{
		/* find largest element of solution vector
		 * (these iL/shiftL shadow the function-scope variables) */
		int32 maxL = 0;
		int32 iL, shiftL;
		for( iL = 0; iL < sizeL; iL++ )
		{
			int32 xL = vecL[ iL ];
			if( xL < 0 ) xL = -xL;
			if( xL > maxL ) maxL = xL;
		}
		
		/* check whether we can left shift without overflow */
		shiftL = 30 - bts_absIntLog2( maxL );
		if( shiftL < ( bbpA - bbpL ) )
		{
			/* 
			    bbs_WARNING1( "flag bts_Int32Mat_solve2( ... ): getting overflow when trying to "
				"compute solution vector with bbp = %d. Choose smaller bbp.\n", bbpA );
			*/

			return FALSE;
		}	

		/* shift left */
		shiftL = bbpA - bbpL;
		for( iL = 0; iL < sizeL; iL++ ) vecL[ iL ] <<= shiftL;
	}

	return TRUE;
}

/* ------------------------------------------------------------------------- */

/* ========================================================================= */