summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrian Warner <warner@lothar.com>2011-10-04 18:45:13 -0400
committerBrian Warner <warner@lothar.com>2011-10-04 18:45:13 -0400
commit70e3ee7d6bc7c137749c300727d4d5b552773bb7 (patch)
tree4fba239dec80bd9f79c64694594e72222c51737c
parent51b219eb9bf46a6ad05b81e422a1508df5e88a6b (diff)
downloadecdsa-70e3ee7d6bc7c137749c300727d4d5b552773bb7.tar.gz
accept default hashfunc in Key constructor, so wrappers are easier to build
-rw-r--r--ecdsa/keys.py35
-rw-r--r--ecdsa/test_pyecdsa.py22
2 files changed, 40 insertions, 17 deletions
diff --git a/ecdsa/keys.py b/ecdsa/keys.py
index 2ef5936..29a1cd7 100644
--- a/ecdsa/keys.py
+++ b/ecdsa/keys.py
@@ -19,15 +19,16 @@ class VerifyingKey:
raise TypeError("Please use SigningKey.generate() to construct me")
@classmethod
- def from_public_point(klass, point, curve=NIST192p):
+ def from_public_point(klass, point, curve=NIST192p, hashfunc=sha1):
self = klass(_error__please_use_generate=True)
self.curve = curve
+ self.default_hashfunc = hashfunc
self.pubkey = ecdsa.Public_key(curve.generator, point)
self.pubkey.order = curve.order
return self
@classmethod
- def from_string(klass, string, curve=NIST192p):
+ def from_string(klass, string, curve=NIST192p, hashfunc=sha1):
order = curve.order
assert len(string) == curve.verifying_key_length, \
(len(string), curve.verifying_key_length)
@@ -40,7 +41,7 @@ class VerifyingKey:
assert ecdsa.point_is_valid(curve.generator, x, y)
import ellipticcurve
point = ellipticcurve.Point(curve.curve, x, y, order)
- return klass.from_public_point(point, curve)
+ return klass.from_public_point(point, curve, hashfunc)
@classmethod
def from_pem(klass, string):
@@ -90,7 +91,8 @@ class VerifyingKey:
self.curve.encoded_oid),
der.encode_bitstring(point_str))
- def verify(self, signature, data, hashfunc=sha1, sigdecode=sigdecode_string):
+ def verify(self, signature, data, hashfunc=None, sigdecode=sigdecode_string):
+ hashfunc = hashfunc or self.default_hashfunc
digest = hashfunc(data).digest()
return self.verify_digest(signature, digest, sigdecode)
@@ -112,9 +114,9 @@ class SigningKey:
raise TypeError("Please use SigningKey.generate() to construct me")
@classmethod
- def generate(klass, curve=NIST192p, entropy=None):
+ def generate(klass, curve=NIST192p, entropy=None, hashfunc=sha1):
secexp = randrange(curve.order, entropy)
- return klass.from_secret_exponent(secexp, curve)
+ return klass.from_secret_exponent(secexp, curve, hashfunc)
# to create a signing key from a short (arbitrary-length) seed, convert
# that seed into an integer with something like
@@ -122,34 +124,36 @@ class SigningKey:
# that integer into SigningKey.from_secret_exponent(secexp, curve)
@classmethod
- def from_secret_exponent(klass, secexp, curve=NIST192p):
+ def from_secret_exponent(klass, secexp, curve=NIST192p, hashfunc=sha1):
self = klass(_error__please_use_generate=True)
self.curve = curve
+ self.default_hashfunc = hashfunc
self.baselen = curve.baselen
n = curve.order
assert 1 <= secexp < n
pubkey_point = curve.generator*secexp
pubkey = ecdsa.Public_key(curve.generator, pubkey_point)
pubkey.order = n
- self.verifying_key = VerifyingKey.from_public_point(pubkey_point, curve)
+ self.verifying_key = VerifyingKey.from_public_point(pubkey_point, curve,
+ hashfunc)
self.privkey = ecdsa.Private_key(pubkey, secexp)
self.privkey.order = n
return self
@classmethod
- def from_string(klass, string, curve=NIST192p):
+ def from_string(klass, string, curve=NIST192p, hashfunc=sha1):
assert len(string) == curve.baselen, (len(string), curve.baselen)
secexp = string_to_number(string)
- return klass.from_secret_exponent(secexp, curve)
+ return klass.from_secret_exponent(secexp, curve, hashfunc)
@classmethod
- def from_pem(klass, string):
+ def from_pem(klass, string, hashfunc=sha1):
# the privkey pem file has two sections: "EC PARAMETERS" and "EC
# PRIVATE KEY". The first is redundant.
privkey_pem = string[string.index("-----BEGIN EC PRIVATE KEY-----"):]
- return klass.from_der(der.unpem(privkey_pem))
+ return klass.from_der(der.unpem(privkey_pem), hashfunc)
@classmethod
- def from_der(klass, string):
+ def from_der(klass, string, hashfunc=sha1):
# SEQ([int(1), octetstring(privkey),cont[0], oid(secp224r1),
# cont[1],bitstring])
s, empty = der.remove_sequence(string)
@@ -185,7 +189,7 @@ class SigningKey:
# our from_string method likes fixed-length privkey strings
if len(privkey_str) < curve.baselen:
privkey_str = "\x00"*(curve.baselen-len(privkey_str)) + privkey_str
- return klass.from_string(privkey_str, curve)
+ return klass.from_string(privkey_str, curve, hashfunc)
def to_string(self):
secexp = self.privkey.secret_multiplier
@@ -209,7 +213,7 @@ class SigningKey:
def get_verifying_key(self):
return self.verifying_key
- def sign(self, data, entropy=None, hashfunc=sha1, sigencode=sigencode_string):
+ def sign(self, data, entropy=None, hashfunc=None, sigencode=sigencode_string):
"""
hashfunc= should behave like hashlib.sha1 . The output length of the
hash (in bytes) must not be longer than the length of the curve order
@@ -222,6 +226,7 @@ class SigningKey:
or hashfunc=hashlib.sha256 for openssl-1.0.0's -ecdsa-with-SHA256.
"""
+ hashfunc = hashfunc or self.default_hashfunc
h = hashfunc(data).digest()
return self.sign_digest(h, entropy, sigencode)
diff --git a/ecdsa/test_pyecdsa.py b/ecdsa/test_pyecdsa.py
index 383a62a..1f1a9bd 100644
--- a/ecdsa/test_pyecdsa.py
+++ b/ecdsa/test_pyecdsa.py
@@ -4,7 +4,7 @@ import time
import shutil
import subprocess
from binascii import hexlify, unhexlify
-from hashlib import sha1
+from hashlib import sha1, sha256
from keys import SigningKey, VerifyingKey
from keys import BadSignatureError
@@ -253,7 +253,25 @@ class ECDSA(unittest.TestCase):
sig_der = priv1.sign(data, sigencode=sigencode_der)
self.failUnlessEqual(type(sig_der), str)
self.failUnless(pub1.verify(sig_der, data, sigdecode=sigdecode_der))
-
+
+ def test_hashfunc(self):
+ sk = SigningKey.generate(curve=NIST256p, hashfunc=sha256)
+ data = "security level is 128 bits"
+ sig = sk.sign(data)
+ vk = VerifyingKey.from_string(sk.get_verifying_key().to_string(),
+ curve=NIST256p, hashfunc=sha256)
+ self.failUnless(vk.verify(sig, data))
+
+ sk2 = SigningKey.generate(curve=NIST256p)
+ sig2 = sk2.sign(data, hashfunc=sha256)
+ vk2 = VerifyingKey.from_string(sk2.get_verifying_key().to_string(),
+ curve=NIST256p, hashfunc=sha256)
+ self.failUnless(vk2.verify(sig2, data))
+
+ vk3 = VerifyingKey.from_string(sk.get_verifying_key().to_string(),
+ curve=NIST256p)
+ self.failUnless(vk3.verify(sig, data, hashfunc=sha256))
+
class OpenSSL(unittest.TestCase):
# test interoperability with OpenSSL tools. Note that openssl's ECDSA