summaryrefslogtreecommitdiff
path: root/hkdf.h
blob: cd64fb3918ce5e98557de3fda45e7d41ed81adca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// hkdf.h - written and placed in public domain by Jeffrey Walton. Copyright assigned to Crypto++ project.

#ifndef CRYPTOPP_HASH_KEY_DERIVATION_FUNCTION_H
#define CRYPTOPP_HASH_KEY_DERIVATION_FUNCTION_H

#include "cryptlib.h"
#include "hrtimer.h"
#include "secblock.h"
#include "hmac.h"

NAMESPACE_BEGIN(CryptoPP)

//! Abstract base class for key derivation functions.
//! Concrete KDFs (such as HKDF in this file) implement the three pure
//! virtual members declared here.
class KeyDerivationFunction
{
public:
	//! Maximum number of bytes which can be produced under a security context.
	virtual size_t MaxDerivedKeyLength() const =0;
	//! Presumably reports whether the implementation makes use of the optional
	//! info argument of DeriveKey (HKDF returns true) -- confirm with callers.
	virtual bool Usesinfo() const =0;
	//! Derive a key from a secret.
	//! derived/derivedLen: output buffer and the number of bytes requested.
	//! secret/secretLen: the input keying material.
	//! salt/saltLen: salt value; how a NULL salt is treated is up to the implementation.
	//! info/infoLen: optional extra input, defaulting to NULL and 0.
	//! The meaning of the return value is not fixed by this interface; the HKDF
	//! implementation in this file returns the requested length.
	virtual unsigned int DeriveKey(byte *derived, size_t derivedLen, const byte *secret, size_t secretLen, const byte *salt, size_t saltLen, const byte* info=NULL, size_t infoLen=0) const =0;

	//! Virtual destructor so implementations can be destroyed through a base pointer.
	virtual ~KeyDerivationFunction() {}
};

//! HMAC-based extract-and-expand key derivation function (HKDF) as specified
//! in RFC 5869. T is the underlying hash and should be a HashTransformation
//! class exposing DIGESTSIZE and StaticAlgorithmName().
//! References: https://eprint.iacr.org/2010/264 and https://tools.ietf.org/html/rfc5869
template <class T>
class HKDF : public KeyDerivationFunction
{
public:
	CRYPTOPP_CONSTANT(DIGESTSIZE = T::DIGESTSIZE);
	CRYPTOPP_CONSTANT(SALTSIZE = T::DIGESTSIZE);

	//! Algorithm name in the form "HKDF(<hash name>)".
	static const char* StaticAlgorithmName () {
		// The string has static storage duration, so the pointer handed back
		// stays valid after this function returns.
		static const std::string name = std::string("HKDF(") + T::StaticAlgorithmName() + ")";
		return name.c_str();
	}

	//! Largest output DeriveKey accepts: 255 blocks of the hash's digest size.
	size_t MaxDerivedKeyLength() const {
		const size_t digestLen = static_cast<size_t>(T::DIGESTSIZE);
		return digestLen * 255;
	}

	//! HKDF mixes the info argument into every expansion block.
	bool Usesinfo() const {
		return true;
	}

	//! Derive derivedLen bytes into derived from secret, salt and info; returns the requested length.
	unsigned int DeriveKey(byte *derived, size_t derivedLen, const byte *secret, size_t secretLen, const byte *salt, size_t saltLen, const byte* info, size_t infoLen) const;

protected:
	// A missing salt (NULL pointer) is replaced by an all-zero vector. That is
	// not the same as an empty salt (non-NULL pointer, length 0), which is used
	// as given. The vector is SALTSIZE bytes long, i.e. one digest of the hash:
	// 32 bytes for SHA-256.
	typedef byte NullVectorType[SALTSIZE];

	//! Shared all-zero salt used when the caller passes no salt.
	static const NullVectorType& GetNullVector() {
		static const byte zeroSalt[SALTSIZE] = {0};
		return zeroSalt;
	}
};

//! Derive key material with HKDF: an extract step keyed by the salt, followed
//! by an expand step that emits digest-sized blocks until derivedLen bytes
//! have been written to derived.
//!   derived, derivedLen  output buffer and number of bytes wanted
//!   secret, secretLen    input keying material
//!   salt, saltLen        extract key; a NULL salt selects the all-zero vector
//!   info, infoLen        optional context, hashed into every block when present
//! Returns derivedLen converted to unsigned int.
//! Throws InvalidArgument when derivedLen exceeds MaxDerivedKeyLength().
template <class T>
unsigned int HKDF<T>::DeriveKey(byte *derived, size_t derivedLen, const byte *secret, size_t secretLen, const byte *salt, size_t saltLen, const byte* info, size_t infoLen) const
{
	static const size_t DIGEST_SIZE = static_cast<size_t>(T::DIGESTSIZE);
	const unsigned int requested = static_cast<unsigned int>(derivedLen);

	assert(secret && secretLen);
	assert(derived && derivedLen);
	assert(derivedLen <= MaxDerivedKeyLength());

	// Release builds have no asserts, so the length limit is also enforced here.
	if (derivedLen > MaxDerivedKeyLength())
		throw InvalidArgument("HKDF: derivedLen must be less than or equal to MaxDerivedKeyLength");

	HMAC<T> hmac;
	FixedSizeSecBlock<byte, DIGEST_SIZE> prk, buffer;

	// Extract: PRK = HMAC(salt, secret). With no salt, key the HMAC with one
	// digest's worth of zero bytes instead.
	const byte* extractKey;
	size_t extractKeyLen;
	if (salt)
	{
		extractKey = salt;
		extractKeyLen = saltLen;
	}
	else
	{
		extractKey = GetNullVector();
		extractKeyLen = DIGEST_SIZE;
	}

	hmac.SetKey(extractKey, extractKeyLen);
	hmac.CalculateDigest(prk, secret, secretLen);

	// Expand: every block is HMAC(PRK, previous block | info | counter), where
	// the first block has no previous block and the counter starts at 1.
	hmac.SetKey(prk.data(), prk.size());

	byte counter = 0;
	byte* out = derived;
	for (size_t remaining = derivedLen; remaining > 0; )
	{
		// Chain in the previous block for every block after the first.
		if (counter != 0)
			hmac.Update(buffer, buffer.size());
		++counter;

		if (info && infoLen)
			hmac.Update(info, infoLen);
		hmac.CalculateDigest(buffer, &counter, 1);

		// The final block may be only partially consumed.
		const size_t chunk = STDMIN(remaining, DIGEST_SIZE);
#if CRYPTOPP_MSC_VERSION
		memcpy_s(out, chunk, buffer, chunk);
#else
		std::memcpy(out, buffer, chunk);
#endif

		out += chunk;
		remaining -= chunk;
	}

	return requested;
}

NAMESPACE_END

#endif // CRYPTOPP_HASH_KEY_DERIVATION_FUNCTION_H