summaryrefslogtreecommitdiff
path: root/lib/java/src/main/java/org/apache/thrift/TNonblockingMultiFetchClient.java
blob: 034cc8599e369cbddec277c7cebac6179c5ba07a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.thrift;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * This class uses a single worker thread to set up non-blocking sockets to a set of remote servers
 * (hostname and port pairs), and sends the same request to all these servers. It then fetches the
 * responses from the servers.
 *
 * <p>Parameters:
 *
 * <p>int maxRecvBufBytesPerServer - an upper limit for the receive buffer size per server (in
 * bytes). If a response from a server exceeds this limit, the client will not allocate memory or
 * read response data for it.
 *
 * <p>int fetchTimeoutSeconds - time limit for fetching responses from all servers (in seconds).
 * After the timeout, the fetch job is stopped and available responses are returned.
 *
 * <p>ByteBuffer requestBuf - request message that is sent to all servers.
 *
 * <p>Output: Responses are stored in an array of ByteBuffers. The index of an element in this array
 * corresponds to the index of the server in the server list. The content of a ByteBuffer may be in
 * one of the following forms:
 *
 * <ol>
 *   <li>The first 4 bytes form an integer indicating the length of the following data, followed by
 *       the data.
 *   <li>The first 4 bytes form an integer indicating the length of the following data, followed by
 *       nothing - this happens when the response data size exceeds maxRecvBufBytesPerServer, and
 *       the client will not read any response data.
 *   <li>No data in the ByteBuffer - this happens when the server does not return any response
 *       within fetchTimeoutSeconds.
 * </ol>
 *
 * <p>In some special cases (no servers are given, fetchTimeoutSeconds less than or equal to 0,
 * requestBuf is null), the return value is null.
 *
 * <p>Note: It assumes all remote servers are TNonblockingServers and use TFramedTransport.
 */
public class TNonblockingMultiFetchClient {

  private static final Logger LOGGER = LoggerFactory.getLogger(TNonblockingMultiFetchClient.class);

  // number of bytes of the frame-size prefix that precedes every framed response
  private static final int FRAME_SIZE_BYTES = 4;

  // if the size of the response msg exceeds this limit (in bytes), we will
  // not read the msg
  private final int maxRecvBufBytesPerServer;

  // time limit for fetching data from all servers (in seconds)
  private final int fetchTimeoutSeconds;

  // store request that will be sent to servers
  private final ByteBuffer requestBuf;
  // lazily created by getRequestBuf()
  private ByteBuffer requestBufDuplication;

  // a list of remote servers
  private final List<InetSocketAddress> servers;

  // store fetch results
  private final TNonblockingMultiFetchStats stats;
  private ByteBuffer[] recvBuf;

  /**
   * Creates a client; no connection is opened until {@link #fetch()} is called.
   *
   * @param maxRecvBufBytesPerServer upper limit (in bytes) of the receive buffer per server.
   * @param fetchTimeoutSeconds time limit (in seconds) for one call to {@link #fetch()}.
   * @param requestBuf request message that is sent to every server.
   * @param servers the servers to contact.
   */
  public TNonblockingMultiFetchClient(
      int maxRecvBufBytesPerServer,
      int fetchTimeoutSeconds,
      ByteBuffer requestBuf,
      List<InetSocketAddress> servers) {
    this.maxRecvBufBytesPerServer = maxRecvBufBytesPerServer;
    this.fetchTimeoutSeconds = fetchTimeoutSeconds;
    this.requestBuf = requestBuf;
    this.servers = servers;

    stats = new TNonblockingMultiFetchStats();
    recvBuf = null;
  }

  /** @return the per-server receive buffer limit (in bytes) given to the constructor. */
  public synchronized int getMaxRecvBufBytesPerServer() {
    return maxRecvBufBytesPerServer;
  }

  /** @return the fetch time limit (in seconds) given to the constructor. */
  public synchronized int getFetchTimeoutSeconds() {
    return fetchTimeoutSeconds;
  }

  /**
   * Returns a duplicate of requestBuf, so that the position, limit and mark of requestBuf itself
   * are not modified by others. The duplicate is created on the first call and the same instance is
   * returned by every later call; as with any {@link ByteBuffer#duplicate()}, it shares its content
   * with requestBuf.
   *
   * @return a duplicate of requestBuf, or null if requestBuf is null.
   */
  public synchronized ByteBuffer getRequestBuf() {
    if (requestBuf == null) {
      return null;
    } else {
      if (requestBufDuplication == null) {
        requestBufDuplication = requestBuf.duplicate();
      }
      return requestBufDuplication;
    }
  }

  /** @return an unmodifiable view of the server list, or null if no list was given. */
  public synchronized List<InetSocketAddress> getServerList() {
    if (servers == null) {
      return null;
    }
    return Collections.unmodifiableList(servers);
  }

  /** @return the statistics object; it is cleared at the start of every {@link #fetch()}. */
  public synchronized TNonblockingMultiFetchStats getFetchStats() {
    return stats;
  }

  /**
   * Main entry function for fetching from servers.
   *
   * <p>The fetch itself runs on a dedicated worker thread; this method waits for it for at most
   * fetchTimeoutSeconds, then cancels it and releases all sockets. If the calling thread is
   * interrupted while waiting, the fetch is cancelled and the interrupt status is restored.
   *
   * @return The fetched data, or null if there is nothing to fetch (no servers, no request, or a
   *     non-positive timeout).
   */
  public synchronized ByteBuffer[] fetch() {
    // clear previous results
    recvBuf = null;
    stats.clear();

    if (servers == null || servers.isEmpty() || requestBuf == null || fetchTimeoutSeconds <= 0) {
      return recvBuf;
    }

    ExecutorService executor = Executors.newSingleThreadExecutor();
    MultiFetch multiFetch = new MultiFetch();
    FutureTask<Void> task = new FutureTask<Void>(multiFetch, null);
    executor.execute(task);
    try {
      task.get(fetchTimeoutSeconds, TimeUnit.SECONDS);
    } catch (InterruptedException ie) {
      // attempt to cancel execution of the task.
      task.cancel(true);
      LOGGER.error("Interrupted during fetch", ie);
      // keep the interrupt visible to the caller
      Thread.currentThread().interrupt();
    } catch (ExecutionException ee) {
      // attempt to cancel execution of the task.
      task.cancel(true);
      LOGGER.error("Exception during fetch", ee);
    } catch (TimeoutException te) {
      // attempt to cancel execution of the task.
      task.cancel(true);
      LOGGER.error("Timeout for fetch", te);
    } finally {
      // always release the worker thread and the sockets, whatever happened above
      executor.shutdownNow();
      multiFetch.close();
    }
    return recvBuf;
  }

  /**
   * Private class that does real fetch job. Users are not allowed to directly use this class, as
   * its run() function only returns once every server has either answered or been given up on, or
   * the thread has been interrupted.
   */
  private class MultiFetch implements Runnable {
    // assigned by the worker thread in run(), read by the caller's thread in close()
    private volatile Selector selector;

    // per-server state, indexed like the server list; only touched by the worker thread
    private ByteBuffer[] sendBuf;
    private long[] numBytesRead;
    private int[] frameSize;
    private boolean[] hasReadFrameSize;
    // true once a server has completed, failed, or been given up on
    private boolean[] finished;
    private int numFinishedServers;

    private long startTimeMillis;

    /**
     * main entry function for fetching.
     *
     * <p>Server responses are stored in TNonblockingMultiFetchClient.recvBuf, and fetch statistics
     * is in TNonblockingMultiFetchClient.stats.
     *
     * <p>Sanity check for parameters has been done in TNonblockingMultiFetchClient before calling
     * this function.
     */
    @Override
    public void run() {
      startTimeMillis = System.currentTimeMillis();

      int numTotalServers = servers.size();
      stats.setNumTotalServers(numTotalServers);

      // buffer for receiving response from servers
      recvBuf = new ByteBuffer[numTotalServers];
      // buffer for sending request
      sendBuf = new ByteBuffer[numTotalServers];
      numBytesRead = new long[numTotalServers];
      frameSize = new int[numTotalServers];
      hasReadFrameSize = new boolean[numTotalServers];
      finished = new boolean[numTotalServers];
      numFinishedServers = 0;

      try {
        selector = Selector.open();
      } catch (IOException ioe) {
        LOGGER.error("Selector opens error", ioe);
        return;
      }

      try {
        for (int i = 0; i < numTotalServers; i++) {
          openConnection(i);
        }

        // wait for events until every server has reached a final state
        while (numFinishedServers < numTotalServers) {
          // stop if the thread is interrupted (e.g., task is cancelled) or the selector has been
          // closed underneath us by close()
          if (Thread.currentThread().isInterrupted() || !selector.isOpen()) {
            return;
          }

          try {
            selector.select();
          } catch (Exception e) {
            LOGGER.error("Selector selects error", e);
            continue;
          }

          Iterator<SelectionKey> it = selector.selectedKeys().iterator();
          while (it.hasNext()) {
            SelectionKey selKey = it.next();
            it.remove();

            // get previously attached index
            int index = (Integer) selKey.attachment();

            // validity is re-checked before each step because a step may close the channel
            if (selKey.isValid() && selKey.isConnectable()) {
              handleConnect(selKey, index);
            }
            if (selKey.isValid() && selKey.isWritable()) {
              handleWrite(selKey, index);
            }
            if (selKey.isValid() && selKey.isReadable()) {
              handleRead(selKey, index);
            }
          }
        }
      } finally {
        // release sockets even if fetch() could not see the selector when it called close()
        close();
      }
    }

    /** Allocates the buffers for server {@code index} and starts a non-blocking connect to it. */
    private void openConnection(int index) {
      // create buffer to send request to server.
      sendBuf[index] = requestBuf.duplicate();
      // create buffer to read response's frame size from server
      recvBuf[index] = ByteBuffer.allocate(FRAME_SIZE_BYTES);
      stats.incTotalRecvBufBytes(FRAME_SIZE_BYTES);

      InetSocketAddress server = servers.get(index);
      SocketChannel s = null;
      SelectionKey key = null;
      try {
        s = SocketChannel.open();
        s.configureBlocking(false);
        // now this method is non-blocking
        s.connect(server);
        key = s.register(selector, s.validOps());
        // attach index of the key
        key.attach(index);
      } catch (Exception e) {
        stats.incNumConnectErrorServers();
        LOGGER.error("Set up socket to server {} error", server, e);

        // free resource
        if (s != null) {
          try {
            s.close();
          } catch (Exception ignored) {
            // best effort: the socket is being discarded anyway
          }
        }
        if (key != null) {
          key.cancel();
        }
        markFinished(index);
      }
    }

    /** Completes a pending connect; a failure is counted as a connect error for the server. */
    private void handleConnect(SelectionKey selKey, int index) {
      try {
        SocketChannel sChannel = (SocketChannel) selKey.channel();
        if (sChannel.finishConnect()) {
          // connected: stop selecting on OP_CONNECT so that select() does not keep waking up
          selKey.interestOps(selKey.interestOps() & ~SelectionKey.OP_CONNECT);
        }
      } catch (Exception e) {
        // e.g., connection refused
        stats.incNumConnectErrorServers();
        LOGGER.error("Socket {} connects to server {} error", index, servers.get(index), e);
        finish(selKey, index);
      }
    }

    /** Writes as much of the request as the socket accepts; gives up on the server on failure. */
    private void handleWrite(SelectionKey selKey, int index) {
      ByteBuffer request = sendBuf[index];
      try {
        if (request.hasRemaining()) {
          SocketChannel sChannel = (SocketChannel) selKey.channel();
          sChannel.write(request);
        }
        if (!request.hasRemaining()) {
          // request fully sent: from now on only the response matters. Leaving OP_WRITE set
          // would make select() return immediately on every call.
          selKey.interestOps(SelectionKey.OP_READ);
        }
      } catch (Exception e) {
        LOGGER.error("Socket {} writes to server {} error", index, servers.get(index), e);
        // the request cannot be delivered, so no response will arrive on this socket
        finish(selKey, index);
      }
    }

    /** Reads available response bytes and closes the socket once the whole frame has arrived. */
    private void handleRead(SelectionKey selKey, int index) {
      try {
        SocketChannel sChannel = (SocketChannel) selKey.channel();
        int bytesRead = sChannel.read(recvBuf[index]);

        if (bytesRead < 0) {
          // end of stream: the key would stay readable forever, so stop watching it
          LOGGER.error(
              "Socket {} reached end of stream from server {} before the response was complete",
              index,
              servers.get(index));
          finish(selKey, index);
          return;
        }
        if (bytesRead == 0) {
          return;
        }
        numBytesRead[index] += bytesRead;

        if (!hasReadFrameSize[index] && recvBuf[index].remaining() == 0) {
          // the frame size has been read completely, prepare to read the actual frame.
          if (!prepareFrameBuffer(selKey, index)) {
            return;
          }
        }

        if (hasReadFrameSize[index]
            && numBytesRead[index] >= frameSize[index] + FRAME_SIZE_BYTES) {
          // has read all data
          finish(selKey, index);
          stats.incNumReadCompletedServers();
          stats.setReadTime(System.currentTimeMillis() - startTimeMillis);
        }
      } catch (Exception e) {
        LOGGER.error("Socket {} reads from server {} error", index, servers.get(index), e);
        finish(selKey, index);
      }
    }

    /**
     * Validates the frame size that was just received and replaces the 4-byte buffer of server
     * {@code index} with one large enough for the whole response.
     *
     * @return false if the frame size is invalid or too large; the server has then been given up
     *     on and its buffer still holds only the frame size.
     */
    private boolean prepareFrameBuffer(SelectionKey selKey, int index) {
      int size = recvBuf[index].getInt(0);
      frameSize[index] = size;

      if (size <= 0) {
        stats.incNumInvalidFrameSize();
        LOGGER.error(
            "Read an invalid frame size {} from {}. Does the server use TFramedTransport?",
            size,
            servers.get(index));
        finish(selKey, index);
        return false;
      }

      // computed as long: size + 4 overflows int for frame sizes close to Integer.MAX_VALUE
      long responseBytes = (long) size + FRAME_SIZE_BYTES;

      if (responseBytes > stats.getMaxResponseBytes()) {
        stats.setMaxResponseBytes((int) Math.min(responseBytes, Integer.MAX_VALUE));
      }

      if (responseBytes > maxRecvBufBytesPerServer) {
        stats.incNumOverflowedRecvBuf();
        LOGGER.error(
            "Read frame size {} from {}, total buffer size would exceed limit {}",
            size,
            servers.get(index),
            maxRecvBufBytesPerServer);
        finish(selKey, index);
        return false;
      }

      // reallocate buffer for actual frame data
      recvBuf[index] = ByteBuffer.allocate(size + FRAME_SIZE_BYTES);
      recvBuf[index].putInt(size);

      stats.incTotalRecvBufBytes(size);
      hasReadFrameSize[index] = true;
      return true;
    }

    /** Closes the socket behind {@code selKey} and records that the server needs no more work. */
    private void finish(SelectionKey selKey, int index) {
      try {
        selKey.channel().close();
      } catch (IOException e) {
        LOGGER.error("Socket {} to server {} close error", index, servers.get(index), e);
      } finally {
        // make sure the key is invalid even if closing the channel failed
        selKey.cancel();
      }
      markFinished(index);
    }

    /** Counts server {@code index} as finished, at most once. */
    private void markFinished(int index) {
      if (!finished[index]) {
        finished[index] = true;
        numFinishedServers++;
      }
    }

    /**
     * dispose any resource allocated. Safe to call more than once, and before the selector has
     * been opened.
     */
    public synchronized void close() {
      Selector sel = selector;
      if (sel == null || !sel.isOpen()) {
        return;
      }
      try {
        for (SelectionKey selKey : sel.keys()) {
          try {
            selKey.channel().close();
          } catch (IOException e) {
            // keep going: the remaining channels and the selector must still be released
            LOGGER.error("Free resource error", e);
          }
        }
      } finally {
        try {
          sel.close();
        } catch (IOException e) {
          LOGGER.error("Free resource error", e);
        }
      }
    }
  }
}