diff options
Diffstat (limited to 'libgo/go/crypto/tls/prf.go')
-rw-r--r-- | libgo/go/crypto/tls/prf.go | 178 |
1 files changed, 127 insertions, 51 deletions
diff --git a/libgo/go/crypto/tls/prf.go b/libgo/go/crypto/tls/prf.go index fb8b3ab4d1e..6127c1ccfe7 100644 --- a/libgo/go/crypto/tls/prf.go +++ b/libgo/go/crypto/tls/prf.go @@ -10,6 +10,8 @@ import ( "crypto/md5" "crypto/sha1" "crypto/sha256" + "crypto/sha512" + "errors" "hash" ) @@ -65,12 +67,14 @@ func prf10(result, secret, label, seed []byte) { } // prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, section 5. -func prf12(result, secret, label, seed []byte) { - labelAndSeed := make([]byte, len(label)+len(seed)) - copy(labelAndSeed, label) - copy(labelAndSeed[len(label):], seed) +func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { + return func(result, secret, label, seed []byte) { + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) - pHash(result, secret, labelAndSeed, sha256.New) + pHash(result, secret, labelAndSeed, hashFunc) + } } // prf30 implements the SSL 3.0 pseudo-random function, as defined in @@ -117,41 +121,49 @@ var keyExpansionLabel = []byte("key expansion") var clientFinishedLabel = []byte("client finished") var serverFinishedLabel = []byte("server finished") -func prfForVersion(version uint16) func(result, secret, label, seed []byte) { +func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { switch version { case VersionSSL30: - return prf30 + return prf30, crypto.Hash(0) case VersionTLS10, VersionTLS11: - return prf10 + return prf10, crypto.Hash(0) case VersionTLS12: - return prf12 + if suite.flags&suiteSHA384 != 0 { + return prf12(sha512.New384), crypto.SHA384 + } + return prf12(sha256.New), crypto.SHA256 default: panic("unknown version") } } +func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) { + prf, _ := prfAndHashForVersion(version, suite) + return prf +} + // masterFromPreMasterSecret generates the master secret from the pre-master // secret. See http://tools.ietf.org/html/rfc5246#section-8.1 -func masterFromPreMasterSecret(version uint16, preMasterSecret, clientRandom, serverRandom []byte) []byte { +func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte { var seed [tlsRandomLength * 2]byte copy(seed[0:len(clientRandom)], clientRandom) copy(seed[len(clientRandom):], serverRandom) masterSecret := make([]byte, masterSecretLength) - prfForVersion(version)(masterSecret, preMasterSecret, masterSecretLabel, seed[0:]) + prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed[0:]) return masterSecret } // keysFromMasterSecret generates the connection keys from the master // secret, given the lengths of the MAC key, cipher key and IV, as defined in // RFC 2246, section 6.3. -func keysFromMasterSecret(version uint16, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { +func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { var seed [tlsRandomLength * 2]byte copy(seed[0:len(clientRandom)], serverRandom) copy(seed[len(serverRandom):], clientRandom) n := 2*macLen + 2*keyLen + 2*ivLen keyMaterial := make([]byte, n) - prfForVersion(version)(keyMaterial, masterSecret, keyExpansionLabel, seed[0:]) + prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed[0:]) clientMAC = keyMaterial[:macLen] keyMaterial = keyMaterial[macLen:] serverMAC = keyMaterial[:macLen] @@ -166,11 +178,33 @@ func keysFromMasterSecret(version uint16, masterSecret, clientRandom, serverRand return } -func newFinishedHash(version uint16) finishedHash { - if version >= VersionTLS12 { - return finishedHash{sha256.New(), sha256.New(), nil, nil, version} +// lookupTLSHash looks up the corresponding crypto.Hash for a given +// TLS hash identifier. +func lookupTLSHash(hash uint8) (crypto.Hash, error) { + switch hash { + case hashSHA1: + return crypto.SHA1, nil + case hashSHA256: + return crypto.SHA256, nil + case hashSHA384: + return crypto.SHA384, nil + default: + return 0, errors.New("tls: unsupported hash algorithm") + } +} + +func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash { + var buffer []byte + if version == VersionSSL30 || version >= VersionTLS12 { + buffer = []byte{} } - return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), version} + + prf, hash := prfAndHashForVersion(version, cipherSuite) + if hash != 0 { + return finishedHash{hash.New(), hash.New(), nil, nil, buffer, version, prf} + } + + return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), buffer, version, prf} } // A finishedHash calculates the hash of a set of handshake messages suitable @@ -183,10 +217,14 @@ type finishedHash struct { clientMD5 hash.Hash serverMD5 hash.Hash + // In TLS 1.2, a full buffer is sadly required. + buffer []byte + version uint16 + prf func(result, secret, label, seed []byte) } -func (h finishedHash) Write(msg []byte) (n int, err error) { +func (h *finishedHash) Write(msg []byte) (n int, err error) { h.client.Write(msg) h.server.Write(msg) @@ -194,14 +232,29 @@ func (h finishedHash) Write(msg []byte) (n int, err error) { h.clientMD5.Write(msg) h.serverMD5.Write(msg) } + + if h.buffer != nil { + h.buffer = append(h.buffer, msg...) + } + return len(msg), nil } +func (h finishedHash) Sum() []byte { + if h.version >= VersionTLS12 { + return h.client.Sum(nil) + } + + out := make([]byte, 0, md5.Size+sha1.Size) + out = h.clientMD5.Sum(out) + return h.client.Sum(out) +} + // finishedSum30 calculates the contents of the verify_data member of a SSLv3 // Finished message given the MD5 and SHA1 hashes of a set of handshake // messages. -func finishedSum30(md5, sha1 hash.Hash, masterSecret []byte, magic [4]byte) []byte { - md5.Write(magic[:]) +func finishedSum30(md5, sha1 hash.Hash, masterSecret []byte, magic []byte) []byte { + md5.Write(magic) md5.Write(masterSecret) md5.Write(ssl30Pad1[:]) md5Digest := md5.Sum(nil) @@ -212,7 +265,7 @@ func finishedSum30(md5, sha1 hash.Hash, masterSecret []byte, magic [4]byte) []by md5.Write(md5Digest) md5Digest = md5.Sum(nil) - sha1.Write(magic[:]) + sha1.Write(magic) sha1.Write(masterSecret) sha1.Write(ssl30Pad1[:40]) sha1Digest := sha1.Sum(nil) @@ -236,19 +289,11 @@ var ssl3ServerFinishedMagic = [4]byte{0x53, 0x52, 0x56, 0x52} // Finished message. func (h finishedHash) clientSum(masterSecret []byte) []byte { if h.version == VersionSSL30 { - return finishedSum30(h.clientMD5, h.client, masterSecret, ssl3ClientFinishedMagic) + return finishedSum30(h.clientMD5, h.client, masterSecret, ssl3ClientFinishedMagic[:]) } out := make([]byte, finishedVerifyLength) - if h.version >= VersionTLS12 { - seed := h.client.Sum(nil) - prf12(out, masterSecret, clientFinishedLabel, seed) - } else { - seed := make([]byte, 0, md5.Size+sha1.Size) - seed = h.clientMD5.Sum(seed) - seed = h.client.Sum(seed) - prf10(out, masterSecret, clientFinishedLabel, seed) - } + h.prf(out, masterSecret, clientFinishedLabel, h.Sum()) return out } @@ -256,36 +301,67 @@ func (h finishedHash) clientSum(masterSecret []byte) []byte { // Finished message. func (h finishedHash) serverSum(masterSecret []byte) []byte { if h.version == VersionSSL30 { - return finishedSum30(h.serverMD5, h.server, masterSecret, ssl3ServerFinishedMagic) + return finishedSum30(h.serverMD5, h.server, masterSecret, ssl3ServerFinishedMagic[:]) } out := make([]byte, finishedVerifyLength) - if h.version >= VersionTLS12 { - seed := h.server.Sum(nil) - prf12(out, masterSecret, serverFinishedLabel, seed) - } else { - seed := make([]byte, 0, md5.Size+sha1.Size) - seed = h.serverMD5.Sum(seed) - seed = h.server.Sum(seed) - prf10(out, masterSecret, serverFinishedLabel, seed) - } + h.prf(out, masterSecret, serverFinishedLabel, h.Sum()) return out } +// selectClientCertSignatureAlgorithm returns a signatureAndHash to sign a +// client's CertificateVerify with, or an error if none can be found. +func (h finishedHash) selectClientCertSignatureAlgorithm(serverList []signatureAndHash, sigType uint8) (signatureAndHash, error) { + if h.version < VersionTLS12 { + // Nothing to negotiate before TLS 1.2. + return signatureAndHash{signature: sigType}, nil + } + + for _, v := range serverList { + if v.signature == sigType && isSupportedSignatureAndHash(v, supportedSignatureAlgorithms) { + return v, nil + } + } + return signatureAndHash{}, errors.New("tls: no supported signature algorithm found for signing client certificate") +} + // hashForClientCertificate returns a digest, hash function, and TLS 1.2 hash // id suitable for signing by a TLS client certificate. -func (h finishedHash) hashForClientCertificate(sigType uint8) ([]byte, crypto.Hash, uint8) { +func (h finishedHash) hashForClientCertificate(signatureAndHash signatureAndHash, masterSecret []byte) ([]byte, crypto.Hash, error) { + if (h.version == VersionSSL30 || h.version >= VersionTLS12) && h.buffer == nil { + panic("a handshake hash for a client-certificate was requested after discarding the handshake buffer") + } + + if h.version == VersionSSL30 { + if signatureAndHash.signature != signatureRSA { + return nil, 0, errors.New("tls: unsupported signature type for client certificate") + } + + md5Hash := md5.New() + md5Hash.Write(h.buffer) + sha1Hash := sha1.New() + sha1Hash.Write(h.buffer) + return finishedSum30(md5Hash, sha1Hash, masterSecret, nil), crypto.MD5SHA1, nil + } if h.version >= VersionTLS12 { - digest := h.server.Sum(nil) - return digest, crypto.SHA256, hashSHA256 + hashAlg, err := lookupTLSHash(signatureAndHash.hash) + if err != nil { + return nil, 0, err + } + hash := hashAlg.New() + hash.Write(h.buffer) + return hash.Sum(nil), hashAlg, nil } - if sigType == signatureECDSA { - digest := h.server.Sum(nil) - return digest, crypto.SHA1, hashSHA1 + + if signatureAndHash.signature == signatureECDSA { + return h.server.Sum(nil), crypto.SHA1, nil } - digest := make([]byte, 0, 36) - digest = h.serverMD5.Sum(digest) - digest = h.server.Sum(digest) - return digest, crypto.MD5SHA1, 0 /* not specified in TLS 1.2. */ + return h.Sum(), crypto.MD5SHA1, nil +} + +// discardHandshakeBuffer is called when there is no more need to +// buffer the entirety of the handshake messages. +func (h *finishedHash) discardHandshakeBuffer() { + h.buffer = nil } |