/* -*- Mode: C; tab-width: 4 -*-
*
* Copyright (c) 2002-2004 Apple Computer, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "Secret.h"
#include <stdarg.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <winsock2.h>
#include <ws2tcpip.h>
#include <windows.h>
#include <process.h>
#include <ntsecapi.h>
#include <lm.h>
#include "DebugServices.h"
// Forward declarations of the file-local conversion helpers defined at the bottom of this file.
//
// MakeLsaStringFromUTF8String: converts a NUL-terminated UTF-8 string into an LSA_UNICODE_STRING whose
// Buffer is allocated with malloc(); the caller releases it with free().
mDNSlocal OSStatus MakeLsaStringFromUTF8String( PLSA_UNICODE_STRING output, const char * input );
// MakeUTF8StringFromLsaString: converts an LSA_UNICODE_STRING into a NUL-terminated UTF-8 string written
// into the caller-supplied buffer "output" of "len" bytes.
mDNSlocal OSStatus MakeUTF8StringFromLsaString( char * output, size_t len, PLSA_UNICODE_STRING input );
BOOL
LsaGetSecret( const char * inDomain, char * outDomain, unsigned outDomainSize, char * outKey, unsigned outKeySize, char * outSecret, unsigned outSecretSize )
{
PLSA_UNICODE_STRING domainLSA;
PLSA_UNICODE_STRING keyLSA;
PLSA_UNICODE_STRING secretLSA;
size_t i;
size_t dlen;
LSA_OBJECT_ATTRIBUTES attrs;
LSA_HANDLE handle = NULL;
NTSTATUS res;
OSStatus err;
check( inDomain );
check( outDomain );
check( outKey );
check( outSecret );
// Initialize
domainLSA = NULL;
keyLSA = NULL;
secretLSA = NULL;
// Make sure we have enough space to add trailing dot
dlen = strlen( inDomain );
err = strcpy_s( outDomain, outDomainSize - 2, inDomain );
require_noerr( err, exit );
// If there isn't a trailing dot, add one because the mDNSResponder
// presents names with the trailing dot.
if ( outDomain[ dlen - 1 ] != '.' )
{
outDomain[ dlen++ ] = '.';
outDomain[ dlen ] = '\0';
}
// Canonicalize name by converting to lower case (keychain and some name servers are case sensitive)
for ( i = 0; i < dlen; i++ )
{
outDomain[i] = (char) tolower( outDomain[i] ); // canonicalize -> lower case
}
// attrs are reserved, so initialize to zeroes.
ZeroMemory( &attrs, sizeof( attrs ) );
// Get a handle to the Policy object on the local system
res = LsaOpenPolicy( NULL, &attrs, POLICY_GET_PRIVATE_INFORMATION, &handle );
err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
require_noerr( err, exit );
// Get the encrypted data
domainLSA = ( PLSA_UNICODE_STRING ) malloc( sizeof( LSA_UNICODE_STRING ) );
require_action( domainLSA != NULL, exit, err = mStatus_NoMemoryErr );
err = MakeLsaStringFromUTF8String( domainLSA, outDomain );
require_noerr( err, exit );
// Retrieve the key
res = LsaRetrievePrivateData( handle, domainLSA, &keyLSA );
err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
require_noerr_quiet( err, exit );
// <rdar://problem/4192119> Lsa secrets use a flat naming space. Therefore, we will prepend "$" to the keyname to
// make sure it doesn't conflict with a zone name.
// Strip off the "$" prefix.
err = MakeUTF8StringFromLsaString( outKey, outKeySize, keyLSA );
require_noerr( err, exit );
require_action( outKey[0] == '$', exit, err = kUnknownErr );
memcpy( outKey, outKey + 1, strlen( outKey ) );
// Retrieve the secret
res = LsaRetrievePrivateData( handle, keyLSA, &secretLSA );
err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
require_noerr_quiet( err, exit );
// Convert the secret to UTF8 string
err = MakeUTF8StringFromLsaString( outSecret, outSecretSize, secretLSA );
require_noerr( err, exit );
exit:
if ( domainLSA != NULL )
{
if ( domainLSA->Buffer != NULL )
{
free( domainLSA->Buffer );
}
free( domainLSA );
}
if ( keyLSA != NULL )
{
LsaFreeMemory( keyLSA );
}
if ( secretLSA != NULL )
{
LsaFreeMemory( secretLSA );
}
if ( handle )
{
LsaClose( handle );
handle = NULL;
}
return ( !err ) ? TRUE : FALSE;
}
mDNSBool
LsaSetSecret( const char * inDomain, const char * inKey, const char * inSecret )
{
size_t inDomainLength;
size_t inKeyLength;
char domain[ 1024 ];
char key[ 1024 ];
LSA_OBJECT_ATTRIBUTES attrs;
LSA_HANDLE handle = NULL;
NTSTATUS res;
LSA_UNICODE_STRING lucZoneName;
LSA_UNICODE_STRING lucKeyName;
LSA_UNICODE_STRING lucSecretName;
BOOL ok = TRUE;
OSStatus err;
require_action( inDomain != NULL, exit, ok = FALSE );
require_action( inKey != NULL, exit, ok = FALSE );
require_action( inSecret != NULL, exit, ok = FALSE );
// If there isn't a trailing dot, add one because the mDNSResponder
// presents names with the trailing dot.
ZeroMemory( domain, sizeof( domain ) );
inDomainLength = strlen( inDomain );
require_action( inDomainLength > 0, exit, ok = FALSE );
err = strcpy_s( domain, sizeof( domain ) - 2, inDomain );
require_action( !err, exit, ok = FALSE );
if ( domain[ inDomainLength - 1 ] != '.' )
{
domain[ inDomainLength++ ] = '.';
domain[ inDomainLength ] = '\0';
}
// <rdar://problem/4192119>
//
// Prepend "$" to the key name, so that there will
// be no conflict between the zone name and the key
// name
ZeroMemory( key, sizeof( key ) );
inKeyLength = strlen( inKey );
require_action( inKeyLength > 0 , exit, ok = FALSE );
key[ 0 ] = '$';
err = strcpy_s( key + 1, sizeof( key ) - 3, inKey );
require_action( !err, exit, ok = FALSE );
inKeyLength++;
if ( key[ inKeyLength - 1 ] != '.' )
{
key[ inKeyLength++ ] = '.';
key[ inKeyLength ] = '\0';
}
// attrs are reserved, so initialize to zeroes.
ZeroMemory( &attrs, sizeof( attrs ) );
// Get a handle to the Policy object on the local system
res = LsaOpenPolicy( NULL, &attrs, POLICY_ALL_ACCESS, &handle );
err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
require_noerr( err, exit );
// Intializing PLSA_UNICODE_STRING structures
err = MakeLsaStringFromUTF8String( &lucZoneName, domain );
require_noerr( err, exit );
err = MakeLsaStringFromUTF8String( &lucKeyName, key );
require_noerr( err, exit );
err = MakeLsaStringFromUTF8String( &lucSecretName, inSecret );
require_noerr( err, exit );
// Store the private data.
res = LsaStorePrivateData( handle, &lucZoneName, &lucKeyName );
err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
require_noerr( err, exit );
res = LsaStorePrivateData( handle, &lucKeyName, &lucSecretName );
err = translate_errno( res == 0, LsaNtStatusToWinError( res ), kUnknownErr );
require_noerr( err, exit );
exit:
if ( handle )
{
LsaClose( handle );
handle = NULL;
}
return ok;
}
//===========================================================================================================================
// MakeLsaStringFromUTF8String
//===========================================================================================================================
// Converts the NUL-terminated UTF-8 string "input" into the LSA_UNICODE_STRING "output".
//
// On success output->Buffer is a malloc'd, NUL-terminated wide string that the caller must free(),
// output->MaximumLength is the allocation size in bytes and output->Length is the string size in
// bytes without the terminator. On failure output->Buffer is NULL and both lengths are 0.
//
// Returns kNoErr on success, mStatus_NoMemoryErr if the allocation fails, and an error derived from
// GetLastError() (or kUnknownErr) if the conversion fails or the result does not fit the USHORT
// length fields.
mDNSlocal OSStatus
MakeLsaStringFromUTF8String( PLSA_UNICODE_STRING output, const char * input )
{
    int         size;
    size_t      bytes;
    OSStatus    err;

    check( input );
    check( output );

    // Start from a well-defined empty value so that every failure path leaves one behind.
    output->Buffer          = NULL;
    output->Length          = 0;
    output->MaximumLength   = 0;

    // First pass: ask for the number of wide characters needed, including the terminator.
    size = MultiByteToWideChar( CP_UTF8, 0, input, -1, NULL, 0 );
    err = translate_errno( size > 0, GetLastError(), kUnknownErr );
    require_noerr( err, exit );

    // The length fields are USHORT byte counts. Refuse strings that do not fit instead of letting
    // the cast truncate the size, which would make the buffer smaller than what the second pass writes.
    bytes = ( (size_t) size ) * sizeof( wchar_t );
    require_action( bytes <= 0xFFFF, exit, err = kUnknownErr );

    output->Buffer = (PWCHAR) malloc( bytes );
    require_action( output->Buffer, exit, err = mStatus_NoMemoryErr );

    // Second pass: do the actual conversion into the new buffer.
    size = MultiByteToWideChar( CP_UTF8, 0, input, -1, output->Buffer, size );
    err = translate_errno( size > 0, GetLastError(), kUnknownErr );
    require_noerr( err, exit );

    // MaximumLength covers the whole allocation. Length excludes the terminating wchar_t, so we
    // subtract one wchar_t from the byte count.
    output->MaximumLength   = (USHORT) bytes;
    output->Length          = (USHORT)( bytes - sizeof( wchar_t ) );

exit:

    if ( err && output->Buffer )
    {
        free( output->Buffer );
        output->Buffer = NULL;
    }

    return( err );
}
//===========================================================================================================================
// MakeUTF8StringFromLsaString
//===========================================================================================================================
// Converts the LSA_UNICODE_STRING "input" into a NUL-terminated UTF-8 string in "output".
//
// output	Destination buffer.
// len		Size of the destination buffer in bytes, including room for the terminating NUL.
// input	String to convert; its Length field is a byte count and the text need not be NUL-terminated.
//
// Returns kNoErr on success, mStatus_NoMemoryErr if the result (plus NUL) does not fit in "len"
// bytes, and an error derived from GetLastError() (or kUnknownErr) for bad arguments or a failed
// conversion. A zero-length input yields an empty output string.
mDNSlocal OSStatus
MakeUTF8StringFromLsaString( char * output, size_t len, PLSA_UNICODE_STRING input )
{
    int         wideCount;
    int         utf8Count;
    OSStatus    err = kNoErr;

    require_action( output != NULL, exit, err = kUnknownErr );
    require_action( len > 0, exit, err = mStatus_NoMemoryErr );
    require_action( input != NULL, exit, err = kUnknownErr );

    // The Length field of this structure holds the number of bytes,
    // but WideCharToMultiByte expects the number of wchar_t's. So
    // we divide by sizeof(wchar_t) to get the correct number.
    wideCount = (int)( input->Length / sizeof( wchar_t ) );

    // WideCharToMultiByte rejects a character count of 0, so an empty string is handled here.
    if ( wideCount == 0 )
    {
        output[ 0 ] = '\0';
        goto exit;
    }

    require_action( input->Buffer != NULL, exit, err = kUnknownErr );

    // First pass: find out how many UTF-8 bytes the conversion produces.
    utf8Count = WideCharToMultiByte( CP_UTF8, 0, input->Buffer, wideCount, NULL, 0, NULL, NULL );
    err = translate_errno( utf8Count > 0, GetLastError(), kUnknownErr );
    require_noerr( err, exit );

    // Ensure that we have enough space (Add one for trailing '\0')
    require_action( ( (size_t) utf8Count + 1 ) <= len, exit, err = mStatus_NoMemoryErr );

    // Convert the string
    utf8Count = WideCharToMultiByte( CP_UTF8, 0, input->Buffer, wideCount, output, utf8Count, NULL, NULL );
    err = translate_errno( utf8Count > 0, GetLastError(), kUnknownErr );
    require_noerr( err, exit );

    // have to add the trailing 0 because WideCharToMultiByte doesn't do it,
    // although it does return the correct size
    output[ utf8Count ] = '\0';

exit:

    return err;
}