summaryrefslogtreecommitdiff
path: root/nss/gtests/ssl_gtest/databuffer.h
blob: e7236d4e9ab99f32187d9bc49f1373f6a0afb9af (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#ifndef databuffer_h__
#define databuffer_h__

#include <algorithm>
#include <cassert>
#include <cstring>
#include <iomanip>
#include <iostream>
#if defined(WIN32) || defined(WIN64)
#include <winsock2.h>
#else
#include <arpa/inet.h>
#endif

extern bool g_ssl_gtest_verbose;

namespace nss_test {

class DataBuffer {
 public:
  DataBuffer() : data_(nullptr), len_(0) {}
  DataBuffer(const uint8_t* data, size_t len) : data_(nullptr), len_(0) {
    Assign(data, len);
  }
  DataBuffer(const DataBuffer& other) : data_(nullptr), len_(0) {
    Assign(other);
  }
  ~DataBuffer() { delete[] data_; }

  DataBuffer& operator=(const DataBuffer& other) {
    if (&other != this) {
      Assign(other);
    }
    return *this;
  }

  void Allocate(size_t len) {
    delete[] data_;
    data_ = new uint8_t[len ? len : 1];  // Don't depend on new [0].
    len_ = len;
  }

  void Truncate(size_t len) { len_ = std::min(len_, len); }

  void Assign(const DataBuffer& other) { Assign(other.data(), other.len()); }

  void Assign(const uint8_t* data, size_t len) {
    if (data) {
      Allocate(len);
      memcpy(static_cast<void*>(data_), static_cast<const void*>(data), len);
    } else {
      assert(len == 0);
      data_ = nullptr;
      len_ = 0;
    }
  }

  // Write will do a new allocation and expand the size of the buffer if needed.
  // Returns the offset of the end of the write.
  size_t Write(size_t index, const uint8_t* val, size_t count) {
    assert(val);
    if (index + count > len_) {
      size_t newlen = index + count;
      uint8_t* tmp = new uint8_t[newlen];  // Always > 0.
      if (data_) {
        memcpy(static_cast<void*>(tmp), static_cast<const void*>(data_), len_);
      }
      if (index > len_) {
        memset(static_cast<void*>(tmp + len_), 0, index - len_);
      }
      delete[] data_;
      data_ = tmp;
      len_ = newlen;
    }
    if (data_) {
      memcpy(static_cast<void*>(data_ + index), static_cast<const void*>(val),
             count);
    }
    return index + count;
  }

  size_t Write(size_t index, const DataBuffer& buf) {
    return Write(index, buf.data(), buf.len());
  }

  // Write an integer, also performing host-to-network order conversion.
  // Returns the offset of the end of the write.
  size_t Write(size_t index, uint32_t val, size_t count) {
    assert(count <= sizeof(uint32_t));
    uint32_t nvalue = htonl(val);
    auto* addr = reinterpret_cast<const uint8_t*>(&nvalue);
    return Write(index, addr + sizeof(uint32_t) - count, count);
  }

  // This can't use the same trick as Write(), since we might be reading from a
  // smaller data source.
  bool Read(size_t index, size_t count, uint32_t* val) const {
    assert(count < sizeof(uint32_t));
    assert(val);
    if ((index > len()) || (count > (len() - index))) {
      return false;
    }
    *val = 0;
    for (size_t i = 0; i < count; ++i) {
      *val = (*val << 8) | data()[index + i];
    }
    return true;
  }

  // Starting at |index|, remove |remove| bytes and replace them with the
  // contents of |buf|.
  void Splice(const DataBuffer& buf, size_t index, size_t remove = 0) {
    Splice(buf.data(), buf.len(), index, remove);
  }

  void Splice(const uint8_t* ins, size_t ins_len, size_t index,
              size_t remove = 0) {
    assert(ins);
    uint8_t* old_value = data_;
    size_t old_len = len_;

    // The amount of stuff remaining from the tail of the old.
    size_t tail_len = old_len - std::min(old_len, index + remove);
    // The new length: the head of the old, the new, and the tail of the old.
    len_ = index + ins_len + tail_len;
    data_ = new uint8_t[len_ ? len_ : 1];

    // The head of the old.
    if (old_value) {
      Write(0, old_value, std::min(old_len, index));
    }
    // Maybe a gap.
    if (old_value && index > old_len) {
      memset(old_value + index, 0, index - old_len);
    }
    // The new.
    Write(index, ins, ins_len);
    // The tail of the old.
    if (tail_len > 0) {
      Write(index + ins_len, old_value + index + remove, tail_len);
    }

    delete[] old_value;
  }

  void Append(const DataBuffer& buf) { Splice(buf, len_); }

  const uint8_t* data() const { return data_; }
  uint8_t* data() { return data_; }
  size_t len() const { return len_; }
  bool empty() const { return len_ == 0; }

 private:
  uint8_t* data_;
  size_t len_;
};

// Number of bytes operator<< prints for a DataBuffer before eliding the rest
// with "..." (the limit is ignored when g_ssl_gtest_verbose is set).
static constexpr size_t kMaxBufferPrint = 32;

// Streams |buf| as "[<len>] " followed by each byte as two hex digits.
// Unless g_ssl_gtest_verbose is set, at most kMaxBufferPrint bytes are shown
// and "..." marks that the rest was elided. The stream is switched back to
// decimal before returning.
inline std::ostream& operator<<(std::ostream& stream, const DataBuffer& buf) {
  const size_t total = buf.len();
  const bool elide = !g_ssl_gtest_verbose && total > kMaxBufferPrint;
  const size_t shown = elide ? kMaxBufferPrint : total;

  stream << "[" << total << "] ";
  const uint8_t* bytes = buf.data();
  for (size_t pos = 0; pos != shown; ++pos) {
    // setw() only applies to the next output, so it is re-applied per byte.
    stream << std::hex << std::setfill('0') << std::setw(2)
           << static_cast<unsigned>(bytes[pos]);
  }
  if (elide) {
    stream << "...";
  }
  return stream << std::dec;
}

// Two buffers are equal when they have the same length and identical bytes.
// Zero-length buffers compare equal without looking at data(), which may be
// null for them.
inline bool operator==(const DataBuffer& a, const DataBuffer& b) {
  if (a.len() != b.len()) {
    return false;
  }
  return a.empty() || memcmp(a.data(), b.data(), a.len()) == 0;
}

// Logical negation of operator== for DataBuffer.
inline bool operator!=(const DataBuffer& a, const DataBuffer& b) {
  const bool same = (a == b);
  return !same;
}

}  // namespace nss_test

#endif