summaryrefslogtreecommitdiff
path: root/libgo/go/crypto/tls/prf.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/crypto/tls/prf.go')
-rw-r--r--libgo/go/crypto/tls/prf.go178
1 files changed, 127 insertions, 51 deletions
diff --git a/libgo/go/crypto/tls/prf.go b/libgo/go/crypto/tls/prf.go
index fb8b3ab4d1e..6127c1ccfe7 100644
--- a/libgo/go/crypto/tls/prf.go
+++ b/libgo/go/crypto/tls/prf.go
@@ -10,6 +10,8 @@ import (
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
+ "crypto/sha512"
+ "errors"
"hash"
)
@@ -65,12 +67,14 @@ func prf10(result, secret, label, seed []byte) {
}
// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, section 5.
-func prf12(result, secret, label, seed []byte) {
- labelAndSeed := make([]byte, len(label)+len(seed))
- copy(labelAndSeed, label)
- copy(labelAndSeed[len(label):], seed)
+func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) {
+ return func(result, secret, label, seed []byte) {
+ labelAndSeed := make([]byte, len(label)+len(seed))
+ copy(labelAndSeed, label)
+ copy(labelAndSeed[len(label):], seed)
- pHash(result, secret, labelAndSeed, sha256.New)
+ pHash(result, secret, labelAndSeed, hashFunc)
+ }
}
// prf30 implements the SSL 3.0 pseudo-random function, as defined in
@@ -117,41 +121,49 @@ var keyExpansionLabel = []byte("key expansion")
var clientFinishedLabel = []byte("client finished")
var serverFinishedLabel = []byte("server finished")
-func prfForVersion(version uint16) func(result, secret, label, seed []byte) {
+func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) {
switch version {
case VersionSSL30:
- return prf30
+ return prf30, crypto.Hash(0)
case VersionTLS10, VersionTLS11:
- return prf10
+ return prf10, crypto.Hash(0)
case VersionTLS12:
- return prf12
+ if suite.flags&suiteSHA384 != 0 {
+ return prf12(sha512.New384), crypto.SHA384
+ }
+ return prf12(sha256.New), crypto.SHA256
default:
panic("unknown version")
}
}
+func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) {
+ prf, _ := prfAndHashForVersion(version, suite)
+ return prf
+}
+
// masterFromPreMasterSecret generates the master secret from the pre-master
// secret. See http://tools.ietf.org/html/rfc5246#section-8.1
-func masterFromPreMasterSecret(version uint16, preMasterSecret, clientRandom, serverRandom []byte) []byte {
+func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte {
var seed [tlsRandomLength * 2]byte
copy(seed[0:len(clientRandom)], clientRandom)
copy(seed[len(clientRandom):], serverRandom)
masterSecret := make([]byte, masterSecretLength)
- prfForVersion(version)(masterSecret, preMasterSecret, masterSecretLabel, seed[0:])
+ prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed[0:])
return masterSecret
}
// keysFromMasterSecret generates the connection keys from the master
// secret, given the lengths of the MAC key, cipher key and IV, as defined in
// RFC 2246, section 6.3.
-func keysFromMasterSecret(version uint16, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
+func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) {
var seed [tlsRandomLength * 2]byte
copy(seed[0:len(clientRandom)], serverRandom)
copy(seed[len(serverRandom):], clientRandom)
n := 2*macLen + 2*keyLen + 2*ivLen
keyMaterial := make([]byte, n)
- prfForVersion(version)(keyMaterial, masterSecret, keyExpansionLabel, seed[0:])
+ prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed[0:])
clientMAC = keyMaterial[:macLen]
keyMaterial = keyMaterial[macLen:]
serverMAC = keyMaterial[:macLen]
@@ -166,11 +178,33 @@ func keysFromMasterSecret(version uint16, masterSecret, clientRandom, serverRand
return
}
-func newFinishedHash(version uint16) finishedHash {
- if version >= VersionTLS12 {
- return finishedHash{sha256.New(), sha256.New(), nil, nil, version}
+// lookupTLSHash looks up the corresponding crypto.Hash for a given
+// TLS hash identifier.
+func lookupTLSHash(hash uint8) (crypto.Hash, error) {
+ switch hash {
+ case hashSHA1:
+ return crypto.SHA1, nil
+ case hashSHA256:
+ return crypto.SHA256, nil
+ case hashSHA384:
+ return crypto.SHA384, nil
+ default:
+ return 0, errors.New("tls: unsupported hash algorithm")
+ }
+}
+
+func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash {
+ var buffer []byte
+ if version == VersionSSL30 || version >= VersionTLS12 {
+ buffer = []byte{}
}
- return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), version}
+
+ prf, hash := prfAndHashForVersion(version, cipherSuite)
+ if hash != 0 {
+ return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf}
+ }
+
+ return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf}
}
// A finishedHash calculates the hash of a set of handshake messages suitable
@@ -183,10 +217,14 @@ type finishedHash struct {
clientMD5 hash.Hash
serverMD5 hash.Hash
+ // In TLS 1.2, a full buffer is sadly required.
+ buffer []byte
+
version uint16
+ prf func(result, secret, label, seed []byte)
}
-func (h finishedHash) Write(msg []byte) (n int, err error) {
+func (h *finishedHash) Write(msg []byte) (n int, err error) {
h.client.Write(msg)
h.server.Write(msg)
@@ -194,14 +232,29 @@ func (h finishedHash) Write(msg []byte) (n int, err error) {
h.clientMD5.Write(msg)
h.serverMD5.Write(msg)
}
+
+ if h.buffer != nil {
+ h.buffer = append(h.buffer, msg...)
+ }
+
return len(msg), nil
}
+func (h finishedHash) Sum() []byte {
+ if h.version >= VersionTLS12 {
+ return h.client.Sum(nil)
+ }
+
+ out := make([]byte, 0, md5.Size+sha1.Size)
+ out = h.clientMD5.Sum(out)
+ return h.client.Sum(out)
+}
+
// finishedSum30 calculates the contents of the verify_data member of a SSLv3
// Finished message given the MD5 and SHA1 hashes of a set of handshake
// messages.
-func finishedSum30(md5, sha1 hash.Hash, masterSecret []byte, magic [4]byte) []byte {
- md5.Write(magic[:])
+func finishedSum30(md5, sha1 hash.Hash, masterSecret []byte, magic []byte) []byte {
+ md5.Write(magic)
md5.Write(masterSecret)
md5.Write(ssl30Pad1[:])
md5Digest := md5.Sum(nil)
@@ -212,7 +265,7 @@ func finishedSum30(md5, sha1 hash.Hash, masterSecret []byte, magic [4]byte) []by
md5.Write(md5Digest)
md5Digest = md5.Sum(nil)
- sha1.Write(magic[:])
+ sha1.Write(magic)
sha1.Write(masterSecret)
sha1.Write(ssl30Pad1[:40])
sha1Digest := sha1.Sum(nil)
@@ -236,19 +289,11 @@ var ssl3ServerFinishedMagic = [4]byte{0x53, 0x52, 0x56, 0x52}
// Finished message.
func (h finishedHash) clientSum(masterSecret []byte) []byte {
if h.version == VersionSSL30 {
- return finishedSum30(h.clientMD5, h.client, masterSecret, ssl3ClientFinishedMagic)
+ return finishedSum30(h.clientMD5, h.client, masterSecret, ssl3ClientFinishedMagic[:])
}
out := make([]byte, finishedVerifyLength)
- if h.version >= VersionTLS12 {
- seed := h.client.Sum(nil)
- prf12(out, masterSecret, clientFinishedLabel, seed)
- } else {
- seed := make([]byte, 0, md5.Size+sha1.Size)
- seed = h.clientMD5.Sum(seed)
- seed = h.client.Sum(seed)
- prf10(out, masterSecret, clientFinishedLabel, seed)
- }
+ h.prf(out, masterSecret, clientFinishedLabel, h.Sum())
return out
}
@@ -256,36 +301,67 @@ func (h finishedHash) clientSum(masterSecret []byte) []byte {
// Finished message.
func (h finishedHash) serverSum(masterSecret []byte) []byte {
if h.version == VersionSSL30 {
- return finishedSum30(h.serverMD5, h.server, masterSecret, ssl3ServerFinishedMagic)
+ return finishedSum30(h.serverMD5, h.server, masterSecret, ssl3ServerFinishedMagic[:])
}
out := make([]byte, finishedVerifyLength)
- if h.version >= VersionTLS12 {
- seed := h.server.Sum(nil)
- prf12(out, masterSecret, serverFinishedLabel, seed)
- } else {
- seed := make([]byte, 0, md5.Size+sha1.Size)
- seed = h.serverMD5.Sum(seed)
- seed = h.server.Sum(seed)
- prf10(out, masterSecret, serverFinishedLabel, seed)
- }
+ h.prf(out, masterSecret, serverFinishedLabel, h.Sum())
return out
}
+// selectClientCertSignatureAlgorithm returns a signatureAndHash to sign a
+// client's CertificateVerify with, or an error if none can be found.
+func (h finishedHash) selectClientCertSignatureAlgorithm(serverList []signatureAndHash, sigType uint8) (signatureAndHash, error) {
+ if h.version < VersionTLS12 {
+ // Nothing to negotiate before TLS 1.2.
+ return signatureAndHash{signature: sigType}, nil
+ }
+
+ for _, v := range serverList {
+ if v.signature == sigType && isSupportedSignatureAndHash(v, supportedSignatureAlgorithms) {
+ return v, nil
+ }
+ }
+ return signatureAndHash{}, errors.New("tls: no supported signature algorithm found for signing client certificate")
+}
+
// hashForClientCertificate returns a digest, hash function, and TLS 1.2 hash
// id suitable for signing by a TLS client certificate.
-func (h finishedHash) hashForClientCertificate(sigType uint8) ([]byte, crypto.Hash, uint8) {
+func (h finishedHash) hashForClientCertificate(signatureAndHash signatureAndHash, masterSecret []byte) ([]byte, crypto.Hash, error) {
+ if (h.version == VersionSSL30 || h.version >= VersionTLS12) && h.buffer == nil {
+ panic("a handshake hash for a client-certificate was requested after discarding the handshake buffer")
+ }
+
+ if h.version == VersionSSL30 {
+ if signatureAndHash.signature != signatureRSA {
+ return nil, 0, errors.New("tls: unsupported signature type for client certificate")
+ }
+
+ md5Hash := md5.New()
+ md5Hash.Write(h.buffer)
+ sha1Hash := sha1.New()
+ sha1Hash.Write(h.buffer)
+ return finishedSum30(md5Hash, sha1Hash, masterSecret, nil), crypto.MD5SHA1, nil
+ }
if h.version >= VersionTLS12 {
- digest := h.server.Sum(nil)
- return digest, crypto.SHA256, hashSHA256
+ hashAlg, err := lookupTLSHash(signatureAndHash.hash)
+ if err != nil {
+ return nil, 0, err
+ }
+ hash := hashAlg.New()
+ hash.Write(h.buffer)
+ return hash.Sum(nil), hashAlg, nil
}
- if sigType == signatureECDSA {
- digest := h.server.Sum(nil)
- return digest, crypto.SHA1, hashSHA1
+
+ if signatureAndHash.signature == signatureECDSA {
+ return h.server.Sum(nil), crypto.SHA1, nil
}
- digest := make([]byte, 0, 36)
- digest = h.serverMD5.Sum(digest)
- digest = h.server.Sum(digest)
- return digest, crypto.MD5SHA1, 0 /* not specified in TLS 1.2. */
+ return h.Sum(), crypto.MD5SHA1, nil
+}
+
+// discardHandshakeBuffer is called when there is no more need to
+// buffer the entirety of the handshake messages.
+func (h *finishedHash) discardHandshakeBuffer() {
+ h.buffer = nil
}