diff options
-rw-r--r-- | lib/dtls-sw.c | 137 | ||||
-rw-r--r-- | lib/gnutls_int.h | 2 | ||||
-rw-r--r-- | tests/dtls-sliding-window.c | 60 | ||||
-rw-r--r-- | tests/mini-dtls-record.c | 3 |
4 files changed, 123 insertions, 79 deletions
diff --git a/lib/dtls-sw.c b/lib/dtls-sw.c index 8c334c05af..36630abb07 100644 --- a/lib/dtls-sw.c +++ b/lib/dtls-sw.c @@ -37,42 +37,26 @@ */ #define DTLS_EPOCH_SHIFT (6*CHAR_BIT) #define DTLS_SEQ_NUM_MASK 0x0000FFFFFFFFFFFF -#define DTLS_WINDOW_HAVE_RECV_PACKET(W) ((W)->dtls_sw_have_recv != 0) -#define DTLS_WINDOW_INIT_AT(W, S) (W)->dtls_sw_bits = ((W)->dtls_sw_have_recv) = 0; (W)->dtls_sw_start = (S&DTLS_SEQ_NUM_MASK) -#define DTLS_WINDOW_INIT(W) DTLS_WINDOW_INIT_AT(W, 0) +#define DTLS_EMPTY_BITMAP (0xFFFFFFFFFFFFFFFFULL) -#define DTLS_WINDOW_INSIDE(W, S) ((((S) & DTLS_SEQ_NUM_MASK) > (W)->dtls_sw_start) && \ - (((S) & DTLS_SEQ_NUM_MASK) - (W)->dtls_sw_start <= (sizeof((W)->dtls_sw_bits) * CHAR_BIT))) +/* We expect the compiler to be able to spot that this is a byteswapping + * load, and emit instructions like 'movbe' on x86_64 where appropriate. +*/ +#define LOAD_UINT64(out, ubytes) \ + out = (((uint64_t)ubytes[0] << 56) | \ + ((uint64_t)ubytes[1] << 48) | \ + ((uint64_t)ubytes[2] << 40) | \ + ((uint64_t)ubytes[3] << 32) | \ + ((uint64_t)ubytes[4] << 24) | \ + ((uint64_t)ubytes[5] << 16) | \ + ((uint64_t)ubytes[6] << 8) | \ + ((uint64_t)ubytes[7] << 0) ) -#define DTLS_WINDOW_OFFSET(W, S) ((((S) & DTLS_SEQ_NUM_MASK) - (W)->dtls_sw_start) - 1) - -#define DTLS_WINDOW_RECEIVED(W, S) (((W)->dtls_sw_bits & ((uint64_t) 1 << DTLS_WINDOW_OFFSET(W, S))) != 0) - -#define DTLS_WINDOW_MARK(W, S) ((W)->dtls_sw_bits |= ((uint64_t) 1 << DTLS_WINDOW_OFFSET(W, S))) - -/* We forcefully advance the window once we have received more than - * 8 packets since the first one. That way we ensure that we don't - * get stuck on connections with many lost packets. */ -#define DTLS_WINDOW_UPDATE(W) \ - if (((W)->dtls_sw_bits & 0xffffffffffff0000LL) != 0) { \ - (W)->dtls_sw_bits = (W)->dtls_sw_bits >> 1; \ - (W)->dtls_sw_start++; \ - } \ - while ((W)->dtls_sw_bits & (uint64_t) 1) { \ - (W)->dtls_sw_bits = (W)->dtls_sw_bits >> 1; \ - (W)->dtls_sw_start++; \ - } - -#define LOAD_UINT64(out, ubytes) \ - for (i = 0; i < 8; i++) { \ - out <<= 8; \ - out |= ubytes[i] & 0xff; \ - } void _dtls_reset_window(struct record_parameters_st *rp) { - DTLS_WINDOW_INIT(rp); + rp->dtls_sw_have_recv = 0; } /* Checks if a sequence number is not replayed. If a replayed @@ -82,7 +66,6 @@ void _dtls_reset_window(struct record_parameters_st *rp) int _dtls_record_check(struct record_parameters_st *rp, uint64 * _seq) { uint64_t seq_num = 0; - unsigned i; LOAD_UINT64(seq_num, _seq->i); @@ -90,24 +73,86 @@ int _dtls_record_check(struct record_parameters_st *rp, uint64 * _seq) return gnutls_assert_val(-1); } - if (!DTLS_WINDOW_HAVE_RECV_PACKET(rp)) { - DTLS_WINDOW_INIT_AT(rp, seq_num); + seq_num &= DTLS_SEQ_NUM_MASK; + + /* + * rp->dtls_sw_next is the next *expected* packet (N), being + * the sequence number *after* the latest we have received. + * + * By definition, therefore, packet N-1 *has* been received. + * And thus there's no point wasting a bit in the bitmap for it. + * + * So the backlog bitmap covers the 64 packets prior to that, + * with the LSB representing packet (N - 2), and the MSB + * representing (N - 65). A received packet is represented + * by a zero bit, and a missing packet is represented by a one. + * + * Thus we can allow out-of-order reception of packets that are + * within a reasonable interval of the latest packet received. + */ + if (!rp->dtls_sw_have_recv) { + rp->dtls_sw_next = seq_num + 1; + rp->dtls_sw_bits = DTLS_EMPTY_BITMAP; rp->dtls_sw_have_recv = 1; return 0; - } + } else if (seq_num == rp->dtls_sw_next) { + /* The common case. This is the packet we expected next. */ - /* are we inside sliding window? */ - if (!DTLS_WINDOW_INSIDE(rp, seq_num)) { - return gnutls_assert_val(-2); - } + rp->dtls_sw_bits <<= 1; - /* already received? */ - if (DTLS_WINDOW_RECEIVED(rp, seq_num)) { - return gnutls_assert_val(-3); + /* This might reach a value higher than 48-bit DTLS sequence + * numbers can actually reach. Which is fine. When that + * happens, we'll do the right thing and just not accept + * any newer packets. Someone needs to start a new epoch. */ + rp->dtls_sw_next++; + return 0; + } else if (seq_num > rp->dtls_sw_next) { + /* The packet we were expecting has gone missing; this one is newer. + * We always advance the window to accommodate it. */ + uint64_t delta = seq_num - rp->dtls_sw_next; + + if (delta >= 64) { + /* We jumped a long way into the future. We have not seen + * any of the previous 32 packets so set the backlog bitmap + * to all ones. */ + rp->dtls_sw_bits = DTLS_EMPTY_BITMAP; + } else if (delta == 63) { + /* Avoid undefined behaviour that shifting by 64 would incur. + * The (clear) top bit represents the packet which is currently + * rp->dtls_sw_next, which we know was already received. */ + rp->dtls_sw_bits = DTLS_EMPTY_BITMAP >> 1; + } else { + /* We have missed (delta) packets. Shift the backlog by that + * amount *plus* the one we would have shifted it anyway if + * we'd received the packet we were expecting. The zero bit + * representing the packet which is currently rp->dtls_sw_next-1, + * which we know has been received, ends up at bit position + * (1<<delta). Then we set all the bits lower than that, which + * represent the missing packets. */ + rp->dtls_sw_bits <<= delta + 1; + rp->dtls_sw_bits |= (1ULL << delta) - 1; + } + rp->dtls_sw_next = seq_num + 1; + return 0; + } else { + /* This packet is older than the one we were expecting. By how much...? */ + uint64_t delta = rp->dtls_sw_next - seq_num; + + if (delta > 65) { + /* Too old. We can't know if it's a replay */ + return gnutls_assert_val(-2); + } else if (delta == 1) { + /* Not in the bitmask since it is by definition already received. */ + return gnutls_assert_val(-3); + } else { + /* Within the sliding window, so we remember whether we've seen it or not */ + uint64_t mask = 1ULL << (rp->dtls_sw_next - seq_num - 2); + + if (!(rp->dtls_sw_bits & mask)) + return gnutls_assert_val(-3); + + rp->dtls_sw_bits &= ~mask; + return 0; + } } - - DTLS_WINDOW_MARK(rp, seq_num); - DTLS_WINDOW_UPDATE(rp); - - return 0; } diff --git a/lib/gnutls_int.h b/lib/gnutls_int.h index a984b49c7f..25d4b3a814 100644 --- a/lib/gnutls_int.h +++ b/lib/gnutls_int.h @@ -622,8 +622,8 @@ struct record_parameters_st { const mac_entry_st *mac; /* for DTLS sliding window */ + uint64_t dtls_sw_next; /* The end point (next expected packet) of the sliding window without epoch */ uint64_t dtls_sw_bits; - uint64_t dtls_sw_start; /* The starting point of the sliding window without epoch */ unsigned dtls_sw_have_recv; /* whether at least a packet has been received */ record_state_st read; diff --git a/tests/dtls-sliding-window.c b/tests/dtls-sliding-window.c index 1ad15296cd..c6a5e3d554 100644 --- a/tests/dtls-sliding-window.c +++ b/tests/dtls-sliding-window.c @@ -36,7 +36,7 @@ struct record_parameters_st { uint64_t dtls_sw_bits; - uint64_t dtls_sw_start; + uint64_t dtls_sw_next; unsigned dtls_sw_have_recv; unsigned epoch; }; @@ -75,8 +75,8 @@ static void uint64_set(uint64* t, uint64_t v) #define RESET_WINDOW \ memset(&state, 0, sizeof(state)) -#define SET_WINDOW_START(x) \ - state.dtls_sw_start = ((x)&DTLS_SEQ_NUM_MASK) +#define SET_WINDOW_NEXT(x) \ + state.dtls_sw_next = (((x)&DTLS_SEQ_NUM_MASK)) #define SET_WINDOW_LAST_RECV(x) \ uint64_set(&t, BSWAP64(x)); \ @@ -88,7 +88,7 @@ static void check_dtls_window_uninit_0(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(0); + SET_WINDOW_NEXT(0); uint64_set(&t, 0); @@ -125,7 +125,7 @@ static void check_dtls_window_12(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(0); + SET_WINDOW_NEXT(0); SET_WINDOW_LAST_RECV(1); uint64_set(&t, BSWAP64(2)); @@ -139,7 +139,7 @@ static void check_dtls_window_19(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(0); + SET_WINDOW_NEXT(0); SET_WINDOW_LAST_RECV(1); uint64_set(&t, BSWAP64(9)); @@ -154,7 +154,7 @@ static void check_dtls_window_skip1(void **glob_state) unsigned i; RESET_WINDOW; - SET_WINDOW_START(0); + SET_WINDOW_NEXT(0); SET_WINDOW_LAST_RECV(1); for (i=2;i<256;i+=2) { @@ -170,7 +170,7 @@ static void check_dtls_window_skip3(void **glob_state) unsigned i; RESET_WINDOW; - SET_WINDOW_START(0); + SET_WINDOW_NEXT(0); SET_WINDOW_LAST_RECV(1); for (i=5;i<256;i+=2) { @@ -185,7 +185,7 @@ static void check_dtls_window_21(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(0); + SET_WINDOW_NEXT(0); SET_WINDOW_LAST_RECV(2); uint64_set(&t, BSWAP64(1)); @@ -199,7 +199,7 @@ static void check_dtls_window_91(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(0); + SET_WINDOW_NEXT(0); SET_WINDOW_LAST_RECV(9); uint64_set(&t, BSWAP64(1)); @@ -213,7 +213,7 @@ static void check_dtls_window_large_21(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(LARGE_INT); + SET_WINDOW_NEXT(LARGE_INT); SET_WINDOW_LAST_RECV(LARGE_INT+2); uint64_set(&t, BSWAP64(LARGE_INT+1)); @@ -227,7 +227,7 @@ static void check_dtls_window_large_12(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(LARGE_INT); + SET_WINDOW_NEXT(LARGE_INT); SET_WINDOW_LAST_RECV(LARGE_INT+1); uint64_set(&t, BSWAP64(LARGE_INT+2)); @@ -241,7 +241,7 @@ static void check_dtls_window_large_91(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(LARGE_INT); + SET_WINDOW_NEXT(LARGE_INT); SET_WINDOW_LAST_RECV(LARGE_INT+9); uint64_set(&t, BSWAP64(LARGE_INT+1)); @@ -255,7 +255,7 @@ static void check_dtls_window_large_19(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(LARGE_INT); + SET_WINDOW_NEXT(LARGE_INT); SET_WINDOW_LAST_RECV(LARGE_INT+1); uint64_set(&t, BSWAP64(LARGE_INT+9)); @@ -269,7 +269,7 @@ static void check_dtls_window_very_large_12(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(INT_OVER_32_BITS); + SET_WINDOW_NEXT(INT_OVER_32_BITS); SET_WINDOW_LAST_RECV(INT_OVER_32_BITS+1); uint64_set(&t, BSWAP64(INT_OVER_32_BITS+2)); @@ -283,7 +283,7 @@ static void check_dtls_window_very_large_91(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(INT_OVER_32_BITS); + SET_WINDOW_NEXT(INT_OVER_32_BITS); SET_WINDOW_LAST_RECV(INT_OVER_32_BITS+9); uint64_set(&t, BSWAP64(INT_OVER_32_BITS+1)); @@ -297,7 +297,7 @@ static void check_dtls_window_very_large_19(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(INT_OVER_32_BITS); + SET_WINDOW_NEXT(INT_OVER_32_BITS); SET_WINDOW_LAST_RECV(INT_OVER_32_BITS+1); uint64_set(&t, BSWAP64(INT_OVER_32_BITS+9)); @@ -311,12 +311,12 @@ static void check_dtls_window_outside(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(0); + SET_WINDOW_NEXT(0); SET_WINDOW_LAST_RECV(1); uint64_set(&t, BSWAP64(1+64)); - assert_int_equal(_dtls_record_check(&state, &t), -2); + assert_int_equal(_dtls_record_check(&state, &t), 0); } static void check_dtls_window_large_outside(void **glob_state) @@ -325,12 +325,12 @@ static void check_dtls_window_large_outside(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(LARGE_INT); + SET_WINDOW_NEXT(LARGE_INT); SET_WINDOW_LAST_RECV(LARGE_INT+1); uint64_set(&t, BSWAP64(LARGE_INT+1+64)); - assert_int_equal(_dtls_record_check(&state, &t), -2); + assert_int_equal(_dtls_record_check(&state, &t), 0); } static void check_dtls_window_very_large_outside(void **glob_state) @@ -339,12 +339,12 @@ static void check_dtls_window_very_large_outside(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(INT_OVER_32_BITS); + SET_WINDOW_NEXT(INT_OVER_32_BITS); SET_WINDOW_LAST_RECV(INT_OVER_32_BITS+1); uint64_set(&t, BSWAP64(INT_OVER_32_BITS+1+64)); - assert_int_equal(_dtls_record_check(&state, &t), -2); + assert_int_equal(_dtls_record_check(&state, &t), 0); } static void check_dtls_window_dup1(void **glob_state) @@ -353,7 +353,7 @@ static void check_dtls_window_dup1(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(LARGE_INT-1); + SET_WINDOW_NEXT(LARGE_INT-1); SET_WINDOW_LAST_RECV(LARGE_INT); uint64_set(&t, BSWAP64(LARGE_INT)); @@ -366,7 +366,7 @@ static void check_dtls_window_dup1(void **glob_state) assert_int_equal(_dtls_record_check(&state, &t), 0); uint64_set(&t, BSWAP64(LARGE_INT+1)); - assert_int_equal(_dtls_record_check(&state, &t), -2); + assert_int_equal(_dtls_record_check(&state, &t), -3); } static void check_dtls_window_dup2(void **glob_state) @@ -375,7 +375,7 @@ static void check_dtls_window_dup2(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(LARGE_INT-1); + SET_WINDOW_NEXT(LARGE_INT-1); SET_WINDOW_LAST_RECV(LARGE_INT); uint64_set(&t, BSWAP64(LARGE_INT)); @@ -397,7 +397,7 @@ static void check_dtls_window_dup3(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(LARGE_INT-1); + SET_WINDOW_NEXT(LARGE_INT-1); SET_WINDOW_LAST_RECV(LARGE_INT); uint64_set(&t, BSWAP64(LARGE_INT)); @@ -425,7 +425,7 @@ static void check_dtls_window_out_of_order(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(LARGE_INT-1); + SET_WINDOW_NEXT(LARGE_INT-1); SET_WINDOW_LAST_RECV(LARGE_INT); uint64_set(&t, BSWAP64(LARGE_INT)); @@ -465,7 +465,7 @@ static void check_dtls_window_epoch_higher(void **glob_state) uint64 t; RESET_WINDOW; - SET_WINDOW_START(LARGE_INT-1); + SET_WINDOW_NEXT(LARGE_INT-1); SET_WINDOW_LAST_RECV(LARGE_INT); uint64_set(&t, BSWAP64(LARGE_INT)); @@ -484,7 +484,7 @@ static void check_dtls_window_epoch_lower(void **glob_state) uint64_set(&t, BSWAP64(0x1000000000000LL)); state.epoch = 1; - SET_WINDOW_START(0x1000000000000LL); + SET_WINDOW_NEXT(0x1000000000000LL); SET_WINDOW_LAST_RECV((0x1000000000000LL) + 1); uint64_set(&t, BSWAP64(2 | 0x1000000000000LL)); diff --git a/tests/mini-dtls-record.c b/tests/mini-dtls-record.c index 8d32d8f3af..63bba89aaf 100644 --- a/tests/mini-dtls-record.c +++ b/tests/mini-dtls-record.c @@ -163,7 +163,7 @@ static ssize_t n_push(gnutls_transport_ptr_t tr, const void *data, size_t len) /* The first five messages are handshake. Thus corresponds to msg_seq+5 */ static int recv_msg_seq[] = - { 1, 2, 3, 4, 5, 6, 12, 28, 8, 9, 10, 11, 13, 15, 16, 14, 18, 20, + { 1, 2, 3, 4, 5, 6, 12, 28, 7, 8, 9, 10, 11, 13, 15, 16, 14, 18, 20, 19, 21, 22, 23, 25, 24, 26, 27, 29, 30, 31, 33, 32, 34, 35, 38, 36, 37, -1 }; @@ -248,7 +248,6 @@ static void client(int fd) fail("received message sequence differs\n"); terminate(); } - if (((uint32_t)recv_msg_seq[current]) != useq) { fail("received message sequence differs (current: %u, got: %u, expected: %u)\n", (unsigned)current, (unsigned)useq, (unsigned)recv_msg_seq[current]); |