diff options
Diffstat (limited to 'nss/lib/freebl/ctr.c')
-rw-r--r-- | nss/lib/freebl/ctr.c | 175 |
1 files changed, 96 insertions, 79 deletions
diff --git a/nss/lib/freebl/ctr.c b/nss/lib/freebl/ctr.c index 1cbf30c..d5715a5 100644 --- a/nss/lib/freebl/ctr.c +++ b/nss/lib/freebl/ctr.c @@ -19,33 +19,38 @@ SECStatus CTR_InitContext(CTRContext *ctr, void *context, freeblCipherFunc cipher, - const unsigned char *param, unsigned int blocksize) + const unsigned char *param, unsigned int blocksize) { const CK_AES_CTR_PARAMS *ctrParams = (const CK_AES_CTR_PARAMS *)param; if (ctrParams->ulCounterBits == 0 || - ctrParams->ulCounterBits > blocksize * PR_BITS_PER_BYTE) { - PORT_SetError(SEC_ERROR_INVALID_ARGS); - return SECFailure; + ctrParams->ulCounterBits > blocksize * PR_BITS_PER_BYTE) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; } /* Invariant: 0 < ctr->bufPtr <= blocksize */ + ctr->checkWrap = PR_FALSE; ctr->bufPtr = blocksize; /* no unused data in the buffer */ ctr->cipher = cipher; ctr->context = context; ctr->counterBits = ctrParams->ulCounterBits; if (blocksize > sizeof(ctr->counter) || - blocksize > sizeof(ctrParams->cb)) { - PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); - return SECFailure; + blocksize > sizeof(ctrParams->cb)) { + PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); + return SECFailure; } PORT_Memcpy(ctr->counter, ctrParams->cb, blocksize); + if (ctr->counterBits < 64) { + PORT_Memcpy(ctr->counterFirst, ctr->counter, blocksize); + ctr->checkWrap = PR_TRUE; + } return SECSuccess; } CTRContext * CTR_CreateContext(void *context, freeblCipherFunc cipher, - const unsigned char *param, unsigned int blocksize) + const unsigned char *param, unsigned int blocksize) { CTRContext *ctr; SECStatus rv; @@ -53,12 +58,12 @@ CTR_CreateContext(void *context, freeblCipherFunc cipher, /* first fill in the Counter context */ ctr = PORT_ZNew(CTRContext); if (ctr == NULL) { - return NULL; + return NULL; } rv = CTR_InitContext(ctr, context, cipher, param, blocksize); if (rv != SECSuccess) { - CTR_DestroyContext(ctr, PR_TRUE); - ctr = NULL; + CTR_DestroyContext(ctr, PR_TRUE); + ctr = NULL; } return ctr; } @@ -68,7 +73,7 @@ CTR_DestroyContext(CTRContext *ctr, PRBool freeit) { PORT_Memset(ctr, 0, sizeof(CTRContext)); if (freeit) { - PORT_Free(ctr); + PORT_Free(ctr); } } @@ -82,23 +87,23 @@ CTR_DestroyContext(CTRContext *ctr, PRBool freeit) */ static void ctr_GetNextCtr(unsigned char *counter, unsigned int counterBits, - unsigned int blocksize) + unsigned int blocksize) { unsigned char *counterPtr = counter + blocksize - 1; unsigned char mask, count; - PORT_Assert(counterBits <= blocksize*PR_BITS_PER_BYTE); + PORT_Assert(counterBits <= blocksize * PR_BITS_PER_BYTE); while (counterBits >= PR_BITS_PER_BYTE) { - if (++(*(counterPtr--))) { - return; - } - counterBits -= PR_BITS_PER_BYTE; + if (++(*(counterPtr--))) { + return; + } + counterBits -= PR_BITS_PER_BYTE; } if (counterBits == 0) { - return; + return; } /* increment the final partial byte */ - mask = (1 << counterBits)-1; + mask = (1 << counterBits) - 1; count = ++(*counterPtr) & mask; *counterPtr = ((*counterPtr) & ~mask) | count; return; @@ -106,64 +111,76 @@ ctr_GetNextCtr(unsigned char *counter, unsigned int counterBits, static void ctr_xor(unsigned char *target, const unsigned char *x, - const unsigned char *y, unsigned int count) + const unsigned char *y, unsigned int count) { unsigned int i; - for (i=0; i < count; i++) { - *target++ = *x++ ^ *y++; + for (i = 0; i < count; i++) { + *target++ = *x++ ^ *y++; } } SECStatus CTR_Update(CTRContext *ctr, unsigned char *outbuf, - unsigned int *outlen, unsigned int maxout, - const unsigned char *inbuf, unsigned int inlen, - unsigned int blocksize) + unsigned int *outlen, unsigned int maxout, + const unsigned char *inbuf, unsigned int inlen, + unsigned int blocksize) { unsigned int tmp; SECStatus rv; if (maxout < inlen) { - *outlen = inlen; - PORT_SetError(SEC_ERROR_OUTPUT_LEN); - return SECFailure; + *outlen = inlen; + PORT_SetError(SEC_ERROR_OUTPUT_LEN); + return SECFailure; } *outlen = 0; if (ctr->bufPtr != blocksize) { - unsigned int needed = PR_MIN(blocksize-ctr->bufPtr, inlen); - ctr_xor(outbuf, inbuf, ctr->buffer + ctr->bufPtr, needed); - ctr->bufPtr += needed; - outbuf += needed; - inbuf += needed; - *outlen += needed; - inlen -= needed; - if (inlen == 0) { - return SECSuccess; - } - PORT_Assert(ctr->bufPtr == blocksize); + unsigned int needed = PR_MIN(blocksize - ctr->bufPtr, inlen); + ctr_xor(outbuf, inbuf, ctr->buffer + ctr->bufPtr, needed); + ctr->bufPtr += needed; + outbuf += needed; + inbuf += needed; + *outlen += needed; + inlen -= needed; + if (inlen == 0) { + return SECSuccess; + } + PORT_Assert(ctr->bufPtr == blocksize); } while (inlen >= blocksize) { - rv = (*ctr->cipher)(ctr->context, ctr->buffer, &tmp, blocksize, - ctr->counter, blocksize, blocksize); - ctr_GetNextCtr(ctr->counter, ctr->counterBits, blocksize); - if (rv != SECSuccess) { - return SECFailure; - } - ctr_xor(outbuf, inbuf, ctr->buffer, blocksize); - outbuf += blocksize; - inbuf += blocksize; - *outlen += blocksize; - inlen -= blocksize; + rv = (*ctr->cipher)(ctr->context, ctr->buffer, &tmp, blocksize, + ctr->counter, blocksize, blocksize); + ctr_GetNextCtr(ctr->counter, ctr->counterBits, blocksize); + if (ctr->checkWrap) { + if (PORT_Memcmp(ctr->counter, ctr->counterFirst, blocksize) == 0) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + } + if (rv != SECSuccess) { + return SECFailure; + } + ctr_xor(outbuf, inbuf, ctr->buffer, blocksize); + outbuf += blocksize; + inbuf += blocksize; + *outlen += blocksize; + inlen -= blocksize; } if (inlen == 0) { - return SECSuccess; + return SECSuccess; } rv = (*ctr->cipher)(ctr->context, ctr->buffer, &tmp, blocksize, - ctr->counter, blocksize, blocksize); + ctr->counter, blocksize, blocksize); ctr_GetNextCtr(ctr->counter, ctr->counterBits, blocksize); + if (ctr->checkWrap) { + if (PORT_Memcmp(ctr->counter, ctr->counterFirst, blocksize) == 0) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + } if (rv != SECSuccess) { - return SECFailure; + return SECFailure; } ctr_xor(outbuf, inbuf, ctr->buffer, inlen); ctr->bufPtr = inlen; @@ -174,52 +191,52 @@ CTR_Update(CTRContext *ctr, unsigned char *outbuf, #if defined(USE_HW_AES) && defined(_MSC_VER) SECStatus CTR_Update_HW_AES(CTRContext *ctr, unsigned char *outbuf, - unsigned int *outlen, unsigned int maxout, - const unsigned char *inbuf, unsigned int inlen, - unsigned int blocksize) + unsigned int *outlen, unsigned int maxout, + const unsigned char *inbuf, unsigned int inlen, + unsigned int blocksize) { unsigned int fullblocks; unsigned int tmp; SECStatus rv; if (maxout < inlen) { - *outlen = inlen; - PORT_SetError(SEC_ERROR_OUTPUT_LEN); - return SECFailure; + *outlen = inlen; + PORT_SetError(SEC_ERROR_OUTPUT_LEN); + return SECFailure; } *outlen = 0; if (ctr->bufPtr != blocksize) { - unsigned int needed = PR_MIN(blocksize-ctr->bufPtr, inlen); - ctr_xor(outbuf, inbuf, ctr->buffer + ctr->bufPtr, needed); - ctr->bufPtr += needed; - outbuf += needed; - inbuf += needed; - *outlen += needed; - inlen -= needed; - if (inlen == 0) { - return SECSuccess; - } - PORT_Assert(ctr->bufPtr == blocksize); - } - - intel_aes_ctr_worker(((AESContext*)(ctr->context))->Nr)( - ctr, outbuf, outlen, maxout, inbuf, inlen, blocksize); + unsigned int needed = PR_MIN(blocksize - ctr->bufPtr, inlen); + ctr_xor(outbuf, inbuf, ctr->buffer + ctr->bufPtr, needed); + ctr->bufPtr += needed; + outbuf += needed; + inbuf += needed; + *outlen += needed; + inlen -= needed; + if (inlen == 0) { + return SECSuccess; + } + PORT_Assert(ctr->bufPtr == blocksize); + } + + intel_aes_ctr_worker(((AESContext *)(ctr->context))->Nr)( + ctr, outbuf, outlen, maxout, inbuf, inlen, blocksize); /* XXX intel_aes_ctr_worker should set *outlen. */ PORT_Assert(*outlen == 0); - fullblocks = (inlen/blocksize)*blocksize; + fullblocks = (inlen / blocksize) * blocksize; *outlen += fullblocks; outbuf += fullblocks; inbuf += fullblocks; inlen -= fullblocks; if (inlen == 0) { - return SECSuccess; + return SECSuccess; } rv = (*ctr->cipher)(ctr->context, ctr->buffer, &tmp, blocksize, - ctr->counter, blocksize, blocksize); + ctr->counter, blocksize, blocksize); ctr_GetNextCtr(ctr->counter, ctr->counterBits, blocksize); if (rv != SECSuccess) { - return SECFailure; + return SECFailure; } ctr_xor(outbuf, inbuf, ctr->buffer, inlen); ctr->bufPtr = inlen; |