summaryrefslogtreecommitdiff
path: root/nss/gtests/ssl_gtest/tls_filter.h
diff options
context:
space:
mode:
Diffstat (limited to 'nss/gtests/ssl_gtest/tls_filter.h')
-rw-r--r--nss/gtests/ssl_gtest/tls_filter.h231
1 files changed, 152 insertions, 79 deletions
diff --git a/nss/gtests/ssl_gtest/tls_filter.h b/nss/gtests/ssl_gtest/tls_filter.h
index fa2e387..e4030e2 100644
--- a/nss/gtests/ssl_gtest/tls_filter.h
+++ b/nss/gtests/ssl_gtest/tls_filter.h
@@ -9,17 +9,67 @@
#include <functional>
#include <memory>
+#include <set>
#include <vector>
#include "test_io.h"
#include "tls_parser.h"
+#include "tls_protect.h"
+
+extern "C" {
+#include "libssl_internals.h"
+}
namespace nss_test {
+class TlsCipherSpec;
+class TlsAgent;
+
+class TlsVersioned {
+ public:
+ TlsVersioned() : version_(0) {}
+ explicit TlsVersioned(uint16_t version) : version_(version) {}
+
+ bool is_dtls() const { return IsDtls(version_); }
+ uint16_t version() const { return version_; }
+
+ void WriteStream(std::ostream& stream) const;
+
+ protected:
+ uint16_t version_;
+};
+
+class TlsRecordHeader : public TlsVersioned {
+ public:
+ TlsRecordHeader() : TlsVersioned(), content_type_(0), sequence_number_(0) {}
+ TlsRecordHeader(uint16_t version, uint8_t content_type,
+ uint64_t sequence_number)
+ : TlsVersioned(version),
+ content_type_(content_type),
+ sequence_number_(sequence_number) {}
+
+ uint8_t content_type() const { return content_type_; }
+ uint64_t sequence_number() const { return sequence_number_; }
+ size_t header_length() const { return is_dtls() ? 11 : 3; }
+
+ // Parse the header; return true if successful; body in an outparam if OK.
+ bool Parse(TlsParser* parser, DataBuffer* body);
+ // Write the header and body to a buffer at the given offset.
+ // Return the offset of the end of the write.
+ size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
+
+ private:
+ uint8_t content_type_;
+ uint64_t sequence_number_;
+};
+
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
public:
- TlsRecordFilter() : count_(0) {}
+ TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_() {}
+
+ void SetAgent(const TlsAgent* agent) { agent_ = agent; }
+ const TlsAgent* agent() const { return agent_; }
// External interface. Overrides PacketFilter.
PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output);
@@ -27,42 +77,14 @@ class TlsRecordFilter : public PacketFilter {
// Report how many packets were altered by the filter.
size_t filtered_packets() const { return count_; }
- class Versioned {
- public:
- Versioned() : version_(0) {}
- explicit Versioned(uint16_t version) : version_(version) {}
-
- bool is_dtls() const { return IsDtls(version_); }
- uint16_t version() const { return version_; }
-
- protected:
- uint16_t version_;
- };
-
- class RecordHeader : public Versioned {
- public:
- RecordHeader() : Versioned(), content_type_(0), sequence_number_(0) {}
- RecordHeader(uint16_t version, uint8_t content_type,
- uint64_t sequence_number)
- : Versioned(version),
- content_type_(content_type),
- sequence_number_(sequence_number) {}
-
- uint8_t content_type() const { return content_type_; }
- uint64_t sequence_number() const { return sequence_number_; }
- size_t header_length() const { return is_dtls() ? 11 : 3; }
-
- // Parse the header; return true if successful; body in an outparam if OK.
- bool Parse(TlsParser* parser, DataBuffer* body);
- // Write the header and body to a buffer at the given offset.
- // Return the offset of the end of the write.
- size_t Write(DataBuffer* buffer, size_t offset,
- const DataBuffer& body) const;
-
- private:
- uint8_t content_type_;
- uint64_t sequence_number_;
- };
+ // Enable decryption. This only works properly for TLS 1.3 and above.
+ // Enabling it for lower version tests will cause undefined
+ // behavior.
+ void EnableDecryption();
+ bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText,
+ uint8_t* inner_content_type, DataBuffer* plaintext);
+ bool Protect(const TlsRecordHeader& header, uint8_t inner_content_type,
+ const DataBuffer& plaintext, DataBuffer* ciphertext);
protected:
// There are two filter functions which can be overriden. Both are
@@ -72,7 +94,7 @@ class TlsRecordFilter : public PacketFilter {
// just lets you change the record contents. By default, the
// outer one calls the inner one, so if you override the outer
// one, the inner one is never called unless you call it yourself.
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& record,
size_t* offset, DataBuffer* output);
@@ -80,16 +102,49 @@ class TlsRecordFilter : public PacketFilter {
// sequence number (which is zero for TLS), plus the existing record payload.
// It returns an action (KEEP, CHANGE, DROP). It writes to the `changed`
// outparam with the new record contents if it chooses to CHANGE the record.
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& data,
DataBuffer* changed) {
return KEEP;
}
private:
+ static void CipherSpecChanged(void* arg, PRBool sending,
+ ssl3CipherSpec* newSpec);
+
+ const TlsAgent* agent_;
size_t count_;
+ std::unique_ptr<TlsCipherSpec> cipher_spec_;
};
+inline std::ostream& operator<<(std::ostream& stream, TlsVersioned v) {
+ v.WriteStream(stream);
+ return stream;
+}
+
+inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) {
+ hdr.WriteStream(stream);
+ stream << ' ';
+ switch (hdr.content_type()) {
+ case kTlsChangeCipherSpecType:
+ stream << "CCS";
+ break;
+ case kTlsAlertType:
+ stream << "Alert";
+ break;
+ case kTlsHandshakeType:
+ stream << "Handshake";
+ break;
+ case kTlsApplicationDataType:
+ stream << "Data";
+ break;
+ default:
+ stream << '<' << hdr.content_type() << '>';
+ break;
+ }
+ return stream << ' ' << std::hex << hdr.sequence_number() << std::dec;
+}
+
// Abstract filter that operates on handshake messages rather than records.
// This assumes that the handshake messages are written in a block as entire
// records and that they don't span records or anything crazy like that.
@@ -97,20 +152,23 @@ class TlsHandshakeFilter : public TlsRecordFilter {
public:
TlsHandshakeFilter() {}
- class HandshakeHeader : public Versioned {
+ class HandshakeHeader : public TlsVersioned {
public:
- HandshakeHeader() : Versioned(), handshake_type_(0), message_seq_(0) {}
+ HandshakeHeader() : TlsVersioned(), handshake_type_(0), message_seq_(0) {}
uint8_t handshake_type() const { return handshake_type_; }
- bool Parse(TlsParser* parser, const RecordHeader& record_header,
+ bool Parse(TlsParser* parser, const TlsRecordHeader& record_header,
DataBuffer* body);
size_t Write(DataBuffer* buffer, size_t offset,
const DataBuffer& body) const;
+ size_t WriteFragment(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body, size_t fragment_offset,
+ size_t fragment_length) const;
private:
// Reads the length from the record header.
// This also reads the DTLS fragment information and checks it.
- bool ReadLength(TlsParser* parser, const RecordHeader& header,
+ bool ReadLength(TlsParser* parser, const TlsRecordHeader& header,
uint32_t* length);
uint8_t handshake_type_;
@@ -119,7 +177,7 @@ class TlsHandshakeFilter : public TlsRecordFilter {
};
protected:
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output);
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
@@ -167,7 +225,7 @@ class TlsConversationRecorder : public TlsRecordFilter {
public:
TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {}
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output);
@@ -175,43 +233,39 @@ class TlsConversationRecorder : public TlsRecordFilter {
DataBuffer& buffer_;
};
-// Records an alert. If an alert has already been recorded, it won't save the
-// new alert unless the old alert is a warning and the new one is fatal.
-class TlsAlertRecorder : public TlsRecordFilter {
- public:
- TlsAlertRecorder() : level_(255), description_(255) {}
-
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
- const DataBuffer& input,
- DataBuffer* output);
-
- uint8_t level() const { return level_; }
- uint8_t description() const { return description_; }
-
- private:
- uint8_t level_;
- uint8_t description_;
-};
-
// Runs multiple packet filters in series.
class ChainedPacketFilter : public PacketFilter {
public:
ChainedPacketFilter() {}
- ChainedPacketFilter(const std::vector<PacketFilter*> filters)
+ ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters)
: filters_(filters.begin(), filters.end()) {}
- virtual ~ChainedPacketFilter();
+ virtual ~ChainedPacketFilter() {}
virtual PacketFilter::Action Filter(const DataBuffer& input,
DataBuffer* output);
// Takes ownership of the filter.
- void Add(PacketFilter* filter) { filters_.push_back(filter); }
+ void Add(std::shared_ptr<PacketFilter> filter) { filters_.push_back(filter); }
private:
- std::vector<PacketFilter*> filters_;
+ std::vector<std::shared_ptr<PacketFilter>> filters_;
};
+typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)>
+ TlsExtensionFinder;
+
class TlsExtensionFilter : public TlsHandshakeFilter {
+ public:
+ TlsExtensionFilter() : handshake_types_() {
+ handshake_types_.insert(kTlsHandshakeClientHello);
+ handshake_types_.insert(kTlsHandshakeServerHello);
+ }
+
+ TlsExtensionFilter(const std::set<uint8_t>& types)
+ : handshake_types_(types) {}
+
+ static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);
+
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -221,15 +275,12 @@ class TlsExtensionFilter : public TlsHandshakeFilter {
const DataBuffer& input,
DataBuffer* output) = 0;
- public:
- static bool FindClientHelloExtensions(TlsParser* parser,
- const Versioned& header);
- static bool FindServerHelloExtensions(TlsParser* parser);
-
private:
PacketFilter::Action FilterExtensions(TlsParser* parser,
const DataBuffer& input,
DataBuffer* output);
+
+ std::set<uint8_t> handshake_types_;
};
class TlsExtensionCapture : public TlsExtensionFilter {
@@ -280,17 +331,17 @@ typedef std::function<void(void)> VoidFunction;
class AfterRecordN : public TlsRecordFilter {
public:
- AfterRecordN(TlsAgent* src, TlsAgent* dest, unsigned int record,
- VoidFunction func)
+ AfterRecordN(std::shared_ptr<TlsAgent>& src, std::shared_ptr<TlsAgent>& dest,
+ unsigned int record, VoidFunction func)
: src_(src), dest_(dest), record_(record), func_(func), counter_(0) {}
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& body,
DataBuffer* out) override;
private:
- TlsAgent* src_;
- TlsAgent* dest_;
+ std::weak_ptr<TlsAgent> src_;
+ std::weak_ptr<TlsAgent> dest_;
unsigned int record_;
VoidFunction func_;
unsigned int counter_;
@@ -300,14 +351,15 @@ class AfterRecordN : public TlsRecordFilter {
// ClientHelloVersion on |server|.
class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter {
public:
- TlsInspectorClientHelloVersionChanger(TlsAgent* server) : server_(server) {}
+ TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server)
+ : server_(server) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output);
private:
- TlsAgent* server_;
+ std::weak_ptr<TlsAgent> server_;
};
// This class selectively drops complete writes. This relies on the fact that
@@ -338,6 +390,27 @@ class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
uint16_t version_;
};
+// Damages the last byte of a handshake message.
+class TlsLastByteDamager : public TlsHandshakeFilter {
+ public:
+ TlsLastByteDamager(uint8_t type) : type_(type) {}
+ PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) override {
+ if (header.handshake_type() != type_) {
+ return KEEP;
+ }
+
+ *output = input;
+
+ output->data()[output->len() - 1]++;
+ return CHANGE;
+ }
+
+ private:
+ uint8_t type_;
+};
+
} // namespace nss_test
#endif