/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "Service.h" #include #include #include #include #include #include #include #include using __gnu_cxx::hash_map; using __gnu_cxx::hash; using namespace std; using namespace boost; using namespace apache::thrift; using namespace apache::thrift::protocol; using namespace apache::thrift::transport; using namespace apache::thrift::server; using namespace apache::thrift::concurrency; using namespace test::stress; struct eqstr { bool operator()(const char* s1, const char* s2) const { return strcmp(s1, s2) == 0; } }; struct ltstr { bool operator()(const char* s1, const char* s2) const { return strcmp(s1, s2) < 0; } }; // typedef hash_map, eqstr> count_map; typedef map count_map; class Server : public ServiceIf { public: Server() {} void count(const char* method) { Guard m(lock_); int ct = counts_[method]; counts_[method] = ++ct; } void echoVoid() { count("echoVoid"); // Sleep to simulate work usleep(5000); return; } count_map getCount() { Guard m(lock_); return counts_; } int8_t echoByte(const int8_t arg) {return arg;} int32_t echoI32(const int32_t arg) {return arg;} int64_t echoI64(const int64_t arg) {return arg;} void echoString(string& out, const string &arg) { if (arg != "hello") { T_ERROR_ABORT("WRONG STRING!!!!"); } out = arg; } void echoList(vector &out, const vector &arg) { out = arg; } void echoSet(set &out, const set &arg) { out = arg; } void echoMap(map &out, const map &arg) { out = arg; } private: count_map counts_; Mutex lock_; }; class ClientThread: public Runnable { public: ClientThread(shared_ptrtransport, shared_ptr client, Monitor& monitor, size_t& workerCount, size_t loopCount, TType loopType) : _transport(transport), _client(client), _monitor(monitor), _workerCount(workerCount), _loopCount(loopCount), _loopType(loopType) {} void run() { // Wait for all worker threads to start {Synchronized s(_monitor); while(_workerCount == 0) { _monitor.wait(); } } _startTime = Util::currentTime(); _transport->open(); switch(_loopType) { case T_VOID: loopEchoVoid(); break; case T_BYTE: loopEchoByte(); break; case T_I32: loopEchoI32(); break; case T_I64: loopEchoI64(); break; case T_STRING: loopEchoString(); break; default: cerr << "Unexpected loop type" << _loopType << endl; break; } _endTime = Util::currentTime(); _transport->close(); _done = true; {Synchronized s(_monitor); _workerCount--; if (_workerCount == 0) { _monitor.notify(); } } } void loopEchoVoid() { for (size_t ix = 0; ix < _loopCount; ix++) { _client->echoVoid(); } } void loopEchoByte() { for (size_t ix = 0; ix < _loopCount; ix++) { int8_t arg = 1; int8_t result; result =_client->echoByte(arg); assert(result == arg); } } void loopEchoI32() { for (size_t ix = 0; ix < _loopCount; ix++) { int32_t arg = 1; int32_t result; result =_client->echoI32(arg); assert(result == arg); } } void loopEchoI64() { for (size_t ix = 0; ix < _loopCount; ix++) { int64_t arg = 1; int64_t result; result =_client->echoI64(arg); assert(result == arg); } } void loopEchoString() { for (size_t ix = 0; ix < _loopCount; ix++) { string arg = "hello"; string result; _client->echoString(result, arg); assert(result == arg); } } shared_ptr _transport; shared_ptr _client; Monitor& _monitor; size_t& _workerCount; size_t _loopCount; TType _loopType; long long _startTime; long long _endTime; bool _done; Monitor _sleep; }; int main(int argc, char **argv) { int port = 9091; string serverType = "simple"; string protocolType = "binary"; size_t workerCount = 4; size_t clientCount = 20; size_t loopCount = 50000; TType loopType = T_VOID; string callName = "echoVoid"; bool runServer = true; bool logRequests = false; string requestLogPath = "./requestlog.tlog"; bool replayRequests = false; ostringstream usage; usage << argv[0] << " [--port=] [--server] [--server-type=] [--protocol-type=] [--workers=] [--clients=] [--loop=]" << endl << "\tclients Number of client threads to create - 0 implies no clients, i.e. server only. Default is " << clientCount << endl << "\thelp Prints this help text." << endl << "\tcall Service method to call. Default is " << callName << endl << "\tloop The number of remote thrift calls each client makes. Default is " << loopCount << endl << "\tport The port the server and clients should bind to for thrift network connections. Default is " << port << endl << "\tserver Run the Thrift server in this process. Default is " << runServer << endl << "\tserver-type Type of server, \"simple\" or \"thread-pool\". Default is " << serverType << endl << "\tprotocol-type Type of protocol, \"binary\", \"ascii\", or \"xml\". Default is " << protocolType << endl << "\tlog-request Log all request to ./requestlog.tlog. Default is " << logRequests << endl << "\treplay-request Replay requests from log file (./requestlog.tlog) Default is " << replayRequests << endl << "\tworkers Number of thread pools workers. Only valid for thread-pool server type. Default is " << workerCount << endl; map args; for (int ix = 1; ix < argc; ix++) { string arg(argv[ix]); if (arg.compare(0,2, "--") == 0) { size_t end = arg.find_first_of("=", 2); string key = string(arg, 2, end - 2); if (end != string::npos) { args[key] = string(arg, end + 1); } else { args[key] = "true"; } } else { throw invalid_argument("Unexcepted command line token: "+arg); } } try { if (!args["clients"].empty()) { clientCount = atoi(args["clients"].c_str()); } if (!args["help"].empty()) { cerr << usage.str(); return 0; } if (!args["loop"].empty()) { loopCount = atoi(args["loop"].c_str()); } if (!args["call"].empty()) { callName = args["call"]; } if (!args["port"].empty()) { port = atoi(args["port"].c_str()); } if (!args["server"].empty()) { runServer = args["server"] == "true"; } if (!args["log-request"].empty()) { logRequests = args["log-request"] == "true"; } if (!args["replay-request"].empty()) { replayRequests = args["replay-request"] == "true"; } if (!args["server-type"].empty()) { serverType = args["server-type"]; } if (!args["workers"].empty()) { workerCount = atoi(args["workers"].c_str()); } } catch(exception& e) { cerr << e.what() << endl; cerr << usage; } shared_ptr threadFactory = shared_ptr(new PosixThreadFactory()); // Dispatcher shared_ptr serviceHandler(new Server()); if (replayRequests) { shared_ptr serviceHandler(new Server()); shared_ptr serviceProcessor(new ServiceProcessor(serviceHandler)); // Transports shared_ptr fileTransport(new TFileTransport(requestLogPath)); fileTransport->setChunkSize(2 * 1024 * 1024); fileTransport->setMaxEventSize(1024 * 16); fileTransport->seekToEnd(); // Protocol Factory shared_ptr protocolFactory(new TBinaryProtocolFactory()); TFileProcessor fileProcessor(serviceProcessor, protocolFactory, fileTransport); fileProcessor.process(0, true); exit(0); } if (runServer) { shared_ptr serviceProcessor(new ServiceProcessor(serviceHandler)); // Protocol Factory shared_ptr protocolFactory(new TBinaryProtocolFactory()); // Transport Factory shared_ptr transportFactory; if (logRequests) { // initialize the log file shared_ptr fileTransport(new TFileTransport(requestLogPath)); fileTransport->setChunkSize(2 * 1024 * 1024); fileTransport->setMaxEventSize(1024 * 16); transportFactory = shared_ptr(new TPipedTransportFactory(fileTransport)); } shared_ptr serverThread; shared_ptr serverThread2; if (serverType == "simple") { serverThread = threadFactory->newThread(shared_ptr(new TNonblockingServer(serviceProcessor, protocolFactory, port))); serverThread2 = threadFactory->newThread(shared_ptr(new TNonblockingServer(serviceProcessor, protocolFactory, port+1))); } else if (serverType == "thread-pool") { shared_ptr threadManager = ThreadManager::newSimpleThreadManager(workerCount); threadManager->threadFactory(threadFactory); threadManager->start(); serverThread = threadFactory->newThread(shared_ptr(new TNonblockingServer(serviceProcessor, protocolFactory, port, threadManager))); serverThread2 = threadFactory->newThread(shared_ptr(new TNonblockingServer(serviceProcessor, protocolFactory, port+1, threadManager))); } cerr << "Starting the server on port " << port << " and " << (port + 1) << endl; serverThread->start(); serverThread2->start(); // If we aren't running clients, just wait forever for external clients if (clientCount == 0) { serverThread->join(); serverThread2->join(); } } sleep(1); if (clientCount > 0) { Monitor monitor; size_t threadCount = 0; set > clientThreads; if (callName == "echoVoid") { loopType = T_VOID;} else if (callName == "echoByte") { loopType = T_BYTE;} else if (callName == "echoI32") { loopType = T_I32;} else if (callName == "echoI64") { loopType = T_I64;} else if (callName == "echoString") { loopType = T_STRING;} else {throw invalid_argument("Unknown service call "+callName);} for (size_t ix = 0; ix < clientCount; ix++) { shared_ptr socket(new TSocket("127.0.0.1", port + (ix % 2))); shared_ptr framedSocket(new TFramedTransport(socket)); shared_ptr protocol(new TBinaryProtocol(framedSocket)); shared_ptr serviceClient(new ServiceClient(protocol)); clientThreads.insert(threadFactory->newThread(shared_ptr(new ClientThread(socket, serviceClient, monitor, threadCount, loopCount, loopType)))); } for (std::set >::const_iterator thread = clientThreads.begin(); thread != clientThreads.end(); thread++) { (*thread)->start(); } long long time00; long long time01; {Synchronized s(monitor); threadCount = clientCount; cerr << "Launch "<< clientCount << " client threads" << endl; time00 = Util::currentTime(); monitor.notifyAll(); while(threadCount > 0) { monitor.wait(); } time01 = Util::currentTime(); } long long firstTime = 9223372036854775807LL; long long lastTime = 0; double averageTime = 0; long long minTime = 9223372036854775807LL; long long maxTime = 0; for (set >::iterator ix = clientThreads.begin(); ix != clientThreads.end(); ix++) { shared_ptr client = dynamic_pointer_cast((*ix)->runnable()); long long delta = client->_endTime - client->_startTime; assert(delta > 0); if (client->_startTime < firstTime) { firstTime = client->_startTime; } if (client->_endTime > lastTime) { lastTime = client->_endTime; } if (delta < minTime) { minTime = delta; } if (delta > maxTime) { maxTime = delta; } averageTime+= delta; } averageTime /= clientCount; cout << "workers :" << workerCount << ", client : " << clientCount << ", loops : " << loopCount << ", rate : " << (clientCount * loopCount * 1000) / ((double)(time01 - time00)) << endl; count_map count = serviceHandler->getCount(); count_map::iterator iter; for (iter = count.begin(); iter != count.end(); ++iter) { printf("%s => %d\n", iter->first, iter->second); } cerr << "done." << endl; } return 0; }