diff options
Diffstat (limited to 'libgo/go/crypto/tls/handshake_messages.go')
-rw-r--r-- | libgo/go/crypto/tls/handshake_messages.go | 98 |
1 files changed, 92 insertions, 6 deletions
diff --git a/libgo/go/crypto/tls/handshake_messages.go b/libgo/go/crypto/tls/handshake_messages.go index 5d14871a348..799a776799a 100644 --- a/libgo/go/crypto/tls/handshake_messages.go +++ b/libgo/go/crypto/tls/handshake_messages.go @@ -16,6 +16,7 @@ type clientHelloMsg struct { nextProtoNeg bool serverName string ocspStapling bool + scts bool supportedCurves []CurveID supportedPoints []uint8 ticketSupported bool @@ -40,6 +41,7 @@ func (m *clientHelloMsg) equal(i interface{}) bool { m.nextProtoNeg == m1.nextProtoNeg && m.serverName == m1.serverName && m.ocspStapling == m1.ocspStapling && + m.scts == m1.scts && eqCurveIDs(m.supportedCurves, m1.supportedCurves) && bytes.Equal(m.supportedPoints, m1.supportedPoints) && m.ticketSupported == m1.ticketSupported && @@ -99,6 +101,9 @@ func (m *clientHelloMsg) marshal() []byte { } numExtensions++ } + if m.scts { + numExtensions++ + } if numExtensions > 0 { extensionsLength += 4 * numExtensions length += 2 + extensionsLength @@ -271,6 +276,13 @@ func (m *clientHelloMsg) marshal() []byte { lengths[0] = byte(stringsLength >> 8) lengths[1] = byte(stringsLength) } + if m.scts { + // https://tools.ietf.org/html/rfc6962#section-3.3.1 + z[0] = byte(extensionSCT >> 8) + z[1] = byte(extensionSCT) + // zero uint16 for the zero-length extension_data + z = z[4:] + } m.raw = x @@ -326,6 +338,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.sessionTicket = nil m.signatureAndHashes = nil m.alpnProtocols = nil + m.scts = false if len(data) == 0 { // ClientHello is optionally followed by extension data @@ -354,12 +367,16 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { switch extension { case extensionServerName: - if length < 2 { + d := data[:length] + if len(d) < 2 { return false } - numNames := int(data[0])<<8 | int(data[1]) - d := data[2:] - for i := 0; i < numNames; i++ { + namesLen := int(d[0])<<8 | int(d[1]) + d = d[2:] + if len(d) != namesLen { + return false + } + for len(d) > 0 { if len(d) < 3 { return false } @@ -370,7 +387,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return false } if nameType == 0 { - m.serverName = string(d[0:nameLen]) + m.serverName = string(d[:nameLen]) break } d = d[nameLen:] @@ -430,7 +447,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.signatureAndHashes[i].signature = d[1] d = d[2:] } - case extensionRenegotiationInfo + 1: + case extensionRenegotiationInfo: if length != 1 || data[0] != 0 { return false } @@ -453,6 +470,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen])) d = d[stringLen:] } + case extensionSCT: + m.scts = true + if length != 0 { + return false + } } data = data[length:] } @@ -470,6 +492,7 @@ type serverHelloMsg struct { nextProtoNeg bool nextProtos []string ocspStapling bool + scts [][]byte ticketSupported bool secureRenegotiation bool alpnProtocol string @@ -481,6 +504,15 @@ func (m *serverHelloMsg) equal(i interface{}) bool { return false } + if len(m.scts) != len(m1.scts) { + return false + } + for i, sct := range m.scts { + if !bytes.Equal(sct, m1.scts[i]) { + return false + } + } + return bytes.Equal(m.raw, m1.raw) && m.vers == m1.vers && bytes.Equal(m.random, m1.random) && @@ -530,6 +562,14 @@ func (m *serverHelloMsg) marshal() []byte { extensionsLength += 2 + 1 + alpnLen numExtensions++ } + sctLen := 0 + if len(m.scts) > 0 { + for _, sct := range m.scts { + sctLen += len(sct) + 2 + } + extensionsLength += 2 + sctLen + numExtensions++ + } if numExtensions > 0 { extensionsLength += 4 * numExtensions @@ -605,6 +645,23 @@ func (m *serverHelloMsg) marshal() []byte { copy(z[7:], []byte(m.alpnProtocol)) z = z[7+alpnLen:] } + if sctLen > 0 { + z[0] = byte(extensionSCT >> 8) + z[1] = byte(extensionSCT) + l := sctLen + 2 + z[2] = byte(l >> 8) + z[3] = byte(l) + z[4] = byte(sctLen >> 8) + z[5] = byte(sctLen) + + z = z[6:] + for _, sct := range m.scts { + z[0] = byte(len(sct) >> 8) + z[1] = byte(len(sct)) + copy(z[2:], sct) + z = z[len(sct)+2:] + } + } m.raw = x @@ -634,6 +691,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { m.nextProtoNeg = false m.nextProtos = nil m.ocspStapling = false + m.scts = nil m.ticketSupported = false m.alpnProtocol = "" @@ -706,6 +764,34 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { } d = d[1:] m.alpnProtocol = string(d) + case extensionSCT: + d := data[:length] + + if len(d) < 2 { + return false + } + l := int(d[0])<<8 | int(d[1]) + d = d[2:] + if len(d) != l { + return false + } + if l == 0 { + continue + } + + m.scts = make([][]byte, 0, 3) + for len(d) != 0 { + if len(d) < 2 { + return false + } + sctLen := int(d[0])<<8 | int(d[1]) + d = d[2:] + if len(d) < sctLen { + return false + } + m.scts = append(m.scts, d[:sctLen]) + d = d[sctLen:] + } } data = data[length:] } |