// Copyright 2010 The Chromium Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/http/mock_sspi_library_win.h" #include #include #include #include #include "base/check_op.h" #include "base/memory/raw_ptr.h" #include "base/strings/string_util_win.h" #include "base/strings/stringprintf.h" #include "base/strings/utf_string_conversions.h" #include "base/time/time.h" #include "testing/gtest/include/gtest/gtest.h" // Comparator so we can use CredHandle and CtxtHandle with std::set. Both of // those classes are typedefs for _SecHandle. bool operator<(const _SecHandle left, const _SecHandle right) { return left.dwUpper < right.dwUpper || left.dwLower < right.dwLower; } namespace net { namespace { int uniquifier_ = 0; struct MockCredential { std::u16string source_principal; std::u16string package; bool has_explicit_credentials = false; int uniquifier = ++uniquifier_; // CredHandle and CtxtHandle both shared the following definition: // // typedef struct _SecHandle { // ULONG_PTR dwLower; // ULONG_PTR dwUpper; // } SecHandle, * PSecHandle; // // ULONG_PTR type can hold a pointer. This function stuffs |this| into dwUpper // and adds a uniquifier to dwLower. This ensures that all PCredHandles issued // by this method during the lifetime of this process is unique. void StoreInHandle(PCredHandle handle) { DCHECK(uniquifier > 0); EXPECT_FALSE(SecIsValidHandle(handle)); handle->dwLower = uniquifier; handle->dwUpper = reinterpret_cast(this); DCHECK(SecIsValidHandle(handle)); } static MockCredential* FromHandle(PCredHandle handle) { return reinterpret_cast(handle->dwUpper); } }; struct MockContext { raw_ptr credential = nullptr; std::u16string target_principal; int uniquifier = ++uniquifier_; int rounds = 0; // CredHandle and CtxtHandle both shared the following definition: // // typedef struct _SecHandle { // ULONG_PTR dwLower; // ULONG_PTR dwUpper; // } SecHandle, * PSecHandle; // // ULONG_PTR type can hold a pointer. This function stuffs |this| into dwUpper // and adds a uniquifier to dwLower. This ensures that all PCredHandles issued // by this method during the lifetime of this process is unique. void StoreInHandle(PCtxtHandle handle) { EXPECT_FALSE(SecIsValidHandle(handle)); DCHECK(uniquifier > 0); handle->dwLower = uniquifier; handle->dwUpper = reinterpret_cast(this); DCHECK(SecIsValidHandle(handle)); } std::string ToString() const { return base::StringPrintf( "%s's token #%d for %S", base::UTF16ToUTF8(credential->source_principal).c_str(), rounds + 1, base::as_wcstr(target_principal)); } static MockContext* FromHandle(PCtxtHandle handle) { return reinterpret_cast(handle->dwUpper); } }; } // namespace MockSSPILibrary::MockSSPILibrary(const wchar_t* package) : SSPILibrary(package) {} MockSSPILibrary::~MockSSPILibrary() { EXPECT_TRUE(expected_package_queries_.empty()); EXPECT_TRUE(expected_freed_packages_.empty()); EXPECT_TRUE(active_credentials_.empty()); EXPECT_TRUE(active_contexts_.empty()); } SECURITY_STATUS MockSSPILibrary::AcquireCredentialsHandle( LPWSTR pszPrincipal, unsigned long fCredentialUse, void* pvLogonId, void* pvAuthData, SEC_GET_KEY_FN pGetKeyFn, void* pvGetKeyArgument, PCredHandle phCredential, PTimeStamp ptsExpiry) { DCHECK(!SecIsValidHandle(phCredential)); auto* credential = new MockCredential; credential->source_principal = pszPrincipal ? base::as_u16cstr(pszPrincipal) : u""; credential->package = base::as_u16cstr(package_name_.c_str()); credential->has_explicit_credentials = !!pvAuthData; credential->StoreInHandle(phCredential); if (ptsExpiry) { ptsExpiry->LowPart = 0xBAA5B780; ptsExpiry->HighPart = 0x01D54E17; } active_credentials_.insert(*phCredential); return SEC_E_OK; } SECURITY_STATUS MockSSPILibrary::InitializeSecurityContext( PCredHandle phCredential, PCtxtHandle phContext, SEC_WCHAR* pszTargetName, unsigned long fContextReq, unsigned long Reserved1, unsigned long TargetDataRep, PSecBufferDesc pInput, unsigned long Reserved2, PCtxtHandle phNewContext, PSecBufferDesc pOutput, unsigned long* contextAttr, PTimeStamp ptsExpiry) { MockContext* new_context = new MockContext; new_context->credential = MockCredential::FromHandle(phCredential); new_context->target_principal = base::as_u16cstr(pszTargetName); new_context->rounds = 0; // Always rotate contexts. That way tests will fail if the caller's context // management is broken. if (phContext && SecIsValidHandle(phContext)) { std::unique_ptr old_context{ MockContext::FromHandle(phContext)}; EXPECT_EQ(old_context->credential, new_context->credential); EXPECT_EQ(1u, active_contexts_.erase(*phContext)); new_context->rounds = old_context->rounds + 1; SecInvalidateHandle(phContext); } new_context->StoreInHandle(phNewContext); active_contexts_.insert(*phNewContext); auto token = new_context->ToString(); PSecBuffer out_buffer = pOutput->pBuffers; out_buffer->cbBuffer = std::min(out_buffer->cbBuffer, token.size()); std::memcpy(out_buffer->pvBuffer, token.data(), out_buffer->cbBuffer); if (ptsExpiry) { ptsExpiry->LowPart = 0xBAA5B780; ptsExpiry->HighPart = 0x01D54E15; } return SEC_E_OK; } SECURITY_STATUS MockSSPILibrary::QueryContextAttributesEx(PCtxtHandle phContext, ULONG ulAttribute, PVOID pBuffer, ULONG cbBuffer) { static const SecPkgInfoW kNegotiatedPackage = { 0, 0, 0, 0, const_cast(L"Itsa me Kerberos!!"), const_cast(L"I like turtles")}; auto* context = MockContext::FromHandle(phContext); switch (ulAttribute) { case SECPKG_ATTR_NATIVE_NAMES: { auto* native_names = reinterpret_cast(pBuffer); DCHECK_EQ(sizeof(*native_names), cbBuffer); native_names->sClientName = base::as_writable_wcstr(context->credential->source_principal); native_names->sServerName = base::as_writable_wcstr(context->target_principal); return SEC_E_OK; } case SECPKG_ATTR_NEGOTIATION_INFO: { auto* negotiation_info = reinterpret_cast(pBuffer); DCHECK_EQ(sizeof(*negotiation_info), cbBuffer); negotiation_info->PackageInfo = const_cast(&kNegotiatedPackage); negotiation_info->NegotiationState = (context->rounds == 1) ? SECPKG_NEGOTIATION_COMPLETE : SECPKG_NEGOTIATION_IN_PROGRESS; return SEC_E_OK; } case SECPKG_ATTR_AUTHORITY: { auto* authority = reinterpret_cast(pBuffer); DCHECK_EQ(sizeof(*authority), cbBuffer); authority->sAuthorityName = const_cast(L"Dodgy Server"); return SEC_E_OK; } default: return SEC_E_UNSUPPORTED_FUNCTION; } } SECURITY_STATUS MockSSPILibrary::QuerySecurityPackageInfo( PSecPkgInfoW* pkgInfo) { if (expected_package_queries_.empty()) { static SecPkgInfoW kDefaultPkgInfo{ 0, 0, 0, kDefaultMaxTokenLength, nullptr, nullptr}; *pkgInfo = &kDefaultPkgInfo; expected_freed_packages_.insert(&kDefaultPkgInfo); return SEC_E_OK; } PackageQuery package_query = expected_package_queries_.front(); expected_package_queries_.pop_front(); *pkgInfo = package_query.package_info; if (package_query.response_code == SEC_E_OK) expected_freed_packages_.insert(package_query.package_info); return package_query.response_code; } SECURITY_STATUS MockSSPILibrary::FreeCredentialsHandle( PCredHandle phCredential) { DCHECK(SecIsValidHandle(phCredential)); EXPECT_EQ(1u, active_credentials_.erase(*phCredential)); std::unique_ptr owned{ MockCredential::FromHandle(phCredential)}; SecInvalidateHandle(phCredential); return SEC_E_OK; } SECURITY_STATUS MockSSPILibrary::DeleteSecurityContext(PCtxtHandle phContext) { std::unique_ptr context{MockContext::FromHandle(phContext)}; EXPECT_EQ(1u, active_contexts_.erase(*phContext)); SecInvalidateHandle(phContext); return SEC_E_OK; } SECURITY_STATUS MockSSPILibrary::FreeContextBuffer(PVOID pvContextBuffer) { PSecPkgInfoW package_info = static_cast(pvContextBuffer); std::set::iterator it = expected_freed_packages_.find( package_info); EXPECT_TRUE(it != expected_freed_packages_.end()); expected_freed_packages_.erase(it); return SEC_E_OK; } void MockSSPILibrary::ExpectQuerySecurityPackageInfo( SECURITY_STATUS response_code, PSecPkgInfoW package_info) { expected_package_queries_.emplace_back( PackageQuery{response_code, package_info}); } } // namespace net