summaryrefslogtreecommitdiff
path: root/nss/gtests/ssl_gtest/tls_filter.cc
diff options
context:
space:
mode:
Diffstat (limited to 'nss/gtests/ssl_gtest/tls_filter.cc')
-rw-r--r--nss/gtests/ssl_gtest/tls_filter.cc326
1 files changed, 240 insertions, 86 deletions
diff --git a/nss/gtests/ssl_gtest/tls_filter.cc b/nss/gtests/ssl_gtest/tls_filter.cc
index 4f7d195..76d9aaa 100644
--- a/nss/gtests/ssl_gtest/tls_filter.cc
+++ b/nss/gtests/ssl_gtest/tls_filter.cc
@@ -15,9 +15,62 @@ extern "C" {
#include <iostream>
#include "gtest_utils.h"
#include "tls_agent.h"
+#include "tls_filter.h"
+#include "tls_protect.h"
namespace nss_test {
+void TlsVersioned::WriteStream(std::ostream& stream) const {
+ stream << (is_dtls() ? "DTLS " : "TLS ");
+ switch (version()) {
+ case 0:
+ stream << "(no version)";
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ stream << "1.0";
+ break;
+ case SSL_LIBRARY_VERSION_DTLS_1_0_WIRE:
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ stream << (is_dtls() ? "1.0" : "1.1");
+ break;
+ case SSL_LIBRARY_VERSION_DTLS_1_2_WIRE:
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ stream << "1.2";
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_3:
+ stream << "1.3";
+ break;
+ default:
+ stream << "Invalid version: " << version();
+ break;
+ }
+}
+
+void TlsRecordFilter::EnableDecryption() {
+ SSLInt_SetCipherSpecChangeFunc(agent()->ssl_fd(), CipherSpecChanged,
+ (void*)this);
+}
+
+void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending,
+ ssl3CipherSpec* newSpec) {
+ TlsRecordFilter* self = static_cast<TlsRecordFilter*>(arg);
+ PRBool isServer = self->agent()->role() == TlsAgent::SERVER;
+
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "Cipher spec changed. Role="
+ << (isServer ? "server" : "client")
+ << " direction=" << (sending ? "send" : "receive") << std::endl;
+ }
+ if (!sending) return;
+
+ self->cipher_spec_.reset(new TlsCipherSpec());
+ bool ret =
+ self->cipher_spec_->Init(SSLInt_CipherSpecToAlgorithm(isServer, newSpec),
+ SSLInt_CipherSpecToKey(isServer, newSpec),
+ SSLInt_CipherSpecToIv(isServer, newSpec));
+ EXPECT_EQ(true, ret);
+}
+
PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
DataBuffer* output) {
bool changed = false;
@@ -25,10 +78,13 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
output->Allocate(input.len());
TlsParser parser(input);
+
while (parser.remaining()) {
- RecordHeader header;
+ TlsRecordHeader header;
DataBuffer record;
+
if (!header.Parse(&parser, &record)) {
+ ADD_FAILURE() << "not a valid record";
return KEEP;
}
@@ -49,12 +105,21 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
return KEEP;
}
-PacketFilter::Action TlsRecordFilter::FilterRecord(const RecordHeader& header,
- const DataBuffer& record,
- size_t* offset,
- DataBuffer* output) {
+PacketFilter::Action TlsRecordFilter::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& record, size_t* offset,
+ DataBuffer* output) {
DataBuffer filtered;
- PacketFilter::Action action = FilterRecord(header, record, &filtered);
+ uint8_t inner_content_type;
+ DataBuffer plaintext;
+
+ if (!Unprotect(header, record, &inner_content_type, &plaintext)) {
+ return KEEP;
+ }
+
+ TlsRecordHeader real_header = {header.version(), inner_content_type,
+ header.sequence_number()};
+
+ PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered);
if (action == KEEP) {
return KEEP;
}
@@ -64,19 +129,21 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(const RecordHeader& header,
return DROP;
}
- const DataBuffer* source = &record;
- if (action == CHANGE) {
- EXPECT_GT(0x10000U, filtered.len());
- std::cerr << "record old: " << record << std::endl;
- std::cerr << "record new: " << filtered << std::endl;
- source = &filtered;
- }
+ EXPECT_GT(0x10000U, filtered.len());
+ std::cerr << "record old: " << plaintext << std::endl;
+ std::cerr << "record new: " << filtered << std::endl;
- *offset = header.Write(output, *offset, *source);
+ DataBuffer ciphertext;
+ bool rv = Protect(header, inner_content_type, filtered, &ciphertext);
+ EXPECT_TRUE(rv);
+ if (!rv) {
+ return KEEP;
+ }
+ *offset = header.Write(output, *offset, ciphertext);
return CHANGE;
}
-bool TlsRecordFilter::RecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
+bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
if (!parser->Read(&content_type_)) {
return false;
}
@@ -102,8 +169,8 @@ bool TlsRecordFilter::RecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
return parser->ReadVariable(body, 2);
}
-size_t TlsRecordFilter::RecordHeader::Write(DataBuffer* buffer, size_t offset,
- const DataBuffer& body) const {
+size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body) const {
offset = buffer->Write(offset, content_type_, 1);
offset = buffer->Write(offset, version_, 2);
if (is_dtls()) {
@@ -116,8 +183,48 @@ size_t TlsRecordFilter::RecordHeader::Write(DataBuffer* buffer, size_t offset,
return offset;
}
+bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
+ const DataBuffer& ciphertext,
+ uint8_t* inner_content_type,
+ DataBuffer* plaintext) {
+ if (!cipher_spec_ || header.content_type() != kTlsApplicationDataType) {
+ *inner_content_type = header.content_type();
+ *plaintext = ciphertext;
+ return true;
+ }
+
+ if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) return false;
+
+ size_t len = plaintext->len();
+ while (len > 0 && !plaintext->data()[len - 1]) {
+ --len;
+ }
+ if (!len) {
+ // Bogus padding.
+ return false;
+ }
+
+ *inner_content_type = plaintext->data()[len - 1];
+ plaintext->Truncate(len - 1);
+
+ return true;
+}
+
+bool TlsRecordFilter::Protect(const TlsRecordHeader& header,
+ uint8_t inner_content_type,
+ const DataBuffer& plaintext,
+ DataBuffer* ciphertext) {
+ if (!cipher_spec_ || header.content_type() != kTlsApplicationDataType) {
+ *ciphertext = plaintext;
+ return true;
+ }
+ DataBuffer padded = plaintext;
+ padded.Write(padded.len(), inner_content_type, 1);
+ return cipher_spec_->Protect(header, padded, ciphertext);
+}
+
PacketFilter::Action TlsHandshakeFilter::FilterRecord(
- const RecordHeader& record_header, const DataBuffer& input,
+ const TlsRecordHeader& record_header, const DataBuffer& input,
DataBuffer* output) {
// Check that the first byte is as requested.
if (record_header.content_type() != kTlsHandshakeType) {
@@ -159,9 +266,8 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord(
return changed ? (offset ? CHANGE : DROP) : KEEP;
}
-bool TlsHandshakeFilter::HandshakeHeader::ReadLength(TlsParser* parser,
- const RecordHeader& header,
- uint32_t* length) {
+bool TlsHandshakeFilter::HandshakeHeader::ReadLength(
+ TlsParser* parser, const TlsRecordHeader& header, uint32_t* length) {
if (!parser->Read(length, 3)) {
return false; // malformed
}
@@ -192,7 +298,7 @@ bool TlsHandshakeFilter::HandshakeHeader::ReadLength(TlsParser* parser,
}
bool TlsHandshakeFilter::HandshakeHeader::Parse(
- TlsParser* parser, const RecordHeader& record_header, DataBuffer* body) {
+ TlsParser* parser, const TlsRecordHeader& record_header, DataBuffer* body) {
version_ = record_header.version();
if (!parser->Read(&handshake_type_)) {
return false; // malformed
@@ -205,15 +311,28 @@ bool TlsHandshakeFilter::HandshakeHeader::Parse(
return parser->Read(body, length);
}
-size_t TlsHandshakeFilter::HandshakeHeader::Write(
- DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
+size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment(
+ DataBuffer* buffer, size_t offset, const DataBuffer& body,
+ size_t fragment_offset, size_t fragment_length) const {
+ EXPECT_TRUE(is_dtls());
+ EXPECT_GE(body.len(), fragment_offset + fragment_length);
offset = buffer->Write(offset, handshake_type(), 1);
offset = buffer->Write(offset, body.len(), 3);
+ offset = buffer->Write(offset, message_seq_, 2);
+ offset = buffer->Write(offset, fragment_offset, 3);
+ offset = buffer->Write(offset, fragment_length, 3);
+ offset =
+ buffer->Write(offset, body.data() + fragment_offset, fragment_length);
+ return offset;
+}
+
+size_t TlsHandshakeFilter::HandshakeHeader::Write(
+ DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
if (is_dtls()) {
- offset = buffer->Write(offset, message_seq_, 2);
- offset = buffer->Write(offset, 0U, 3); // fragment_offset
- offset = buffer->Write(offset, body.len(), 3);
+ return WriteFragment(buffer, offset, body, 0U, body.len());
}
+ offset = buffer->Write(offset, handshake_type(), 1);
+ offset = buffer->Write(offset, body.len(), 3);
offset = buffer->Write(offset, body);
return offset;
}
@@ -244,42 +363,12 @@ PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake(
}
PacketFilter::Action TlsConversationRecorder::FilterRecord(
- const RecordHeader& header, const DataBuffer& input, DataBuffer* output) {
+ const TlsRecordHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
buffer_.Append(input);
return KEEP;
}
-PacketFilter::Action TlsAlertRecorder::FilterRecord(const RecordHeader& header,
- const DataBuffer& input,
- DataBuffer* output) {
- if (level_ == kTlsAlertFatal) { // already fatal
- return KEEP;
- }
- if (header.content_type() != kTlsAlertType) {
- return KEEP;
- }
-
- std::cerr << "Alert: " << input << std::endl;
-
- TlsParser parser(input);
- uint8_t lvl;
- if (!parser.Read(&lvl)) {
- return KEEP;
- }
- if (lvl == kTlsAlertWarning) { // not strong enough
- return KEEP;
- }
- level_ = lvl;
- (void)parser.Read(&description_);
- return KEEP;
-}
-
-ChainedPacketFilter::~ChainedPacketFilter() {
- for (auto it = filters_.begin(); it != filters_.end(); ++it) {
- delete *it;
- }
-}
-
PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
DataBuffer* output) {
DataBuffer in(input);
@@ -297,28 +386,7 @@ PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
return changed ? CHANGE : KEEP;
}
-PacketFilter::Action TlsExtensionFilter::FilterHandshake(
- const HandshakeHeader& header, const DataBuffer& input,
- DataBuffer* output) {
- if (header.handshake_type() == kTlsHandshakeClientHello) {
- TlsParser parser(input);
- if (!FindClientHelloExtensions(&parser, header)) {
- return KEEP;
- }
- return FilterExtensions(&parser, input, output);
- }
- if (header.handshake_type() == kTlsHandshakeServerHello) {
- TlsParser parser(input);
- if (!FindServerHelloExtensions(&parser)) {
- return KEEP;
- }
- return FilterExtensions(&parser, input, output);
- }
- return KEEP;
-}
-
-bool TlsExtensionFilter::FindClientHelloExtensions(TlsParser* parser,
- const Versioned& header) {
+bool FindClientHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
if (!parser->Skip(2 + 32)) { // version + random
return false;
}
@@ -337,7 +405,7 @@ bool TlsExtensionFilter::FindClientHelloExtensions(TlsParser* parser,
return true;
}
-bool TlsExtensionFilter::FindServerHelloExtensions(TlsParser* parser) {
+bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
uint32_t vtmp;
if (!parser->Read(&vtmp, 2)) {
return false;
@@ -362,6 +430,92 @@ bool TlsExtensionFilter::FindServerHelloExtensions(TlsParser* parser) {
return true;
}
+static bool FindHelloRetryExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ // TODO for -19 add cipher suite
+ if (!parser->Skip(2)) { // version
+ return false;
+ }
+ return true;
+}
+
+bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) {
+ return true;
+}
+
+static bool FindCertReqExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ if (!parser->SkipVariable(1)) { // request context
+ return false;
+ }
+ // TODO remove the next two for -19
+ if (!parser->SkipVariable(2)) { // signature_algorithms
+ return false;
+ }
+ if (!parser->SkipVariable(2)) { // certificate_authorities
+ return false;
+ }
+ return true;
+}
+
+// Only look at the EE cert for this one.
+static bool FindCertificateExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ if (!parser->SkipVariable(1)) { // request context
+ return false;
+ }
+ if (!parser->Skip(3)) { // length of certificate list
+ return false;
+ }
+ if (!parser->SkipVariable(3)) { // ASN1Cert
+ return false;
+ }
+ return true;
+}
+
+static bool FindNewSessionTicketExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ if (!parser->Skip(8)) { // lifetime, age add
+ return false;
+ }
+ if (!parser->SkipVariable(2)) { // ticket
+ return false;
+ }
+ return true;
+}
+
+static const std::map<uint16_t, TlsExtensionFinder> kExtensionFinders = {
+ {kTlsHandshakeClientHello, FindClientHelloExtensions},
+ {kTlsHandshakeServerHello, FindServerHelloExtensions},
+ {kTlsHandshakeHelloRetryRequest, FindHelloRetryExtensions},
+ {kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions},
+ {kTlsHandshakeCertificateRequest, FindCertReqExtensions},
+ {kTlsHandshakeCertificate, FindCertificateExtensions},
+ {kTlsHandshakeNewSessionTicket, FindNewSessionTicketExtensions}};
+
+bool TlsExtensionFilter::FindExtensions(TlsParser* parser,
+ const HandshakeHeader& header) {
+ auto it = kExtensionFinders.find(header.handshake_type());
+ if (it == kExtensionFinders.end()) {
+ return false;
+ }
+ return (it->second)(parser, header);
+}
+
+PacketFilter::Action TlsExtensionFilter::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ if (handshake_types_.count(header.handshake_type()) == 0) {
+ return KEEP;
+ }
+
+ TlsParser parser(input);
+ if (!FindExtensions(&parser, header)) {
+ return KEEP;
+ }
+ return FilterExtensions(&parser, input, output);
+}
+
PacketFilter::Action TlsExtensionFilter::FilterExtensions(
TlsParser* parser, const DataBuffer& input, DataBuffer* output) {
size_t length_offset = parser->consumed();
@@ -456,14 +610,14 @@ PacketFilter::Action TlsExtensionDropper::FilterExtension(
return KEEP;
}
-PacketFilter::Action AfterRecordN::FilterRecord(const RecordHeader& header,
+PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
const DataBuffer& body,
DataBuffer* out) {
if (counter_++ == record_) {
DataBuffer buf;
header.Write(&buf, 0, body);
- src_->SendDirect(buf);
- dest_->Handshake();
+ src_.lock()->SendDirect(buf);
+ dest_.lock()->Handshake();
func_();
return DROP;
}
@@ -476,7 +630,7 @@ PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake(
DataBuffer* output) {
if (header.handshake_type() == kTlsHandshakeClientKeyExchange) {
EXPECT_EQ(SECSuccess,
- SSLInt_IncrementClientHandshakeVersion(server_->ssl_fd()));
+ SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd()));
}
return KEEP;
}