|
|
@ -38,7 +38,7 @@ static void OneTimeInit() {
|
|
|
|
// certificate, and presenting one to some arbitrary server
|
|
|
|
// certificate, and presenting one to some arbitrary server
|
|
|
|
// might be a privacy concern? Who knows, though.
|
|
|
|
// might be a privacy concern? Who knows, though.
|
|
|
|
|
|
|
|
|
|
|
|
SECURITY_STATUS ret =
|
|
|
|
const SECURITY_STATUS ret =
|
|
|
|
AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND,
|
|
|
|
AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND,
|
|
|
|
nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr);
|
|
|
|
nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr);
|
|
|
|
if (ret != SEC_E_OK) {
|
|
|
|
if (ret != SEC_E_OK) {
|
|
|
@ -121,14 +121,14 @@ public:
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Result FillCiphertextReadBuf() {
|
|
|
|
Result FillCiphertextReadBuf() {
|
|
|
|
size_t fill_size = read_buf_fill_size_ ? read_buf_fill_size_ : 4096;
|
|
|
|
const size_t fill_size = read_buf_fill_size_ ? read_buf_fill_size_ : 4096;
|
|
|
|
read_buf_fill_size_ = 0;
|
|
|
|
read_buf_fill_size_ = 0;
|
|
|
|
// This unnecessarily zeroes the buffer; oh well.
|
|
|
|
// This unnecessarily zeroes the buffer; oh well.
|
|
|
|
size_t offset = ciphertext_read_buf_.size();
|
|
|
|
const size_t offset = ciphertext_read_buf_.size();
|
|
|
|
ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
|
|
|
|
ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; });
|
|
|
|
ciphertext_read_buf_.resize(offset + fill_size, 0);
|
|
|
|
ciphertext_read_buf_.resize(offset + fill_size, 0);
|
|
|
|
auto read_span = std::span(ciphertext_read_buf_).subspan(offset, fill_size);
|
|
|
|
const auto read_span = std::span(ciphertext_read_buf_).subspan(offset, fill_size);
|
|
|
|
auto [actual, err] = socket_->Recv(0, read_span);
|
|
|
|
const auto [actual, err] = socket_->Recv(0, read_span);
|
|
|
|
switch (err) {
|
|
|
|
switch (err) {
|
|
|
|
case Network::Errno::SUCCESS:
|
|
|
|
case Network::Errno::SUCCESS:
|
|
|
|
ASSERT(static_cast<size_t>(actual) <= fill_size);
|
|
|
|
ASSERT(static_cast<size_t>(actual) <= fill_size);
|
|
|
@ -147,7 +147,7 @@ public:
|
|
|
|
// Returns success if the write buffer has been completely emptied.
|
|
|
|
// Returns success if the write buffer has been completely emptied.
|
|
|
|
Result FlushCiphertextWriteBuf() {
|
|
|
|
Result FlushCiphertextWriteBuf() {
|
|
|
|
while (!ciphertext_write_buf_.empty()) {
|
|
|
|
while (!ciphertext_write_buf_.empty()) {
|
|
|
|
auto [actual, err] = socket_->Send(ciphertext_write_buf_, 0);
|
|
|
|
const auto [actual, err] = socket_->Send(ciphertext_write_buf_, 0);
|
|
|
|
switch (err) {
|
|
|
|
switch (err) {
|
|
|
|
case Network::Errno::SUCCESS:
|
|
|
|
case Network::Errno::SUCCESS:
|
|
|
|
ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf_.size());
|
|
|
|
ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf_.size());
|
|
|
@ -165,9 +165,10 @@ public:
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Result CallInitializeSecurityContext() {
|
|
|
|
Result CallInitializeSecurityContext() {
|
|
|
|
unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_INTEGRITY |
|
|
|
|
const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY |
|
|
|
|
ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM |
|
|
|
|
ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT |
|
|
|
|
ISC_REQ_USE_SUPPLIED_CREDS;
|
|
|
|
ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM |
|
|
|
|
|
|
|
|
ISC_REQ_USE_SUPPLIED_CREDS;
|
|
|
|
unsigned long attr;
|
|
|
|
unsigned long attr;
|
|
|
|
// https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
|
|
|
|
// https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
|
|
|
|
std::array<SecBuffer, 2> input_buffers{{
|
|
|
|
std::array<SecBuffer, 2> input_buffers{{
|
|
|
@ -219,7 +220,7 @@ public:
|
|
|
|
ciphertext_read_buf_.size());
|
|
|
|
ciphertext_read_buf_.size());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
SECURITY_STATUS ret =
|
|
|
|
const SECURITY_STATUS ret =
|
|
|
|
InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt_ : nullptr,
|
|
|
|
InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt_ : nullptr,
|
|
|
|
// Caller ensured we have set a hostname:
|
|
|
|
// Caller ensured we have set a hostname:
|
|
|
|
const_cast<char*>(hostname_.value().c_str()), req,
|
|
|
|
const_cast<char*>(hostname_.value().c_str()), req,
|
|
|
@ -231,15 +232,15 @@ public:
|
|
|
|
nullptr); // ptsExpiry
|
|
|
|
nullptr); // ptsExpiry
|
|
|
|
|
|
|
|
|
|
|
|
if (output_buffers[0].pvBuffer) {
|
|
|
|
if (output_buffers[0].pvBuffer) {
|
|
|
|
std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
|
|
|
|
const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer),
|
|
|
|
output_buffers[0].cbBuffer);
|
|
|
|
output_buffers[0].cbBuffer);
|
|
|
|
ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), span.begin(), span.end());
|
|
|
|
ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), span.begin(), span.end());
|
|
|
|
FreeContextBuffer(output_buffers[0].pvBuffer);
|
|
|
|
FreeContextBuffer(output_buffers[0].pvBuffer);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (output_buffers[1].pvBuffer) {
|
|
|
|
if (output_buffers[1].pvBuffer) {
|
|
|
|
std::span span(static_cast<u8*>(output_buffers[1].pvBuffer),
|
|
|
|
const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer),
|
|
|
|
output_buffers[1].cbBuffer);
|
|
|
|
output_buffers[1].cbBuffer);
|
|
|
|
// The documentation doesn't explain what format this data is in.
|
|
|
|
// The documentation doesn't explain what format this data is in.
|
|
|
|
LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(),
|
|
|
|
LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(),
|
|
|
|
Common::HexToString(span));
|
|
|
|
Common::HexToString(span));
|
|
|
@ -280,7 +281,7 @@ public:
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Result GrabStreamSizes() {
|
|
|
|
Result GrabStreamSizes() {
|
|
|
|
SECURITY_STATUS ret =
|
|
|
|
const SECURITY_STATUS ret =
|
|
|
|
QueryContextAttributes(&ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_);
|
|
|
|
QueryContextAttributes(&ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_);
|
|
|
|
if (ret != SEC_E_OK) {
|
|
|
|
if (ret != SEC_E_OK) {
|
|
|
|
LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
|
|
|
|
LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}",
|
|
|
@ -301,7 +302,7 @@ public:
|
|
|
|
}
|
|
|
|
}
|
|
|
|
while (1) {
|
|
|
|
while (1) {
|
|
|
|
if (!cleartext_read_buf_.empty()) {
|
|
|
|
if (!cleartext_read_buf_.empty()) {
|
|
|
|
size_t read_size = std::min(cleartext_read_buf_.size(), data.size());
|
|
|
|
const size_t read_size = std::min(cleartext_read_buf_.size(), data.size());
|
|
|
|
std::memcpy(data.data(), cleartext_read_buf_.data(), read_size);
|
|
|
|
std::memcpy(data.data(), cleartext_read_buf_.data(), read_size);
|
|
|
|
cleartext_read_buf_.erase(cleartext_read_buf_.begin(),
|
|
|
|
cleartext_read_buf_.erase(cleartext_read_buf_.begin(),
|
|
|
|
cleartext_read_buf_.begin() + read_size);
|
|
|
|
cleartext_read_buf_.begin() + read_size);
|
|
|
@ -366,7 +367,7 @@ public:
|
|
|
|
return ResultInternalError;
|
|
|
|
return ResultInternalError;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Result r = FillCiphertextReadBuf();
|
|
|
|
const Result r = FillCiphertextReadBuf();
|
|
|
|
if (r != ResultSuccess) {
|
|
|
|
if (r != ResultSuccess) {
|
|
|
|
return r;
|
|
|
|
return r;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -430,7 +431,7 @@ public:
|
|
|
|
.pBuffers = buffers.data(),
|
|
|
|
.pBuffers = buffers.data(),
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
SECURITY_STATUS ret = EncryptMessage(&ctxt_, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
|
|
|
|
const SECURITY_STATUS ret = EncryptMessage(&ctxt_, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0);
|
|
|
|
if (ret != SEC_E_OK) {
|
|
|
|
if (ret != SEC_E_OK) {
|
|
|
|
LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
|
|
|
|
LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret));
|
|
|
|
return ResultInternalError;
|
|
|
|
return ResultInternalError;
|
|
|
@ -445,19 +446,19 @@ public:
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ResultVal<size_t> WriteAlreadyEncryptedData() {
|
|
|
|
ResultVal<size_t> WriteAlreadyEncryptedData() {
|
|
|
|
Result r = FlushCiphertextWriteBuf();
|
|
|
|
const Result r = FlushCiphertextWriteBuf();
|
|
|
|
if (r != ResultSuccess) {
|
|
|
|
if (r != ResultSuccess) {
|
|
|
|
return r;
|
|
|
|
return r;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// write buf is empty
|
|
|
|
// write buf is empty
|
|
|
|
size_t cleartext_bytes_written = cleartext_write_buf_.size();
|
|
|
|
const size_t cleartext_bytes_written = cleartext_write_buf_.size();
|
|
|
|
cleartext_write_buf_.clear();
|
|
|
|
cleartext_write_buf_.clear();
|
|
|
|
return cleartext_bytes_written;
|
|
|
|
return cleartext_bytes_written;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
|
|
|
|
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override {
|
|
|
|
PCCERT_CONTEXT returned_cert = nullptr;
|
|
|
|
PCCERT_CONTEXT returned_cert = nullptr;
|
|
|
|
SECURITY_STATUS ret =
|
|
|
|
const SECURITY_STATUS ret =
|
|
|
|
QueryContextAttributes(&ctxt_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
|
|
|
|
QueryContextAttributes(&ctxt_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert);
|
|
|
|
if (ret != SEC_E_OK) {
|
|
|
|
if (ret != SEC_E_OK) {
|
|
|
|
LOG_ERROR(Service_SSL,
|
|
|
|
LOG_ERROR(Service_SSL,
|
|
|
@ -527,7 +528,7 @@ public:
|
|
|
|
|
|
|
|
|
|
|
|
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
|
|
|
|
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() {
|
|
|
|
auto conn = std::make_unique<SSLConnectionBackendSchannel>();
|
|
|
|
auto conn = std::make_unique<SSLConnectionBackendSchannel>();
|
|
|
|
Result res = conn->Init();
|
|
|
|
const Result res = conn->Init();
|
|
|
|
if (res.IsFailure()) {
|
|
|
|
if (res.IsFailure()) {
|
|
|
|
return res;
|
|
|
|
return res;
|
|
|
|
}
|
|
|
|
}
|
|
|
|