// Copyright (c) 2009 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // See "SSPI Sample Application" at // http://msdn.microsoft.com/en-us/library/aa918273.aspx #include "net/http/http_auth_sspi_win.h" #include "base/base64.h" #include "base/logging.h" #include "base/string_util.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" #include "net/http/http_auth.h" namespace net { HttpAuthSSPI::HttpAuthSSPI(const std::string& scheme, SEC_WCHAR* security_package) : scheme_(scheme), security_package_(security_package), max_token_length_(0) { SecInvalidateHandle(&cred_); SecInvalidateHandle(&ctxt_); } HttpAuthSSPI::~HttpAuthSSPI() { ResetSecurityContext(); if (SecIsValidHandle(&cred_)) { FreeCredentialsHandle(&cred_); SecInvalidateHandle(&cred_); } } bool HttpAuthSSPI::NeedsIdentity() const { return decoded_server_auth_token_.empty(); } bool HttpAuthSSPI::IsFinalRound() const { return !decoded_server_auth_token_.empty(); } void HttpAuthSSPI::ResetSecurityContext() { if (SecIsValidHandle(&ctxt_)) { DeleteSecurityContext(&ctxt_); SecInvalidateHandle(&ctxt_); } } bool HttpAuthSSPI::ParseChallenge(std::string::const_iterator challenge_begin, std::string::const_iterator challenge_end) { // Verify the challenge's auth-scheme. HttpAuth::ChallengeTokenizer challenge_tok(challenge_begin, challenge_end); if (!challenge_tok.valid() || !LowerCaseEqualsASCII(challenge_tok.scheme(), StringToLowerASCII(scheme_).c_str())) return false; // Extract the auth-data. We can't use challenge_tok.GetNext() because // auth-data is base64-encoded and may contain '=' padding at the end, // which would be mistaken for a name=value pair. challenge_begin += scheme_.length(); // Skip over scheme name. HttpUtil::TrimLWS(&challenge_begin, &challenge_end); std::string encoded_auth_token(challenge_begin, challenge_end); int encoded_length = encoded_auth_token.length(); // Strip off any padding. // (See https://bugzilla.mozilla.org/show_bug.cgi?id=230351.) // // Our base64 decoder requires that the length be a multiple of 4. while (encoded_length > 0 && encoded_length % 4 != 0 && encoded_auth_token[encoded_length - 1] == '=') encoded_length--; encoded_auth_token.erase(encoded_length); std::string decoded_auth_token; bool rv = base::Base64Decode(encoded_auth_token, &decoded_auth_token); if (rv) { decoded_server_auth_token_ = decoded_auth_token; } return rv; } int HttpAuthSSPI::GenerateCredentials(const std::wstring& username, const std::wstring& password, const GURL& origin, const HttpRequestInfo* request, const ProxyInfo* proxy, std::string* out_credentials) { // |username| may be in the form "DOMAIN\user". Parse it into the two // components. std::wstring domain; std::wstring user; SplitDomainAndUser(username, &domain, &user); // Initial challenge. if (!IsFinalRound()) { int rv = OnFirstRound(domain, user, password); if (rv != OK) return rv; } void* out_buf; int out_buf_len; int rv = GetNextSecurityToken( origin, static_cast<void *>(const_cast<char *>( decoded_server_auth_token_.c_str())), decoded_server_auth_token_.length(), &out_buf, &out_buf_len); if (rv != OK) return rv; // Base64 encode data in output buffer and prepend the scheme. std::string encode_input(static_cast<char*>(out_buf), out_buf_len); std::string encode_output; bool ok = base::Base64Encode(encode_input, &encode_output); // OK, we are done with |out_buf| free(out_buf); if (!ok) return rv; *out_credentials = scheme_ + " " + encode_output; return OK; } int HttpAuthSSPI::OnFirstRound(const std::wstring& domain, const std::wstring& user, const std::wstring& password) { int rv = DetermineMaxTokenLength(security_package_, &max_token_length_); if (rv != OK) { return rv; } rv = AcquireCredentials(security_package_, domain, user, password, &cred_); return rv; } int HttpAuthSSPI::GetNextSecurityToken( const GURL& origin, const void * in_token, int in_token_len, void** out_token, int* out_token_len) { SECURITY_STATUS status; TimeStamp expiry; DWORD ctxt_attr; CtxtHandle* ctxt_ptr; SecBufferDesc in_buffer_desc, out_buffer_desc; SecBufferDesc* in_buffer_desc_ptr; SecBuffer in_buffer, out_buffer; if (in_token_len > 0) { // Prepare input buffer. in_buffer_desc.ulVersion = SECBUFFER_VERSION; in_buffer_desc.cBuffers = 1; in_buffer_desc.pBuffers = &in_buffer; in_buffer.BufferType = SECBUFFER_TOKEN; in_buffer.cbBuffer = in_token_len; in_buffer.pvBuffer = const_cast<void*>(in_token); ctxt_ptr = &ctxt_; in_buffer_desc_ptr = &in_buffer_desc; } else { // If there is no input token, then we are starting a new authentication // sequence. If we have already initialized our security context, then // we're incorrectly reusing the auth handler for a new sequence. if (SecIsValidHandle(&ctxt_)) { LOG(ERROR) << "Cannot restart authentication sequence"; return ERR_UNEXPECTED; } ctxt_ptr = NULL; in_buffer_desc_ptr = NULL; } // Prepare output buffer. out_buffer_desc.ulVersion = SECBUFFER_VERSION; out_buffer_desc.cBuffers = 1; out_buffer_desc.pBuffers = &out_buffer; out_buffer.BufferType = SECBUFFER_TOKEN; out_buffer.cbBuffer = max_token_length_; out_buffer.pvBuffer = malloc(out_buffer.cbBuffer); if (!out_buffer.pvBuffer) return ERR_OUT_OF_MEMORY; // The service principal name of the destination server. See // http://msdn.microsoft.com/en-us/library/ms677949%28VS.85%29.aspx std::wstring target(L"HTTP/"); target.append(ASCIIToWide(GetHostAndPort(origin))); wchar_t* target_name = const_cast<wchar_t*>(target.c_str()); // This returns a token that is passed to the remote server. status = InitializeSecurityContext(&cred_, // phCredential ctxt_ptr, // phContext target_name, // pszTargetName 0, // fContextReq 0, // Reserved1 (must be 0) SECURITY_NATIVE_DREP, // TargetDataRep in_buffer_desc_ptr, // pInput 0, // Reserved2 (must be 0) &ctxt_, // phNewContext &out_buffer_desc, // pOutput &ctxt_attr, // pfContextAttr &expiry); // ptsExpiry // On success, the function returns SEC_I_CONTINUE_NEEDED on the first call // and SEC_E_OK on the second call. On failure, the function returns an // error code. if (status != SEC_I_CONTINUE_NEEDED && status != SEC_E_OK) { LOG(ERROR) << "InitializeSecurityContext failed: " << status; ResetSecurityContext(); free(out_buffer.pvBuffer); return ERR_UNEXPECTED; // TODO(wtc): map error code. } if (!out_buffer.cbBuffer) { free(out_buffer.pvBuffer); out_buffer.pvBuffer = NULL; } *out_token = out_buffer.pvBuffer; *out_token_len = out_buffer.cbBuffer; return OK; } void SplitDomainAndUser(const std::wstring& combined, std::wstring* domain, std::wstring* user) { size_t backslash_idx = combined.find(L'\\'); if (backslash_idx == std::wstring::npos) { domain->clear(); *user = combined; } else { *domain = combined.substr(0, backslash_idx); *user = combined.substr(backslash_idx + 1); } } int DetermineMaxTokenLength(const std::wstring& package, ULONG* max_token_length) { PSecPkgInfo pkg_info; SECURITY_STATUS status = QuerySecurityPackageInfo( const_cast<wchar_t *>(package.c_str()), &pkg_info); if (status != SEC_E_OK) { LOG(ERROR) << "Security package " << package << " not found"; return ERR_UNEXPECTED; } *max_token_length = pkg_info->cbMaxToken; FreeContextBuffer(pkg_info); return OK; } int AcquireCredentials(const SEC_WCHAR* package, const std::wstring& domain, const std::wstring& user, const std::wstring& password, CredHandle* cred) { SEC_WINNT_AUTH_IDENTITY identity; identity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE; identity.User = reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(user.c_str())); identity.UserLength = user.size(); identity.Domain = reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(domain.c_str())); identity.DomainLength = domain.size(); identity.Password = reinterpret_cast<unsigned short*>(const_cast<wchar_t*>(password.c_str())); identity.PasswordLength = password.size(); TimeStamp expiry; // Pass the username/password to get the credentials handle. // Note: If the 5th argument is NULL, it uses the default cached credentials // for the logged in user, which can be used for single sign-on. SECURITY_STATUS status = AcquireCredentialsHandle( NULL, // pszPrincipal const_cast<SEC_WCHAR*>(package), // pszPackage SECPKG_CRED_OUTBOUND, // fCredentialUse NULL, // pvLogonID &identity, // pAuthData NULL, // pGetKeyFn (not used) NULL, // pvGetKeyArgument (not used) cred, // phCredential &expiry); // ptsExpiry if (status != SEC_E_OK) return ERR_UNEXPECTED; return OK; } } // namespace net