summaryrefslogtreecommitdiff
path: root/pssr.cpp
blob: 19360d3ac860e9bc98959b64e0fa1a50987a9e67 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
// pssr.cpp - written and placed in the public domain by Wei Dai

#include "pch.h"
#include "pssr.h"

NAMESPACE_BEGIN(CryptoPP)

// One-byte EMSA2 hash identifiers for the hash functions below; more in dll.cpp
// NOTE(review): the values are presumably the ISO/IEC 10118-3 dedicated hash
// function identifiers that EMSA2 places in its trailer -- confirm against the standard.
template<> const byte EMSA2HashId<RIPEMD160>::id = 0x31;
template<> const byte EMSA2HashId<RIPEMD128>::id = 0x32;
template<> const byte EMSA2HashId<Whirlpool>::id = 0x37;

#ifndef CRYPTOPP_IMPORTS

unsigned int PSSR_MEM_Base::MaxRecoverableLength(unsigned int representativeBitLength, unsigned int hashIdentifierLength, unsigned int digestLength) const
{
	if (AllowRecovery())
	{
		unsigned int saltLen = SaltLen(digestLength);
		unsigned int minPadLen = MinPadLen(digestLength);
		return SaturatingSubtract(representativeBitLength, 8*(minPadLen + saltLen + digestLength + hashIdentifierLength) + 9) / 8;
	}
	return 0;
}

// Reports whether this encoding uses a salt: true exactly when SaltLen()
// returns a positive length for a (nominal) one-byte digest.
bool PSSR_MEM_Base::IsProbabilistic() const
{
	const bool usesSalt = SaltLen(1) > 0;
	return usesSalt;
}

// Always returns true.
// NOTE(review): presumably tells the signature framework that a message may
// carry a part that is hashed but not embedded in the representative --
// confirm against the base-class declaration.
bool PSSR_MEM_Base::AllowNonrecoverablePart() const
{
	return true;
}

// Always returns false.
// NOTE(review): presumably tells the signature framework that the recoverable
// part does not have to be supplied before the non-recoverable part --
// confirm against the base-class declaration.
bool PSSR_MEM_Base::RecoverablePartFirst() const
{
	return false;
}

// Builds the message representative for signing and writes it to
// `representative` (BitsToBytes(representativeBitLength) bytes).
//
// Layout produced, front to back:
//   masked DB || H || hash identifier bytes || trailer byte
// where DB = 00 ... || 01 || recoverableMessage || salt, and H is the hash of
//   M' = 64-bit big-endian bit length of recoverableMessage
//        || recoverableMessage || digest || salt.
// `digest` is whatever `hash` yields from Final() on entry.
// NOTE(review): presumably the caller has already fed the non-recoverable
// message part into `hash` -- confirm against the caller.
//
// rng supplies the salt (SaltLen(digest size) bytes). messageEmpty is not
// used here. The function does not check that the recoverable message fits;
// NOTE(review): callers are presumably expected to respect
// MaxRecoverableLength() -- confirm.
void PSSR_MEM_Base::ComputeMessageRepresentative(RandomNumberGenerator &rng,
	const byte *recoverableMessage, unsigned int recoverableMessageLength,
	HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty,
	byte *representative, unsigned int representativeBitLength) const
{
	// trailer = hash identifier bytes plus the final marker byte
	const unsigned int trailerLen = hashIdentifier.second + 1;
	const unsigned int encodedLen = BitsToBytes(representativeBitLength);
	const unsigned int hashLen = hash.DigestSize();
	const unsigned int saltBytes = SaltLen(hashLen);
	// number of bytes in front of H (the DB area)
	const unsigned int dbLen = encodedLen - trailerLen - hashLen;
	byte *const hashField = representative + encodedLen - trailerLen - hashLen;

	SecByteBlock priorDigest(hashLen);
	SecByteBlock salt(saltBytes);
	hash.Final(priorDigest);
	rng.GenerateBlock(salt, saltBytes);

	// H = hash of M'; the 8-byte prefix is the recoverable length in bits,
	// split into a high word (length >> 29) and a low word (length << 3)
	byte bitCount[8];
	UnalignedPutWord(BIG_ENDIAN_ORDER, bitCount, (word32)SafeRightShift<29>(recoverableMessageLength));
	UnalignedPutWord(BIG_ENDIAN_ORDER, bitCount+4, word32(recoverableMessageLength << 3));
	hash.Update(bitCount, 8);
	hash.Update(recoverableMessage, recoverableMessageLength);
	hash.Update(priorDigest, hashLen);
	hash.Update(salt, saltBytes);
	hash.Final(hashField);

	// Fill the DB area from the mask generation function seeded with H.
	// NOTE(review): the trailing `false` presumably selects "overwrite"
	// rather than "XOR into existing contents" -- confirm against the MGF interface.
	GetMGF().GenerateAndMask(hash, representative, dbLen, hashField, hashLen, false);

	// XOR 01 || recoverableMessage || salt into the tail of the DB area,
	// directly in front of H
	byte *const separator = hashField - salt.size() - recoverableMessageLength - 1;
	byte *const messageField = separator + 1;
	*separator ^= 1;
	xorbuf(messageField, recoverableMessage, recoverableMessageLength);
	xorbuf(messageField + recoverableMessageLength, salt, salt.size());

	// append the hash identifier and the trailer byte: 0xcc when an
	// identifier is present, 0xbc otherwise
	memcpy(representative + encodedLen - trailerLen, hashIdentifier.first, hashIdentifier.second);
	const byte trailerByte = hashIdentifier.second ? 0xcc : 0xbc;
	representative[encodedLen - 1] = trailerByte;

	// when the bit length is not a whole number of bytes, crop the first
	// byte to the remaining bit count
	const unsigned int leadingBits = representativeBitLength % 8;
	if (leadingBits != 0)
		representative[0] = (byte)Crop(representative[0], leadingBits);
}

// Checks a message representative and extracts its recoverable part.
//
// Expected layout of `representative` (BitsToBytes(representativeBitLength) bytes):
//   masked DB || H || hash identifier bytes || trailer byte
// with DB = 00 ... || 01 || M || salt after unmasking. `representative` is
// modified in place (the DB area is unmasked and the first byte cropped).
//
// hash            - Final() is called on entry to obtain the digest that goes
//                   into M'; the object is then reused for the MGF and for H.
// hashIdentifier  - (pointer, length) pair compared against the trailer.
// messageEmpty    - not used here.
// recoverableMessage - receives the recovered bytes. It is never written with
//                   more than MaxRecoverableLength(representativeBitLength,
//                   hashIdentifier.second, digest size) bytes.
//                   NOTE(review): callers are presumably expected to supply a
//                   buffer of that size -- confirm against the caller.
//
// Returns a DecodingResult whose isValidCoding is true only if the trailer,
// the padding structure, the length bound and H all check out, and whose
// messageLength is the number of bytes written to recoverableMessage.
//
// Throws NotImplemented if recovery is disabled yet a valid, non-empty
// recoverable message was found. Because MaxRecoverableLength() returns 0
// when recovery is disabled, the length bound below already rejects that
// case, so the throw is kept only as a backstop.
DecodingResult PSSR_MEM_Base::RecoverMessageFromRepresentative(
	HashTransformation &hash, HashIdentifier hashIdentifier, bool messageEmpty,
	byte *representative, unsigned int representativeBitLength,
	byte *recoverableMessage) const
{
	const unsigned int u = hashIdentifier.second + 1;
	const unsigned int representativeByteLength = BitsToBytes(representativeBitLength);
	const unsigned int digestSize = hash.DigestSize();
	const unsigned int saltSize = SaltLen(digestSize);
	const byte *const h = representative + representativeByteLength - u - digestSize;

	SecByteBlock digest(digestSize);
	hash.Final(digest);

	// NOTE(review): DecodingResult(0) presumably starts out as "valid, length 0";
	// every check below can only clear the flag, never set it.
	DecodingResult result(0);
	bool &valid = result.isValidCoding;
	unsigned int &recoverableMessageLength = result.messageLength;

	// trailer byte is 0xcc when a hash identifier is present, 0xbc otherwise
	valid = (representative[representativeByteLength - 1] == (hashIdentifier.second ? 0xcc : 0xbc)) && valid;
	// an empty identifier trivially matches; skipping memcmp avoids passing it
	// a possibly-null pointer with length 0
	valid = (hashIdentifier.second == 0
		|| memcmp(representative + representativeByteLength - u, hashIdentifier.first, hashIdentifier.second) == 0) && valid;

	// unmask the DB area using H as the MGF seed
	GetMGF().GenerateAndMask(hash, representative, representativeByteLength - u - digestSize, h, digestSize);
	if (representativeBitLength % 8 != 0)
		representative[0] = (byte)Crop(representative[0], representativeBitLength % 8);

	// extract salt and recoverableMessage from DB = 00 ... || 01 || M || salt
	byte *salt = representative + representativeByteLength - u - digestSize - saltSize;
	// first non-zero byte of DB, or the byte just before the salt if the
	// preceding bytes are all zero
	byte *M = representative;
	while (M != salt-1 && *M == 0)
		++M;

	// The length is derived from untrusted data, so it must be bounded
	// BEFORE anything is copied to the caller's buffer.
	const unsigned int candidateLength = (unsigned int)(salt-M-1);
	if (*M == 0x01
		&& (unsigned int)(M - representative - (representativeBitLength % 8 != 0)) >= MinPadLen(digestSize)
		&& candidateLength <= MaxRecoverableLength(representativeBitLength, hashIdentifier.second, digestSize))
	{
		recoverableMessageLength = candidateLength;
		// nothing to copy for a pure (non-recovery) signature; the buffer
		// may legitimately be null/empty in that case
		if (recoverableMessageLength != 0)
			memcpy(recoverableMessage, M+1, recoverableMessageLength);
	}
	else
		valid = false;	// recoverableMessageLength stays 0

	// verify H = hash of M'
	byte c[8];
	UnalignedPutWord(BIG_ENDIAN_ORDER, c, (word32)SafeRightShift<29>(recoverableMessageLength));
	UnalignedPutWord(BIG_ENDIAN_ORDER, c+4, word32(recoverableMessageLength << 3));
	hash.Update(c, 8);
	hash.Update(recoverableMessage, recoverableMessageLength);
	hash.Update(digest, digestSize);
	hash.Update(salt, saltSize);
	valid = hash.Verify(h) && valid;

	if (!AllowRecovery() && valid && recoverableMessageLength != 0)
		{throw NotImplemented("PSSR_MEM: message recovery disabled");}

	return result;
}

#endif

NAMESPACE_END