summaryrefslogtreecommitdiff
path: root/libgo/go/crypto/tls/handshake_messages.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/crypto/tls/handshake_messages.go')
-rw-r--r--libgo/go/crypto/tls/handshake_messages.go98
1 files changed, 92 insertions, 6 deletions
diff --git a/libgo/go/crypto/tls/handshake_messages.go b/libgo/go/crypto/tls/handshake_messages.go
index 5d14871a348..799a776799a 100644
--- a/libgo/go/crypto/tls/handshake_messages.go
+++ b/libgo/go/crypto/tls/handshake_messages.go
@@ -16,6 +16,7 @@ type clientHelloMsg struct {
nextProtoNeg bool
serverName string
ocspStapling bool
+ scts bool
supportedCurves []CurveID
supportedPoints []uint8
ticketSupported bool
@@ -40,6 +41,7 @@ func (m *clientHelloMsg) equal(i interface{}) bool {
m.nextProtoNeg == m1.nextProtoNeg &&
m.serverName == m1.serverName &&
m.ocspStapling == m1.ocspStapling &&
+ m.scts == m1.scts &&
eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
m.ticketSupported == m1.ticketSupported &&
@@ -99,6 +101,9 @@ func (m *clientHelloMsg) marshal() []byte {
}
numExtensions++
}
+ if m.scts {
+ numExtensions++
+ }
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
@@ -271,6 +276,13 @@ func (m *clientHelloMsg) marshal() []byte {
lengths[0] = byte(stringsLength >> 8)
lengths[1] = byte(stringsLength)
}
+ if m.scts {
+ // https://tools.ietf.org/html/rfc6962#section-3.3.1
+ z[0] = byte(extensionSCT >> 8)
+ z[1] = byte(extensionSCT)
+ // zero uint16 for the zero-length extension_data
+ z = z[4:]
+ }
m.raw = x
@@ -326,6 +338,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.sessionTicket = nil
m.signatureAndHashes = nil
m.alpnProtocols = nil
+ m.scts = false
if len(data) == 0 {
// ClientHello is optionally followed by extension data
@@ -354,12 +367,16 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
switch extension {
case extensionServerName:
- if length < 2 {
+ d := data[:length]
+ if len(d) < 2 {
return false
}
- numNames := int(data[0])<<8 | int(data[1])
- d := data[2:]
- for i := 0; i < numNames; i++ {
+ namesLen := int(d[0])<<8 | int(d[1])
+ d = d[2:]
+ if len(d) != namesLen {
+ return false
+ }
+ for len(d) > 0 {
if len(d) < 3 {
return false
}
@@ -370,7 +387,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
return false
}
if nameType == 0 {
- m.serverName = string(d[0:nameLen])
+ m.serverName = string(d[:nameLen])
break
}
d = d[nameLen:]
@@ -430,7 +447,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.signatureAndHashes[i].signature = d[1]
d = d[2:]
}
- case extensionRenegotiationInfo + 1:
+ case extensionRenegotiationInfo:
if length != 1 || data[0] != 0 {
return false
}
@@ -453,6 +470,11 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
d = d[stringLen:]
}
+ case extensionSCT:
+ m.scts = true
+ if length != 0 {
+ return false
+ }
}
data = data[length:]
}
@@ -470,6 +492,7 @@ type serverHelloMsg struct {
nextProtoNeg bool
nextProtos []string
ocspStapling bool
+ scts [][]byte
ticketSupported bool
secureRenegotiation bool
alpnProtocol string
@@ -481,6 +504,15 @@ func (m *serverHelloMsg) equal(i interface{}) bool {
return false
}
+ if len(m.scts) != len(m1.scts) {
+ return false
+ }
+ for i, sct := range m.scts {
+ if !bytes.Equal(sct, m1.scts[i]) {
+ return false
+ }
+ }
+
return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
@@ -530,6 +562,14 @@ func (m *serverHelloMsg) marshal() []byte {
extensionsLength += 2 + 1 + alpnLen
numExtensions++
}
+ sctLen := 0
+ if len(m.scts) > 0 {
+ for _, sct := range m.scts {
+ sctLen += len(sct) + 2
+ }
+ extensionsLength += 2 + sctLen
+ numExtensions++
+ }
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
@@ -605,6 +645,23 @@ func (m *serverHelloMsg) marshal() []byte {
copy(z[7:], []byte(m.alpnProtocol))
z = z[7+alpnLen:]
}
+ if sctLen > 0 {
+ z[0] = byte(extensionSCT >> 8)
+ z[1] = byte(extensionSCT)
+ l := sctLen + 2
+ z[2] = byte(l >> 8)
+ z[3] = byte(l)
+ z[4] = byte(sctLen >> 8)
+ z[5] = byte(sctLen)
+
+ z = z[6:]
+ for _, sct := range m.scts {
+ z[0] = byte(len(sct) >> 8)
+ z[1] = byte(len(sct))
+ copy(z[2:], sct)
+ z = z[len(sct)+2:]
+ }
+ }
m.raw = x
@@ -634,6 +691,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
m.nextProtoNeg = false
m.nextProtos = nil
m.ocspStapling = false
+ m.scts = nil
m.ticketSupported = false
m.alpnProtocol = ""
@@ -706,6 +764,34 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
}
d = d[1:]
m.alpnProtocol = string(d)
+ case extensionSCT:
+ d := data[:length]
+
+ if len(d) < 2 {
+ return false
+ }
+ l := int(d[0])<<8 | int(d[1])
+ d = d[2:]
+ if len(d) != l {
+ return false
+ }
+ if l == 0 {
+ continue
+ }
+
+ m.scts = make([][]byte, 0, 3)
+ for len(d) != 0 {
+ if len(d) < 2 {
+ return false
+ }
+ sctLen := int(d[0])<<8 | int(d[1])
+ d = d[2:]
+ if len(d) < sctLen {
+ return false
+ }
+ m.scts = append(m.scts, d[:sctLen])
+ d = d[sctLen:]
+ }
}
data = data[length:]
}