// path: root/pssr.cpp
// blob: 020cb989172591050696d101bbf0b52c61d69344
// pssr.cpp - written and placed in the public domain by Wei Dai

#include "pch.h"
#include "pssr.h"

NAMESPACE_BEGIN(CryptoPP)

// Explicit specializations of the one-byte EMSA2HashId<H>::id constant for the
// SHA and RIPEMD160 hash classes (the primary template is presumably declared
// in pssr.h; the byte values are taken as given here — TODO confirm against
// the standard that assigns them).
template<> const byte EMSA2HashId<SHA>::id = 0x33;
template<> const byte EMSA2HashId<RIPEMD160>::id = 0x31;

unsigned int PSSR_MEM_Base::MaxRecoverableLength(unsigned int representativeBitLength, unsigned int hashIdentifierLength, unsigned int digestLength) const
{
	if (AllowRecovery())
	{
		unsigned int saltLen = SaltLen(digestLength);
		unsigned int minPadLen = MinPadLen(digestLength);
		return SaturatingSubtract(representativeBitLength, 8*(minPadLen + saltLen + digestLength + hashIdentifierLength) + 9) / 8;
	}
	return 0;
}

// Reports whether the encoding is randomized: true exactly when SaltLen(1),
// the salt length requested for a one-byte digest, is positive.
bool PSSR_MEM_Base::IsProbabilistic() const
{
	const bool usesSalt = (SaltLen(1) > 0);
	return usesSalt;
}

// This encoding method always permits a non-recoverable message part.
bool PSSR_MEM_Base::AllowNonrecoverablePart() const
{
	const bool allowed = true;
	return allowed;
}

// This encoding method never asks for the recoverable part to come first.
bool PSSR_MEM_Base::RecoverablePartFirst() const
{
	const bool recoverableFirst = false;
	return recoverableFirst;
}

// Builds the message representative for signing.
//
// Parameters:
//   rng                      - supplies the salt; SaltLen(digestSize) bytes are drawn.
//   recoverableMessage,
//   recoverableMessageLength - message part that is embedded in the representative.
//   hash                     - holds the accumulated message hash on entry; it is
//                              finalized here to obtain the digest and then reused
//                              to compute H and handed to the mask generator.
//   hashIdentifier           - (pointer, length) pair; when the length is non-zero the
//                              identifier bytes are appended before the trailer.
//   messageEmpty             - not used by this function.
//   representative           - output, BitsToBytes(representativeBitLength) bytes.
//   representativeBitLength  - size of the representative in bits.
//
// Layout written, from the start of the buffer:
//   masked DB || H (digestSize bytes) || hash identifier || trailer byte
// where DB = 00 ... || 01 || recoverableMessage || salt, and the trailer is 0xcc
// when a hash identifier is present and 0xbc otherwise.
//
// Precondition (not checked here): the representative must be large enough for
// the padding, message, salt, digest, identifier and trailer; otherwise the
// pointer arithmetic below runs before the start of the buffer.
void PSSR_MEM_Base::ComputeMessageRepresentative(RandomNumberGenerator &rng, 
	const byte *recoverableMessage, unsigned int recoverableMessageLength,
	HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty,
	byte *representative, unsigned int representativeBitLength) const
{
	// u = bytes occupied by the hash identifier plus the one-byte trailer
	const unsigned int u = hashIdentifier.second + 1;
	const unsigned int representativeByteLength = BitsToBytes(representativeBitLength);
	const unsigned int digestSize = hash.DigestSize();
	const unsigned int saltSize = SaltLen(digestSize);
	byte *const h = representative + representativeByteLength - u - digestSize;

	SecByteBlock digest(digestSize), salt(saltSize);
	hash.Final(digest);
	rng.GenerateBlock(salt, saltSize);

	// compute H = hash of M' = C || recoverableMessage || digest || salt,
	// where C is the recoverable message length in bits as a 64-bit big-endian value
	byte c[8];
	UnalignedPutWord(BIG_ENDIAN_ORDER, c, (word32)SafeRightShift<29>(recoverableMessageLength));
	UnalignedPutWord(BIG_ENDIAN_ORDER, c+4, word32(recoverableMessageLength << 3));
	hash.Update(c, 8);
	hash.Update(recoverableMessage, recoverableMessageLength);
	hash.Update(digest, digestSize);
	hash.Update(salt, saltSize);
	hash.Final(h);

	// compute representative: fill the DB region with the mask derived from H
	// (the trailing `false` presumably selects "overwrite" rather than "xor into" —
	// confirm against the MGF interface), then xor the DB contents into it
	GetMGF().GenerateAndMask(hash, representative, representativeByteLength - u - digestSize, h, digestSize, false);
	byte *xorStart = representative + representativeByteLength - u - digestSize - salt.size() - recoverableMessageLength - 1;
	xorStart[0] ^= 1;	// the 0x01 separator
	xorbuf(xorStart + 1, recoverableMessage, recoverableMessageLength);
	xorbuf(xorStart + 1 + recoverableMessageLength, salt, salt.size());

	// Only copy when there is an identifier: with no identifier the pointer may be
	// null, and passing a null pointer to memcpy is undefined even for a zero count.
	if (hashIdentifier.second != 0)
		memcpy(representative + representativeByteLength - u, hashIdentifier.first, hashIdentifier.second);
	representative[representativeByteLength - 1] = hashIdentifier.second ? 0xcc : 0xbc;

	// clear the unused high bits of the leading byte when the bit length is not a multiple of 8
	if (representativeBitLength % 8 != 0)
		representative[0] = Crop(representative[0], representativeBitLength % 8);
}

// Checks a message representative and extracts the recoverable message from it.
//
// Parameters:
//   hash                    - holds the accumulated message hash on entry; it is
//                             finalized here to obtain the digest and then reused
//                             by the mask generator and to recompute H.
//   hashIdentifier          - (pointer, length) pair expected before the trailer.
//   messageEmpty            - not used by this function.
//   representative          - BitsToBytes(representativeBitLength) bytes; modified
//                             in place (the DB region is unmasked).
//   representativeBitLength - size of the representative in bits.
//   recoverableMessage      - receives the recovered bytes. At most
//                             MaxRecoverableLength(...) bytes are written; callers
//                             are expected to size the buffer with that function
//                             (NOTE(review): confirm against the callers).
//
// Returns a DecodingResult whose isValidCoding flag is the conjunction of all
// checks and whose messageLength is the recovered length (0 when the padding is
// malformed or the embedded message would not fit).
//
// Throws NotImplemented if a valid, non-empty message is recovered while
// AllowRecovery() is false. With the length bound below this can only happen if
// MaxRecoverableLength() reports a non-zero capacity in that situation.
//
// Precondition (not checked here): the representative must be large enough for
// the salt, digest, identifier and trailer plus at least one DB byte.
DecodingResult PSSR_MEM_Base::RecoverMessageFromRepresentative(
	HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty,
	byte *representative, unsigned int representativeBitLength,
	byte *recoverableMessage) const
{
	// u = bytes occupied by the hash identifier plus the one-byte trailer
	const unsigned int u = hashIdentifier.second + 1;
	const unsigned int representativeByteLength = BitsToBytes(representativeBitLength);
	const unsigned int digestSize = hash.DigestSize();
	const unsigned int saltSize = SaltLen(digestSize);
	const byte *const h = representative + representativeByteLength - u - digestSize;

	SecByteBlock digest(digestSize);
	hash.Final(digest);

	DecodingResult result(0);
	bool &valid = result.isValidCoding;
	unsigned int &recoverableMessageLength = result.messageLength;

	// trailer byte and hash identifier
	valid = (representative[representativeByteLength - 1] == (hashIdentifier.second ? 0xcc : 0xbc)) && valid;
	// With no identifier the pointer may be null; a null pointer must not reach
	// memcmp even with a zero count.
	if (hashIdentifier.second != 0)
		valid = (memcmp(representative + representativeByteLength - u, hashIdentifier.first, hashIdentifier.second) == 0) && valid;

	// unmask DB and clear the unused high bits of the leading byte
	GetMGF().GenerateAndMask(hash, representative, representativeByteLength - u - digestSize, h, digestSize);
	if (representativeBitLength % 8 != 0)
		representative[0] = Crop(representative[0], representativeBitLength % 8);

	// extract salt and recoverableMessage from DB = 00 ... || 01 || M || salt
	byte *salt = representative + representativeByteLength - u - digestSize - saltSize;

	// M = first non-zero byte in [representative, salt-1), or salt-1 if there is none
	byte *M = representative;
	while (M < salt-1 && *M == 0)
		++M;

	// Number of whole zero bytes in front of the separator, not counting a partial
	// leading byte. When the separator itself sits in that partial byte the count
	// is zero; doing this in unsigned arithmetic with an explicit clamp keeps the
	// result identical on every platform instead of depending on how a negative
	// pointer difference converts for the comparison.
	const unsigned int separatorOffset = static_cast<unsigned int>(M - representative);
	const unsigned int partialLeadingByte = (representativeBitLength % 8 != 0) ? 1 : 0;
	const unsigned int padBytes = (separatorOffset > partialLeadingByte) ? separatorOffset - partialLeadingByte : 0;
	const unsigned int candidateLength = static_cast<unsigned int>(salt - M - 1);

	// The length is bounded before anything is copied: this happens before H has
	// been verified, so the representative is still untrusted at this point.
	if (*M == 0x01
		&& padBytes >= MinPadLen(digestSize)
		&& candidateLength <= MaxRecoverableLength(representativeBitLength, hashIdentifier.second, digestSize))
	{
		recoverableMessageLength = candidateLength;
		if (recoverableMessageLength != 0)
			memcpy(recoverableMessage, M+1, recoverableMessageLength);
	}
	else
	{
		recoverableMessageLength = 0;
		valid = false;
	}

	// verify H = hash of M' = C || recoverableMessage || digest || salt,
	// where C is the recoverable message length in bits as a 64-bit big-endian value
	byte c[8];
	UnalignedPutWord(BIG_ENDIAN_ORDER, c, (word32)SafeRightShift<29>(recoverableMessageLength));
	UnalignedPutWord(BIG_ENDIAN_ORDER, c+4, word32(recoverableMessageLength << 3));
	hash.Update(c, 8);
	hash.Update(recoverableMessage, recoverableMessageLength);
	hash.Update(digest, digestSize);
	hash.Update(salt, saltSize);
	valid = hash.Verify(h) && valid;

	if (!AllowRecovery() && valid && recoverableMessageLength != 0)
		{throw NotImplemented("PSSR_MEM: message recovery disabled");}
	
	return result;
}

NAMESPACE_END