diff options
Diffstat (limited to 'lib')
417 files changed, 64921 insertions, 0 deletions
diff --git a/lib/Makefile.am b/lib/Makefile.am new file mode 100644 index 000000000..3558dd803 --- /dev/null +++ b/lib/Makefile.am @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +SUBDIRS = \ + cpp + +if WITH_MONO +SUBDIRS += csharp +endif + +if WITH_JAVA +SUBDIRS += java +endif + +if WITH_PYTHON +SUBDIRS += py +endif + +if WITH_ERLANG +SUBDIRS += erl +endif + +if WITH_RUBY +SUBDIRS += rb +endif + +if WITH_PERL +SUBDIRS += perl +endif + +# All of the libs that don't use Automake need to go in here +# so they will end up in our release tarballs. +EXTRA_DIST = \ + cocoa \ + hs \ + ocaml \ + php \ + erl \ + st diff --git a/lib/cocoa/README b/lib/cocoa/README new file mode 100644 index 000000000..bbe3c934f --- /dev/null +++ b/lib/cocoa/README @@ -0,0 +1,21 @@ +Thrift Cocoa Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. diff --git a/lib/cocoa/src/TApplicationException.h b/lib/cocoa/src/TApplicationException.h new file mode 100644 index 000000000..cf1641d96 --- /dev/null +++ b/lib/cocoa/src/TApplicationException.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TException.h" +#import "TProtocol.h" + +enum { + TApplicationException_UNKNOWN = 0, + TApplicationException_UNKNOWN_METHOD = 1, + TApplicationException_INVALID_MESSAGE_TYPE = 2, + TApplicationException_WRONG_METHOD_NAME = 3, + TApplicationException_BAD_SEQUENCE_ID = 4, + TApplicationException_MISSING_RESULT = 5 +}; + +// FIXME +@interface TApplicationException : TException { + int mType; +} + ++ (TApplicationException *) read: (id <TProtocol>) protocol; + +- (void) write: (id <TProtocol>) protocol; + ++ (TApplicationException *) exceptionWithType: (int) type + reason: (NSString *) message; + +@end diff --git a/lib/cocoa/src/TApplicationException.m b/lib/cocoa/src/TApplicationException.m new file mode 100644 index 000000000..706875377 --- /dev/null +++ b/lib/cocoa/src/TApplicationException.m @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TApplicationException.h" +#import "TProtocolUtil.h" + +@implementation TApplicationException + +- (id) initWithType: (int) type + reason: (NSString *) reason +{ + mType = type; + + NSString * name; + switch (type) { + case TApplicationException_UNKNOWN_METHOD: + name = @"Unknown method"; + break; + case TApplicationException_INVALID_MESSAGE_TYPE: + name = @"Invalid message type"; + break; + case TApplicationException_WRONG_METHOD_NAME: + name = @"Wrong method name"; + break; + case TApplicationException_BAD_SEQUENCE_ID: + name = @"Bad sequence ID"; + break; + case TApplicationException_MISSING_RESULT: + name = @"Missing result"; + break; + default: + name = @"Unknown"; + break; + } + + self = [super initWithName: name reason: reason userInfo: nil]; + return self; +} + + ++ (TApplicationException *) read: (id <TProtocol>) protocol +{ + NSString * reason = nil; + int type = TApplicationException_UNKNOWN; + int fieldType; + int fieldID; + + [protocol readStructBeginReturningName: NULL]; + + while (true) { + [protocol readFieldBeginReturningName: NULL + type: &fieldType + fieldID: &fieldID]; + if (fieldType == TType_STOP) { + break; + } + switch (fieldID) { + case 1: + if (fieldType == TType_STRING) { + reason = [protocol readString]; + } else { + [TProtocolUtil skipType: fieldType onProtocol: protocol]; + } + break; + case 2: + if (fieldType == TType_I32) { + type = [protocol readI32]; + } else { + [TProtocolUtil skipType: fieldType onProtocol: protocol]; + } + break; + default: + [TProtocolUtil skipType: fieldType onProtocol: protocol]; + break; + } + [protocol readFieldEnd]; + } + [protocol readStructEnd]; + + return [TApplicationException exceptionWithType: type reason: reason]; +} + + +- (void) write: (id <TProtocol>) protocol +{ + [protocol writeStructBeginWithName: @"TApplicationException"]; + + if ([self reason] != nil) { + [protocol writeFieldBeginWithName: @"message" + type: TType_STRING + fieldID: 1]; + [protocol writeString: [self reason]]; + [protocol writeFieldEnd]; + } + + [protocol writeFieldBeginWithName: @"type" + type: TType_I32 + fieldID: 2]; + [protocol writeI32: mType]; + [protocol writeFieldEnd]; + + [protocol writeFieldStop]; + [protocol writeStructEnd]; +} + + ++ (TApplicationException *) exceptionWithType: (int) type + reason: (NSString *) reason +{ + return [[[TApplicationException alloc] initWithType: type + reason: reason] autorelease]; +} + +@end diff --git a/lib/cocoa/src/TException.h b/lib/cocoa/src/TException.h new file mode 100644 index 000000000..b069a8683 --- /dev/null +++ b/lib/cocoa/src/TException.h @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import <Cocoa/Cocoa.h> + +@interface TException : NSException { +} + ++ (id) exceptionWithName: (NSString *) name; + ++ (id) exceptionWithName: (NSString *) name + reason: (NSString *) reason; + ++ (id) exceptionWithName: (NSString *) name + reason: (NSString *) reason + error: (NSError *) error; + +@end diff --git a/lib/cocoa/src/TException.m b/lib/cocoa/src/TException.m new file mode 100644 index 000000000..7c84199d8 --- /dev/null +++ b/lib/cocoa/src/TException.m @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TException.h" + +@implementation TException + ++ (id) exceptionWithName: (NSString *) name +{ + return [self exceptionWithName: name reason: @"unknown" error: nil]; +} + + ++ (id) exceptionWithName: (NSString *) name + reason: (NSString *) reason +{ + return [self exceptionWithName: name reason: reason error: nil]; +} + + ++ (id) exceptionWithName: (NSString *) name + reason: (NSString *) reason + error: (NSError *) error +{ + NSDictionary * userInfo = nil; + if (error != nil) { + userInfo = [NSDictionary dictionaryWithObject: error forKey: @"error"]; + } + + return [super exceptionWithName: name + reason: reason + userInfo: userInfo]; +} + + +- (NSString *) description +{ + NSMutableString * result = [NSMutableString stringWithString: [self name]]; + [result appendFormat: @": %@", [self reason]]; + if ([self userInfo] != nil) { + [result appendFormat: @"\n userInfo = %@", [self userInfo]]; + } + + return result; +} + + +@end diff --git a/lib/cocoa/src/TProcessor.h b/lib/cocoa/src/TProcessor.h new file mode 100644 index 000000000..f8df225ee --- /dev/null +++ b/lib/cocoa/src/TProcessor.h @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import <Cocoa/Cocoa.h> + + +@protocol TProcessor <NSObject> + +- (BOOL) processOnInputProtocol: (id <TProtocol>) inProtocol + outputProtocol: (id <TProtocol>) outProtocol; + +@end diff --git a/lib/cocoa/src/protocol/TBinaryProtocol.h b/lib/cocoa/src/protocol/TBinaryProtocol.h new file mode 100644 index 000000000..52cf2669a --- /dev/null +++ b/lib/cocoa/src/protocol/TBinaryProtocol.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TProtocol.h" +#import "TTransport.h" +#import "TProtocolFactory.h" + + +@interface TBinaryProtocol : NSObject <TProtocol> { + id <TTransport> mTransport; + BOOL mStrictRead; + BOOL mStrictWrite; + int32_t mMessageSizeLimit; +} + +- (id) initWithTransport: (id <TTransport>) transport; + +- (id) initWithTransport: (id <TTransport>) transport + strictRead: (BOOL) strictRead + strictWrite: (BOOL) strictWrite; + +- (int32_t) messageSizeLimit; +- (void) setMessageSizeLimit: (int32_t) sizeLimit; + +@end; + + +@interface TBinaryProtocolFactory : NSObject <TProtocolFactory> { +} + ++ (TBinaryProtocolFactory *) sharedFactory; + +- (TBinaryProtocol *) newProtocolOnTransport: (id <TTransport>) transport; + +@end diff --git a/lib/cocoa/src/protocol/TBinaryProtocol.m b/lib/cocoa/src/protocol/TBinaryProtocol.m new file mode 100644 index 000000000..ba7f4629d --- /dev/null +++ b/lib/cocoa/src/protocol/TBinaryProtocol.m @@ -0,0 +1,469 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TBinaryProtocol.h" +#import "TProtocolException.h" + +int32_t VERSION_1 = 0x80010000; +int32_t VERSION_MASK = 0xffff0000; + + +static TBinaryProtocolFactory * gSharedFactory = nil; + +@implementation TBinaryProtocolFactory + ++ (TBinaryProtocolFactory *) sharedFactory { + if (gSharedFactory == nil) { + gSharedFactory = [[TBinaryProtocolFactory alloc] init]; + } + + return gSharedFactory; +} + +- (TBinaryProtocol *) newProtocolOnTransport: (id <TTransport>) transport { + return [[[TBinaryProtocol alloc] initWithTransport: transport] autorelease]; +} + +@end + + + +@implementation TBinaryProtocol + +- (id) initWithTransport: (id <TTransport>) transport +{ + return [self initWithTransport: transport strictRead: NO strictWrite: NO]; +} + +- (id) initWithTransport: (id <TTransport>) transport + strictRead: (BOOL) strictRead + strictWrite: (BOOL) strictWrite +{ + self = [super init]; + mTransport = [transport retain]; + mStrictRead = strictRead; + mStrictWrite = strictWrite; + return self; +} + + +- (int32_t) messageSizeLimit +{ + return mMessageSizeLimit; +} + + +- (void) setMessageSizeLimit: (int32_t) sizeLimit +{ + mMessageSizeLimit = sizeLimit; +} + + +- (void) dealloc +{ + [mTransport release]; + [super dealloc]; +} + + +- (id <TTransport>) transport +{ + return mTransport; +} + + +- (NSString *) readStringBody: (int) size +{ + char buff[size+1]; + [mTransport readAll: (uint8_t *) buff offset: 0 length: size]; + buff[size] = 0; + return [NSString stringWithUTF8String: buff]; +} + + +- (void) readMessageBeginReturningName: (NSString **) name + type: (int *) type + sequenceID: (int *) sequenceID +{ + int32_t size = [self readI32]; + if (size < 0) { + int version = size & VERSION_MASK; + if (version != VERSION_1) { + @throw [TProtocolException exceptionWithName: @"TProtocolException" + reason: @"Bad version in readMessageBegin"]; + } + if (type != NULL) { + *type = version & 0x00FF; + } + NSString * messageName = [self readString]; + if (name != NULL) { + *name = messageName; + } + int seqID = [self readI32]; + if (sequenceID != NULL) { + *sequenceID = seqID; + } + } else { + if (mStrictRead) { + @throw [TProtocolException exceptionWithName: @"TProtocolException" + reason: @"Missing version in readMessageBegin, old client?"]; + } + if ([self messageSizeLimit] > 0 && size > [self messageSizeLimit]) { + @throw [TProtocolException exceptionWithName: @"TProtocolException" + reason: [NSString stringWithFormat: @"Message too big. Size limit is: %d Message size is: %d", + mMessageSizeLimit, + size]]; + } + NSString * messageName = [self readStringBody: size]; + if (name != NULL) { + *name = messageName; + } + int messageType = [self readByte]; + if (type != NULL) { + *type = messageType; + } + int seqID = [self readI32]; + if (sequenceID != NULL) { + *sequenceID = seqID; + } + } +} + + +- (void) readMessageEnd {} + + +- (void) readStructBeginReturningName: (NSString **) name +{ + if (name != NULL) { + *name = nil; + } +} + + +- (void) readStructEnd {} + + +- (void) readFieldBeginReturningName: (NSString **) name + type: (int *) fieldType + fieldID: (int *) fieldID +{ + if (name != NULL) { + *name = nil; + } + int ft = [self readByte]; + if (fieldType != NULL) { + *fieldType = ft; + } + if (ft != TType_STOP) { + int fid = [self readI16]; + if (fieldID != NULL) { + *fieldID = fid; + } + } +} + + +- (void) readFieldEnd {} + + +- (int32_t) readI32 +{ + uint8_t i32rd[4]; + [mTransport readAll: i32rd offset: 0 length: 4]; + return + ((i32rd[0] & 0xff) << 24) | + ((i32rd[1] & 0xff) << 16) | + ((i32rd[2] & 0xff) << 8) | + ((i32rd[3] & 0xff)); +} + + +- (NSString *) readString +{ + int size = [self readI32]; + return [self readStringBody: size]; +} + + +- (BOOL) readBool +{ + return [self readByte] == 1; +} + +- (uint8_t) readByte +{ + uint8_t myByte; + [mTransport readAll: &myByte offset: 0 length: 1]; + return myByte; +} + +- (short) readI16 +{ + uint8_t buff[2]; + [mTransport readAll: buff offset: 0 length: 2]; + return (short) + (((buff[0] & 0xff) << 8) | + ((buff[1] & 0xff))); + return 0; +} + +- (int64_t) readI64; +{ + uint8_t i64rd[8]; + [mTransport readAll: i64rd offset: 0 length: 8]; + return + ((int64_t)(i64rd[0] & 0xff) << 56) | + ((int64_t)(i64rd[1] & 0xff) << 48) | + ((int64_t)(i64rd[2] & 0xff) << 40) | + ((int64_t)(i64rd[3] & 0xff) << 32) | + ((int64_t)(i64rd[4] & 0xff) << 24) | + ((int64_t)(i64rd[5] & 0xff) << 16) | + ((int64_t)(i64rd[6] & 0xff) << 8) | + ((int64_t)(i64rd[7] & 0xff)); +} + +- (double) readDouble; +{ + // FIXME - will this get us into trouble on PowerPC? + int64_t ieee754 = [self readI64]; + return *((double *) &ieee754); +} + + +- (NSData *) readBinary +{ + int32_t size = [self readI32]; + uint8_t * buff = malloc(size); + if (buff == NULL) { + @throw [TProtocolException + exceptionWithName: @"TProtocolException" + reason: [NSString stringWithFormat: @"Out of memory. Unable to allocate %d bytes trying to read binary data.", + size]]; + } + [mTransport readAll: buff offset: 0 length: size]; + return [NSData dataWithBytesNoCopy: buff length: size]; +} + + +- (void) readMapBeginReturningKeyType: (int *) keyType + valueType: (int *) valueType + size: (int *) size +{ + int kt = [self readByte]; + int vt = [self readByte]; + int s = [self readI32]; + if (keyType != NULL) { + *keyType = kt; + } + if (valueType != NULL) { + *valueType = vt; + } + if (size != NULL) { + *size = s; + } +} + +- (void) readMapEnd {} + + +- (void) readSetBeginReturningElementType: (int *) elementType + size: (int *) size +{ + int et = [self readByte]; + int s = [self readI32]; + if (elementType != NULL) { + *elementType = et; + } + if (size != NULL) { + *size = s; + } +} + + +- (void) readSetEnd {} + + +- (void) readListBeginReturningElementType: (int *) elementType + size: (int *) size +{ + int et = [self readByte]; + int s = [self readI32]; + if (elementType != NULL) { + *elementType = et; + } + if (size != NULL) { + *size = s; + } +} + + +- (void) readListEnd {} + + +- (void) writeByte: (uint8_t) value +{ + [mTransport write: &value offset: 0 length: 1]; +} + + +- (void) writeMessageBeginWithName: (NSString *) name + type: (int) messageType + sequenceID: (int) sequenceID +{ + if (mStrictWrite) { + int version = VERSION_1 | messageType; + [self writeI32: version]; + [self writeString: name]; + [self writeI32: sequenceID]; + } else { + [self writeString: name]; + [self writeByte: messageType]; + [self writeI32: sequenceID]; + } +} + + +- (void) writeMessageEnd {} + + +- (void) writeStructBeginWithName: (NSString *) name {} + + +- (void) writeStructEnd {} + + +- (void) writeFieldBeginWithName: (NSString *) name + type: (int) fieldType + fieldID: (int) fieldID +{ + [self writeByte: fieldType]; + [self writeI16: fieldID]; +} + + +- (void) writeI32: (int32_t) value +{ + uint8_t buff[4]; + buff[0] = 0xFF & (value >> 24); + buff[1] = 0xFF & (value >> 16); + buff[2] = 0xFF & (value >> 8); + buff[3] = 0xFF & value; + [mTransport write: buff offset: 0 length: 4]; +} + +- (void) writeI16: (short) value +{ + uint8_t buff[2]; + buff[0] = 0xff & (value >> 8); + buff[1] = 0xff & value; + [mTransport write: buff offset: 0 length: 2]; +} + + +- (void) writeI64: (int64_t) value +{ + uint8_t buff[8]; + buff[0] = 0xFF & (value >> 56); + buff[1] = 0xFF & (value >> 48); + buff[2] = 0xFF & (value >> 40); + buff[3] = 0xFF & (value >> 32); + buff[4] = 0xFF & (value >> 24); + buff[5] = 0xFF & (value >> 16); + buff[6] = 0xFF & (value >> 8); + buff[7] = 0xFF & value; + [mTransport write: buff offset: 0 length: 8]; +} + +- (void) writeDouble: (double) value +{ + // spit out IEEE 754 bits - FIXME - will this get us in trouble on + // PowerPC? + [self writeI64: *((int64_t *) &value)]; +} + + +- (void) writeString: (NSString *) value +{ + if (value != nil) { + const char * utf8Bytes = [value UTF8String]; + size_t length = strlen(utf8Bytes); + [self writeI32: length]; + [mTransport write: (uint8_t *) utf8Bytes offset: 0 length: length]; + } else { + // instead of crashing when we get null, let's write out a zero + // length string + [self writeI32: 0]; + } +} + + +- (void) writeBinary: (NSData *) data +{ + [self writeI32: [data length]]; + [mTransport write: [data bytes] offset: 0 length: [data length]]; +} + +- (void) writeFieldStop +{ + [self writeByte: TType_STOP]; +} + + +- (void) writeFieldEnd {} + + +- (void) writeMapBeginWithKeyType: (int) keyType + valueType: (int) valueType + size: (int) size +{ + [self writeByte: keyType]; + [self writeByte: valueType]; + [self writeI32: size]; +} + +- (void) writeMapEnd {} + + +- (void) writeSetBeginWithElementType: (int) elementType + size: (int) size +{ + [self writeByte: elementType]; + [self writeI32: size]; +} + +- (void) writeSetEnd {} + + +- (void) writeListBeginWithElementType: (int) elementType + size: (int) size +{ + [self writeByte: elementType]; + [self writeI32: size]; +} + +- (void) writeListEnd {} + + +- (void) writeBool: (BOOL) value +{ + [self writeByte: (value ? 1 : 0)]; +} + +@end diff --git a/lib/cocoa/src/protocol/TProtocol.h b/lib/cocoa/src/protocol/TProtocol.h new file mode 100644 index 000000000..cc8cdb4bf --- /dev/null +++ b/lib/cocoa/src/protocol/TProtocol.h @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import <Cocoa/Cocoa.h> + +#import "TTransport.h" + + +enum { + TMessageType_CALL = 1, + TMessageType_REPLY = 2, + TMessageType_EXCEPTION = 3, + TMessageType_ONEWAY = 4 +}; + +enum { + TType_STOP = 0, + TType_VOID = 1, + TType_BOOL = 2, + TType_BYTE = 3, + TType_DOUBLE = 4, + TType_I16 = 6, + TType_I32 = 8, + TType_I64 = 10, + TType_STRING = 11, + TType_STRUCT = 12, + TType_MAP = 13, + TType_SET = 14, + TType_LIST = 15 +}; + + +@protocol TProtocol <NSObject> + +- (id <TTransport>) transport; + +- (void) readMessageBeginReturningName: (NSString **) name + type: (int *) type + sequenceID: (int *) sequenceID; +- (void) readMessageEnd; + +- (void) readStructBeginReturningName: (NSString **) name; +- (void) readStructEnd; + +- (void) readFieldBeginReturningName: (NSString **) name + type: (int *) fieldType + fieldID: (int *) fieldID; +- (void) readFieldEnd; + +- (NSString *) readString; + +- (BOOL) readBool; + +- (unsigned char) readByte; + +- (short) readI16; + +- (int32_t) readI32; + +- (int64_t) readI64; + +- (double) readDouble; + +- (NSData *) readBinary; + +- (void) readMapBeginReturningKeyType: (int *) keyType + valueType: (int *) valueType + size: (int *) size; +- (void) readMapEnd; + + +- (void) readSetBeginReturningElementType: (int *) elementType + size: (int *) size; +- (void) readSetEnd; + + +- (void) readListBeginReturningElementType: (int *) elementType + size: (int *) size; +- (void) readListEnd; + + +- (void) writeMessageBeginWithName: (NSString *) name + type: (int) messageType + sequenceID: (int) sequenceID; +- (void) writeMessageEnd; + +- (void) writeStructBeginWithName: (NSString *) name; +- (void) writeStructEnd; + +- (void) writeFieldBeginWithName: (NSString *) name + type: (int) fieldType + fieldID: (int) fieldID; + +- (void) writeI32: (int32_t) value; + +- (void) writeI64: (int64_t) value; + +- (void) writeI16: (short) value; + +- (void) writeByte: (uint8_t) value; + +- (void) writeString: (NSString *) value; + +- (void) writeDouble: (double) value; + +- (void) writeBool: (BOOL) value; + +- (void) writeBinary: (NSData *) data; + +- (void) writeFieldStop; + +- (void) writeFieldEnd; + +- (void) writeMapBeginWithKeyType: (int) keyType + valueType: (int) valueType + size: (int) size; +- (void) writeMapEnd; + + +- (void) writeSetBeginWithElementType: (int) elementType + size: (int) size; +- (void) writeSetEnd; + + +- (void) writeListBeginWithElementType: (int) elementType + size: (int) size; + +- (void) writeListEnd; + + +@end + diff --git a/lib/cocoa/src/protocol/TProtocolException.h b/lib/cocoa/src/protocol/TProtocolException.h new file mode 100644 index 000000000..ad354fc2f --- /dev/null +++ b/lib/cocoa/src/protocol/TProtocolException.h @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TException.h" + +@interface TProtocolException : TException { +} + +@end diff --git a/lib/cocoa/src/protocol/TProtocolException.m b/lib/cocoa/src/protocol/TProtocolException.m new file mode 100644 index 000000000..681487a43 --- /dev/null +++ b/lib/cocoa/src/protocol/TProtocolException.m @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TProtocolException.h" + +@implementation TProtocolException +@end diff --git a/lib/cocoa/src/protocol/TProtocolFactory.h b/lib/cocoa/src/protocol/TProtocolFactory.h new file mode 100644 index 000000000..2d125e96f --- /dev/null +++ b/lib/cocoa/src/protocol/TProtocolFactory.h @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import <Cocoa/Cocoa.h> +#import "TProtocol.h" +#import "TTransport.h" + + +@protocol TProtocolFactory <NSObject> + +- (id <TProtocol>) newProtocolOnTransport: (id <TTransport>) transport; + +@end diff --git a/lib/cocoa/src/protocol/TProtocolUtil.h b/lib/cocoa/src/protocol/TProtocolUtil.h new file mode 100644 index 000000000..c2d2521cc --- /dev/null +++ b/lib/cocoa/src/protocol/TProtocolUtil.h @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TProtocol.h" +#import "TTransport.h" + +@interface TProtocolUtil : NSObject { + +} + ++ (void) skipType: (int) type onProtocol: (id <TProtocol>) protocol; + +@end; diff --git a/lib/cocoa/src/protocol/TProtocolUtil.m b/lib/cocoa/src/protocol/TProtocolUtil.m new file mode 100644 index 000000000..13d70954f --- /dev/null +++ b/lib/cocoa/src/protocol/TProtocolUtil.m @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TProtocolUtil.h" + +@implementation TProtocolUtil + ++ (void) skipType: (int) type onProtocol: (id <TProtocol>) protocol +{ + switch (type) { + case TType_BOOL: + [protocol readBool]; + break; + case TType_BYTE: + [protocol readByte]; + break; + case TType_I16: + [protocol readI16]; + break; + case TType_I32: + [protocol readI32]; + break; + case TType_I64: + [protocol readI64]; + break; + case TType_DOUBLE: + [protocol readDouble]; + break; + case TType_STRING: + [protocol readString]; + break; + case TType_STRUCT: + [protocol readStructBeginReturningName: NULL]; + while (true) { + int fieldType; + [protocol readFieldBeginReturningName: nil type: &fieldType fieldID: nil]; + if (fieldType == TType_STOP) { + break; + } + [TProtocolUtil skipType: fieldType onProtocol: protocol]; + [protocol readFieldEnd]; + } + [protocol readStructEnd]; + break; + case TType_MAP: + { + int keyType; + int valueType; + int size; + [protocol readMapBeginReturningKeyType: &keyType valueType: &valueType size: &size]; + int i; + for (i = 0; i < size; i++) { + [TProtocolUtil skipType: keyType onProtocol: protocol]; + [TProtocolUtil skipType: valueType onProtocol: protocol]; + } + [protocol readMapEnd]; + } + break; + case TType_SET: + { + int elemType; + int size; + [protocol readSetBeginReturningElementType: &elemType size: &size]; + int i; + for (i = 0; i < size; i++) { + [TProtocolUtil skipType: elemType onProtocol: protocol]; + } + [protocol readSetEnd]; + } + break; + case TType_LIST: + { + int elemType; + int size; + [protocol readListBeginReturningElementType: &elemType size: &size]; + int i; + for (i = 0; i < size; i++) { + [TProtocolUtil skipType: elemType onProtocol: protocol]; + } + [protocol readListEnd]; + } + break; + default: + return; + } +} + +@end diff --git a/lib/cocoa/src/server/TSocketServer.h b/lib/cocoa/src/server/TSocketServer.h new file mode 100644 index 000000000..3d4a9e0e7 --- /dev/null +++ b/lib/cocoa/src/server/TSocketServer.h @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import <Cocoa/Cocoa.h> +#import "TProtocolFactory.h" +#import "TProcessor.h" + + +@interface TSocketServer : NSObject { + NSSocketPort * mServerSocket; + NSFileHandle * mSocketFileHandle; + id <TProtocolFactory> mInputProtocolFactory; + id <TProtocolFactory> mOutputProtocolFactory; + id <TProcessor> mProcessor; +} + +- (id) initWithPort: (int) port + protocolFactory: (id <TProtocolFactory>) protocolFactory + processor: (id <TProcessor>) processor; + +@end + + + diff --git a/lib/cocoa/src/server/TSocketServer.m b/lib/cocoa/src/server/TSocketServer.m new file mode 100644 index 000000000..97d8bae7b --- /dev/null +++ b/lib/cocoa/src/server/TSocketServer.m @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import <Cocoa/Cocoa.h> +#import "TSocketServer.h" +#import "TNSFileHandleTransport.h" +#import "TProtocol.h" +#import "TTransportException.h" + + +@implementation TSocketServer + +- (id) initWithPort: (int) port + protocolFactory: (id <TProtocolFactory>) protocolFactory + processor: (id <TProcessor>) processor; +{ + self = [super init]; + + mInputProtocolFactory = [protocolFactory retain]; + mOutputProtocolFactory = [protocolFactory retain]; + mProcessor = [processor retain]; + + // create a socket + mServerSocket = [[NSSocketPort alloc] initWithTCPPort: port]; + // FIXME - move this separate start method and add method to close + // and cleanup any open ports + + if (mServerSocket == nil) { + NSLog(@"Unable to listen on TCP port %d", port); + } else { + NSLog(@"Listening on TCP port %d", port); + + // wrap it in a file handle so we can get messages from it + mSocketFileHandle = [[NSFileHandle alloc] initWithFileDescriptor: [mServerSocket socket] + closeOnDealloc: YES]; + + // register for notifications of accepted incoming connections + [[NSNotificationCenter defaultCenter] addObserver: self + selector: @selector(connectionAccepted:) + name: NSFileHandleConnectionAcceptedNotification + object: mSocketFileHandle]; + + // tell socket to listen + [mSocketFileHandle acceptConnectionInBackgroundAndNotify]; + } + + return self; +} + + +- (void) dealloc { + [mInputProtocolFactory release]; + [mOutputProtocolFactory release]; + [mProcessor release]; + [mSocketFileHandle release]; + [mServerSocket release]; + [super dealloc]; +} + + +- (void) connectionAccepted: (NSNotification *) aNotification +{ + NSFileHandle * socket = [[aNotification userInfo] objectForKey: NSFileHandleNotificationFileHandleItem]; + + // now that we have a client connected, spin off a thread to handle activity + [NSThread detachNewThreadSelector: @selector(handleClientConnection:) + toTarget: self + withObject: socket]; + + [[aNotification object] acceptConnectionInBackgroundAndNotify]; +} + + +- (void) handleClientConnection: (NSFileHandle *) clientSocket +{ + NSAutoreleasePool * pool = [[NSAutoreleasePool alloc] init]; + + TNSFileHandleTransport * transport = [[TNSFileHandleTransport alloc] initWithFileHandle: clientSocket]; + + id <TProtocol> inProtocol = [mInputProtocolFactory newProtocolOnTransport: transport]; + id <TProtocol> outProtocol = [mOutputProtocolFactory newProtocolOnTransport: transport]; + + @try { + while ([mProcessor processOnInputProtocol: inProtocol outputProtocol: outProtocol]); + } + @catch (TTransportException * te) { + NSLog(@"%@", te); + } + + [pool release]; +} + + + +@end + + + diff --git a/lib/cocoa/src/transport/THTTPClient.h b/lib/cocoa/src/transport/THTTPClient.h new file mode 100644 index 000000000..86f3f0548 --- /dev/null +++ b/lib/cocoa/src/transport/THTTPClient.h @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import <Cocoa/Cocoa.h> +#import "TTransport.h" + +@interface THTTPClient : NSObject <TTransport> { + NSURL * mURL; + NSMutableURLRequest * mRequest; + NSMutableData * mRequestData; + NSData * mResponseData; + int mResponseDataOffset; + NSString * mUserAgent; + int mTimeout; +} + +- (id) initWithURL: (NSURL *) aURL; + +- (id) initWithURL: (NSURL *) aURL + userAgent: (NSString *) userAgent + timeout: (int) timeout; + +- (void) setURL: (NSURL *) aURL; + +@end + diff --git a/lib/cocoa/src/transport/THTTPClient.m b/lib/cocoa/src/transport/THTTPClient.m new file mode 100644 index 000000000..6391bead8 --- /dev/null +++ b/lib/cocoa/src/transport/THTTPClient.m @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "THTTPClient.h" +#import "TTransportException.h" + +@implementation THTTPClient + + +- (void) setupRequest +{ + if (mRequest != nil) { + [mRequest release]; + } + + // set up our request object that we'll use for each request + mRequest = [[NSMutableURLRequest alloc] initWithURL: mURL]; + [mRequest setHTTPMethod: @"POST"]; + [mRequest setValue: @"application/x-thrift" forHTTPHeaderField: @"Content-Type"]; + [mRequest setValue: @"application/x-thrift" forHTTPHeaderField: @"Accept"]; + + NSString * userAgent = mUserAgent; + if (!userAgent) { + userAgent = @"Cocoa/THTTPClient"; + } + [mRequest setValue: userAgent forHTTPHeaderField: @"User-Agent"]; + + [mRequest setCachePolicy: NSURLRequestReloadIgnoringCacheData]; + if (mTimeout) { + [mRequest setTimeoutInterval: mTimeout]; + } +} + + +- (id) initWithURL: (NSURL *) aURL +{ + return [self initWithURL: aURL + userAgent: nil + timeout: 0]; +} + + +- (id) initWithURL: (NSURL *) aURL + userAgent: (NSString *) userAgent + timeout: (int) timeout +{ + self = [super init]; + if (!self) { + return nil; + } + + mTimeout = timeout; + if (userAgent) { + mUserAgent = [userAgent retain]; + } + mURL = [aURL retain]; + + [self setupRequest]; + + // create our request data buffer + mRequestData = [[NSMutableData alloc] initWithCapacity: 1024]; + + return self; +} + + +- (void) setURL: (NSURL *) aURL +{ + [aURL retain]; + [mURL release]; + mURL = aURL; + + [self setupRequest]; +} + + +- (void) dealloc +{ + [mURL release]; + [mUserAgent release]; + [mRequest release]; + [mRequestData release]; + [mResponseData release]; + [super dealloc]; +} + + +- (int) readAll: (uint8_t *) buf offset: (int) off length: (int) len +{ + NSRange r; + r.location = mResponseDataOffset; + r.length = len; + + [mResponseData getBytes: buf+off range: r]; + mResponseDataOffset += len; + + return len; +} + + +- (void) write: (const uint8_t *) data offset: (unsigned int) offset length: (unsigned int) length +{ + [mRequestData appendBytes: data+offset length: length]; +} + + +- (void) flush +{ + [mRequest setHTTPBody: mRequestData]; // not sure if it copies the data + + // make the HTTP request + NSURLResponse * response; + NSError * error; + NSData * responseData = + [NSURLConnection sendSynchronousRequest: mRequest returningResponse: &response error: &error]; + + [mRequestData setLength: 0]; + + if (responseData == nil) { + @throw [TTransportException exceptionWithName: @"TTransportException" + reason: @"Could not make HTTP request" + error: error]; + } + if (![response isKindOfClass: [NSHTTPURLResponse class]]) { + @throw [TTransportException exceptionWithName: @"TTransportException" + reason: @"Unexpected NSURLResponse type"]; + } + + NSHTTPURLResponse * httpResponse = (NSHTTPURLResponse *) response; + if ([httpResponse statusCode] != 200) { + @throw [TTransportException exceptionWithName: @"TTransportException" + reason: [NSString stringWithFormat: @"Bad response from HTTP server: %d", + [httpResponse statusCode]]]; + } + + // phew! + [mResponseData release]; + mResponseData = [responseData retain]; + mResponseDataOffset = 0; +} + + +@end diff --git a/lib/cocoa/src/transport/TNSFileHandleTransport.h b/lib/cocoa/src/transport/TNSFileHandleTransport.h new file mode 100644 index 000000000..64a6af3cf --- /dev/null +++ b/lib/cocoa/src/transport/TNSFileHandleTransport.h @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +#import <Cocoa/Cocoa.h> +#import "TTransport.h" + +@interface TNSFileHandleTransport : NSObject <TTransport> { + NSFileHandle * mInputFileHandle; + NSFileHandle * mOutputFileHandle; +} + +- (id) initWithFileHandle: (NSFileHandle *) fileHandle; + +- (id) initWithInputFileHandle: (NSFileHandle *) inputFileHandle + outputFileHandle: (NSFileHandle *) outputFileHandle; + + +@end diff --git a/lib/cocoa/src/transport/TNSFileHandleTransport.m b/lib/cocoa/src/transport/TNSFileHandleTransport.m new file mode 100644 index 000000000..153393414 --- /dev/null +++ b/lib/cocoa/src/transport/TNSFileHandleTransport.m @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +#import "TNSFileHandleTransport.h" +#import "TTransportException.h" + + +@implementation TNSFileHandleTransport + +- (id) initWithFileHandle: (NSFileHandle *) fileHandle +{ + return [self initWithInputFileHandle: fileHandle + outputFileHandle: fileHandle]; +} + + +- (id) initWithInputFileHandle: (NSFileHandle *) inputFileHandle + outputFileHandle: (NSFileHandle *) outputFileHandle +{ + self = [super init]; + + mInputFileHandle = [inputFileHandle retain]; + mOutputFileHandle = [outputFileHandle retain]; + + return self; +} + + +- (void) dealloc { + [mInputFileHandle release]; + [mOutputFileHandle release]; + [super dealloc]; +} + + +- (int) readAll: (uint8_t *) buf offset: (int) off length: (int) len +{ + int got = 0; + while (got < len) { + NSData * d = [mInputFileHandle readDataOfLength: len-got]; + if ([d length] == 0) { + @throw [TTransportException exceptionWithName: @"TTransportException" + reason: @"Cannot read. No more data."]; + } + [d getBytes: buf+got]; + got += [d length]; + } + return got; +} + + +- (void) write: (uint8_t *) data offset: (unsigned int) offset length: (unsigned int) length +{ + NSData * dataObject = [[NSData alloc] initWithBytesNoCopy: data+offset + length: length + freeWhenDone: NO]; + + [mOutputFileHandle writeData: dataObject]; + + + [dataObject release]; +} + + +- (void) flush +{ + +} + +@end diff --git a/lib/cocoa/src/transport/TNSStreamTransport.h b/lib/cocoa/src/transport/TNSStreamTransport.h new file mode 100644 index 000000000..295a185cd --- /dev/null +++ b/lib/cocoa/src/transport/TNSStreamTransport.h @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import <Cocoa/Cocoa.h> +#import "TTransport.h" + +@interface TNSStreamTransport : NSObject <TTransport> { + NSInputStream * mInput; + NSOutputStream * mOutput; +} + +- (id) initWithInputStream: (NSInputStream *) input + outputStream: (NSOutputStream *) output; + +- (id) initWithInputStream: (NSInputStream *) input; + +- (id) initWithOutputStream: (NSOutputStream *) output; + +@end + + + diff --git a/lib/cocoa/src/transport/TNSStreamTransport.m b/lib/cocoa/src/transport/TNSStreamTransport.m new file mode 100644 index 000000000..52a02e277 --- /dev/null +++ b/lib/cocoa/src/transport/TNSStreamTransport.m @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TNSStreamTransport.h" +#import "TTransportException.h" + + +@implementation TNSStreamTransport + +- (id) initWithInputStream: (NSInputStream *) input + outputStream: (NSOutputStream *) output +{ + [super init]; + mInput = [input retain]; + mOutput = [output retain]; + return self; +} + +- (id) initWithInputStream: (NSInputStream *) input +{ + return [self initWithInputStream: input outputStream: nil]; +} + +- (id) initWithOutputStream: (NSOutputStream *) output +{ + return [self initWithInputStream: nil outputStream: output]; +} + +- (void) dealloc +{ + [mInput release]; + [mOutput release]; + [super dealloc]; +} + + +- (int) readAll: (uint8_t *) buf offset: (int) off length: (int) len +{ + int got = 0; + int ret = 0; + while (got < len) { + ret = [mInput read: buf+off+got maxLength: len-got]; + if (ret <= 0) { + @throw [TTransportException exceptionWithReason: @"Cannot read. Remote side has closed."]; + } + got += ret; + } + return got; +} + + +// FIXME:geech:20071019 - make this write all +- (void) write: (uint8_t *) data offset: (unsigned int) offset length: (unsigned int) length +{ + int result = [mOutput write: data+offset maxLength: length]; + if (result == -1) { + @throw [TTransportException exceptionWithReason: @"Error writing to transport output stream." + error: [mOutput streamError]]; + } else if (result == 0) { + @throw [TTransportException exceptionWithReason: @"End of output stream."]; + } else if (result != length) { + @throw [TTransportException exceptionWithReason: @"Output stream did not write all of our data."]; + } +} + +- (void) flush +{ + // no flush for you! +} + +@end diff --git a/lib/cocoa/src/transport/TSocketClient.h b/lib/cocoa/src/transport/TSocketClient.h new file mode 100644 index 000000000..a883acbb2 --- /dev/null +++ b/lib/cocoa/src/transport/TSocketClient.h @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import <Cocoa/Cocoa.h> +#import "TNSStreamTransport.h" + +@interface TSocketClient : TNSStreamTransport { +} + +- (id) initWithHostname: (NSString *) hostname + port: (int) port; + +@end + + + diff --git a/lib/cocoa/src/transport/TSocketClient.m b/lib/cocoa/src/transport/TSocketClient.m new file mode 100644 index 000000000..7c07c5613 --- /dev/null +++ b/lib/cocoa/src/transport/TSocketClient.m @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import <Cocoa/Cocoa.h> +#import "TSocketClient.h" + +@implementation TSocketClient + +- (id) initWithHostname: (NSString *) hostname + port: (int) port +{ + NSInputStream * input = nil; + NSOutputStream * output = nil; + + [NSStream getStreamsToHost: [NSHost hostWithName: hostname] + port: port + inputStream: &input + outputStream: &output]; + + self = [super initWithInputStream: input outputStream: output]; + [input open]; + [output open]; + + return self; +} + + +@end + + + diff --git a/lib/cocoa/src/transport/TTransport.h b/lib/cocoa/src/transport/TTransport.h new file mode 100644 index 000000000..61ebbd214 --- /dev/null +++ b/lib/cocoa/src/transport/TTransport.h @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +@protocol TTransport <NSObject> + + /** + * Guarantees that all of len bytes are read + * + * @param buf Buffer to read into + * @param off Index in buffer to start storing bytes at + * @param len Maximum number of bytes to read + * @return The number of bytes actually read, which must be equal to len + * @throws TTransportException if there was an error reading data + */ +- (int) readAll: (uint8_t *) buf offset: (int) off length: (int) len; + +- (void) write: (const uint8_t *) data offset: (unsigned int) offset length: (unsigned int) length; + +- (void) flush; +@end diff --git a/lib/cocoa/src/transport/TTransportException.h b/lib/cocoa/src/transport/TTransportException.h new file mode 100644 index 000000000..6749fe289 --- /dev/null +++ b/lib/cocoa/src/transport/TTransportException.h @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TException.h" + +@interface TTransportException : TException { +} + ++ (id) exceptionWithReason: (NSString *) reason + error: (NSError *) error; + ++ (id) exceptionWithReason: (NSString *) reason; + +@end diff --git a/lib/cocoa/src/transport/TTransportException.m b/lib/cocoa/src/transport/TTransportException.m new file mode 100644 index 000000000..aa67149e1 --- /dev/null +++ b/lib/cocoa/src/transport/TTransportException.m @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TTransportException.h" + +@implementation TTransportException + ++ (id) exceptionWithReason: (NSString *) reason + error: (NSError *) error +{ + NSDictionary * userInfo = nil; + if (error != nil) { + userInfo = [NSDictionary dictionaryWithObject: error forKey: @"error"]; + } + + return [super exceptionWithName: @"TTransportException" + reason: reason + userInfo: userInfo]; +} + + ++ (id) exceptionWithReason: (NSString *) reason +{ + return [self exceptionWithReason: reason error: nil]; +} + +@end diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am new file mode 100644 index 000000000..dc0b6ae5f --- /dev/null +++ b/lib/cpp/Makefile.am @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +ACLOCAL_AMFLAGS = -I ./aclocal + +pkgconfigdir = $(libdir)/pkgconfig + +lib_LTLIBRARIES = libthrift.la +pkgconfig_DATA = thrift.pc + +## We only build the extra libraries if we have the dependencies, +## but we install all of the headers unconditionally. +if AMX_HAVE_LIBEVENT +lib_LTLIBRARIES += libthriftnb.la +pkgconfig_DATA += thrift-nb.pc +endif +if AMX_HAVE_ZLIB +lib_LTLIBRARIES += libthriftz.la +pkgconfig_DATA += thrift-z.pc +endif + +AM_CXXFLAGS = -Wall +AM_CPPFLAGS = $(BOOST_CPPFLAGS) -I$(srcdir)/src + +# Define the source files for the module + +libthrift_la_SOURCES = src/Thrift.cpp \ + src/concurrency/Mutex.cpp \ + src/concurrency/Monitor.cpp \ + src/concurrency/PosixThreadFactory.cpp \ + src/concurrency/ThreadManager.cpp \ + src/concurrency/TimerManager.cpp \ + src/concurrency/Util.cpp \ + src/protocol/TBinaryProtocol.cpp \ + src/protocol/TCompactProtocol.cpp \ + src/protocol/TDebugProtocol.cpp \ + src/protocol/TDenseProtocol.cpp \ + src/protocol/TJSONProtocol.cpp \ + src/protocol/TBase64Utils.cpp \ + src/transport/TTransportException.cpp \ + src/transport/TFDTransport.cpp \ + src/transport/TFileTransport.cpp \ + src/transport/TSimpleFileTransport.cpp \ + src/transport/THttpClient.cpp \ + src/transport/TSocket.cpp \ + src/transport/TSocketPool.cpp \ + src/transport/TServerSocket.cpp \ + src/transport/TTransportUtils.cpp \ + src/transport/TBufferTransports.cpp \ + src/server/TServer.cpp \ + src/server/TSimpleServer.cpp \ + src/server/TThreadPoolServer.cpp \ + src/server/TThreadedServer.cpp \ + src/processor/PeekProcessor.cpp + +libthriftnb_la_SOURCES = src/server/TNonblockingServer.cpp + +libthriftz_la_SOURCES = src/transport/TZlibTransport.cpp + + +# Flags for the various libraries +libthriftnb_la_CPPFLAGS = $(AM_CPPFLAGS) $(LIBEVENT_CPPFLAGS) +libthriftz_la_CPPFLAGS = $(AM_CPPFLAGS) $(ZLIB_CPPFLAGS) + + +include_thriftdir = $(includedir)/thrift +include_thrift_HEADERS = \ + $(top_builddir)/config.h \ + src/Thrift.h \ + src/TReflectionLocal.h \ + src/TProcessor.h \ + src/TLogging.h + +include_concurrencydir = $(include_thriftdir)/concurrency +include_concurrency_HEADERS = \ + src/concurrency/Exception.h \ + src/concurrency/Mutex.h \ + src/concurrency/Monitor.h \ + src/concurrency/PosixThreadFactory.h \ + src/concurrency/Thread.h \ + src/concurrency/ThreadManager.h \ + src/concurrency/TimerManager.h \ + src/concurrency/FunctionRunner.h \ + src/concurrency/Util.h + +include_protocoldir = $(include_thriftdir)/protocol +include_protocol_HEADERS = \ + src/protocol/TBinaryProtocol.h \ + src/protocol/TCompactProtocol.h \ + src/protocol/TDenseProtocol.h \ + src/protocol/TDebugProtocol.h \ + src/protocol/TOneWayProtocol.h \ + src/protocol/TBase64Utils.h \ + src/protocol/TJSONProtocol.h \ + src/protocol/TProtocolTap.h \ + src/protocol/TProtocolException.h \ + src/protocol/TProtocol.h + +include_transportdir = $(include_thriftdir)/transport +include_transport_HEADERS = \ + src/transport/TFDTransport.h \ + src/transport/TFileTransport.h \ + src/transport/TSimpleFileTransport.h \ + src/transport/TServerSocket.h \ + src/transport/TServerTransport.h \ + src/transport/THttpClient.h \ + src/transport/TSocket.h \ + src/transport/TSocketPool.h \ + src/transport/TTransport.h \ + src/transport/TTransportException.h \ + src/transport/TTransportUtils.h \ + src/transport/TBufferTransports.h \ + src/transport/TShortReadTransport.h \ + src/transport/TZlibTransport.h + +include_serverdir = $(include_thriftdir)/server +include_server_HEADERS = \ + src/server/TServer.h \ + src/server/TSimpleServer.h \ + src/server/TThreadPoolServer.h \ + src/server/TThreadedServer.h \ + src/server/TNonblockingServer.h + +include_processordir = $(include_thriftdir)/processor +include_processor_HEADERS = \ + src/processor/PeekProcessor.h \ + src/processor/StatsProcessor.h + +noinst_PROGRAMS = concurrency_test + +concurrency_test_SOURCES = src/concurrency/test/Tests.cpp \ + src/concurrency/test/ThreadFactoryTests.h \ + src/concurrency/test/ThreadManagerTests.h \ + src/concurrency/test/TimerManagerTests.h + +concurrency_test_LDADD = libthrift.la + +EXTRA_DIST = \ + README \ + thrift-nb.pc.in \ + thrift.pc.in \ + thrift-z.pc.in diff --git a/lib/cpp/README b/lib/cpp/README new file mode 100644 index 000000000..576d01702 --- /dev/null +++ b/lib/cpp/README @@ -0,0 +1,67 @@ +Thrift C++ Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Using Thrift with C++ +===================== + +The Thrift C++ libraries are built using the GNU tools. Follow the instructions +in the top-level README, or run bootstrap.sh in this folder to generate the +Makefiles. + +In case you do not want to open another README file, do this: + ./bootstrap.sh + ./configure (--with-boost=/usr/local) + make + sudo make install + +Thrift is divided into two libraries. + +libthrift + The core Thrift library contains all the core Thrift code. It requires + boost shared pointers, pthreads, and librt. + +libthriftnb + This library contains the Thrift nonblocking server, which uses libevent. + To link this library you will also need to link libevent. + +Linking Against Thrift +====================== + +After you build and install Thrift the libraries are installed to +/usr/local/lib by default. Make sure this is in your LDPATH. + +On Linux, the best way to do this is to ensure that /usr/local/lib is in +your /etc/ld.so.conf and then run /sbin/ldconfig. + +Depending upon whether you are linking dynamically or statically and how +your build environment it set up, you may need to include additional +libraries when linking against thrift, such as librt and/or libpthread. If +you are using libthriftnb you will also need libevent. + +Dependencies +============ + +boost shared pointers +http://www.boost.org/libs/smart_ptr/smart_ptr.htm + +libevent (for libthriftnb only) +http://monkey.org/~provos/libevent/ diff --git a/lib/cpp/src/TLogging.h b/lib/cpp/src/TLogging.h new file mode 100644 index 000000000..2df82dd7e --- /dev/null +++ b/lib/cpp/src/TLogging.h @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TLOGGING_H_ +#define _THRIFT_TLOGGING_H_ 1 + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +/** + * Contains utility macros for debugging and logging. + * + */ + +#ifndef HAVE_CLOCK_GETTIME +#include <time.h> +#else +#include <sys/time.h> +#endif + +#ifdef HAVE_STDINT_H +#include <stdint.h> +#endif + +/** + * T_GLOBAL_DEBUGGING_LEVEL = 0: all debugging turned off, debug macros undefined + * T_GLOBAL_DEBUGGING_LEVEL = 1: all debugging turned on + */ +#define T_GLOBAL_DEBUGGING_LEVEL 0 + + +/** + * T_GLOBAL_LOGGING_LEVEL = 0: all logging turned off, logging macros undefined + * T_GLOBAL_LOGGING_LEVEL = 1: all logging turned on + */ +#define T_GLOBAL_LOGGING_LEVEL 1 + + +/** + * Standard wrapper around fprintf what will prefix the file name and line + * number to the line. Uses T_GLOBAL_DEBUGGING_LEVEL to control whether it is + * turned on or off. + * + * @param format_string + */ +#if T_GLOBAL_DEBUGGING_LEVEL > 0 + #define T_DEBUG(format_string,...) \ + if (T_GLOBAL_DEBUGGING_LEVEL > 0) { \ + fprintf(stderr,"[%s,%d] " #format_string " \n", __FILE__, __LINE__,##__VA_ARGS__); \ + } +#else + #define T_DEBUG(format_string,...) +#endif + + +/** + * analagous to T_DEBUG but also prints the time + * + * @param string format_string input: printf style format string + */ +#if T_GLOBAL_DEBUGGING_LEVEL > 0 + #define T_DEBUG_T(format_string,...) \ + { \ + if (T_GLOBAL_DEBUGGING_LEVEL > 0) { \ + time_t now; \ + char dbgtime[26] ; \ + time(&now); \ + ctime_r(&now, dbgtime); \ + dbgtime[24] = '\0'; \ + fprintf(stderr,"[%s,%d] [%s] " #format_string " \n", __FILE__, __LINE__,dbgtime,##__VA_ARGS__); \ + } \ + } +#else + #define T_DEBUG_T(format_string,...) +#endif + + +/** + * analagous to T_DEBUG but uses input level to determine whether or not the string + * should be logged. + * + * @param int level: specified debug level + * @param string format_string input: format string + */ +#define T_DEBUG_L(level, format_string,...) \ + if ((level) > 0) { \ + fprintf(stderr,"[%s,%d] " #format_string " \n", __FILE__, __LINE__,##__VA_ARGS__); \ + } + + +/** + * Explicit error logging. Prints time, file name and line number + * + * @param string format_string input: printf style format string + */ +#define T_ERROR(format_string,...) \ + { \ + time_t now; \ + char dbgtime[26] ; \ + time(&now); \ + ctime_r(&now, dbgtime); \ + dbgtime[24] = '\0'; \ + fprintf(stderr,"[%s,%d] [%s] ERROR: " #format_string " \n", __FILE__, __LINE__,dbgtime,##__VA_ARGS__); \ + } + + +/** + * Analagous to T_ERROR, additionally aborting the process. + * WARNING: macro calls abort(), ending program execution + * + * @param string format_string input: printf style format string + */ +#define T_ERROR_ABORT(format_string,...) \ + { \ + time_t now; \ + char dbgtime[26] ; \ + time(&now); \ + ctime_r(&now, dbgtime); \ + dbgtime[24] = '\0'; \ + fprintf(stderr,"[%s,%d] [%s] ERROR: Going to abort " #format_string " \n", __FILE__, __LINE__,dbgtime,##__VA_ARGS__); \ + exit(1); \ + } + + +/** + * Log input message + * + * @param string format_string input: printf style format string + */ +#if T_GLOBAL_LOGGING_LEVEL > 0 + #define T_LOG_OPER(format_string,...) \ + { \ + if (T_GLOBAL_LOGGING_LEVEL > 0) { \ + time_t now; \ + char dbgtime[26] ; \ + time(&now); \ + ctime_r(&now, dbgtime); \ + dbgtime[24] = '\0'; \ + fprintf(stderr,"[%s] " #format_string " \n", dbgtime,##__VA_ARGS__); \ + } \ + } +#else + #define T_LOG_OPER(format_string,...) +#endif + +#endif // #ifndef _THRIFT_TLOGGING_H_ diff --git a/lib/cpp/src/TProcessor.h b/lib/cpp/src/TProcessor.h new file mode 100644 index 000000000..f2d5279a2 --- /dev/null +++ b/lib/cpp/src/TProcessor.h @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TPROCESSOR_H_ +#define _THRIFT_TPROCESSOR_H_ 1 + +#include <string> +#include <protocol/TProtocol.h> +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { + +/** + * A processor is a generic object that acts upon two streams of data, one + * an input and the other an output. The definition of this object is loose, + * though the typical case is for some sort of server that either generates + * responses to an input stream or forwards data from one pipe onto another. + * + */ +class TProcessor { + public: + virtual ~TProcessor() {} + + virtual bool process(boost::shared_ptr<protocol::TProtocol> in, + boost::shared_ptr<protocol::TProtocol> out) = 0; + + bool process(boost::shared_ptr<apache::thrift::protocol::TProtocol> io) { + return process(io, io); + } + + protected: + TProcessor() {} +}; + +}} // apache::thrift + +#endif // #ifndef _THRIFT_PROCESSOR_H_ diff --git a/lib/cpp/src/TReflectionLocal.h b/lib/cpp/src/TReflectionLocal.h new file mode 100644 index 000000000..e83e47530 --- /dev/null +++ b/lib/cpp/src/TReflectionLocal.h @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TREFLECTIONLOCAL_H_ +#define _THRIFT_TREFLECTIONLOCAL_H_ 1 + +#include <stdint.h> +#include <cstring> +#include <protocol/TProtocol.h> + +/** + * Local Reflection is a blanket term referring to the the structure + * and generation of this particular representation of Thrift types. + * (It is called local because it cannot be serialized by Thrift). + * + */ + +namespace apache { namespace thrift { namespace reflection { namespace local { + +using apache::thrift::protocol::TType; + +// We include this many bytes of the structure's fingerprint when serializing +// a top-level structure. Long enough to make collisions unlikely, short +// enough to not significantly affect the amount of memory used. +const int FP_PREFIX_LEN = 4; + +struct FieldMeta { + int16_t tag; + bool is_optional; +}; + +struct TypeSpec { + TType ttype; + uint8_t fp_prefix[FP_PREFIX_LEN]; + + // Use an anonymous union here so we can fit two TypeSpecs in one cache line. + union { + struct { + // Use parallel arrays here for denser packing (of the arrays). + FieldMeta* metas; + TypeSpec** specs; + } tstruct; + struct { + TypeSpec *subtype1; + TypeSpec *subtype2; + } tcontainer; + }; + + // Static initialization of unions isn't really possible, + // so take the plunge and use constructors. + // Hopefully they'll be evaluated at compile time. + + TypeSpec(TType ttype) : ttype(ttype) { + std::memset(fp_prefix, 0, FP_PREFIX_LEN); + } + + TypeSpec(TType ttype, + const uint8_t* fingerprint, + FieldMeta* metas, + TypeSpec** specs) : + ttype(ttype) + { + std::memcpy(fp_prefix, fingerprint, FP_PREFIX_LEN); + tstruct.metas = metas; + tstruct.specs = specs; + } + + TypeSpec(TType ttype, TypeSpec* subtype1, TypeSpec* subtype2) : + ttype(ttype) + { + std::memset(fp_prefix, 0, FP_PREFIX_LEN); + tcontainer.subtype1 = subtype1; + tcontainer.subtype2 = subtype2; + } + +}; + +}}}} // apache::thrift::reflection::local + +#endif // #ifndef _THRIFT_TREFLECTIONLOCAL_H_ diff --git a/lib/cpp/src/Thrift.cpp b/lib/cpp/src/Thrift.cpp new file mode 100644 index 000000000..ed99205b8 --- /dev/null +++ b/lib/cpp/src/Thrift.cpp @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <Thrift.h> +#include <cstring> +#include <boost/lexical_cast.hpp> +#include <protocol/TProtocol.h> +#include <stdarg.h> +#include <stdio.h> + +namespace apache { namespace thrift { + +TOutput GlobalOutput; + +void TOutput::printf(const char *message, ...) { + // Try to reduce heap usage, even if printf is called rarely. + static const int STACK_BUF_SIZE = 256; + char stack_buf[STACK_BUF_SIZE]; + va_list ap; + + va_start(ap, message); + int need = vsnprintf(stack_buf, STACK_BUF_SIZE, message, ap); + va_end(ap); + + if (need < STACK_BUF_SIZE) { + f_(stack_buf); + return; + } + + char *heap_buf = (char*)malloc((need+1) * sizeof(char)); + if (heap_buf == NULL) { + // Malloc failed. We might as well print the stack buffer. + f_(stack_buf); + return; + } + + va_start(ap, message); + int rval = vsnprintf(heap_buf, need+1, message, ap); + va_end(ap); + // TODO(shigin): inform user + if (rval != -1) { + f_(heap_buf); + } + free(heap_buf); +} + +void TOutput::perror(const char *message, int errno_copy) { + std::string out = message + strerror_s(errno_copy); + f_(out.c_str()); +} + +std::string TOutput::strerror_s(int errno_copy) { +#ifndef HAVE_STRERROR_R + return "errno = " + boost::lexical_cast<std::string>(errno_copy); +#else // HAVE_STRERROR_R + + char b_errbuf[1024] = { '\0' }; +#ifdef STRERROR_R_CHAR_P + char *b_error = strerror_r(errno_copy, b_errbuf, sizeof(b_errbuf)); +#else + char *b_error = b_errbuf; + int rv = strerror_r(errno_copy, b_errbuf, sizeof(b_errbuf)); + if (rv == -1) { + // strerror_r failed. omgwtfbbq. + return "XSI-compliant strerror_r() failed with errno = " + + boost::lexical_cast<std::string>(errno_copy); + } +#endif + // Can anyone prove that explicit cast is probably not necessary + // to ensure that the string object is constructed before + // b_error becomes invalid? + return std::string(b_error); + +#endif // HAVE_STRERROR_R +} + +uint32_t TApplicationException::read(apache::thrift::protocol::TProtocol* iprot) { + uint32_t xfer = 0; + std::string fname; + apache::thrift::protocol::TType ftype; + int16_t fid; + + xfer += iprot->readStructBegin(fname); + + while (true) { + xfer += iprot->readFieldBegin(fname, ftype, fid); + if (ftype == apache::thrift::protocol::T_STOP) { + break; + } + switch (fid) { + case 1: + if (ftype == apache::thrift::protocol::T_STRING) { + xfer += iprot->readString(message_); + } else { + xfer += iprot->skip(ftype); + } + break; + case 2: + if (ftype == apache::thrift::protocol::T_I32) { + int32_t type; + xfer += iprot->readI32(type); + type_ = (TApplicationExceptionType)type; + } else { + xfer += iprot->skip(ftype); + } + break; + default: + xfer += iprot->skip(ftype); + break; + } + xfer += iprot->readFieldEnd(); + } + + xfer += iprot->readStructEnd(); + return xfer; +} + +uint32_t TApplicationException::write(apache::thrift::protocol::TProtocol* oprot) const { + uint32_t xfer = 0; + xfer += oprot->writeStructBegin("TApplicationException"); + xfer += oprot->writeFieldBegin("message", apache::thrift::protocol::T_STRING, 1); + xfer += oprot->writeString(message_); + xfer += oprot->writeFieldEnd(); + xfer += oprot->writeFieldBegin("type", apache::thrift::protocol::T_I32, 2); + xfer += oprot->writeI32(type_); + xfer += oprot->writeFieldEnd(); + xfer += oprot->writeFieldStop(); + xfer += oprot->writeStructEnd(); + return xfer; +} + +}} // apache::thrift diff --git a/lib/cpp/src/Thrift.h b/lib/cpp/src/Thrift.h new file mode 100644 index 000000000..26d2b0fcd --- /dev/null +++ b/lib/cpp/src/Thrift.h @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_THRIFT_H_ +#define _THRIFT_THRIFT_H_ 1 + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif +#include <stdio.h> + +#include <netinet/in.h> +#ifdef HAVE_INTTYPES_H +#include <inttypes.h> +#endif +#include <string> +#include <map> +#include <list> +#include <set> +#include <vector> +#include <exception> + +#include "TLogging.h" + +namespace apache { namespace thrift { + +class TOutput { + public: + TOutput() : f_(&errorTimeWrapper) {} + + inline void setOutputFunction(void (*function)(const char *)){ + f_ = function; + } + + inline void operator()(const char *message){ + f_(message); + } + + // It is important to have a const char* overload here instead of + // just the string version, otherwise errno could be corrupted + // if there is some problem allocating memory when constructing + // the string. + void perror(const char *message, int errno_copy); + inline void perror(const std::string &message, int errno_copy) { + perror(message.c_str(), errno_copy); + } + + void printf(const char *message, ...); + + inline static void errorTimeWrapper(const char* msg) { + time_t now; + char dbgtime[25]; + time(&now); + ctime_r(&now, dbgtime); + dbgtime[24] = 0; + fprintf(stderr, "Thrift: %s %s\n", dbgtime, msg); + } + + /** Just like strerror_r but returns a C++ string object. */ + static std::string strerror_s(int errno_copy); + + private: + void (*f_)(const char *); +}; + +extern TOutput GlobalOutput; + +namespace protocol { + class TProtocol; +} + +class TException : public std::exception { + public: + TException() {} + + TException(const std::string& message) : + message_(message) {} + + virtual ~TException() throw() {} + + virtual const char* what() const throw() { + if (message_.empty()) { + return "Default TException."; + } else { + return message_.c_str(); + } + } + + protected: + std::string message_; + +}; + +class TApplicationException : public TException { + public: + + /** + * Error codes for the various types of exceptions. + */ + enum TApplicationExceptionType + { UNKNOWN = 0 + , UNKNOWN_METHOD = 1 + , INVALID_MESSAGE_TYPE = 2 + , WRONG_METHOD_NAME = 3 + , BAD_SEQUENCE_ID = 4 + , MISSING_RESULT = 5 + }; + + TApplicationException() : + TException(), + type_(UNKNOWN) {} + + TApplicationException(TApplicationExceptionType type) : + TException(), + type_(type) {} + + TApplicationException(const std::string& message) : + TException(message), + type_(UNKNOWN) {} + + TApplicationException(TApplicationExceptionType type, + const std::string& message) : + TException(message), + type_(type) {} + + virtual ~TApplicationException() throw() {} + + /** + * Returns an error code that provides information about the type of error + * that has occurred. + * + * @return Error code + */ + TApplicationExceptionType getType() { + return type_; + } + + virtual const char* what() const throw() { + if (message_.empty()) { + switch (type_) { + case UNKNOWN : return "TApplicationException: Unknown application exception"; + case UNKNOWN_METHOD : return "TApplicationException: Unknown method"; + case INVALID_MESSAGE_TYPE : return "TApplicationException: Invalid message type"; + case WRONG_METHOD_NAME : return "TApplicationException: Wrong method name"; + case BAD_SEQUENCE_ID : return "TApplicationException: Bad sequence identifier"; + case MISSING_RESULT : return "TApplicationException: Missing result"; + default : return "TApplicationException: (Invalid exception type)"; + }; + } else { + return message_.c_str(); + } + } + + uint32_t read(protocol::TProtocol* iprot); + uint32_t write(protocol::TProtocol* oprot) const; + + protected: + /** + * Error code + */ + TApplicationExceptionType type_; + +}; + + +// Forward declare this structure used by TDenseProtocol +namespace reflection { namespace local { +struct TypeSpec; +}} + + +}} // apache::thrift + +#endif // #ifndef _THRIFT_THRIFT_H_ diff --git a/lib/cpp/src/concurrency/Exception.h b/lib/cpp/src/concurrency/Exception.h new file mode 100644 index 000000000..ec4662976 --- /dev/null +++ b/lib/cpp/src/concurrency/Exception.h @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_EXCEPTION_H_ +#define _THRIFT_CONCURRENCY_EXCEPTION_H_ 1 + +#include <exception> +#include <Thrift.h> + +namespace apache { namespace thrift { namespace concurrency { + +class NoSuchTaskException : public apache::thrift::TException {}; + +class UncancellableTaskException : public apache::thrift::TException {}; + +class InvalidArgumentException : public apache::thrift::TException {}; + +class IllegalStateException : public apache::thrift::TException {}; + +class TimedOutException : public apache::thrift::TException { +public: + TimedOutException():TException("TimedOutException"){}; + TimedOutException(const std::string& message ) : + TException(message) {} +}; + +class TooManyPendingTasksException : public apache::thrift::TException { +public: + TooManyPendingTasksException():TException("TooManyPendingTasksException"){}; + TooManyPendingTasksException(const std::string& message ) : + TException(message) {} +}; + +class SystemResourceException : public apache::thrift::TException { +public: + SystemResourceException() {} + + SystemResourceException(const std::string& message) : + TException(message) {} +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_EXCEPTION_H_ diff --git a/lib/cpp/src/concurrency/FunctionRunner.h b/lib/cpp/src/concurrency/FunctionRunner.h new file mode 100644 index 000000000..221692767 --- /dev/null +++ b/lib/cpp/src/concurrency/FunctionRunner.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_FUNCTION_RUNNER_H +#define _THRIFT_CONCURRENCY_FUNCTION_RUNNER_H 1 + +#include <tr1/functional> +#include "thrift/lib/cpp/concurrency/Thread.h" + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Convenient implementation of Runnable that will execute arbitrary callbacks. + * Interfaces are provided to accept both a generic 'void(void)' callback, and + * a 'void* (void*)' pthread_create-style callback. + * + * Example use: + * void* my_thread_main(void* arg); + * shared_ptr<ThreadFactory> factory = ...; + * shared_ptr<Thread> thread = + * factory->newThread(shared_ptr<FunctionRunner>( + * new FunctionRunner(my_thread_main, some_argument))); + * thread->start(); + * + * + */ + +class FunctionRunner : public Runnable { + public: + // This is the type of callback 'pthread_create()' expects. + typedef void* (*PthreadFuncPtr)(void *arg); + // This a fully-generic void(void) callback for custom bindings. + typedef std::tr1::function<void()> VoidFunc; + + /** + * Given a 'pthread_create' style callback, this FunctionRunner will + * execute the given callback. Note that the 'void*' return value is ignored. + */ + FunctionRunner(PthreadFuncPtr func, void* arg) + : func_(std::tr1::bind(func, arg)) + { } + + /** + * Given a generic callback, this FunctionRunner will execute it. + */ + FunctionRunner(const VoidFunc& cob) + : func_(cob) + { } + + + void run() { + func_(); + } + + private: + VoidFunc func_; +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_FUNCTION_RUNNER_H diff --git a/lib/cpp/src/concurrency/Monitor.cpp b/lib/cpp/src/concurrency/Monitor.cpp new file mode 100644 index 000000000..2055caa95 --- /dev/null +++ b/lib/cpp/src/concurrency/Monitor.cpp @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "Monitor.h" +#include "Exception.h" +#include "Util.h" + +#include <assert.h> +#include <errno.h> + +#include <iostream> + +#include <pthread.h> + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Monitor implementation using the POSIX pthread library + * + * @version $Id:$ + */ +class Monitor::Impl { + + public: + + Impl() : + mutexInitialized_(false), + condInitialized_(false) { + + if (pthread_mutex_init(&pthread_mutex_, NULL) == 0) { + mutexInitialized_ = true; + + if (pthread_cond_init(&pthread_cond_, NULL) == 0) { + condInitialized_ = true; + } + } + + if (!mutexInitialized_ || !condInitialized_) { + cleanup(); + throw SystemResourceException(); + } + } + + ~Impl() { cleanup(); } + + void lock() const { pthread_mutex_lock(&pthread_mutex_); } + + void unlock() const { pthread_mutex_unlock(&pthread_mutex_); } + + void wait(int64_t timeout) const { + + // XXX Need to assert that caller owns mutex + assert(timeout >= 0LL); + if (timeout == 0LL) { + int iret = pthread_cond_wait(&pthread_cond_, &pthread_mutex_); + assert(iret == 0); + } else { + struct timespec abstime; + int64_t now = Util::currentTime(); + Util::toTimespec(abstime, now + timeout); + int result = pthread_cond_timedwait(&pthread_cond_, + &pthread_mutex_, + &abstime); + if (result == ETIMEDOUT) { + // pthread_cond_timedwait has been observed to return early on + // various platforms, so comment out this assert. + //assert(Util::currentTime() >= (now + timeout)); + throw TimedOutException(); + } + } + } + + void notify() { + // XXX Need to assert that caller owns mutex + int iret = pthread_cond_signal(&pthread_cond_); + assert(iret == 0); + } + + void notifyAll() { + // XXX Need to assert that caller owns mutex + int iret = pthread_cond_broadcast(&pthread_cond_); + assert(iret == 0); + } + + private: + + void cleanup() { + if (mutexInitialized_) { + mutexInitialized_ = false; + int iret = pthread_mutex_destroy(&pthread_mutex_); + assert(iret == 0); + } + + if (condInitialized_) { + condInitialized_ = false; + int iret = pthread_cond_destroy(&pthread_cond_); + assert(iret == 0); + } + } + + mutable pthread_mutex_t pthread_mutex_; + mutable bool mutexInitialized_; + mutable pthread_cond_t pthread_cond_; + mutable bool condInitialized_; +}; + +Monitor::Monitor() : impl_(new Monitor::Impl()) {} + +Monitor::~Monitor() { delete impl_; } + +void Monitor::lock() const { impl_->lock(); } + +void Monitor::unlock() const { impl_->unlock(); } + +void Monitor::wait(int64_t timeout) const { impl_->wait(timeout); } + +void Monitor::notify() const { impl_->notify(); } + +void Monitor::notifyAll() const { impl_->notifyAll(); } + +}}} // apache::thrift::concurrency diff --git a/lib/cpp/src/concurrency/Monitor.h b/lib/cpp/src/concurrency/Monitor.h new file mode 100644 index 000000000..234bf3269 --- /dev/null +++ b/lib/cpp/src/concurrency/Monitor.h @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_MONITOR_H_ +#define _THRIFT_CONCURRENCY_MONITOR_H_ 1 + +#include "Exception.h" + +namespace apache { namespace thrift { namespace concurrency { + +/** + * A monitor is a combination mutex and condition-event. Waiting and + * notifying condition events requires that the caller own the mutex. Mutex + * lock and unlock operations can be performed independently of condition + * events. This is more or less analogous to java.lang.Object multi-thread + * operations + * + * Note that all methods are const. Monitors implement logical constness, not + * bit constness. This allows const methods to call monitor methods without + * needing to cast away constness or change to non-const signatures. + * + * @version $Id:$ + */ +class Monitor { + + public: + + Monitor(); + + virtual ~Monitor(); + + virtual void lock() const; + + virtual void unlock() const; + + virtual void wait(int64_t timeout=0LL) const; + + virtual void notify() const; + + virtual void notifyAll() const; + + private: + + class Impl; + + Impl* impl_; +}; + +class Synchronized { + public: + + Synchronized(const Monitor& value) : + monitor_(value) { + monitor_.lock(); + } + + ~Synchronized() { + monitor_.unlock(); + } + + private: + const Monitor& monitor_; +}; + + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_MONITOR_H_ diff --git a/lib/cpp/src/concurrency/Mutex.cpp b/lib/cpp/src/concurrency/Mutex.cpp new file mode 100644 index 000000000..045dbdfe2 --- /dev/null +++ b/lib/cpp/src/concurrency/Mutex.cpp @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "Mutex.h" + +#include <assert.h> +#include <pthread.h> + +using boost::shared_ptr; + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Implementation of Mutex class using POSIX mutex + * + * @version $Id:$ + */ +class Mutex::impl { + public: + impl(Initializer init) : initialized_(false) { + init(&pthread_mutex_); + initialized_ = true; + } + + ~impl() { + if (initialized_) { + initialized_ = false; + int ret = pthread_mutex_destroy(&pthread_mutex_); + assert(ret == 0); + } + } + + void lock() const { pthread_mutex_lock(&pthread_mutex_); } + + bool trylock() const { return (0 == pthread_mutex_trylock(&pthread_mutex_)); } + + void unlock() const { pthread_mutex_unlock(&pthread_mutex_); } + + private: + mutable pthread_mutex_t pthread_mutex_; + mutable bool initialized_; +}; + +Mutex::Mutex(Initializer init) : impl_(new Mutex::impl(init)) {} + +void Mutex::lock() const { impl_->lock(); } + +bool Mutex::trylock() const { return impl_->trylock(); } + +void Mutex::unlock() const { impl_->unlock(); } + +void Mutex::DEFAULT_INITIALIZER(void* arg) { + pthread_mutex_t* pthread_mutex = (pthread_mutex_t*)arg; + int ret = pthread_mutex_init(pthread_mutex, NULL); + assert(ret == 0); +} + +static void init_with_kind(pthread_mutex_t* mutex, int kind) { + pthread_mutexattr_t mutexattr; + int ret = pthread_mutexattr_init(&mutexattr); + assert(ret == 0); + + // Apparently, this can fail. Should we really be aborting? + ret = pthread_mutexattr_settype(&mutexattr, kind); + assert(ret == 0); + + ret = pthread_mutex_init(mutex, &mutexattr); + assert(ret == 0); + + ret = pthread_mutexattr_destroy(&mutexattr); + assert(ret == 0); +} + +#ifdef PTHREAD_ADAPTIVE_MUTEX_INITIALIZER_NP +void Mutex::ADAPTIVE_INITIALIZER(void* arg) { + // From mysql source: mysys/my_thr_init.c + // Set mutex type to "fast" a.k.a "adaptive" + // + // In this case the thread may steal the mutex from some other thread + // that is waiting for the same mutex. This will save us some + // context switches but may cause a thread to 'starve forever' while + // waiting for the mutex (not likely if the code within the mutex is + // short). + init_with_kind((pthread_mutex_t*)arg, PTHREAD_MUTEX_ADAPTIVE_NP); +} +#endif + +#ifdef PTHREAD_RECURSIVE_MUTEX_INITIALIZER_NP +void Mutex::RECURSIVE_INITIALIZER(void* arg) { + init_with_kind((pthread_mutex_t*)arg, PTHREAD_MUTEX_RECURSIVE_NP); +} +#endif + + +/** + * Implementation of ReadWriteMutex class using POSIX rw lock + * + * @version $Id:$ + */ +class ReadWriteMutex::impl { +public: + impl() : initialized_(false) { + int ret = pthread_rwlock_init(&rw_lock_, NULL); + assert(ret == 0); + initialized_ = true; + } + + ~impl() { + if(initialized_) { + initialized_ = false; + int ret = pthread_rwlock_destroy(&rw_lock_); + assert(ret == 0); + } + } + + void acquireRead() const { pthread_rwlock_rdlock(&rw_lock_); } + + void acquireWrite() const { pthread_rwlock_wrlock(&rw_lock_); } + + bool attemptRead() const { return pthread_rwlock_tryrdlock(&rw_lock_); } + + bool attemptWrite() const { return pthread_rwlock_trywrlock(&rw_lock_); } + + void release() const { pthread_rwlock_unlock(&rw_lock_); } + +private: + mutable pthread_rwlock_t rw_lock_; + mutable bool initialized_; +}; + +ReadWriteMutex::ReadWriteMutex() : impl_(new ReadWriteMutex::impl()) {} + +void ReadWriteMutex::acquireRead() const { impl_->acquireRead(); } + +void ReadWriteMutex::acquireWrite() const { impl_->acquireWrite(); } + +bool ReadWriteMutex::attemptRead() const { return impl_->attemptRead(); } + +bool ReadWriteMutex::attemptWrite() const { return impl_->attemptWrite(); } + +void ReadWriteMutex::release() const { impl_->release(); } + +}}} // apache::thrift::concurrency + diff --git a/lib/cpp/src/concurrency/Mutex.h b/lib/cpp/src/concurrency/Mutex.h new file mode 100644 index 000000000..884412bea --- /dev/null +++ b/lib/cpp/src/concurrency/Mutex.h @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_MUTEX_H_ +#define _THRIFT_CONCURRENCY_MUTEX_H_ 1 + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace concurrency { + +/** + * A simple mutex class + * + * @version $Id:$ + */ +class Mutex { + public: + typedef void (*Initializer)(void*); + + Mutex(Initializer init = DEFAULT_INITIALIZER); + virtual ~Mutex() {} + virtual void lock() const; + virtual bool trylock() const; + virtual void unlock() const; + + static void DEFAULT_INITIALIZER(void*); + static void ADAPTIVE_INITIALIZER(void*); + static void RECURSIVE_INITIALIZER(void*); + + private: + + class impl; + boost::shared_ptr<impl> impl_; +}; + +class ReadWriteMutex { +public: + ReadWriteMutex(); + virtual ~ReadWriteMutex() {} + + // these get the lock and block until it is done successfully + virtual void acquireRead() const; + virtual void acquireWrite() const; + + // these attempt to get the lock, returning false immediately if they fail + virtual bool attemptRead() const; + virtual bool attemptWrite() const; + + // this releases both read and write locks + virtual void release() const; + +private: + + class impl; + boost::shared_ptr<impl> impl_; +}; + +class Guard { + public: + Guard(const Mutex& value) : mutex_(value) { + mutex_.lock(); + } + ~Guard() { + mutex_.unlock(); + } + + private: + const Mutex& mutex_; +}; + +class RWGuard { + public: + RWGuard(const ReadWriteMutex& value, bool write = 0) : rw_mutex_(value) { + if (write) { + rw_mutex_.acquireWrite(); + } else { + rw_mutex_.acquireRead(); + } + } + ~RWGuard() { + rw_mutex_.release(); + } + private: + const ReadWriteMutex& rw_mutex_; +}; + + +// A little hack to prevent someone from trying to do "Guard(m);" +// Sorry for polluting the global namespace, but I think it's worth it. +#define Guard(m) incorrect_use_of_Guard(m) +#define RWGuard(m) incorrect_use_of_RWGuard(m) + + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_MUTEX_H_ diff --git a/lib/cpp/src/concurrency/PosixThreadFactory.cpp b/lib/cpp/src/concurrency/PosixThreadFactory.cpp new file mode 100644 index 000000000..e48dce39e --- /dev/null +++ b/lib/cpp/src/concurrency/PosixThreadFactory.cpp @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "PosixThreadFactory.h" +#include "Exception.h" + +#if GOOGLE_PERFTOOLS_REGISTER_THREAD +# include <google/profiler.h> +#endif + +#include <assert.h> +#include <pthread.h> + +#include <iostream> + +#include <boost/weak_ptr.hpp> + +namespace apache { namespace thrift { namespace concurrency { + +using boost::shared_ptr; +using boost::weak_ptr; + +/** + * The POSIX thread class. + * + * @version $Id:$ + */ +class PthreadThread: public Thread { + public: + + enum STATE { + uninitialized, + starting, + started, + stopping, + stopped + }; + + static const int MB = 1024 * 1024; + + static void* threadMain(void* arg); + + private: + pthread_t pthread_; + STATE state_; + int policy_; + int priority_; + int stackSize_; + weak_ptr<PthreadThread> self_; + bool detached_; + + public: + + PthreadThread(int policy, int priority, int stackSize, bool detached, shared_ptr<Runnable> runnable) : + pthread_(0), + state_(uninitialized), + policy_(policy), + priority_(priority), + stackSize_(stackSize), + detached_(detached) { + + this->Thread::runnable(runnable); + } + + ~PthreadThread() { + /* Nothing references this thread, if is is not detached, do a join + now, otherwise the thread-id and, possibly, other resources will + be leaked. */ + if(!detached_) { + try { + join(); + } catch(...) { + // We're really hosed. + } + } + } + + void start() { + if (state_ != uninitialized) { + return; + } + + pthread_attr_t thread_attr; + if (pthread_attr_init(&thread_attr) != 0) { + throw SystemResourceException("pthread_attr_init failed"); + } + + if(pthread_attr_setdetachstate(&thread_attr, + detached_ ? + PTHREAD_CREATE_DETACHED : + PTHREAD_CREATE_JOINABLE) != 0) { + throw SystemResourceException("pthread_attr_setdetachstate failed"); + } + + // Set thread stack size + if (pthread_attr_setstacksize(&thread_attr, MB * stackSize_) != 0) { + throw SystemResourceException("pthread_attr_setstacksize failed"); + } + + // Set thread policy + if (pthread_attr_setschedpolicy(&thread_attr, policy_) != 0) { + throw SystemResourceException("pthread_attr_setschedpolicy failed"); + } + + struct sched_param sched_param; + sched_param.sched_priority = priority_; + + // Set thread priority + if (pthread_attr_setschedparam(&thread_attr, &sched_param) != 0) { + throw SystemResourceException("pthread_attr_setschedparam failed"); + } + + // Create reference + shared_ptr<PthreadThread>* selfRef = new shared_ptr<PthreadThread>(); + *selfRef = self_.lock(); + + state_ = starting; + + if (pthread_create(&pthread_, &thread_attr, threadMain, (void*)selfRef) != 0) { + throw SystemResourceException("pthread_create failed"); + } + } + + void join() { + if (!detached_ && state_ != uninitialized) { + void* ignore; + /* XXX + If join fails it is most likely due to the fact + that the last reference was the thread itself and cannot + join. This results in leaked threads and will eventually + cause the process to run out of thread resources. + We're beyond the point of throwing an exception. Not clear how + best to handle this. */ + detached_ = pthread_join(pthread_, &ignore) == 0; + } + } + + Thread::id_t getId() { + return (Thread::id_t)pthread_; + } + + shared_ptr<Runnable> runnable() const { return Thread::runnable(); } + + void runnable(shared_ptr<Runnable> value) { Thread::runnable(value); } + + void weakRef(shared_ptr<PthreadThread> self) { + assert(self.get() == this); + self_ = weak_ptr<PthreadThread>(self); + } +}; + +void* PthreadThread::threadMain(void* arg) { + shared_ptr<PthreadThread> thread = *(shared_ptr<PthreadThread>*)arg; + delete reinterpret_cast<shared_ptr<PthreadThread>*>(arg); + + if (thread == NULL) { + return (void*)0; + } + + if (thread->state_ != starting) { + return (void*)0; + } + +#if GOOGLE_PERFTOOLS_REGISTER_THREAD + ProfilerRegisterThread(); +#endif + + thread->state_ = starting; + thread->runnable()->run(); + if (thread->state_ != stopping && thread->state_ != stopped) { + thread->state_ = stopping; + } + + return (void*)0; +} + +/** + * POSIX Thread factory implementation + */ +class PosixThreadFactory::Impl { + + private: + POLICY policy_; + PRIORITY priority_; + int stackSize_; + bool detached_; + + /** + * Converts generic posix thread schedule policy enums into pthread + * API values. + */ + static int toPthreadPolicy(POLICY policy) { + switch (policy) { + case OTHER: + return SCHED_OTHER; + case FIFO: + return SCHED_FIFO; + case ROUND_ROBIN: + return SCHED_RR; + } + return SCHED_OTHER; + } + + /** + * Converts relative thread priorities to absolute value based on posix + * thread scheduler policy + * + * The idea is simply to divide up the priority range for the given policy + * into the correpsonding relative priority level (lowest..highest) and + * then pro-rate accordingly. + */ + static int toPthreadPriority(POLICY policy, PRIORITY priority) { + int pthread_policy = toPthreadPolicy(policy); + int min_priority = sched_get_priority_min(pthread_policy); + int max_priority = sched_get_priority_max(pthread_policy); + int quanta = (HIGHEST - LOWEST) + 1; + float stepsperquanta = (max_priority - min_priority) / quanta; + + if (priority <= HIGHEST) { + return (int)(min_priority + stepsperquanta * priority); + } else { + // should never get here for priority increments. + assert(false); + return (int)(min_priority + stepsperquanta * NORMAL); + } + } + + public: + + Impl(POLICY policy, PRIORITY priority, int stackSize, bool detached) : + policy_(policy), + priority_(priority), + stackSize_(stackSize), + detached_(detached) {} + + /** + * Creates a new POSIX thread to run the runnable object + * + * @param runnable A runnable object + */ + shared_ptr<Thread> newThread(shared_ptr<Runnable> runnable) const { + shared_ptr<PthreadThread> result = shared_ptr<PthreadThread>(new PthreadThread(toPthreadPolicy(policy_), toPthreadPriority(policy_, priority_), stackSize_, detached_, runnable)); + result->weakRef(result); + runnable->thread(result); + return result; + } + + int getStackSize() const { return stackSize_; } + + void setStackSize(int value) { stackSize_ = value; } + + PRIORITY getPriority() const { return priority_; } + + /** + * Sets priority. + * + * XXX + * Need to handle incremental priorities properly. + */ + void setPriority(PRIORITY value) { priority_ = value; } + + bool isDetached() const { return detached_; } + + void setDetached(bool value) { detached_ = value; } + + Thread::id_t getCurrentThreadId() const { + // TODO(dreiss): Stop using C-style casts. + return (id_t)pthread_self(); + } + +}; + +PosixThreadFactory::PosixThreadFactory(POLICY policy, PRIORITY priority, int stackSize, bool detached) : + impl_(new PosixThreadFactory::Impl(policy, priority, stackSize, detached)) {} + +shared_ptr<Thread> PosixThreadFactory::newThread(shared_ptr<Runnable> runnable) const { return impl_->newThread(runnable); } + +int PosixThreadFactory::getStackSize() const { return impl_->getStackSize(); } + +void PosixThreadFactory::setStackSize(int value) { impl_->setStackSize(value); } + +PosixThreadFactory::PRIORITY PosixThreadFactory::getPriority() const { return impl_->getPriority(); } + +void PosixThreadFactory::setPriority(PosixThreadFactory::PRIORITY value) { impl_->setPriority(value); } + +bool PosixThreadFactory::isDetached() const { return impl_->isDetached(); } + +void PosixThreadFactory::setDetached(bool value) { impl_->setDetached(value); } + +Thread::id_t PosixThreadFactory::getCurrentThreadId() const { return impl_->getCurrentThreadId(); } + +}}} // apache::thrift::concurrency diff --git a/lib/cpp/src/concurrency/PosixThreadFactory.h b/lib/cpp/src/concurrency/PosixThreadFactory.h new file mode 100644 index 000000000..d6d83a3a1 --- /dev/null +++ b/lib/cpp/src/concurrency/PosixThreadFactory.h @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_POSIXTHREADFACTORY_H_ +#define _THRIFT_CONCURRENCY_POSIXTHREADFACTORY_H_ 1 + +#include "Thread.h" + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace concurrency { + +/** + * A thread factory to create posix threads + * + * @version $Id:$ + */ +class PosixThreadFactory : public ThreadFactory { + + public: + + /** + * POSIX Thread scheduler policies + */ + enum POLICY { + OTHER, + FIFO, + ROUND_ROBIN + }; + + /** + * POSIX Thread scheduler relative priorities, + * + * Absolute priority is determined by scheduler policy and OS. This + * enumeration specifies relative priorities such that one can specify a + * priority withing a giving scheduler policy without knowing the absolute + * value of the priority. + */ + enum PRIORITY { + LOWEST = 0, + LOWER = 1, + LOW = 2, + NORMAL = 3, + HIGH = 4, + HIGHER = 5, + HIGHEST = 6, + INCREMENT = 7, + DECREMENT = 8 + }; + + /** + * Posix thread (pthread) factory. All threads created by a factory are reference-counted + * via boost::shared_ptr and boost::weak_ptr. The factory guarantees that threads and + * the Runnable tasks they host will be properly cleaned up once the last strong reference + * to both is given up. + * + * Threads are created with the specified policy, priority, stack-size and detachable-mode + * detached means the thread is free-running and will release all system resources the + * when it completes. A detachable thread is not joinable. The join method + * of a detachable thread will return immediately with no error. + * + * By default threads are not joinable. + */ + + PosixThreadFactory(POLICY policy=ROUND_ROBIN, PRIORITY priority=NORMAL, int stackSize=1, bool detached=true); + + // From ThreadFactory; + boost::shared_ptr<Thread> newThread(boost::shared_ptr<Runnable> runnable) const; + + // From ThreadFactory; + Thread::id_t getCurrentThreadId() const; + + /** + * Gets stack size for created threads + * + * @return int size in megabytes + */ + virtual int getStackSize() const; + + /** + * Sets stack size for created threads + * + * @param value size in megabytes + */ + virtual void setStackSize(int value); + + /** + * Gets priority relative to current policy + */ + virtual PRIORITY getPriority() const; + + /** + * Sets priority relative to current policy + */ + virtual void setPriority(PRIORITY priority); + + /** + * Sets detached mode of threads + */ + virtual void setDetached(bool detached); + + /** + * Gets current detached mode + */ + virtual bool isDetached() const; + + private: + class Impl; + boost::shared_ptr<Impl> impl_; +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_POSIXTHREADFACTORY_H_ diff --git a/lib/cpp/src/concurrency/Thread.h b/lib/cpp/src/concurrency/Thread.h new file mode 100644 index 000000000..d4282adbc --- /dev/null +++ b/lib/cpp/src/concurrency/Thread.h @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_THREAD_H_ +#define _THRIFT_CONCURRENCY_THREAD_H_ 1 + +#include <stdint.h> +#include <boost/shared_ptr.hpp> +#include <boost/weak_ptr.hpp> + +namespace apache { namespace thrift { namespace concurrency { + +class Thread; + +/** + * Minimal runnable class. More or less analogous to java.lang.Runnable. + * + * @version $Id:$ + */ +class Runnable { + + public: + virtual ~Runnable() {}; + virtual void run() = 0; + + /** + * Gets the thread object that is hosting this runnable object - can return + * an empty boost::shared pointer if no references remain on thet thread object + */ + virtual boost::shared_ptr<Thread> thread() { return thread_.lock(); } + + /** + * Sets the thread that is executing this object. This is only meant for + * use by concrete implementations of Thread. + */ + virtual void thread(boost::shared_ptr<Thread> value) { thread_ = value; } + + private: + boost::weak_ptr<Thread> thread_; +}; + +/** + * Minimal thread class. Returned by thread factory bound to a Runnable object + * and ready to start execution. More or less analogous to java.lang.Thread + * (minus all the thread group, priority, mode and other baggage, since that + * is difficult to abstract across platforms and is left for platform-specific + * ThreadFactory implemtations to deal with + * + * @see apache::thrift::concurrency::ThreadFactory) + */ +class Thread { + + public: + + typedef uint64_t id_t; + + virtual ~Thread() {}; + + /** + * Starts the thread. Does platform specific thread creation and + * configuration then invokes the run method of the Runnable object bound + * to this thread. + */ + virtual void start() = 0; + + /** + * Join this thread. Current thread blocks until this target thread + * completes. + */ + virtual void join() = 0; + + /** + * Gets the thread's platform-specific ID + */ + virtual id_t getId() = 0; + + /** + * Gets the runnable object this thread is hosting + */ + virtual boost::shared_ptr<Runnable> runnable() const { return _runnable; } + + protected: + virtual void runnable(boost::shared_ptr<Runnable> value) { _runnable = value; } + + private: + boost::shared_ptr<Runnable> _runnable; + +}; + +/** + * Factory to create platform-specific thread object and bind them to Runnable + * object for execution + */ +class ThreadFactory { + + public: + virtual ~ThreadFactory() {} + virtual boost::shared_ptr<Thread> newThread(boost::shared_ptr<Runnable> runnable) const = 0; + + /** Gets the current thread id or unknown_thread_id if the current thread is not a thrift thread */ + + static const Thread::id_t unknown_thread_id; + + virtual Thread::id_t getCurrentThreadId() const = 0; +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_THREAD_H_ diff --git a/lib/cpp/src/concurrency/ThreadManager.cpp b/lib/cpp/src/concurrency/ThreadManager.cpp new file mode 100644 index 000000000..abfcf6e70 --- /dev/null +++ b/lib/cpp/src/concurrency/ThreadManager.cpp @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "ThreadManager.h" +#include "Exception.h" +#include "Monitor.h" + +#include <boost/shared_ptr.hpp> + +#include <assert.h> +#include <queue> +#include <set> + +#if defined(DEBUG) +#include <iostream> +#endif //defined(DEBUG) + +namespace apache { namespace thrift { namespace concurrency { + +using boost::shared_ptr; +using boost::dynamic_pointer_cast; + +/** + * ThreadManager class + * + * This class manages a pool of threads. It uses a ThreadFactory to create + * threads. It never actually creates or destroys worker threads, rather + * it maintains statistics on number of idle threads, number of active threads, + * task backlog, and average wait and service times. + * + * @version $Id:$ + */ +class ThreadManager::Impl : public ThreadManager { + + public: + Impl() : + workerCount_(0), + workerMaxCount_(0), + idleCount_(0), + pendingTaskCountMax_(0), + state_(ThreadManager::UNINITIALIZED) {} + + ~Impl() { stop(); } + + void start(); + + void stop() { stopImpl(false); } + + void join() { stopImpl(true); } + + const ThreadManager::STATE state() const { + return state_; + } + + shared_ptr<ThreadFactory> threadFactory() const { + Synchronized s(monitor_); + return threadFactory_; + } + + void threadFactory(shared_ptr<ThreadFactory> value) { + Synchronized s(monitor_); + threadFactory_ = value; + } + + void addWorker(size_t value); + + void removeWorker(size_t value); + + size_t idleWorkerCount() const { + return idleCount_; + } + + size_t workerCount() const { + Synchronized s(monitor_); + return workerCount_; + } + + size_t pendingTaskCount() const { + Synchronized s(monitor_); + return tasks_.size(); + } + + size_t totalTaskCount() const { + Synchronized s(monitor_); + return tasks_.size() + workerCount_ - idleCount_; + } + + size_t pendingTaskCountMax() const { + Synchronized s(monitor_); + return pendingTaskCountMax_; + } + + void pendingTaskCountMax(const size_t value) { + Synchronized s(monitor_); + pendingTaskCountMax_ = value; + } + + bool canSleep(); + + void add(shared_ptr<Runnable> value, int64_t timeout); + + void remove(shared_ptr<Runnable> task); + +private: + void stopImpl(bool join); + + size_t workerCount_; + size_t workerMaxCount_; + size_t idleCount_; + size_t pendingTaskCountMax_; + + ThreadManager::STATE state_; + shared_ptr<ThreadFactory> threadFactory_; + + + friend class ThreadManager::Task; + std::queue<shared_ptr<Task> > tasks_; + Monitor monitor_; + Monitor workerMonitor_; + + friend class ThreadManager::Worker; + std::set<shared_ptr<Thread> > workers_; + std::set<shared_ptr<Thread> > deadWorkers_; + std::map<const Thread::id_t, shared_ptr<Thread> > idMap_; +}; + +class ThreadManager::Task : public Runnable { + + public: + enum STATE { + WAITING, + EXECUTING, + CANCELLED, + COMPLETE + }; + + Task(shared_ptr<Runnable> runnable) : + runnable_(runnable), + state_(WAITING) {} + + ~Task() {} + + void run() { + if (state_ == EXECUTING) { + runnable_->run(); + state_ = COMPLETE; + } + } + + private: + shared_ptr<Runnable> runnable_; + friend class ThreadManager::Worker; + STATE state_; +}; + +class ThreadManager::Worker: public Runnable { + enum STATE { + UNINITIALIZED, + STARTING, + STARTED, + STOPPING, + STOPPED + }; + + public: + Worker(ThreadManager::Impl* manager) : + manager_(manager), + state_(UNINITIALIZED), + idle_(false) {} + + ~Worker() {} + + private: + bool isActive() const { + return + (manager_->workerCount_ <= manager_->workerMaxCount_) || + (manager_->state_ == JOINING && !manager_->tasks_.empty()); + } + + public: + /** + * Worker entry point + * + * As long as worker thread is running, pull tasks off the task queue and + * execute. + */ + void run() { + bool active = false; + bool notifyManager = false; + + /** + * Increment worker semaphore and notify manager if worker count reached + * desired max + * + * Note: We have to release the monitor and acquire the workerMonitor + * since that is what the manager blocks on for worker add/remove + */ + { + Synchronized s(manager_->monitor_); + active = manager_->workerCount_ < manager_->workerMaxCount_; + if (active) { + manager_->workerCount_++; + notifyManager = manager_->workerCount_ == manager_->workerMaxCount_; + } + } + + if (notifyManager) { + Synchronized s(manager_->workerMonitor_); + manager_->workerMonitor_.notify(); + notifyManager = false; + } + + while (active) { + shared_ptr<ThreadManager::Task> task; + + /** + * While holding manager monitor block for non-empty task queue (Also + * check that the thread hasn't been requested to stop). Once the queue + * is non-empty, dequeue a task, release monitor, and execute. If the + * worker max count has been decremented such that we exceed it, mark + * ourself inactive, decrement the worker count and notify the manager + * (technically we're notifying the next blocked thread but eventually + * the manager will see it. + */ + { + Synchronized s(manager_->monitor_); + active = isActive(); + + while (active && manager_->tasks_.empty()) { + manager_->idleCount_++; + idle_ = true; + manager_->monitor_.wait(); + active = isActive(); + idle_ = false; + manager_->idleCount_--; + } + + if (active) { + if (!manager_->tasks_.empty()) { + task = manager_->tasks_.front(); + manager_->tasks_.pop(); + if (task->state_ == ThreadManager::Task::WAITING) { + task->state_ = ThreadManager::Task::EXECUTING; + } + + /* If we have a pending task max and we just dropped below it, wakeup any + thread that might be blocked on add. */ + if (manager_->pendingTaskCountMax_ != 0 && + manager_->tasks_.size() == manager_->pendingTaskCountMax_ - 1) { + manager_->monitor_.notify(); + } + } + } else { + idle_ = true; + manager_->workerCount_--; + notifyManager = (manager_->workerCount_ == manager_->workerMaxCount_); + } + } + + if (task != NULL) { + if (task->state_ == ThreadManager::Task::EXECUTING) { + try { + task->run(); + } catch(...) { + // XXX need to log this + } + } + } + } + + { + Synchronized s(manager_->workerMonitor_); + manager_->deadWorkers_.insert(this->thread()); + if (notifyManager) { + manager_->workerMonitor_.notify(); + } + } + + return; + } + + private: + ThreadManager::Impl* manager_; + friend class ThreadManager::Impl; + STATE state_; + bool idle_; +}; + + + void ThreadManager::Impl::addWorker(size_t value) { + std::set<shared_ptr<Thread> > newThreads; + for (size_t ix = 0; ix < value; ix++) { + class ThreadManager::Worker; + shared_ptr<ThreadManager::Worker> worker = shared_ptr<ThreadManager::Worker>(new ThreadManager::Worker(this)); + newThreads.insert(threadFactory_->newThread(worker)); + } + + { + Synchronized s(monitor_); + workerMaxCount_ += value; + workers_.insert(newThreads.begin(), newThreads.end()); + } + + for (std::set<shared_ptr<Thread> >::iterator ix = newThreads.begin(); ix != newThreads.end(); ix++) { + shared_ptr<ThreadManager::Worker> worker = dynamic_pointer_cast<ThreadManager::Worker, Runnable>((*ix)->runnable()); + worker->state_ = ThreadManager::Worker::STARTING; + (*ix)->start(); + idMap_.insert(std::pair<const Thread::id_t, shared_ptr<Thread> >((*ix)->getId(), *ix)); + } + + { + Synchronized s(workerMonitor_); + while (workerCount_ != workerMaxCount_) { + workerMonitor_.wait(); + } + } +} + +void ThreadManager::Impl::start() { + + if (state_ == ThreadManager::STOPPED) { + return; + } + + { + Synchronized s(monitor_); + if (state_ == ThreadManager::UNINITIALIZED) { + if (threadFactory_ == NULL) { + throw InvalidArgumentException(); + } + state_ = ThreadManager::STARTED; + monitor_.notifyAll(); + } + + while (state_ == STARTING) { + monitor_.wait(); + } + } +} + +void ThreadManager::Impl::stopImpl(bool join) { + bool doStop = false; + if (state_ == ThreadManager::STOPPED) { + return; + } + + { + Synchronized s(monitor_); + if (state_ != ThreadManager::STOPPING && + state_ != ThreadManager::JOINING && + state_ != ThreadManager::STOPPED) { + doStop = true; + state_ = join ? ThreadManager::JOINING : ThreadManager::STOPPING; + } + } + + if (doStop) { + removeWorker(workerCount_); + } + + // XXX + // should be able to block here for transition to STOPPED since we're no + // using shared_ptrs + + { + Synchronized s(monitor_); + state_ = ThreadManager::STOPPED; + } + +} + +void ThreadManager::Impl::removeWorker(size_t value) { + std::set<shared_ptr<Thread> > removedThreads; + { + Synchronized s(monitor_); + if (value > workerMaxCount_) { + throw InvalidArgumentException(); + } + + workerMaxCount_ -= value; + + if (idleCount_ < value) { + for (size_t ix = 0; ix < idleCount_; ix++) { + monitor_.notify(); + } + } else { + monitor_.notifyAll(); + } + } + + { + Synchronized s(workerMonitor_); + + while (workerCount_ != workerMaxCount_) { + workerMonitor_.wait(); + } + + for (std::set<shared_ptr<Thread> >::iterator ix = deadWorkers_.begin(); ix != deadWorkers_.end(); ix++) { + workers_.erase(*ix); + idMap_.erase((*ix)->getId()); + } + + deadWorkers_.clear(); + } +} + + bool ThreadManager::Impl::canSleep() { + const Thread::id_t id = threadFactory_->getCurrentThreadId(); + return idMap_.find(id) == idMap_.end(); + } + + void ThreadManager::Impl::add(shared_ptr<Runnable> value, int64_t timeout) { + Synchronized s(monitor_); + + if (state_ != ThreadManager::STARTED) { + throw IllegalStateException(); + } + + if (pendingTaskCountMax_ > 0 && (tasks_.size() >= pendingTaskCountMax_)) { + if (canSleep() && timeout >= 0) { + while (pendingTaskCountMax_ > 0 && tasks_.size() >= pendingTaskCountMax_) { + monitor_.wait(timeout); + } + } else { + throw TooManyPendingTasksException(); + } + } + + tasks_.push(shared_ptr<ThreadManager::Task>(new ThreadManager::Task(value))); + + // If idle thread is available notify it, otherwise all worker threads are + // running and will get around to this task in time. + if (idleCount_ > 0) { + monitor_.notify(); + } + } + +void ThreadManager::Impl::remove(shared_ptr<Runnable> task) { + Synchronized s(monitor_); + if (state_ != ThreadManager::STARTED) { + throw IllegalStateException(); + } +} + +class SimpleThreadManager : public ThreadManager::Impl { + + public: + SimpleThreadManager(size_t workerCount=4, size_t pendingTaskCountMax=0) : + workerCount_(workerCount), + pendingTaskCountMax_(pendingTaskCountMax), + firstTime_(true) { + } + + void start() { + ThreadManager::Impl::pendingTaskCountMax(pendingTaskCountMax_); + ThreadManager::Impl::start(); + addWorker(workerCount_); + } + + private: + const size_t workerCount_; + const size_t pendingTaskCountMax_; + bool firstTime_; + Monitor monitor_; +}; + + +shared_ptr<ThreadManager> ThreadManager::newThreadManager() { + return shared_ptr<ThreadManager>(new ThreadManager::Impl()); +} + +shared_ptr<ThreadManager> ThreadManager::newSimpleThreadManager(size_t count, size_t pendingTaskCountMax) { + return shared_ptr<ThreadManager>(new SimpleThreadManager(count, pendingTaskCountMax)); +} + +}}} // apache::thrift::concurrency + diff --git a/lib/cpp/src/concurrency/ThreadManager.h b/lib/cpp/src/concurrency/ThreadManager.h new file mode 100644 index 000000000..6e5a17817 --- /dev/null +++ b/lib/cpp/src/concurrency/ThreadManager.h @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_THREADMANAGER_H_ +#define _THRIFT_CONCURRENCY_THREADMANAGER_H_ 1 + +#include <boost/shared_ptr.hpp> +#include <sys/types.h> +#include "Thread.h" + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Thread Pool Manager and related classes + * + * @version $Id:$ + */ +class ThreadManager; + +/** + * ThreadManager class + * + * This class manages a pool of threads. It uses a ThreadFactory to create + * threads. It never actually creates or destroys worker threads, rather + * It maintains statistics on number of idle threads, number of active threads, + * task backlog, and average wait and service times and informs the PoolPolicy + * object bound to instances of this manager of interesting transitions. It is + * then up the PoolPolicy object to decide if the thread pool size needs to be + * adjusted and call this object addWorker and removeWorker methods to make + * changes. + * + * This design allows different policy implementations to used this code to + * handle basic worker thread management and worker task execution and focus on + * policy issues. The simplest policy, StaticPolicy, does nothing other than + * create a fixed number of threads. + */ +class ThreadManager { + + protected: + ThreadManager() {} + + public: + virtual ~ThreadManager() {} + + /** + * Starts the thread manager. Verifies all attributes have been properly + * initialized, then allocates necessary resources to begin operation + */ + virtual void start() = 0; + + /** + * Stops the thread manager. Aborts all remaining unprocessed task, shuts + * down all created worker threads, and realeases all allocated resources. + * This method blocks for all worker threads to complete, thus it can + * potentially block forever if a worker thread is running a task that + * won't terminate. + */ + virtual void stop() = 0; + + /** + * Joins the thread manager. This is the same as stop, except that it will + * block until all the workers have finished their work. At that point + * the ThreadManager will transition into the STOPPED state. + */ + virtual void join() = 0; + + enum STATE { + UNINITIALIZED, + STARTING, + STARTED, + JOINING, + STOPPING, + STOPPED + }; + + virtual const STATE state() const = 0; + + virtual boost::shared_ptr<ThreadFactory> threadFactory() const = 0; + + virtual void threadFactory(boost::shared_ptr<ThreadFactory> value) = 0; + + virtual void addWorker(size_t value=1) = 0; + + virtual void removeWorker(size_t value=1) = 0; + + /** + * Gets the current number of idle worker threads + */ + virtual size_t idleWorkerCount() const = 0; + + /** + * Gets the current number of total worker threads + */ + virtual size_t workerCount() const = 0; + + /** + * Gets the current number of pending tasks + */ + virtual size_t pendingTaskCount() const = 0; + + /** + * Gets the current number of pending and executing tasks + */ + virtual size_t totalTaskCount() const = 0; + + /** + * Gets the maximum pending task count. 0 indicates no maximum + */ + virtual size_t pendingTaskCountMax() const = 0; + + /** + * Adds a task to be executed at some time in the future by a worker thread. + * + * This method will block if pendingTaskCountMax() in not zero and pendingTaskCount() + * is greater than or equalt to pendingTaskCountMax(). If this method is called in the + * context of a ThreadManager worker thread it will throw a + * TooManyPendingTasksException + * + * @param task The task to queue for execution + * + * @param timeout Time to wait in milliseconds to add a task when a pending-task-count + * is specified. Specific cases: + * timeout = 0 : Wait forever to queue task. + * timeout = -1 : Return immediately if pending task count exceeds specified max + * + * @throws TooManyPendingTasksException Pending task count exceeds max pending task count + */ + virtual void add(boost::shared_ptr<Runnable>task, int64_t timeout=0LL) = 0; + + /** + * Removes a pending task + */ + virtual void remove(boost::shared_ptr<Runnable> task) = 0; + + static boost::shared_ptr<ThreadManager> newThreadManager(); + + /** + * Creates a simple thread manager the uses count number of worker threads and has + * a pendingTaskCountMax maximum pending tasks. The default, 0, specified no limit + * on pending tasks + */ + static boost::shared_ptr<ThreadManager> newSimpleThreadManager(size_t count=4, size_t pendingTaskCountMax=0); + + class Task; + + class Worker; + + class Impl; +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_THREADMANAGER_H_ diff --git a/lib/cpp/src/concurrency/TimerManager.cpp b/lib/cpp/src/concurrency/TimerManager.cpp new file mode 100644 index 000000000..25515dc82 --- /dev/null +++ b/lib/cpp/src/concurrency/TimerManager.cpp @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TimerManager.h" +#include "Exception.h" +#include "Util.h" + +#include <assert.h> +#include <iostream> +#include <set> + +namespace apache { namespace thrift { namespace concurrency { + +using boost::shared_ptr; + +typedef std::multimap<int64_t, shared_ptr<TimerManager::Task> >::iterator task_iterator; +typedef std::pair<task_iterator, task_iterator> task_range; + +/** + * TimerManager class + * + * @version $Id:$ + */ +class TimerManager::Task : public Runnable { + + public: + enum STATE { + WAITING, + EXECUTING, + CANCELLED, + COMPLETE + }; + + Task(shared_ptr<Runnable> runnable) : + runnable_(runnable), + state_(WAITING) {} + + ~Task() { + } + + void run() { + if (state_ == EXECUTING) { + runnable_->run(); + state_ = COMPLETE; + } + } + + private: + shared_ptr<Runnable> runnable_; + class TimerManager::Dispatcher; + friend class TimerManager::Dispatcher; + STATE state_; +}; + +class TimerManager::Dispatcher: public Runnable { + + public: + Dispatcher(TimerManager* manager) : + manager_(manager) {} + + ~Dispatcher() {} + + /** + * Dispatcher entry point + * + * As long as dispatcher thread is running, pull tasks off the task taskMap_ + * and execute. + */ + void run() { + { + Synchronized s(manager_->monitor_); + if (manager_->state_ == TimerManager::STARTING) { + manager_->state_ = TimerManager::STARTED; + manager_->monitor_.notifyAll(); + } + } + + do { + std::set<shared_ptr<TimerManager::Task> > expiredTasks; + { + Synchronized s(manager_->monitor_); + task_iterator expiredTaskEnd; + int64_t now = Util::currentTime(); + while (manager_->state_ == TimerManager::STARTED && + (expiredTaskEnd = manager_->taskMap_.upper_bound(now)) == manager_->taskMap_.begin()) { + int64_t timeout = 0LL; + if (!manager_->taskMap_.empty()) { + timeout = manager_->taskMap_.begin()->first - now; + } + assert((timeout != 0 && manager_->taskCount_ > 0) || (timeout == 0 && manager_->taskCount_ == 0)); + try { + manager_->monitor_.wait(timeout); + } catch (TimedOutException &e) {} + now = Util::currentTime(); + } + + if (manager_->state_ == TimerManager::STARTED) { + for (task_iterator ix = manager_->taskMap_.begin(); ix != expiredTaskEnd; ix++) { + shared_ptr<TimerManager::Task> task = ix->second; + expiredTasks.insert(task); + if (task->state_ == TimerManager::Task::WAITING) { + task->state_ = TimerManager::Task::EXECUTING; + } + manager_->taskCount_--; + } + manager_->taskMap_.erase(manager_->taskMap_.begin(), expiredTaskEnd); + } + } + + for (std::set<shared_ptr<Task> >::iterator ix = expiredTasks.begin(); ix != expiredTasks.end(); ix++) { + (*ix)->run(); + } + + } while (manager_->state_ == TimerManager::STARTED); + + { + Synchronized s(manager_->monitor_); + if (manager_->state_ == TimerManager::STOPPING) { + manager_->state_ = TimerManager::STOPPED; + manager_->monitor_.notify(); + } + } + return; + } + + private: + TimerManager* manager_; + friend class TimerManager; +}; + +TimerManager::TimerManager() : + taskCount_(0), + state_(TimerManager::UNINITIALIZED), + dispatcher_(shared_ptr<Dispatcher>(new Dispatcher(this))) { +} + + +TimerManager::~TimerManager() { + + // If we haven't been explicitly stopped, do so now. We don't need to grab + // the monitor here, since stop already takes care of reentrancy. + + if (state_ != STOPPED) { + try { + stop(); + } catch(...) { + throw; + // uhoh + } + } +} + +void TimerManager::start() { + bool doStart = false; + { + Synchronized s(monitor_); + if (threadFactory_ == NULL) { + throw InvalidArgumentException(); + } + if (state_ == TimerManager::UNINITIALIZED) { + state_ = TimerManager::STARTING; + doStart = true; + } + } + + if (doStart) { + dispatcherThread_ = threadFactory_->newThread(dispatcher_); + dispatcherThread_->start(); + } + + { + Synchronized s(monitor_); + while (state_ == TimerManager::STARTING) { + monitor_.wait(); + } + assert(state_ != TimerManager::STARTING); + } +} + +void TimerManager::stop() { + bool doStop = false; + { + Synchronized s(monitor_); + if (state_ == TimerManager::UNINITIALIZED) { + state_ = TimerManager::STOPPED; + } else if (state_ != STOPPING && state_ != STOPPED) { + doStop = true; + state_ = STOPPING; + monitor_.notifyAll(); + } + while (state_ != STOPPED) { + monitor_.wait(); + } + } + + if (doStop) { + // Clean up any outstanding tasks + for (task_iterator ix = taskMap_.begin(); ix != taskMap_.end(); ix++) { + taskMap_.erase(ix); + } + + // Remove dispatcher's reference to us. + dispatcher_->manager_ = NULL; + } +} + +shared_ptr<const ThreadFactory> TimerManager::threadFactory() const { + Synchronized s(monitor_); + return threadFactory_; +} + +void TimerManager::threadFactory(shared_ptr<const ThreadFactory> value) { + Synchronized s(monitor_); + threadFactory_ = value; +} + +size_t TimerManager::taskCount() const { + return taskCount_; +} + +void TimerManager::add(shared_ptr<Runnable> task, int64_t timeout) { + int64_t now = Util::currentTime(); + timeout += now; + + { + Synchronized s(monitor_); + if (state_ != TimerManager::STARTED) { + throw IllegalStateException(); + } + + taskCount_++; + taskMap_.insert(std::pair<int64_t, shared_ptr<Task> >(timeout, shared_ptr<Task>(new Task(task)))); + + // If the task map was empty, or if we have an expiration that is earlier + // than any previously seen, kick the dispatcher so it can update its + // timeout + if (taskCount_ == 1 || timeout < taskMap_.begin()->first) { + monitor_.notify(); + } + } +} + +void TimerManager::add(shared_ptr<Runnable> task, const struct timespec& value) { + + int64_t expiration; + Util::toMilliseconds(expiration, value); + + int64_t now = Util::currentTime(); + + if (expiration < now) { + throw InvalidArgumentException(); + } + + add(task, expiration - now); +} + + +void TimerManager::remove(shared_ptr<Runnable> task) { + Synchronized s(monitor_); + if (state_ != TimerManager::STARTED) { + throw IllegalStateException(); + } +} + +const TimerManager::STATE TimerManager::state() const { return state_; } + +}}} // apache::thrift::concurrency + diff --git a/lib/cpp/src/concurrency/TimerManager.h b/lib/cpp/src/concurrency/TimerManager.h new file mode 100644 index 000000000..f3f799f93 --- /dev/null +++ b/lib/cpp/src/concurrency/TimerManager.h @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_TIMERMANAGER_H_ +#define _THRIFT_CONCURRENCY_TIMERMANAGER_H_ 1 + +#include "Exception.h" +#include "Monitor.h" +#include "Thread.h" + +#include <boost/shared_ptr.hpp> +#include <map> +#include <time.h> + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Timer Manager + * + * This class dispatches timer tasks when they fall due. + * + * @version $Id:$ + */ +class TimerManager { + + public: + + TimerManager(); + + virtual ~TimerManager(); + + virtual boost::shared_ptr<const ThreadFactory> threadFactory() const; + + virtual void threadFactory(boost::shared_ptr<const ThreadFactory> value); + + /** + * Starts the timer manager service + * + * @throws IllegalArgumentException Missing thread factory attribute + */ + virtual void start(); + + /** + * Stops the timer manager service + */ + virtual void stop(); + + virtual size_t taskCount() const ; + + /** + * Adds a task to be executed at some time in the future by a worker thread. + * + * @param task The task to execute + * @param timeout Time in milliseconds to delay before executing task + */ + virtual void add(boost::shared_ptr<Runnable> task, int64_t timeout); + + /** + * Adds a task to be executed at some time in the future by a worker thread. + * + * @param task The task to execute + * @param timeout Absolute time in the future to execute task. + */ + virtual void add(boost::shared_ptr<Runnable> task, const struct timespec& timeout); + + /** + * Removes a pending task + * + * @throws NoSuchTaskException Specified task doesn't exist. It was either + * processed already or this call was made for a + * task that was never added to this timer + * + * @throws UncancellableTaskException Specified task is already being + * executed or has completed execution. + */ + virtual void remove(boost::shared_ptr<Runnable> task); + + enum STATE { + UNINITIALIZED, + STARTING, + STARTED, + STOPPING, + STOPPED + }; + + virtual const STATE state() const; + + private: + boost::shared_ptr<const ThreadFactory> threadFactory_; + class Task; + friend class Task; + std::multimap<int64_t, boost::shared_ptr<Task> > taskMap_; + size_t taskCount_; + Monitor monitor_; + STATE state_; + class Dispatcher; + friend class Dispatcher; + boost::shared_ptr<Dispatcher> dispatcher_; + boost::shared_ptr<Thread> dispatcherThread_; +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_TIMERMANAGER_H_ diff --git a/lib/cpp/src/concurrency/Util.cpp b/lib/cpp/src/concurrency/Util.cpp new file mode 100644 index 000000000..1c4493716 --- /dev/null +++ b/lib/cpp/src/concurrency/Util.cpp @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "Util.h" + +#ifdef HAVE_CONFIG_H +#include <config.h> +#endif + +#if defined(HAVE_CLOCK_GETTIME) +#include <time.h> +#elif defined(HAVE_GETTIMEOFDAY) +#include <sys/time.h> +#endif // defined(HAVE_CLOCK_GETTIME) + +namespace apache { namespace thrift { namespace concurrency { + +const int64_t Util::currentTime() { + int64_t result; + +#if defined(HAVE_CLOCK_GETTIME) + struct timespec now; + int ret = clock_gettime(CLOCK_REALTIME, &now); + assert(ret == 0); + toMilliseconds(result, now); +#elif defined(HAVE_GETTIMEOFDAY) + struct timeval now; + int ret = gettimeofday(&now, NULL); + assert(ret == 0); + toMilliseconds(result, now); +#else +#error "No high-precision clock is available." +#endif // defined(HAVE_CLOCK_GETTIME) + + return result; +} + + +}}} // apache::thrift::concurrency diff --git a/lib/cpp/src/concurrency/Util.h b/lib/cpp/src/concurrency/Util.h new file mode 100644 index 000000000..25fcc2086 --- /dev/null +++ b/lib/cpp/src/concurrency/Util.h @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_UTIL_H_ +#define _THRIFT_CONCURRENCY_UTIL_H_ 1 + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> +#include <time.h> +#include <sys/time.h> + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Utility methods + * + * This class contains basic utility methods for converting time formats, + * and other common platform-dependent concurrency operations. + * It should not be included in API headers for other concurrency library + * headers, since it will, by definition, pull in all sorts of horrid + * platform dependent crap. Rather it should be inluded directly in + * concurrency library implementation source. + * + * @version $Id:$ + */ +class Util { + + static const int64_t NS_PER_S = 1000000000LL; + static const int64_t US_PER_S = 1000000LL; + static const int64_t MS_PER_S = 1000LL; + + static const int64_t NS_PER_MS = NS_PER_S / MS_PER_S; + static const int64_t US_PER_MS = US_PER_S / MS_PER_S; + + public: + + /** + * Converts millisecond timestamp into a timespec struct + * + * @param struct timespec& result + * @param time or duration in milliseconds + */ + static void toTimespec(struct timespec& result, int64_t value) { + result.tv_sec = value / MS_PER_S; // ms to s + result.tv_nsec = (value % MS_PER_S) * NS_PER_MS; // ms to ns + } + + static void toTimeval(struct timeval& result, int64_t value) { + result.tv_sec = value / MS_PER_S; // ms to s + result.tv_usec = (value % MS_PER_S) * US_PER_MS; // ms to us + } + + /** + * Converts struct timespec to milliseconds + */ + static const void toMilliseconds(int64_t& result, const struct timespec& value) { + result = (value.tv_sec * MS_PER_S) + (value.tv_nsec / NS_PER_MS); + // round up -- int64_t cast is to avoid a compiler error for some GCCs + if (int64_t(value.tv_nsec) % NS_PER_MS >= (NS_PER_MS / 2)) { + ++result; + } + } + + /** + * Converts struct timeval to milliseconds + */ + static const void toMilliseconds(int64_t& result, const struct timeval& value) { + result = (value.tv_sec * MS_PER_S) + (value.tv_usec / US_PER_MS); + // round up -- int64_t cast is to avoid a compiler error for some GCCs + if (int64_t(value.tv_usec) % US_PER_MS >= (US_PER_MS / 2)) { + ++result; + } + } + + /** + * Get current time as milliseconds from epoch + */ + static const int64_t currentTime(); +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_UTIL_H_ diff --git a/lib/cpp/src/concurrency/test/Tests.cpp b/lib/cpp/src/concurrency/test/Tests.cpp new file mode 100644 index 000000000..c80bb883f --- /dev/null +++ b/lib/cpp/src/concurrency/test/Tests.cpp @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <iostream> +#include <vector> +#include <string> + +#include "ThreadFactoryTests.h" +#include "TimerManagerTests.h" +#include "ThreadManagerTests.h" + +int main(int argc, char** argv) { + + std::string arg; + + std::vector<std::string> args(argc - 1 > 1 ? argc - 1 : 1); + + args[0] = "all"; + + for (int ix = 1; ix < argc; ix++) { + args[ix - 1] = std::string(argv[ix]); + } + + bool runAll = args[0].compare("all") == 0; + + if (runAll || args[0].compare("thread-factory") == 0) { + + ThreadFactoryTests threadFactoryTests; + + std::cout << "ThreadFactory tests..." << std::endl; + + size_t count = 1000; + size_t floodLoops = 1; + size_t floodCount = 100000; + + std::cout << "\t\tThreadFactory reap N threads test: N = " << count << std::endl; + + assert(threadFactoryTests.reapNThreads(count)); + + std::cout << "\t\tThreadFactory floodN threads test: N = " << floodCount << std::endl; + + assert(threadFactoryTests.floodNTest(floodLoops, floodCount)); + + std::cout << "\t\tThreadFactory synchronous start test" << std::endl; + + assert(threadFactoryTests.synchStartTest()); + + std::cout << "\t\tThreadFactory monitor timeout test" << std::endl; + + assert(threadFactoryTests.monitorTimeoutTest()); + } + + if (runAll || args[0].compare("util") == 0) { + + std::cout << "Util tests..." << std::endl; + + std::cout << "\t\tUtil minimum time" << std::endl; + + int64_t time00 = Util::currentTime(); + int64_t time01 = Util::currentTime(); + + std::cout << "\t\t\tMinimum time: " << time01 - time00 << "ms" << std::endl; + + time00 = Util::currentTime(); + time01 = time00; + size_t count = 0; + + while (time01 < time00 + 10) { + count++; + time01 = Util::currentTime(); + } + + std::cout << "\t\t\tscall per ms: " << count / (time01 - time00) << std::endl; + } + + + if (runAll || args[0].compare("timer-manager") == 0) { + + std::cout << "TimerManager tests..." << std::endl; + + std::cout << "\t\tTimerManager test00" << std::endl; + + TimerManagerTests timerManagerTests; + + assert(timerManagerTests.test00()); + } + + if (runAll || args[0].compare("thread-manager") == 0) { + + std::cout << "ThreadManager tests..." << std::endl; + + { + + size_t workerCount = 100; + + size_t taskCount = 100000; + + int64_t delay = 10LL; + + std::cout << "\t\tThreadManager load test: worker count: " << workerCount << " task count: " << taskCount << " delay: " << delay << std::endl; + + ThreadManagerTests threadManagerTests; + + assert(threadManagerTests.loadTest(taskCount, delay, workerCount)); + + std::cout << "\t\tThreadManager block test: worker count: " << workerCount << " delay: " << delay << std::endl; + + assert(threadManagerTests.blockTest(delay, workerCount)); + + } + } + + if (runAll || args[0].compare("thread-manager-benchmark") == 0) { + + std::cout << "ThreadManager benchmark tests..." << std::endl; + + { + + size_t minWorkerCount = 2; + + size_t maxWorkerCount = 512; + + size_t tasksPerWorker = 1000; + + int64_t delay = 10LL; + + for (size_t workerCount = minWorkerCount; workerCount < maxWorkerCount; workerCount*= 2) { + + size_t taskCount = workerCount * tasksPerWorker; + + std::cout << "\t\tThreadManager load test: worker count: " << workerCount << " task count: " << taskCount << " delay: " << delay << std::endl; + + ThreadManagerTests threadManagerTests; + + threadManagerTests.loadTest(taskCount, delay, workerCount); + } + } + } +} diff --git a/lib/cpp/src/concurrency/test/ThreadFactoryTests.h b/lib/cpp/src/concurrency/test/ThreadFactoryTests.h new file mode 100644 index 000000000..859fbaf51 --- /dev/null +++ b/lib/cpp/src/concurrency/test/ThreadFactoryTests.h @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <config.h> +#include <concurrency/Thread.h> +#include <concurrency/PosixThreadFactory.h> +#include <concurrency/Monitor.h> +#include <concurrency/Util.h> + +#include <assert.h> +#include <iostream> +#include <set> + +namespace apache { namespace thrift { namespace concurrency { namespace test { + +using boost::shared_ptr; +using namespace apache::thrift::concurrency; + +/** + * ThreadManagerTests class + * + * @version $Id:$ + */ +class ThreadFactoryTests { + +public: + + static const double ERROR; + + class Task: public Runnable { + + public: + + Task() {} + + void run() { + std::cout << "\t\t\tHello World" << std::endl; + } + }; + + /** + * Hello world test + */ + bool helloWorldTest() { + + PosixThreadFactory threadFactory = PosixThreadFactory(); + + shared_ptr<Task> task = shared_ptr<Task>(new ThreadFactoryTests::Task()); + + shared_ptr<Thread> thread = threadFactory.newThread(task); + + thread->start(); + + thread->join(); + + std::cout << "\t\t\tSuccess!" << std::endl; + + return true; + } + + /** + * Reap N threads + */ + class ReapNTask: public Runnable { + + public: + + ReapNTask(Monitor& monitor, int& activeCount) : + _monitor(monitor), + _count(activeCount) {} + + void run() { + Synchronized s(_monitor); + + _count--; + + //std::cout << "\t\t\tthread count: " << _count << std::endl; + + if (_count == 0) { + _monitor.notify(); + } + } + + Monitor& _monitor; + + int& _count; + }; + + bool reapNThreads(int loop=1, int count=10) { + + PosixThreadFactory threadFactory = PosixThreadFactory(); + + Monitor* monitor = new Monitor(); + + for(int lix = 0; lix < loop; lix++) { + + int* activeCount = new int(count); + + std::set<shared_ptr<Thread> > threads; + + int tix; + + for (tix = 0; tix < count; tix++) { + try { + threads.insert(threadFactory.newThread(shared_ptr<Runnable>(new ReapNTask(*monitor, *activeCount)))); + } catch(SystemResourceException& e) { + std::cout << "\t\t\tfailed to create " << lix * count + tix << " thread " << e.what() << std::endl; + throw e; + } + } + + tix = 0; + for (std::set<shared_ptr<Thread> >::const_iterator thread = threads.begin(); thread != threads.end(); tix++, ++thread) { + + try { + (*thread)->start(); + } catch(SystemResourceException& e) { + std::cout << "\t\t\tfailed to start " << lix * count + tix << " thread " << e.what() << std::endl; + throw e; + } + } + + { + Synchronized s(*monitor); + while (*activeCount > 0) { + monitor->wait(1000); + } + } + + for (std::set<shared_ptr<Thread> >::const_iterator thread = threads.begin(); thread != threads.end(); thread++) { + threads.erase(*thread); + } + + std::cout << "\t\t\treaped " << lix * count << " threads" << std::endl; + } + + std::cout << "\t\t\tSuccess!" << std::endl; + + return true; + } + + class SynchStartTask: public Runnable { + + public: + + enum STATE { + UNINITIALIZED, + STARTING, + STARTED, + STOPPING, + STOPPED + }; + + SynchStartTask(Monitor& monitor, volatile STATE& state) : + _monitor(monitor), + _state(state) {} + + void run() { + { + Synchronized s(_monitor); + if (_state == SynchStartTask::STARTING) { + _state = SynchStartTask::STARTED; + _monitor.notify(); + } + } + + { + Synchronized s(_monitor); + while (_state == SynchStartTask::STARTED) { + _monitor.wait(); + } + + if (_state == SynchStartTask::STOPPING) { + _state = SynchStartTask::STOPPED; + _monitor.notifyAll(); + } + } + } + + private: + Monitor& _monitor; + volatile STATE& _state; + }; + + bool synchStartTest() { + + Monitor monitor; + + SynchStartTask::STATE state = SynchStartTask::UNINITIALIZED; + + shared_ptr<SynchStartTask> task = shared_ptr<SynchStartTask>(new SynchStartTask(monitor, state)); + + PosixThreadFactory threadFactory = PosixThreadFactory(); + + shared_ptr<Thread> thread = threadFactory.newThread(task); + + if (state == SynchStartTask::UNINITIALIZED) { + + state = SynchStartTask::STARTING; + + thread->start(); + } + + { + Synchronized s(monitor); + while (state == SynchStartTask::STARTING) { + monitor.wait(); + } + } + + assert(state != SynchStartTask::STARTING); + + { + Synchronized s(monitor); + + try { + monitor.wait(100); + } catch(TimedOutException& e) { + } + + if (state == SynchStartTask::STARTED) { + + state = SynchStartTask::STOPPING; + + monitor.notify(); + } + + while (state == SynchStartTask::STOPPING) { + monitor.wait(); + } + } + + assert(state == SynchStartTask::STOPPED); + + bool success = true; + + std::cout << "\t\t\t" << (success ? "Success" : "Failure") << "!" << std::endl; + + return true; + } + + /** See how accurate monitor timeout is. */ + + bool monitorTimeoutTest(size_t count=1000, int64_t timeout=10) { + + Monitor monitor; + + int64_t startTime = Util::currentTime(); + + for (size_t ix = 0; ix < count; ix++) { + { + Synchronized s(monitor); + try { + monitor.wait(timeout); + } catch(TimedOutException& e) { + } + } + } + + int64_t endTime = Util::currentTime(); + + double error = ((endTime - startTime) - (count * timeout)) / (double)(count * timeout); + + if (error < 0.0) { + + error *= 1.0; + } + + bool success = error < ThreadFactoryTests::ERROR; + + std::cout << "\t\t\t" << (success ? "Success" : "Failure") << "! expected time: " << count * timeout << "ms elapsed time: "<< endTime - startTime << "ms error%: " << error * 100.0 << std::endl; + + return success; + } + + + class FloodTask : public Runnable { + public: + + FloodTask(const size_t id) :_id(id) {} + ~FloodTask(){ + if(_id % 1000 == 0) { + std::cout << "\t\tthread " << _id << " done" << std::endl; + } + } + + void run(){ + if(_id % 1000 == 0) { + std::cout << "\t\tthread " << _id << " started" << std::endl; + } + + usleep(1); + } + const size_t _id; + }; + + void foo(PosixThreadFactory *tf) { + } + + bool floodNTest(size_t loop=1, size_t count=100000) { + + bool success = false; + + for(size_t lix = 0; lix < loop; lix++) { + + PosixThreadFactory threadFactory = PosixThreadFactory(); + threadFactory.setDetached(true); + + for(size_t tix = 0; tix < count; tix++) { + + try { + + shared_ptr<FloodTask> task(new FloodTask(lix * count + tix )); + + shared_ptr<Thread> thread = threadFactory.newThread(task); + + thread->start(); + + usleep(1); + + } catch (TException& e) { + + std::cout << "\t\t\tfailed to start " << lix * count + tix << " thread " << e.what() << std::endl; + + return success; + } + } + + std::cout << "\t\t\tflooded " << (lix + 1) * count << " threads" << std::endl; + + success = true; + } + + return success; + } +}; + +const double ThreadFactoryTests::ERROR = .20; + +}}}} // apache::thrift::concurrency::test + diff --git a/lib/cpp/src/concurrency/test/ThreadManagerTests.h b/lib/cpp/src/concurrency/test/ThreadManagerTests.h new file mode 100644 index 000000000..e7b517431 --- /dev/null +++ b/lib/cpp/src/concurrency/test/ThreadManagerTests.h @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <config.h> +#include <concurrency/ThreadManager.h> +#include <concurrency/PosixThreadFactory.h> +#include <concurrency/Monitor.h> +#include <concurrency/Util.h> + +#include <assert.h> +#include <set> +#include <iostream> +#include <set> +#include <stdint.h> + +namespace apache { namespace thrift { namespace concurrency { namespace test { + +using namespace apache::thrift::concurrency; + +/** + * ThreadManagerTests class + * + * @version $Id:$ + */ +class ThreadManagerTests { + +public: + + static const double ERROR; + + class Task: public Runnable { + + public: + + Task(Monitor& monitor, size_t& count, int64_t timeout) : + _monitor(monitor), + _count(count), + _timeout(timeout), + _done(false) {} + + void run() { + + _startTime = Util::currentTime(); + + { + Synchronized s(_sleep); + + try { + _sleep.wait(_timeout); + } catch(TimedOutException& e) { + ; + }catch(...) { + assert(0); + } + } + + _endTime = Util::currentTime(); + + _done = true; + + { + Synchronized s(_monitor); + + // std::cout << "Thread " << _count << " completed " << std::endl; + + _count--; + + if (_count == 0) { + + _monitor.notify(); + } + } + } + + Monitor& _monitor; + size_t& _count; + int64_t _timeout; + int64_t _startTime; + int64_t _endTime; + bool _done; + Monitor _sleep; + }; + + /** + * Dispatch count tasks, each of which blocks for timeout milliseconds then + * completes. Verify that all tasks completed and that thread manager cleans + * up properly on delete. + */ + bool loadTest(size_t count=100, int64_t timeout=100LL, size_t workerCount=4) { + + Monitor monitor; + + size_t activeCount = count; + + shared_ptr<ThreadManager> threadManager = ThreadManager::newSimpleThreadManager(workerCount); + + shared_ptr<PosixThreadFactory> threadFactory = shared_ptr<PosixThreadFactory>(new PosixThreadFactory()); + + threadFactory->setPriority(PosixThreadFactory::HIGHEST); + + threadManager->threadFactory(threadFactory); + + threadManager->start(); + + std::set<shared_ptr<ThreadManagerTests::Task> > tasks; + + for (size_t ix = 0; ix < count; ix++) { + + tasks.insert(shared_ptr<ThreadManagerTests::Task>(new ThreadManagerTests::Task(monitor, activeCount, timeout))); + } + + int64_t time00 = Util::currentTime(); + + for (std::set<shared_ptr<ThreadManagerTests::Task> >::iterator ix = tasks.begin(); ix != tasks.end(); ix++) { + + threadManager->add(*ix); + } + + { + Synchronized s(monitor); + + while(activeCount > 0) { + + monitor.wait(); + } + } + + int64_t time01 = Util::currentTime(); + + int64_t firstTime = 9223372036854775807LL; + int64_t lastTime = 0; + + double averageTime = 0; + int64_t minTime = 9223372036854775807LL; + int64_t maxTime = 0; + + for (std::set<shared_ptr<ThreadManagerTests::Task> >::iterator ix = tasks.begin(); ix != tasks.end(); ix++) { + + shared_ptr<ThreadManagerTests::Task> task = *ix; + + int64_t delta = task->_endTime - task->_startTime; + + assert(delta > 0); + + if (task->_startTime < firstTime) { + firstTime = task->_startTime; + } + + if (task->_endTime > lastTime) { + lastTime = task->_endTime; + } + + if (delta < minTime) { + minTime = delta; + } + + if (delta > maxTime) { + maxTime = delta; + } + + averageTime+= delta; + } + + averageTime /= count; + + std::cout << "\t\t\tfirst start: " << firstTime << "ms Last end: " << lastTime << "ms min: " << minTime << "ms max: " << maxTime << "ms average: " << averageTime << "ms" << std::endl; + + double expectedTime = ((count + (workerCount - 1)) / workerCount) * timeout; + + double error = ((time01 - time00) - expectedTime) / expectedTime; + + if (error < 0) { + error*= -1.0; + } + + bool success = error < ERROR; + + std::cout << "\t\t\t" << (success ? "Success" : "Failure") << "! expected time: " << expectedTime << "ms elapsed time: "<< time01 - time00 << "ms error%: " << error * 100.0 << std::endl; + + return success; + } + + class BlockTask: public Runnable { + + public: + + BlockTask(Monitor& monitor, Monitor& bmonitor, size_t& count) : + _monitor(monitor), + _bmonitor(bmonitor), + _count(count) {} + + void run() { + { + Synchronized s(_bmonitor); + + _bmonitor.wait(); + + } + + { + Synchronized s(_monitor); + + _count--; + + if (_count == 0) { + + _monitor.notify(); + } + } + } + + Monitor& _monitor; + Monitor& _bmonitor; + size_t& _count; + }; + + /** + * Block test. Create pendingTaskCountMax tasks. Verify that we block adding the + * pendingTaskCountMax + 1th task. Verify that we unblock when a task completes */ + + bool blockTest(int64_t timeout=100LL, size_t workerCount=2) { + + bool success = false; + + try { + + Monitor bmonitor; + Monitor monitor; + + size_t pendingTaskMaxCount = workerCount; + + size_t activeCounts[] = {workerCount, pendingTaskMaxCount, 1}; + + shared_ptr<ThreadManager> threadManager = ThreadManager::newSimpleThreadManager(workerCount, pendingTaskMaxCount); + + shared_ptr<PosixThreadFactory> threadFactory = shared_ptr<PosixThreadFactory>(new PosixThreadFactory()); + + threadFactory->setPriority(PosixThreadFactory::HIGHEST); + + threadManager->threadFactory(threadFactory); + + threadManager->start(); + + std::set<shared_ptr<ThreadManagerTests::BlockTask> > tasks; + + for (size_t ix = 0; ix < workerCount; ix++) { + + tasks.insert(shared_ptr<ThreadManagerTests::BlockTask>(new ThreadManagerTests::BlockTask(monitor, bmonitor,activeCounts[0]))); + } + + for (size_t ix = 0; ix < pendingTaskMaxCount; ix++) { + + tasks.insert(shared_ptr<ThreadManagerTests::BlockTask>(new ThreadManagerTests::BlockTask(monitor, bmonitor,activeCounts[1]))); + } + + for (std::set<shared_ptr<ThreadManagerTests::BlockTask> >::iterator ix = tasks.begin(); ix != tasks.end(); ix++) { + threadManager->add(*ix); + } + + if(!(success = (threadManager->totalTaskCount() == pendingTaskMaxCount + workerCount))) { + throw TException("Unexpected pending task count"); + } + + shared_ptr<ThreadManagerTests::BlockTask> extraTask(new ThreadManagerTests::BlockTask(monitor, bmonitor, activeCounts[2])); + + try { + threadManager->add(extraTask, 1); + throw TException("Unexpected success adding task in excess of pending task count"); + } catch(TimedOutException& e) { + } + + std::cout << "\t\t\t" << "Pending tasks " << threadManager->pendingTaskCount() << std::endl; + + { + Synchronized s(bmonitor); + + bmonitor.notifyAll(); + } + + { + Synchronized s(monitor); + + while(activeCounts[0] != 0) { + monitor.wait(); + } + } + + std::cout << "\t\t\t" << "Pending tasks " << threadManager->pendingTaskCount() << std::endl; + + try { + threadManager->add(extraTask, 1); + } catch(TimedOutException& e) { + std::cout << "\t\t\t" << "add timed out unexpectedly" << std::endl; + throw TException("Unexpected timeout adding task"); + + } catch(TooManyPendingTasksException& e) { + std::cout << "\t\t\t" << "add encountered too many pending exepctions" << std::endl; + throw TException("Unexpected timeout adding task"); + } + + // Wake up tasks that were pending before and wait for them to complete + + { + Synchronized s(bmonitor); + + bmonitor.notifyAll(); + } + + { + Synchronized s(monitor); + + while(activeCounts[1] != 0) { + monitor.wait(); + } + } + + // Wake up the extra task and wait for it to complete + + { + Synchronized s(bmonitor); + + bmonitor.notifyAll(); + } + + { + Synchronized s(monitor); + + while(activeCounts[2] != 0) { + monitor.wait(); + } + } + + if(!(success = (threadManager->totalTaskCount() == 0))) { + throw TException("Unexpected pending task count"); + } + + } catch(TException& e) { + } + + std::cout << "\t\t\t" << (success ? "Success" : "Failure") << std::endl; + return success; + } +}; + +const double ThreadManagerTests::ERROR = .20; + +}}}} // apache::thrift::concurrency + +using namespace apache::thrift::concurrency::test; + diff --git a/lib/cpp/src/concurrency/test/TimerManagerTests.h b/lib/cpp/src/concurrency/test/TimerManagerTests.h new file mode 100644 index 000000000..e6fe6ce7e --- /dev/null +++ b/lib/cpp/src/concurrency/test/TimerManagerTests.h @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <concurrency/TimerManager.h> +#include <concurrency/PosixThreadFactory.h> +#include <concurrency/Monitor.h> +#include <concurrency/Util.h> + +#include <assert.h> +#include <iostream> + +namespace apache { namespace thrift { namespace concurrency { namespace test { + +using namespace apache::thrift::concurrency; + +/** + * ThreadManagerTests class + * + * @version $Id:$ + */ +class TimerManagerTests { + + public: + + static const double ERROR; + + class Task: public Runnable { + public: + + Task(Monitor& monitor, int64_t timeout) : + _timeout(timeout), + _startTime(Util::currentTime()), + _monitor(monitor), + _success(false), + _done(false) {} + + ~Task() { std::cerr << this << std::endl; } + + void run() { + + _endTime = Util::currentTime(); + + // Figure out error percentage + + int64_t delta = _endTime - _startTime; + + + delta = delta > _timeout ? delta - _timeout : _timeout - delta; + + float error = delta / _timeout; + + if(error < ERROR) { + _success = true; + } + + _done = true; + + std::cout << "\t\t\tTimerManagerTests::Task[" << this << "] done" << std::endl; //debug + + {Synchronized s(_monitor); + _monitor.notifyAll(); + } + } + + int64_t _timeout; + int64_t _startTime; + int64_t _endTime; + Monitor& _monitor; + bool _success; + bool _done; + }; + + /** + * This test creates two tasks and waits for the first to expire within 10% + * of the expected expiration time. It then verifies that the timer manager + * properly clean up itself and the remaining orphaned timeout task when the + * manager goes out of scope and its destructor is called. + */ + bool test00(int64_t timeout=1000LL) { + + shared_ptr<TimerManagerTests::Task> orphanTask = shared_ptr<TimerManagerTests::Task>(new TimerManagerTests::Task(_monitor, 10 * timeout)); + + { + + TimerManager timerManager; + + timerManager.threadFactory(shared_ptr<PosixThreadFactory>(new PosixThreadFactory())); + + timerManager.start(); + + assert(timerManager.state() == TimerManager::STARTED); + + shared_ptr<TimerManagerTests::Task> task = shared_ptr<TimerManagerTests::Task>(new TimerManagerTests::Task(_monitor, timeout)); + + { + Synchronized s(_monitor); + + timerManager.add(orphanTask, 10 * timeout); + + timerManager.add(task, timeout); + + _monitor.wait(); + } + + assert(task->_done); + + + std::cout << "\t\t\t" << (task->_success ? "Success" : "Failure") << "!" << std::endl; + } + + // timerManager.stop(); This is where it happens via destructor + + assert(!orphanTask->_done); + + return true; + } + + friend class TestTask; + + Monitor _monitor; +}; + +const double TimerManagerTests::ERROR = .20; + +}}}} // apache::thrift::concurrency + diff --git a/lib/cpp/src/processor/PeekProcessor.cpp b/lib/cpp/src/processor/PeekProcessor.cpp new file mode 100644 index 000000000..c721861bc --- /dev/null +++ b/lib/cpp/src/processor/PeekProcessor.cpp @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "PeekProcessor.h" + +using namespace apache::thrift::transport; +using namespace apache::thrift::protocol; +using namespace apache::thrift; + +namespace apache { namespace thrift { namespace processor { + +PeekProcessor::PeekProcessor() { + memoryBuffer_.reset(new TMemoryBuffer()); + targetTransport_ = memoryBuffer_; +} +PeekProcessor::~PeekProcessor() {} + +void PeekProcessor::initialize(boost::shared_ptr<TProcessor> actualProcessor, + boost::shared_ptr<TProtocolFactory> protocolFactory, + boost::shared_ptr<TPipedTransportFactory> transportFactory) { + actualProcessor_ = actualProcessor; + pipedProtocol_ = protocolFactory->getProtocol(targetTransport_); + transportFactory_ = transportFactory; + transportFactory_->initializeTargetTransport(targetTransport_); +} + +boost::shared_ptr<TTransport> PeekProcessor::getPipedTransport(boost::shared_ptr<TTransport> in) { + return transportFactory_->getTransport(in); +} + +void PeekProcessor::setTargetTransport(boost::shared_ptr<TTransport> targetTransport) { + targetTransport_ = targetTransport; + if (boost::dynamic_pointer_cast<TMemoryBuffer>(targetTransport_)) { + memoryBuffer_ = boost::dynamic_pointer_cast<TMemoryBuffer>(targetTransport); + } else if (boost::dynamic_pointer_cast<TPipedTransport>(targetTransport_)) { + memoryBuffer_ = boost::dynamic_pointer_cast<TMemoryBuffer>(boost::dynamic_pointer_cast<TPipedTransport>(targetTransport_)->getTargetTransport()); + } + + if (!memoryBuffer_) { + throw TException("Target transport must be a TMemoryBuffer or a TPipedTransport with TMemoryBuffer"); + } +} + +bool PeekProcessor::process(boost::shared_ptr<TProtocol> in, + boost::shared_ptr<TProtocol> out) { + + std::string fname; + TMessageType mtype; + int32_t seqid; + in->readMessageBegin(fname, mtype, seqid); + + if (mtype != T_CALL) { + throw TException("Unexpected message type"); + } + + // Peek at the name + peekName(fname); + + TType ftype; + int16_t fid; + while (true) { + in->readFieldBegin(fname, ftype, fid); + if (ftype == T_STOP) { + break; + } + + // Peek at the variable + peek(in, ftype, fid); + in->readFieldEnd(); + } + in->readMessageEnd(); + in->getTransport()->readEnd(); + + // + // All the data is now in memoryBuffer_ and ready to be processed + // + + // Let's first take a peek at the full data in memory + uint8_t* buffer; + uint32_t size; + memoryBuffer_->getBuffer(&buffer, &size); + peekBuffer(buffer, size); + + // Done peeking at variables + peekEnd(); + + bool ret = actualProcessor_->process(pipedProtocol_, out); + memoryBuffer_->resetBuffer(); + return ret; +} + +void PeekProcessor::peekName(const std::string& fname) { +} + +void PeekProcessor::peekBuffer(uint8_t* buffer, uint32_t size) { +} + +void PeekProcessor::peek(boost::shared_ptr<TProtocol> in, + TType ftype, + int16_t fid) { + in->skip(ftype); +} + +void PeekProcessor::peekEnd() {} + +}}} diff --git a/lib/cpp/src/processor/PeekProcessor.h b/lib/cpp/src/processor/PeekProcessor.h new file mode 100644 index 000000000..0f7c016a0 --- /dev/null +++ b/lib/cpp/src/processor/PeekProcessor.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef PEEKPROCESSOR_H +#define PEEKPROCESSOR_H + +#include <string> +#include <TProcessor.h> +#include <transport/TTransport.h> +#include <transport/TTransportUtils.h> +#include <transport/TBufferTransports.h> +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace processor { + +/* + * Class for peeking at the raw data that is being processed by another processor + * and gives the derived class a chance to change behavior accordingly + * + */ +class PeekProcessor : public apache::thrift::TProcessor { + + public: + PeekProcessor(); + virtual ~PeekProcessor(); + + // Input here: actualProcessor - the underlying processor + // protocolFactory - the protocol factory used to wrap the memory buffer + // transportFactory - this TPipedTransportFactory is used to wrap the source transport + // via a call to getPipedTransport + void initialize(boost::shared_ptr<apache::thrift::TProcessor> actualProcessor, + boost::shared_ptr<apache::thrift::protocol::TProtocolFactory> protocolFactory, + boost::shared_ptr<apache::thrift::transport::TPipedTransportFactory> transportFactory); + + boost::shared_ptr<apache::thrift::transport::TTransport> getPipedTransport(boost::shared_ptr<apache::thrift::transport::TTransport> in); + + void setTargetTransport(boost::shared_ptr<apache::thrift::transport::TTransport> targetTransport); + + virtual bool process(boost::shared_ptr<apache::thrift::protocol::TProtocol> in, + boost::shared_ptr<apache::thrift::protocol::TProtocol> out); + + // The following three functions can be overloaded by child classes to + // achieve desired peeking behavior + virtual void peekName(const std::string& fname); + virtual void peekBuffer(uint8_t* buffer, uint32_t size); + virtual void peek(boost::shared_ptr<apache::thrift::protocol::TProtocol> in, + apache::thrift::protocol::TType ftype, + int16_t fid); + virtual void peekEnd(); + + private: + boost::shared_ptr<apache::thrift::TProcessor> actualProcessor_; + boost::shared_ptr<apache::thrift::protocol::TProtocol> pipedProtocol_; + boost::shared_ptr<apache::thrift::transport::TPipedTransportFactory> transportFactory_; + boost::shared_ptr<apache::thrift::transport::TMemoryBuffer> memoryBuffer_; + boost::shared_ptr<apache::thrift::transport::TTransport> targetTransport_; +}; + +}}} // apache::thrift::processor + +#endif diff --git a/lib/cpp/src/processor/StatsProcessor.h b/lib/cpp/src/processor/StatsProcessor.h new file mode 100644 index 000000000..820b3ad4b --- /dev/null +++ b/lib/cpp/src/processor/StatsProcessor.h @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef STATSPROCESSOR_H +#define STATSPROCESSOR_H + +#include <boost/shared_ptr.hpp> +#include <transport/TTransport.h> +#include <protocol/TProtocol.h> +#include <TProcessor.h> + +namespace apache { namespace thrift { namespace processor { + +/* + * Class for keeping track of function call statistics and printing them if desired + * + */ +class StatsProcessor : public apache::thrift::TProcessor { +public: + StatsProcessor(bool print, bool frequency) + : print_(print), + frequency_(frequency) + {} + virtual ~StatsProcessor() {}; + + virtual bool process(boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot, boost::shared_ptr<apache::thrift::protocol::TProtocol> poprot) { + + piprot_ = piprot; + + std::string fname; + apache::thrift::protocol::TMessageType mtype; + int32_t seqid; + + piprot_->readMessageBegin(fname, mtype, seqid); + if (mtype != apache::thrift::protocol::T_CALL) { + if (print_) { + printf("Unknown message type\n"); + } + throw apache::thrift::TException("Unexpected message type"); + } + if (print_) { + printf("%s (", fname.c_str()); + } + if (frequency_) { + if (frequency_map_.find(fname) != frequency_map_.end()) { + frequency_map_[fname]++; + } else { + frequency_map_[fname] = 1; + } + } + + apache::thrift::protocol::TType ftype; + int16_t fid; + + while (true) { + piprot_->readFieldBegin(fname, ftype, fid); + if (ftype == apache::thrift::protocol::T_STOP) { + break; + } + + printAndPassToBuffer(ftype); + if (print_) { + printf(", "); + } + } + + if (print_) { + printf("\b\b)\n"); + } + return true; + } + + const std::map<std::string, int64_t>& get_frequency_map() { + return frequency_map_; + } + +protected: + void printAndPassToBuffer(apache::thrift::protocol::TType ftype) { + switch (ftype) { + case apache::thrift::protocol::T_BOOL: + { + bool boolv; + piprot_->readBool(boolv); + if (print_) { + printf("%d", boolv); + } + } + break; + case apache::thrift::protocol::T_BYTE: + { + int8_t bytev; + piprot_->readByte(bytev); + if (print_) { + printf("%d", bytev); + } + } + break; + case apache::thrift::protocol::T_I16: + { + int16_t i16; + piprot_->readI16(i16); + if (print_) { + printf("%d", i16); + } + } + break; + case apache::thrift::protocol::T_I32: + { + int32_t i32; + piprot_->readI32(i32); + if (print_) { + printf("%d", i32); + } + } + break; + case apache::thrift::protocol::T_I64: + { + int64_t i64; + piprot_->readI64(i64); + if (print_) { + printf("%ld", i64); + } + } + break; + case apache::thrift::protocol::T_DOUBLE: + { + double dub; + piprot_->readDouble(dub); + if (print_) { + printf("%f", dub); + } + } + break; + case apache::thrift::protocol::T_STRING: + { + std::string str; + piprot_->readString(str); + if (print_) { + printf("%s", str.c_str()); + } + } + break; + case apache::thrift::protocol::T_STRUCT: + { + std::string name; + int16_t fid; + apache::thrift::protocol::TType ftype; + piprot_->readStructBegin(name); + if (print_) { + printf("<"); + } + while (true) { + piprot_->readFieldBegin(name, ftype, fid); + if (ftype == apache::thrift::protocol::T_STOP) { + break; + } + printAndPassToBuffer(ftype); + if (print_) { + printf(","); + } + piprot_->readFieldEnd(); + } + piprot_->readStructEnd(); + if (print_) { + printf("\b>"); + } + } + break; + case apache::thrift::protocol::T_MAP: + { + apache::thrift::protocol::TType keyType; + apache::thrift::protocol::TType valType; + uint32_t i, size; + piprot_->readMapBegin(keyType, valType, size); + if (print_) { + printf("{"); + } + for (i = 0; i < size; i++) { + printAndPassToBuffer(keyType); + if (print_) { + printf("=>"); + } + printAndPassToBuffer(valType); + if (print_) { + printf(","); + } + } + piprot_->readMapEnd(); + if (print_) { + printf("\b}"); + } + } + break; + case apache::thrift::protocol::T_SET: + { + apache::thrift::protocol::TType elemType; + uint32_t i, size; + piprot_->readSetBegin(elemType, size); + if (print_) { + printf("{"); + } + for (i = 0; i < size; i++) { + printAndPassToBuffer(elemType); + if (print_) { + printf(","); + } + } + piprot_->readSetEnd(); + if (print_) { + printf("\b}"); + } + } + break; + case apache::thrift::protocol::T_LIST: + { + apache::thrift::protocol::TType elemType; + uint32_t i, size; + piprot_->readListBegin(elemType, size); + if (print_) { + printf("["); + } + for (i = 0; i < size; i++) { + printAndPassToBuffer(elemType); + if (print_) { + printf(","); + } + } + piprot_->readListEnd(); + if (print_) { + printf("\b]"); + } + } + break; + default: + break; + } + } + + boost::shared_ptr<apache::thrift::protocol::TProtocol> piprot_; + std::map<std::string, int64_t> frequency_map_; + + bool print_; + bool frequency_; +}; + +}}} // apache::thrift::processor + +#endif diff --git a/lib/cpp/src/protocol/TBase64Utils.cpp b/lib/cpp/src/protocol/TBase64Utils.cpp new file mode 100644 index 000000000..14481c49c --- /dev/null +++ b/lib/cpp/src/protocol/TBase64Utils.cpp @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TBase64Utils.h" + +#include <boost/static_assert.hpp> + +using std::string; + +namespace apache { namespace thrift { namespace protocol { + + +static const uint8_t *kBase64EncodeTable = (const uint8_t *) + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +void base64_encode(const uint8_t *in, uint32_t len, uint8_t *buf) { + buf[0] = kBase64EncodeTable[(in[0] >> 2) & 0x3F]; + if (len == 3) { + buf[1] = kBase64EncodeTable[((in[0] << 4) + (in[1] >> 4)) & 0x3f]; + buf[2] = kBase64EncodeTable[((in[1] << 2) + (in[2] >> 6)) & 0x3f]; + buf[3] = kBase64EncodeTable[in[2] & 0x3f]; + } else if (len == 2) { + buf[1] = kBase64EncodeTable[((in[0] << 4) + (in[1] >> 4)) & 0x3f]; + buf[2] = kBase64EncodeTable[(in[1] << 2) & 0x3f]; + } else { // len == 1 + buf[1] = kBase64EncodeTable[(in[0] << 4) & 0x3f]; + } +} + +static const uint8_t kBase64DecodeTable[256] ={ + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,62,-1,-1,-1,63, + 52,53,54,55,56,57,58,59,60,61,-1,-1,-1,-1,-1,-1, + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14, + 15,16,17,18,19,20,21,22,23,24,25,-1,-1,-1,-1,-1, + -1,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40, + 41,42,43,44,45,46,47,48,49,50,51,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, +}; + +void base64_decode(uint8_t *buf, uint32_t len) { + buf[0] = (kBase64DecodeTable[buf[0]] << 2) | + (kBase64DecodeTable[buf[1]] >> 4); + if (len > 2) { + buf[1] = ((kBase64DecodeTable[buf[1]] << 4) & 0xf0) | + (kBase64DecodeTable[buf[2]] >> 2); + if (len > 3) { + buf[2] = ((kBase64DecodeTable[buf[2]] << 6) & 0xc0) | + (kBase64DecodeTable[buf[3]]); + } + } +} + + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TBase64Utils.h b/lib/cpp/src/protocol/TBase64Utils.h new file mode 100644 index 000000000..3def73350 --- /dev/null +++ b/lib/cpp/src/protocol/TBase64Utils.h @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TBASE64UTILS_H_ +#define _THRIFT_PROTOCOL_TBASE64UTILS_H_ + +#include <stdint.h> +#include <string> + +namespace apache { namespace thrift { namespace protocol { + +// in must be at least len bytes +// len must be 1, 2, or 3 +// buf must be a buffer of at least 4 bytes and may not overlap in +// the data is not padded with '='; the caller can do this if desired +void base64_encode(const uint8_t *in, uint32_t len, uint8_t *buf); + +// buf must be a buffer of at least 4 bytes and contain base64 encoded values +// buf will be changed to contain output bytes +// len is number of bytes to consume from input (must be 2, 3, or 4) +// no '=' padding should be included in the input +void base64_decode(uint8_t *buf, uint32_t len); + +}}} // apache::thrift::protocol + +#endif // #define _THRIFT_PROTOCOL_TBASE64UTILS_H_ diff --git a/lib/cpp/src/protocol/TBinaryProtocol.cpp b/lib/cpp/src/protocol/TBinaryProtocol.cpp new file mode 100644 index 000000000..6a4838b44 --- /dev/null +++ b/lib/cpp/src/protocol/TBinaryProtocol.cpp @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TBinaryProtocol.h" + +#include <limits> + +using std::string; + +namespace apache { namespace thrift { namespace protocol { + +uint32_t TBinaryProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + if (strict_write_) { + int32_t version = (VERSION_1) | ((int32_t)messageType); + uint32_t wsize = 0; + wsize += writeI32(version); + wsize += writeString(name); + wsize += writeI32(seqid); + return wsize; + } else { + uint32_t wsize = 0; + wsize += writeString(name); + wsize += writeByte((int8_t)messageType); + wsize += writeI32(seqid); + return wsize; + } +} + +uint32_t TBinaryProtocol::writeMessageEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeStructBegin(const char* name) { + return 0; +} + +uint32_t TBinaryProtocol::writeStructEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + uint32_t wsize = 0; + wsize += writeByte((int8_t)fieldType); + wsize += writeI16(fieldId); + return wsize; +} + +uint32_t TBinaryProtocol::writeFieldEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeFieldStop() { + return + writeByte((int8_t)T_STOP); +} + +uint32_t TBinaryProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + uint32_t wsize = 0; + wsize += writeByte((int8_t)keyType); + wsize += writeByte((int8_t)valType); + wsize += writeI32((int32_t)size); + return wsize; +} + +uint32_t TBinaryProtocol::writeMapEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + uint32_t wsize = 0; + wsize += writeByte((int8_t) elemType); + wsize += writeI32((int32_t)size); + return wsize; +} + +uint32_t TBinaryProtocol::writeListEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + uint32_t wsize = 0; + wsize += writeByte((int8_t)elemType); + wsize += writeI32((int32_t)size); + return wsize; +} + +uint32_t TBinaryProtocol::writeSetEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeBool(const bool value) { + uint8_t tmp = value ? 1 : 0; + trans_->write(&tmp, 1); + return 1; +} + +uint32_t TBinaryProtocol::writeByte(const int8_t byte) { + trans_->write((uint8_t*)&byte, 1); + return 1; +} + +uint32_t TBinaryProtocol::writeI16(const int16_t i16) { + int16_t net = (int16_t)htons(i16); + trans_->write((uint8_t*)&net, 2); + return 2; +} + +uint32_t TBinaryProtocol::writeI32(const int32_t i32) { + int32_t net = (int32_t)htonl(i32); + trans_->write((uint8_t*)&net, 4); + return 4; +} + +uint32_t TBinaryProtocol::writeI64(const int64_t i64) { + int64_t net = (int64_t)htonll(i64); + trans_->write((uint8_t*)&net, 8); + return 8; +} + +uint32_t TBinaryProtocol::writeDouble(const double dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits<double>::is_iec559); + + uint64_t bits = bitwise_cast<uint64_t>(dub); + bits = htonll(bits); + trans_->write((uint8_t*)&bits, 8); + return 8; +} + + +uint32_t TBinaryProtocol::writeString(const string& str) { + uint32_t size = str.size(); + uint32_t result = writeI32((int32_t)size); + if (size > 0) { + trans_->write((uint8_t*)str.data(), size); + } + return result + size; +} + +uint32_t TBinaryProtocol::writeBinary(const string& str) { + return TBinaryProtocol::writeString(str); +} + +/** + * Reading functions + */ + +uint32_t TBinaryProtocol::readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + uint32_t result = 0; + int32_t sz; + result += readI32(sz); + + if (sz < 0) { + // Check for correct version number + int32_t version = sz & VERSION_MASK; + if (version != VERSION_1) { + throw TProtocolException(TProtocolException::BAD_VERSION, "Bad version identifier"); + } + messageType = (TMessageType)(sz & 0x000000ff); + result += readString(name); + result += readI32(seqid); + } else { + if (strict_read_) { + throw TProtocolException(TProtocolException::BAD_VERSION, "No version identifier... old protocol client in strict mode?"); + } else { + // Handle pre-versioned input + int8_t type; + result += readStringBody(name, sz); + result += readByte(type); + messageType = (TMessageType)type; + result += readI32(seqid); + } + } + return result; +} + +uint32_t TBinaryProtocol::readMessageEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readStructBegin(string& name) { + name = ""; + return 0; +} + +uint32_t TBinaryProtocol::readStructEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readFieldBegin(string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t result = 0; + int8_t type; + result += readByte(type); + fieldType = (TType)type; + if (fieldType == T_STOP) { + fieldId = 0; + return result; + } + result += readI16(fieldId); + return result; +} + +uint32_t TBinaryProtocol::readFieldEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + int8_t k, v; + uint32_t result = 0; + int32_t sizei; + result += readByte(k); + keyType = (TType)k; + result += readByte(v); + valType = (TType)v; + result += readI32(sizei); + if (sizei < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + return result; +} + +uint32_t TBinaryProtocol::readMapEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readListBegin(TType& elemType, + uint32_t& size) { + int8_t e; + uint32_t result = 0; + int32_t sizei; + result += readByte(e); + elemType = (TType)e; + result += readI32(sizei); + if (sizei < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + return result; +} + +uint32_t TBinaryProtocol::readListEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readSetBegin(TType& elemType, + uint32_t& size) { + int8_t e; + uint32_t result = 0; + int32_t sizei; + result += readByte(e); + elemType = (TType)e; + result += readI32(sizei); + if (sizei < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + return result; +} + +uint32_t TBinaryProtocol::readSetEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readBool(bool& value) { + uint8_t b[1]; + trans_->readAll(b, 1); + value = *(int8_t*)b != 0; + return 1; +} + +uint32_t TBinaryProtocol::readByte(int8_t& byte) { + uint8_t b[1]; + trans_->readAll(b, 1); + byte = *(int8_t*)b; + return 1; +} + +uint32_t TBinaryProtocol::readI16(int16_t& i16) { + uint8_t b[2]; + trans_->readAll(b, 2); + i16 = *(int16_t*)b; + i16 = (int16_t)ntohs(i16); + return 2; +} + +uint32_t TBinaryProtocol::readI32(int32_t& i32) { + uint8_t b[4]; + trans_->readAll(b, 4); + i32 = *(int32_t*)b; + i32 = (int32_t)ntohl(i32); + return 4; +} + +uint32_t TBinaryProtocol::readI64(int64_t& i64) { + uint8_t b[8]; + trans_->readAll(b, 8); + i64 = *(int64_t*)b; + i64 = (int64_t)ntohll(i64); + return 8; +} + +uint32_t TBinaryProtocol::readDouble(double& dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits<double>::is_iec559); + + uint64_t bits; + uint8_t b[8]; + trans_->readAll(b, 8); + bits = *(uint64_t*)b; + bits = ntohll(bits); + dub = bitwise_cast<double>(bits); + return 8; +} + +uint32_t TBinaryProtocol::readString(string& str) { + uint32_t result; + int32_t size; + result = readI32(size); + return result + readStringBody(str, size); +} + +uint32_t TBinaryProtocol::readBinary(string& str) { + return TBinaryProtocol::readString(str); +} + +uint32_t TBinaryProtocol::readStringBody(string& str, int32_t size) { + uint32_t result = 0; + + // Catch error cases + if (size < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } + if (string_limit_ > 0 && size > string_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + + // Catch empty string case + if (size == 0) { + str = ""; + return result; + } + + // Use the heap here to prevent stack overflow for v. large strings + if (size > string_buf_size_ || string_buf_ == NULL) { + void* new_string_buf = std::realloc(string_buf_, (uint32_t)size); + if (new_string_buf == NULL) { + throw TProtocolException(TProtocolException::UNKNOWN, "Out of memory in TBinaryProtocol::readString"); + } + string_buf_ = (uint8_t*)new_string_buf; + string_buf_size_ = size; + } + trans_->readAll(string_buf_, size); + str = string((char*)string_buf_, size); + return (uint32_t)size; +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TBinaryProtocol.h b/lib/cpp/src/protocol/TBinaryProtocol.h new file mode 100644 index 000000000..7fd3de673 --- /dev/null +++ b/lib/cpp/src/protocol/TBinaryProtocol.h @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ 1 + +#include "TProtocol.h" + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace protocol { + +/** + * The default binary protocol for thrift. Writes all data in a very basic + * binary format, essentially just spitting out the raw bytes. + * + */ +class TBinaryProtocol : public TProtocol { + protected: + static const int32_t VERSION_MASK = 0xffff0000; + static const int32_t VERSION_1 = 0x80010000; + // VERSION_2 (0x80020000) is taken by TDenseProtocol. + + public: + TBinaryProtocol(boost::shared_ptr<TTransport> trans) : + TProtocol(trans), + string_limit_(0), + container_limit_(0), + strict_read_(false), + strict_write_(true), + string_buf_(NULL), + string_buf_size_(0) {} + + TBinaryProtocol(boost::shared_ptr<TTransport> trans, + int32_t string_limit, + int32_t container_limit, + bool strict_read, + bool strict_write) : + TProtocol(trans), + string_limit_(string_limit), + container_limit_(container_limit), + strict_read_(strict_read), + strict_write_(strict_write), + string_buf_(NULL), + string_buf_size_(0) {} + + ~TBinaryProtocol() { + if (string_buf_ != NULL) { + std::free(string_buf_); + string_buf_size_ = 0; + } + } + + void setStringSizeLimit(int32_t string_limit) { + string_limit_ = string_limit; + } + + void setContainerSizeLimit(int32_t container_limit) { + container_limit_ = container_limit; + } + + void setStrict(bool strict_read, bool strict_write) { + strict_read_ = strict_read; + strict_write_ = strict_write; + } + + /** + * Writing functions. + */ + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + virtual uint32_t writeMessageEnd(); + + + uint32_t writeStructBegin(const char* name); + + uint32_t writeStructEnd(); + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + uint32_t writeFieldEnd(); + + uint32_t writeFieldStop(); + + uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + uint32_t writeMapEnd(); + + uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + uint32_t writeListEnd(); + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + uint32_t writeSetEnd(); + + uint32_t writeBool(const bool value); + + uint32_t writeByte(const int8_t byte); + + uint32_t writeI16(const int16_t i16); + + uint32_t writeI32(const int32_t i32); + + uint32_t writeI64(const int64_t i64); + + uint32_t writeDouble(const double dub); + + uint32_t writeString(const std::string& str); + + uint32_t writeBinary(const std::string& str); + + /** + * Reading functions + */ + + + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid); + + uint32_t readMessageEnd(); + + uint32_t readStructBegin(std::string& name); + + uint32_t readStructEnd(); + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId); + + uint32_t readFieldEnd(); + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size); + + uint32_t readMapEnd(); + + uint32_t readListBegin(TType& elemType, + uint32_t& size); + + uint32_t readListEnd(); + + uint32_t readSetBegin(TType& elemType, + uint32_t& size); + + uint32_t readSetEnd(); + + uint32_t readBool(bool& value); + + uint32_t readByte(int8_t& byte); + + uint32_t readI16(int16_t& i16); + + uint32_t readI32(int32_t& i32); + + uint32_t readI64(int64_t& i64); + + uint32_t readDouble(double& dub); + + uint32_t readString(std::string& str); + + uint32_t readBinary(std::string& str); + + protected: + uint32_t readStringBody(std::string& str, int32_t sz); + + int32_t string_limit_; + int32_t container_limit_; + + // Enforce presence of version identifier + bool strict_read_; + bool strict_write_; + + // Buffer for reading strings, save for the lifetime of the protocol to + // avoid memory churn allocating memory on every string read + uint8_t* string_buf_; + int32_t string_buf_size_; + +}; + +/** + * Constructs binary protocol handlers + */ +class TBinaryProtocolFactory : public TProtocolFactory { + public: + TBinaryProtocolFactory() : + string_limit_(0), + container_limit_(0), + strict_read_(false), + strict_write_(true) {} + + TBinaryProtocolFactory(int32_t string_limit, int32_t container_limit, bool strict_read, bool strict_write) : + string_limit_(string_limit), + container_limit_(container_limit), + strict_read_(strict_read), + strict_write_(strict_write) {} + + virtual ~TBinaryProtocolFactory() {} + + void setStringSizeLimit(int32_t string_limit) { + string_limit_ = string_limit; + } + + void setContainerSizeLimit(int32_t container_limit) { + container_limit_ = container_limit; + } + + void setStrict(bool strict_read, bool strict_write) { + strict_read_ = strict_read; + strict_write_ = strict_write; + } + + boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) { + return boost::shared_ptr<TProtocol>(new TBinaryProtocol(trans, string_limit_, container_limit_, strict_read_, strict_write_)); + } + + private: + int32_t string_limit_; + int32_t container_limit_; + bool strict_read_; + bool strict_write_; + +}; + +}}} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ diff --git a/lib/cpp/src/protocol/TCompactProtocol.cpp b/lib/cpp/src/protocol/TCompactProtocol.cpp new file mode 100644 index 000000000..ce2ee54d2 --- /dev/null +++ b/lib/cpp/src/protocol/TCompactProtocol.cpp @@ -0,0 +1,736 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TCompactProtocol.h" + +#include <config.h> +#include <limits> + +/* + * TCompactProtocol::i*ToZigzag depend on the fact that the right shift + * operator on a signed integer is an arithmetic (sign-extending) shift. + * If this is not the case, the current implementation will not work. + * If anyone encounters this error, we can try to figure out the best + * way to implement an arithmetic right shift on their platform. + */ +#if !defined(SIGNED_RIGHT_SHIFT_IS) || !defined(ARITHMETIC_RIGHT_SHIFT) +# error "Unable to determine the behavior of a signed right shift" +#endif +#if SIGNED_RIGHT_SHIFT_IS != ARITHMETIC_RIGHT_SHIFT +# error "TCompactProtocol currenly only works if a signed right shift is arithmetic" +#endif + +#ifdef __GNUC__ +#define UNLIKELY(val) (__builtin_expect((val), 0)) +#else +#define UNLIKELY(val) (val) +#endif + +namespace apache { namespace thrift { namespace protocol { + +const int8_t TCompactProtocol::TTypeToCType[16] = { + CT_STOP, // T_STOP + 0, // unused + CT_BOOLEAN_TRUE, // T_BOOL + CT_BYTE, // T_BYTE + CT_DOUBLE, // T_DOUBLE + 0, // unused + CT_I16, // T_I16 + 0, // unused + CT_I32, // T_I32 + 0, // unused + CT_I64, // T_I64 + CT_BINARY, // T_STRING + CT_STRUCT, // T_STRUCT + CT_MAP, // T_MAP + CT_SET, // T_SET + CT_LIST, // T_LIST + }; + + +uint32_t TCompactProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + uint32_t wsize = 0; + wsize += writeByte(PROTOCOL_ID); + wsize += writeByte((VERSION_N & VERSION_MASK) | (((int32_t)messageType << TYPE_SHIFT_AMOUNT) & TYPE_MASK)); + wsize += writeVarint32(seqid); + wsize += writeString(name); + return wsize; +} + +/** + * Write a field header containing the field id and field type. If the + * difference between the current field id and the last one is small (< 15), + * then the field id will be encoded in the 4 MSB as a delta. Otherwise, the + * field id will follow the type header as a zigzag varint. + */ +uint32_t TCompactProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + if (fieldType == T_BOOL) { + booleanField_.name = name; + booleanField_.fieldType = fieldType; + booleanField_.fieldId = fieldId; + } else { + return writeFieldBeginInternal(name, fieldType, fieldId, -1); + } + return 0; +} + +/** + * Write the STOP symbol so we know there are no more fields in this struct. + */ +uint32_t TCompactProtocol::writeFieldStop() { + return writeByte(T_STOP); +} + +/** + * Write a struct begin. This doesn't actually put anything on the wire. We + * use it as an opportunity to put special placeholder markers on the field + * stack so we can get the field id deltas correct. + */ +uint32_t TCompactProtocol::writeStructBegin(const char* name) { + lastField_.push(lastFieldId_); + lastFieldId_ = 0; + return 0; +} + +/** + * Write a struct end. This doesn't actually put anything on the wire. We use + * this as an opportunity to pop the last field from the current struct off + * of the field stack. + */ +uint32_t TCompactProtocol::writeStructEnd() { + lastFieldId_ = lastField_.top(); + lastField_.pop(); + return 0; +} + +/** + * Write a List header. + */ +uint32_t TCompactProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + return writeCollectionBegin(elemType, size); +} + +/** + * Write a set header. + */ +uint32_t TCompactProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + return writeCollectionBegin(elemType, size); +} + +/** + * Write a map header. If the map is empty, omit the key and value type + * headers, as we don't need any additional information to skip it. + */ +uint32_t TCompactProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + uint32_t wsize = 0; + + if (size == 0) { + wsize += writeByte(0); + } else { + wsize += writeVarint32(size); + wsize += writeByte(getCompactType(keyType) << 4 | getCompactType(valType)); + } + return wsize; +} + +/** + * Write a boolean value. Potentially, this could be a boolean field, in + * which case the field header info isn't written yet. If so, decide what the + * right type header is for the value and then write the field header. + * Otherwise, write a single byte. + */ +uint32_t TCompactProtocol::writeBool(const bool value) { + uint32_t wsize = 0; + + if (booleanField_.name != NULL) { + // we haven't written the field header yet + wsize += writeFieldBeginInternal(booleanField_.name, + booleanField_.fieldType, + booleanField_.fieldId, + value ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE); + booleanField_.name = NULL; + } else { + // we're not part of a field, so just write the value + wsize += writeByte(value ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE); + } + return wsize; +} + +uint32_t TCompactProtocol::writeByte(const int8_t byte) { + trans_->write((uint8_t*)&byte, 1); + return 1; +} + +/** + * Write an i16 as a zigzag varint. + */ +uint32_t TCompactProtocol::writeI16(const int16_t i16) { + return writeVarint32(i32ToZigzag(i16)); +} + +/** + * Write an i32 as a zigzag varint. + */ +uint32_t TCompactProtocol::writeI32(const int32_t i32) { + return writeVarint32(i32ToZigzag(i32)); +} + +/** + * Write an i64 as a zigzag varint. + */ +uint32_t TCompactProtocol::writeI64(const int64_t i64) { + return writeVarint64(i64ToZigzag(i64)); +} + +/** + * Write a double to the wire as 8 bytes. + */ +uint32_t TCompactProtocol::writeDouble(const double dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits<double>::is_iec559); + + uint64_t bits = bitwise_cast<uint64_t>(dub); + bits = htolell(bits); + trans_->write((uint8_t*)&bits, 8); + return 8; +} + +/** + * Write a string to the wire with a varint size preceeding. + */ +uint32_t TCompactProtocol::writeString(const std::string& str) { + return writeBinary(str); +} + +uint32_t TCompactProtocol::writeBinary(const std::string& str) { + uint32_t ssize = str.size(); + uint32_t wsize = writeVarint32(ssize) + ssize; + trans_->write((uint8_t*)str.data(), ssize); + return wsize; +} + +// +// Internal Writing methods +// + +/** + * The workhorse of writeFieldBegin. It has the option of doing a + * 'type override' of the type header. This is used specifically in the + * boolean field case. + */ +int32_t TCompactProtocol::writeFieldBeginInternal(const char* name, + const TType fieldType, + const int16_t fieldId, + int8_t typeOverride) { + uint32_t wsize = 0; + + // if there's a type override, use that. + int8_t typeToWrite = (typeOverride == -1 ? getCompactType(fieldType) : typeOverride); + + // check if we can use delta encoding for the field id + if (fieldId > lastFieldId_ && fieldId - lastFieldId_ <= 15) { + // write them together + wsize += writeByte((fieldId - lastFieldId_) << 4 | typeToWrite); + } else { + // write them separate + wsize += writeByte(typeToWrite); + wsize += writeI16(fieldId); + } + + lastFieldId_ = fieldId; + return wsize; +} + +/** + * Abstract method for writing the start of lists and sets. List and sets on + * the wire differ only by the type indicator. + */ +uint32_t TCompactProtocol::writeCollectionBegin(int8_t elemType, int32_t size) { + uint32_t wsize = 0; + if (size <= 14) { + wsize += writeByte(size << 4 | getCompactType(elemType)); + } else { + wsize += writeByte(0xf0 | getCompactType(elemType)); + wsize += writeVarint32(size); + } + return wsize; +} + +/** + * Write an i32 as a varint. Results in 1-5 bytes on the wire. + */ +uint32_t TCompactProtocol::writeVarint32(uint32_t n) { + uint8_t buf[5]; + uint32_t wsize = 0; + + while (true) { + if ((n & ~0x7F) == 0) { + buf[wsize++] = (int8_t)n; + break; + } else { + buf[wsize++] = (int8_t)((n & 0x7F) | 0x80); + n >>= 7; + } + } + trans_->write(buf, wsize); + return wsize; +} + +/** + * Write an i64 as a varint. Results in 1-10 bytes on the wire. + */ +uint32_t TCompactProtocol::writeVarint64(uint64_t n) { + uint8_t buf[10]; + uint32_t wsize = 0; + + while (true) { + if ((n & ~0x7FL) == 0) { + buf[wsize++] = (int8_t)n; + break; + } else { + buf[wsize++] = (int8_t)((n & 0x7F) | 0x80); + n >>= 7; + } + } + trans_->write(buf, wsize); + return wsize; +} + +/** + * Convert l into a zigzag long. This allows negative numbers to be + * represented compactly as a varint. + */ +uint64_t TCompactProtocol::i64ToZigzag(const int64_t l) { + return (l << 1) ^ (l >> 63); +} + +/** + * Convert n into a zigzag int. This allows negative numbers to be + * represented compactly as a varint. + */ +uint32_t TCompactProtocol::i32ToZigzag(const int32_t n) { + return (n << 1) ^ (n >> 31); +} + +/** + * Given a TType value, find the appropriate TCompactProtocol.Type value + */ +int8_t TCompactProtocol::getCompactType(int8_t ttype) { + return TTypeToCType[ttype]; +} + +// +// Reading Methods +// + +/** + * Read a message header. + */ +uint32_t TCompactProtocol::readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + uint32_t rsize = 0; + int8_t protocolId; + int8_t versionAndType; + int8_t version; + + rsize += readByte(protocolId); + if (protocolId != PROTOCOL_ID) { + throw TProtocolException(TProtocolException::BAD_VERSION, "Bad protocol identifier"); + } + + rsize += readByte(versionAndType); + version = (int8_t)(versionAndType & VERSION_MASK); + if (version != VERSION_N) { + throw TProtocolException(TProtocolException::BAD_VERSION, "Bad protocol version"); + } + + messageType = (TMessageType)((versionAndType >> TYPE_SHIFT_AMOUNT) & 0x03); + rsize += readVarint32(seqid); + rsize += readString(name); + + return rsize; +} + +/** + * Read a struct begin. There's nothing on the wire for this, but it is our + * opportunity to push a new struct begin marker on the field stack. + */ +uint32_t TCompactProtocol::readStructBegin(std::string& name) { + name = ""; + lastField_.push(lastFieldId_); + lastFieldId_ = 0; + return 0; +} + +/** + * Doesn't actually consume any wire data, just removes the last field for + * this struct from the field stack. + */ +uint32_t TCompactProtocol::readStructEnd() { + lastFieldId_ = lastField_.top(); + lastField_.pop(); + return 0; +} + +/** + * Read a field header off the wire. + */ +uint32_t TCompactProtocol::readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t rsize = 0; + int8_t byte; + int8_t type; + + rsize += readByte(byte); + type = (byte & 0x0f); + + // if it's a stop, then we can return immediately, as the struct is over. + if (type == T_STOP) { + fieldType = T_STOP; + fieldId = 0; + return rsize; + } + + // mask off the 4 MSB of the type header. it could contain a field id delta. + int16_t modifier = (int16_t)(((uint8_t)byte & 0xf0) >> 4); + if (modifier == 0) { + // not a delta, look ahead for the zigzag varint field id. + rsize += readI16(fieldId); + } else { + fieldId = (int16_t)(lastFieldId_ + modifier); + } + fieldType = getTType(type); + + // if this happens to be a boolean field, the value is encoded in the type + if (type == CT_BOOLEAN_TRUE || type == CT_BOOLEAN_FALSE) { + // save the boolean value in a special instance variable. + boolValue_.hasBoolValue = true; + boolValue_.boolValue = (type == CT_BOOLEAN_TRUE ? true : false); + } + + // push the new field onto the field stack so we can keep the deltas going. + lastFieldId_ = fieldId; + return rsize; +} + +/** + * Read a map header off the wire. If the size is zero, skip reading the key + * and value type. This means that 0-length maps will yield TMaps without the + * "correct" types. + */ +uint32_t TCompactProtocol::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + uint32_t rsize = 0; + int8_t kvType = 0; + int32_t msize = 0; + + rsize += readVarint32(msize); + if (msize != 0) + rsize += readByte(kvType); + + if (msize < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && msize > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + + keyType = getTType((int8_t)((uint8_t)kvType >> 4)); + valType = getTType((int8_t)((uint8_t)kvType & 0xf)); + size = (uint32_t)msize; + + return rsize; +} + +/** + * Read a list header off the wire. If the list size is 0-14, the size will + * be packed into the element type header. If it's a longer list, the 4 MSB + * of the element type header will be 0xF, and a varint will follow with the + * true size. + */ +uint32_t TCompactProtocol::readListBegin(TType& elemType, + uint32_t& size) { + int8_t size_and_type; + uint32_t rsize = 0; + int32_t lsize; + + rsize += readByte(size_and_type); + + lsize = ((uint8_t)size_and_type >> 4) & 0x0f; + if (lsize == 15) { + rsize += readVarint32(lsize); + } + + if (lsize < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && lsize > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + + elemType = getTType((int8_t)(size_and_type & 0x0f)); + size = (uint32_t)lsize; + + return rsize; +} + +/** + * Read a set header off the wire. If the set size is 0-14, the size will + * be packed into the element type header. If it's a longer set, the 4 MSB + * of the element type header will be 0xF, and a varint will follow with the + * true size. + */ +uint32_t TCompactProtocol::readSetBegin(TType& elemType, + uint32_t& size) { + return readListBegin(elemType, size); +} + +/** + * Read a boolean off the wire. If this is a boolean field, the value should + * already have been read during readFieldBegin, so we'll just consume the + * pre-stored value. Otherwise, read a byte. + */ +uint32_t TCompactProtocol::readBool(bool& value) { + if (boolValue_.hasBoolValue == true) { + value = boolValue_.boolValue; + boolValue_.hasBoolValue = false; + return 0; + } else { + int8_t val; + readByte(val); + value = (val == CT_BOOLEAN_TRUE); + return 1; + } +} + +/** + * Read a single byte off the wire. Nothing interesting here. + */ +uint32_t TCompactProtocol::readByte(int8_t& byte) { + uint8_t b[1]; + trans_->readAll(b, 1); + byte = *(int8_t*)b; + return 1; +} + +/** + * Read an i16 from the wire as a zigzag varint. + */ +uint32_t TCompactProtocol::readI16(int16_t& i16) { + int32_t value; + uint32_t rsize = readVarint32(value); + i16 = (int16_t)zigzagToI32(value); + return rsize; +} + +/** + * Read an i32 from the wire as a zigzag varint. + */ +uint32_t TCompactProtocol::readI32(int32_t& i32) { + int32_t value; + uint32_t rsize = readVarint32(value); + i32 = zigzagToI32(value); + return rsize; +} + +/** + * Read an i64 from the wire as a zigzag varint. + */ +uint32_t TCompactProtocol::readI64(int64_t& i64) { + int64_t value; + uint32_t rsize = readVarint64(value); + i64 = zigzagToI64(value); + return rsize; +} + +/** + * No magic here - just read a double off the wire. + */ +uint32_t TCompactProtocol::readDouble(double& dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits<double>::is_iec559); + + uint64_t bits; + uint8_t b[8]; + trans_->readAll(b, 8); + bits = *(uint64_t*)b; + bits = letohll(bits); + dub = bitwise_cast<double>(bits); + return 8; +} + +uint32_t TCompactProtocol::readString(std::string& str) { + return readBinary(str); +} + +/** + * Read a byte[] from the wire. + */ +uint32_t TCompactProtocol::readBinary(std::string& str) { + int32_t rsize = 0; + int32_t size; + + rsize += readVarint32(size); + // Catch empty string case + if (size == 0) { + str = ""; + return rsize; + } + + // Catch error cases + if (size < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } + if (string_limit_ > 0 && size > string_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + + // Use the heap here to prevent stack overflow for v. large strings + if (size > string_buf_size_ || string_buf_ == NULL) { + void* new_string_buf = std::realloc(string_buf_, (uint32_t)size); + if (new_string_buf == NULL) { + throw TProtocolException(TProtocolException::UNKNOWN, "Out of memory in TCompactProtocol::readString"); + } + string_buf_ = (uint8_t*)new_string_buf; + string_buf_size_ = size; + } + trans_->readAll(string_buf_, size); + str.assign((char*)string_buf_, size); + + return rsize + (uint32_t)size; +} + +/** + * Read an i32 from the wire as a varint. The MSB of each byte is set + * if there is another byte to follow. This can read up to 5 bytes. + */ +uint32_t TCompactProtocol::readVarint32(int32_t& i32) { + int64_t val; + uint32_t rsize = readVarint64(val); + i32 = (int32_t)val; + return rsize; +} + +/** + * Read an i64 from the wire as a proper varint. The MSB of each byte is set + * if there is another byte to follow. This can read up to 10 bytes. + */ +uint32_t TCompactProtocol::readVarint64(int64_t& i64) { + uint32_t rsize = 0; + uint64_t val = 0; + int shift = 0; + uint8_t buf[10]; // 64 bits / (7 bits/byte) = 10 bytes. + uint32_t buf_size = sizeof(buf); + const uint8_t* borrowed = trans_->borrow(buf, &buf_size); + + // Fast path. + if (borrowed != NULL) { + while (true) { + uint8_t byte = borrowed[rsize]; + rsize++; + val |= (uint64_t)(byte & 0x7f) << shift; + shift += 7; + if (!(byte & 0x80)) { + i64 = val; + trans_->consume(rsize); + return rsize; + } + // Have to check for invalid data so we don't crash. + if (UNLIKELY(rsize == sizeof(buf))) { + throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes."); + } + } + } + + // Slow path. + else { + while (true) { + uint8_t byte; + rsize += trans_->readAll(&byte, 1); + val |= (uint64_t)(byte & 0x7f) << shift; + shift += 7; + if (!(byte & 0x80)) { + i64 = val; + return rsize; + } + // Might as well check for invalid data on the slow path too. + if (UNLIKELY(rsize >= sizeof(buf))) { + throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes."); + } + } + } +} + +/** + * Convert from zigzag int to int. + */ +int32_t TCompactProtocol::zigzagToI32(uint32_t n) { + return (n >> 1) ^ -(n & 1); +} + +/** + * Convert from zigzag long to long. + */ +int64_t TCompactProtocol::zigzagToI64(uint64_t n) { + return (n >> 1) ^ -(n & 1); +} + +TType TCompactProtocol::getTType(int8_t type) { + switch (type) { + case T_STOP: + return T_STOP; + case CT_BOOLEAN_FALSE: + case CT_BOOLEAN_TRUE: + return T_BOOL; + case CT_BYTE: + return T_BYTE; + case CT_I16: + return T_I16; + case CT_I32: + return T_I32; + case CT_I64: + return T_I64; + case CT_DOUBLE: + return T_DOUBLE; + case CT_BINARY: + return T_STRING; + case CT_LIST: + return T_LIST; + case CT_SET: + return T_SET; + case CT_MAP: + return T_MAP; + case CT_STRUCT: + return T_STRUCT; + default: + throw TException("don't know what type: " + type); + } + return T_STOP; +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TCompactProtocol.h b/lib/cpp/src/protocol/TCompactProtocol.h new file mode 100644 index 000000000..b4e06f0aa --- /dev/null +++ b/lib/cpp/src/protocol/TCompactProtocol.h @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_H_ 1 + +#include "TProtocol.h" + +#include <stack> +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace protocol { + +/** + * C++ Implementation of the Compact Protocol as described in THRIFT-110 + */ +class TCompactProtocol : public TProtocol { + + protected: + static const int8_t PROTOCOL_ID = 0x82; + static const int8_t VERSION_N = 1; + static const int8_t VERSION_MASK = 0x1f; // 0001 1111 + static const int8_t TYPE_MASK = 0xE0; // 1110 0000 + static const int32_t TYPE_SHIFT_AMOUNT = 5; + + /** + * (Writing) If we encounter a boolean field begin, save the TField here + * so it can have the value incorporated. + */ + struct { + const char* name; + TType fieldType; + int16_t fieldId; + } booleanField_; + + /** + * (Reading) If we read a field header, and it's a boolean field, save + * the boolean value here so that readBool can use it. + */ + struct { + bool hasBoolValue; + bool boolValue; + } boolValue_; + + /** + * Used to keep track of the last field for the current and previous structs, + * so we can do the delta stuff. + */ + + std::stack<int16_t> lastField_; + int16_t lastFieldId_; + + enum Types { + CT_STOP = 0x00, + CT_BOOLEAN_TRUE = 0x01, + CT_BOOLEAN_FALSE = 0x02, + CT_BYTE = 0x03, + CT_I16 = 0x04, + CT_I32 = 0x05, + CT_I64 = 0x06, + CT_DOUBLE = 0x07, + CT_BINARY = 0x08, + CT_LIST = 0x09, + CT_SET = 0x0A, + CT_MAP = 0x0B, + CT_STRUCT = 0x0C, + }; + + static const int8_t TTypeToCType[16]; + + public: + TCompactProtocol(boost::shared_ptr<TTransport> trans) : + TProtocol(trans), + lastFieldId_(0), + string_limit_(0), + string_buf_(NULL), + string_buf_size_(0), + container_limit_(0) { + booleanField_.name = NULL; + boolValue_.hasBoolValue = false; + } + + TCompactProtocol(boost::shared_ptr<TTransport> trans, + int32_t string_limit, + int32_t container_limit) : + TProtocol(trans), + lastFieldId_(0), + string_limit_(string_limit), + string_buf_(NULL), + string_buf_size_(0), + container_limit_(container_limit) { + booleanField_.name = NULL; + boolValue_.hasBoolValue = false; + } + + + + /** + * Writing functions + */ + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + uint32_t writeStructBegin(const char* name); + + uint32_t writeStructEnd(); + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + uint32_t writeFieldStop(); + + uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + virtual uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + uint32_t writeBool(const bool value); + + uint32_t writeByte(const int8_t byte); + + uint32_t writeI16(const int16_t i16); + + uint32_t writeI32(const int32_t i32); + + uint32_t writeI64(const int64_t i64); + + uint32_t writeDouble(const double dub); + + uint32_t writeString(const std::string& str); + + uint32_t writeBinary(const std::string& str); + + /** + * These methods are called by structs, but don't actually have any wired + * output or purpose + */ + virtual uint32_t writeMessageEnd() { return 0; } + uint32_t writeMapEnd() { return 0; } + uint32_t writeListEnd() { return 0; } + uint32_t writeSetEnd() { return 0; } + uint32_t writeFieldEnd() { return 0; } + + protected: + int32_t writeFieldBeginInternal(const char* name, + const TType fieldType, + const int16_t fieldId, + int8_t typeOverride); + uint32_t writeCollectionBegin(int8_t elemType, int32_t size); + uint32_t writeVarint32(uint32_t n); + uint32_t writeVarint64(uint64_t n); + uint64_t i64ToZigzag(const int64_t l); + uint32_t i32ToZigzag(const int32_t n); + inline int8_t getCompactType(int8_t ttype); + + public: + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid); + + uint32_t readStructBegin(std::string& name); + + uint32_t readStructEnd(); + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId); + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size); + + uint32_t readListBegin(TType& elemType, + uint32_t& size); + + uint32_t readSetBegin(TType& elemType, + uint32_t& size); + + uint32_t readBool(bool& value); + + uint32_t readByte(int8_t& byte); + + uint32_t readI16(int16_t& i16); + + uint32_t readI32(int32_t& i32); + + uint32_t readI64(int64_t& i64); + + uint32_t readDouble(double& dub); + + uint32_t readString(std::string& str); + + uint32_t readBinary(std::string& str); + + /* + *These methods are here for the struct to call, but don't have any wire + * encoding. + */ + uint32_t readMessageEnd() { return 0; } + uint32_t readFieldEnd() { return 0; } + uint32_t readMapEnd() { return 0; } + uint32_t readListEnd() { return 0; } + uint32_t readSetEnd() { return 0; } + + protected: + uint32_t readVarint32(int32_t& i32); + uint32_t readVarint64(int64_t& i64); + int32_t zigzagToI32(uint32_t n); + int64_t zigzagToI64(uint64_t n); + TType getTType(int8_t type); + + // Buffer for reading strings, save for the lifetime of the protocol to + // avoid memory churn allocating memory on every string read + int32_t string_limit_; + uint8_t* string_buf_; + int32_t string_buf_size_; + int32_t container_limit_; +}; + +/** + * Constructs compact protocol handlers + */ +class TCompactProtocolFactory : public TProtocolFactory { + public: + TCompactProtocolFactory() : + string_limit_(0), + container_limit_(0) {} + + TCompactProtocolFactory(int32_t string_limit, int32_t container_limit) : + string_limit_(string_limit), + container_limit_(container_limit) {} + + virtual ~TCompactProtocolFactory() {} + + void setStringSizeLimit(int32_t string_limit) { + string_limit_ = string_limit; + } + + void setContainerSizeLimit(int32_t container_limit) { + container_limit_ = container_limit; + } + + boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) { + return boost::shared_ptr<TProtocol>(new TCompactProtocol(trans, string_limit_, container_limit_)); + } + + private: + int32_t string_limit_; + int32_t container_limit_; + +}; + +}}} // apache::thrift::protocol + +#endif diff --git a/lib/cpp/src/protocol/TDebugProtocol.cpp b/lib/cpp/src/protocol/TDebugProtocol.cpp new file mode 100644 index 000000000..40aa36bad --- /dev/null +++ b/lib/cpp/src/protocol/TDebugProtocol.cpp @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TDebugProtocol.h" + +#include <cassert> +#include <cctype> +#include <cstdio> +#include <stdexcept> +#include <boost/static_assert.hpp> +#include <boost/lexical_cast.hpp> + +using std::string; + + +static string byte_to_hex(const uint8_t byte) { + char buf[3]; + int ret = std::sprintf(buf, "%02x", (int)byte); + assert(ret == 2); + assert(buf[2] == '\0'); + return buf; +} + + +namespace apache { namespace thrift { namespace protocol { + +string TDebugProtocol::fieldTypeName(TType type) { + switch (type) { + case T_STOP : return "stop" ; + case T_VOID : return "void" ; + case T_BOOL : return "bool" ; + case T_BYTE : return "byte" ; + case T_I16 : return "i16" ; + case T_I32 : return "i32" ; + case T_U64 : return "u64" ; + case T_I64 : return "i64" ; + case T_DOUBLE : return "double" ; + case T_STRING : return "string" ; + case T_STRUCT : return "struct" ; + case T_MAP : return "map" ; + case T_SET : return "set" ; + case T_LIST : return "list" ; + case T_UTF8 : return "utf8" ; + case T_UTF16 : return "utf16" ; + default: return "unknown"; + } +} + +void TDebugProtocol::indentUp() { + indent_str_ += string(indent_inc, ' '); +} + +void TDebugProtocol::indentDown() { + if (indent_str_.length() < (string::size_type)indent_inc) { + throw TProtocolException(TProtocolException::INVALID_DATA); + } + indent_str_.erase(indent_str_.length() - indent_inc); +} + +uint32_t TDebugProtocol::writePlain(const string& str) { + trans_->write((uint8_t*)str.data(), str.length()); + return str.length(); +} + +uint32_t TDebugProtocol::writeIndented(const string& str) { + trans_->write((uint8_t*)indent_str_.data(), indent_str_.length()); + trans_->write((uint8_t*)str.data(), str.length()); + return indent_str_.length() + str.length(); +} + +uint32_t TDebugProtocol::startItem() { + uint32_t size; + + switch (write_state_.back()) { + case UNINIT: + // XXX figure out what to do here. + //throw TProtocolException(TProtocolException::INVALID_DATA); + //return writeIndented(str); + return 0; + case STRUCT: + return 0; + case SET: + return writeIndented(""); + case MAP_KEY: + return writeIndented(""); + case MAP_VALUE: + return writePlain(" -> "); + case LIST: + size = writeIndented( + "[" + boost::lexical_cast<string>(list_idx_.back()) + "] = "); + list_idx_.back()++; + return size; + default: + throw std::logic_error("Invalid enum value."); + } +} + +uint32_t TDebugProtocol::endItem() { + //uint32_t size; + + switch (write_state_.back()) { + case UNINIT: + // XXX figure out what to do here. + //throw TProtocolException(TProtocolException::INVALID_DATA); + //return writeIndented(str); + return 0; + case STRUCT: + return writePlain(",\n"); + case SET: + return writePlain(",\n"); + case MAP_KEY: + write_state_.back() = MAP_VALUE; + return 0; + case MAP_VALUE: + write_state_.back() = MAP_KEY; + return writePlain(",\n"); + case LIST: + return writePlain(",\n"); + default: + throw std::logic_error("Invalid enum value."); + } +} + +uint32_t TDebugProtocol::writeItem(const std::string& str) { + uint32_t size = 0; + size += startItem(); + size += writePlain(str); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + string mtype; + switch (messageType) { + case T_CALL : mtype = "call" ; break; + case T_REPLY : mtype = "reply" ; break; + case T_EXCEPTION : mtype = "exn" ; break; + } + + uint32_t size = writeIndented("(" + mtype + ") " + name + "("); + indentUp(); + return size; +} + +uint32_t TDebugProtocol::writeMessageEnd() { + indentDown(); + return writeIndented(")\n"); +} + +uint32_t TDebugProtocol::writeStructBegin(const char* name) { + uint32_t size = 0; + size += startItem(); + size += writePlain(string(name) + " {\n"); + indentUp(); + write_state_.push_back(STRUCT); + return size; +} + +uint32_t TDebugProtocol::writeStructEnd() { + indentDown(); + write_state_.pop_back(); + uint32_t size = 0; + size += writeIndented("}"); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + // sprintf(id_str, "%02d", fieldId); + string id_str = boost::lexical_cast<string>(fieldId); + if (id_str.length() == 1) id_str = '0' + id_str; + + return writeIndented( + id_str + ": " + + name + " (" + + fieldTypeName(fieldType) + ") = "); +} + +uint32_t TDebugProtocol::writeFieldEnd() { + assert(write_state_.back() == STRUCT); + return 0; +} + +uint32_t TDebugProtocol::writeFieldStop() { + return 0; + //writeIndented("***STOP***\n"); +} + +uint32_t TDebugProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + // TODO(dreiss): Optimize short maps? + uint32_t bsize = 0; + bsize += startItem(); + bsize += writePlain( + "map<" + fieldTypeName(keyType) + "," + fieldTypeName(valType) + ">" + "[" + boost::lexical_cast<string>(size) + "] {\n"); + indentUp(); + write_state_.push_back(MAP_KEY); + return bsize; +} + +uint32_t TDebugProtocol::writeMapEnd() { + indentDown(); + write_state_.pop_back(); + uint32_t size = 0; + size += writeIndented("}"); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + // TODO(dreiss): Optimize short arrays. + uint32_t bsize = 0; + bsize += startItem(); + bsize += writePlain( + "list<" + fieldTypeName(elemType) + ">" + "[" + boost::lexical_cast<string>(size) + "] {\n"); + indentUp(); + write_state_.push_back(LIST); + list_idx_.push_back(0); + return bsize; +} + +uint32_t TDebugProtocol::writeListEnd() { + indentDown(); + write_state_.pop_back(); + list_idx_.pop_back(); + uint32_t size = 0; + size += writeIndented("}"); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + // TODO(dreiss): Optimize short sets. + uint32_t bsize = 0; + bsize += startItem(); + bsize += writePlain( + "set<" + fieldTypeName(elemType) + ">" + "[" + boost::lexical_cast<string>(size) + "] {\n"); + indentUp(); + write_state_.push_back(SET); + return bsize; +} + +uint32_t TDebugProtocol::writeSetEnd() { + indentDown(); + write_state_.pop_back(); + uint32_t size = 0; + size += writeIndented("}"); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeBool(const bool value) { + return writeItem(value ? "true" : "false"); +} + +uint32_t TDebugProtocol::writeByte(const int8_t byte) { + return writeItem("0x" + byte_to_hex(byte)); +} + +uint32_t TDebugProtocol::writeI16(const int16_t i16) { + return writeItem(boost::lexical_cast<string>(i16)); +} + +uint32_t TDebugProtocol::writeI32(const int32_t i32) { + return writeItem(boost::lexical_cast<string>(i32)); +} + +uint32_t TDebugProtocol::writeI64(const int64_t i64) { + return writeItem(boost::lexical_cast<string>(i64)); +} + +uint32_t TDebugProtocol::writeDouble(const double dub) { + return writeItem(boost::lexical_cast<string>(dub)); +} + + +uint32_t TDebugProtocol::writeString(const string& str) { + // XXX Raw/UTF-8? + + string to_show = str; + if (to_show.length() > (string::size_type)string_limit_) { + to_show = str.substr(0, string_prefix_size_); + to_show += "[...](" + boost::lexical_cast<string>(str.length()) + ")"; + } + + string output = "\""; + + for (string::const_iterator it = to_show.begin(); it != to_show.end(); ++it) { + if (*it == '\\') { + output += "\\\\"; + } else if (*it == '"') { + output += "\\\""; + } else if (std::isprint(*it)) { + output += *it; + } else { + switch (*it) { + case '\a': output += "\\a"; break; + case '\b': output += "\\b"; break; + case '\f': output += "\\f"; break; + case '\n': output += "\\n"; break; + case '\r': output += "\\r"; break; + case '\t': output += "\\t"; break; + case '\v': output += "\\v"; break; + default: + output += "\\x"; + output += byte_to_hex(*it); + } + } + } + + output += '\"'; + return writeItem(output); +} + +uint32_t TDebugProtocol::writeBinary(const string& str) { + // XXX Hex? + return TDebugProtocol::writeString(str); +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TDebugProtocol.h b/lib/cpp/src/protocol/TDebugProtocol.h new file mode 100644 index 000000000..ab69e0ca5 --- /dev/null +++ b/lib/cpp/src/protocol/TDebugProtocol.h @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TDEBUGPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TDEBUGPROTOCOL_H_ 1 + +#include "TProtocol.h" +#include "TOneWayProtocol.h" + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace protocol { + +/* + +!!! EXPERIMENTAL CODE !!! + +This protocol is very much a work in progress. +It doesn't handle many cases properly. +It throws exceptions in many cases. +It probably segfaults in many cases. +Bug reports and feature requests are welcome. +Complaints are not. :R + +*/ + + +/** + * Protocol that prints the payload in a nice human-readable format. + * Reading from this protocol is not supported. + * + */ +class TDebugProtocol : public TWriteOnlyProtocol { + private: + enum write_state_t + { UNINIT + , STRUCT + , LIST + , SET + , MAP_KEY + , MAP_VALUE + }; + + public: + TDebugProtocol(boost::shared_ptr<TTransport> trans) + : TWriteOnlyProtocol(trans, "TDebugProtocol") + , string_limit_(DEFAULT_STRING_LIMIT) + , string_prefix_size_(DEFAULT_STRING_PREFIX_SIZE) + { + write_state_.push_back(UNINIT); + } + + static const int32_t DEFAULT_STRING_LIMIT = 256; + static const int32_t DEFAULT_STRING_PREFIX_SIZE = 16; + + void setStringSizeLimit(int32_t string_limit) { + string_limit_ = string_limit; + } + + void setStringPrefixSize(int32_t string_prefix_size) { + string_prefix_size_ = string_prefix_size; + } + + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + virtual uint32_t writeMessageEnd(); + + + uint32_t writeStructBegin(const char* name); + + uint32_t writeStructEnd(); + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + uint32_t writeFieldEnd(); + + uint32_t writeFieldStop(); + + uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + uint32_t writeMapEnd(); + + uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + uint32_t writeListEnd(); + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + uint32_t writeSetEnd(); + + uint32_t writeBool(const bool value); + + uint32_t writeByte(const int8_t byte); + + uint32_t writeI16(const int16_t i16); + + uint32_t writeI32(const int32_t i32); + + uint32_t writeI64(const int64_t i64); + + uint32_t writeDouble(const double dub); + + uint32_t writeString(const std::string& str); + + uint32_t writeBinary(const std::string& str); + + + private: + void indentUp(); + void indentDown(); + uint32_t writePlain(const std::string& str); + uint32_t writeIndented(const std::string& str); + uint32_t startItem(); + uint32_t endItem(); + uint32_t writeItem(const std::string& str); + + static std::string fieldTypeName(TType type); + + int32_t string_limit_; + int32_t string_prefix_size_; + + std::string indent_str_; + static const int indent_inc = 2; + + std::vector<write_state_t> write_state_; + std::vector<int> list_idx_; +}; + +/** + * Constructs debug protocol handlers + */ +class TDebugProtocolFactory : public TProtocolFactory { + public: + TDebugProtocolFactory() {} + virtual ~TDebugProtocolFactory() {} + + boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) { + return boost::shared_ptr<TProtocol>(new TDebugProtocol(trans)); + } + +}; + +}}} // apache::thrift::protocol + + +// TODO(dreiss): Move (part of) ThriftDebugString into a .cpp file and remove this. +#include <transport/TBufferTransports.h> + +namespace apache { namespace thrift { + +template<typename ThriftStruct> +std::string ThriftDebugString(const ThriftStruct& ts) { + using namespace apache::thrift::transport; + using namespace apache::thrift::protocol; + TMemoryBuffer* buffer = new TMemoryBuffer; + boost::shared_ptr<TTransport> trans(buffer); + TDebugProtocol protocol(trans); + + ts.write(&protocol); + + uint8_t* buf; + uint32_t size; + buffer->getBuffer(&buf, &size); + return std::string((char*)buf, (unsigned int)size); +} + +// TODO(dreiss): This is badly broken. Don't use it unless you are me. +#if 0 +template<typename Object> +std::string DebugString(const std::vector<Object>& vec) { + using namespace apache::thrift::transport; + using namespace apache::thrift::protocol; + TMemoryBuffer* buffer = new TMemoryBuffer; + boost::shared_ptr<TTransport> trans(buffer); + TDebugProtocol protocol(trans); + + // I am gross! + protocol.writeStructBegin("SomeRandomVector"); + + // TODO: Fix this with a trait. + protocol.writeListBegin((TType)99, vec.size()); + typename std::vector<Object>::const_iterator it; + for (it = vec.begin(); it != vec.end(); ++it) { + it->write(&protocol); + } + protocol.writeListEnd(); + + uint8_t* buf; + uint32_t size; + buffer->getBuffer(&buf, &size); + return std::string((char*)buf, (unsigned int)size); +} +#endif // 0 + +}} // apache::thrift + + +#endif // #ifndef _THRIFT_PROTOCOL_TDEBUGPROTOCOL_H_ + + diff --git a/lib/cpp/src/protocol/TDenseProtocol.cpp b/lib/cpp/src/protocol/TDenseProtocol.cpp new file mode 100644 index 000000000..8e76dc479 --- /dev/null +++ b/lib/cpp/src/protocol/TDenseProtocol.cpp @@ -0,0 +1,762 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + +IMPLEMENTATION DETAILS + +TDenseProtocol was designed to have a smaller serialized form than +TBinaryProtocol. This is accomplished using two techniques. The first is +variable-length integer encoding. We use the same technique that the Standard +MIDI File format uses for "variable-length quantities" +(http://en.wikipedia.org/wiki/Variable-length_quantity). +All integers (including i16, but not byte) are first cast to uint64_t, +then written out as variable-length quantities. This has the unfortunate side +effect that all negative numbers require 10 bytes, but negative numbers tend +to be far less common than positive ones. + +The second technique eliminating the field ids used by TBinaryProtocol. This +decision required support from the Thrift compiler and also sacrifices some of +the backward and forward compatibility of TBinaryProtocol. + +We considered implementing this technique by generating separate readers and +writers for the dense protocol (this is how Pillar, Thrift's predecessor, +worked), but this idea had a few problems: +- Our abstractions go out the window. +- We would have to maintain a second code generator. +- Preserving compatibility with old versions of the structures would be a + nightmare. + +Therefore, we chose an alternate implementation that stored the description of +the data neither in the data itself (like TBinaryProtocol) nor in the +serialization code (like Pillar), but instead in a separate data structure, +called a TypeSpec. TypeSpecs are generated by the Thrift compiler +(specifically in the t_cpp_generator), and their structure should be +documented there (TODO(dreiss): s/should be/is/). + +We maintain a stack of TypeSpecs within the protocol so it knows where the +generated code is in the reading/writing process. For example, if we are +writing an i32 contained in a struct bar, contained in a struct foo, then the +stack would look like: TOP , i32 , struct bar , struct foo , BOTTOM. +The following invariant: whenever we are about to read/write an object +(structBegin, containerBegin, or a scalar), the TypeSpec on the top of the +stack must match the type being read/written. The main reasons that this +invariant must be maintained is that if we ever start reading a structure, we +must have its exact TypeSpec in order to pass the right tags to the +deserializer. + +We use the following strategies for maintaining this invariant: + +- For structures, we have a separate stack of indexes, one for each structure + on the TypeSpec stack. These are indexes into the list of fields in the + structure's TypeSpec. When we {read,write}FieldBegin, we push on the + TypeSpec for the field. +- When we begin writing a list or set, we push on the TypeSpec for the + element type. +- For maps, we have a separate stack of booleans, one for each map on the + TypeSpec stack. The boolean is true if we are writing the key for that + map, and false if we are writing the value. Maps are the trickiest case + because the generated code does not call any protocol method between + the key and the value. As a result, we potentially have to switch + between map key state and map value state after reading/writing any object. +- This job is handled by the stateTransition method. It is called after + reading/writing every object. It pops the current TypeSpec off the stack, + then optionally pushes a new one on, depending on what the next TypeSpec is. + If it is a struct, the job is left to the next writeFieldBegin. If it is a + set or list, the just-popped typespec is pushed back on. If it is a map, + the top of the key/value stack is toggled, and the appropriate TypeSpec + is pushed. + +Optional fields are a little tricky also. We write a zero byte if they are +absent and prefix them with an 0x01 byte if they are present +*/ + +#define __STDC_LIMIT_MACROS +#include <stdint.h> +#include "TDenseProtocol.h" +#include "TReflectionLocal.h" + +// Leaving this on for now. Disabling it will turn off asserts, which should +// give a performance boost. When we have *really* thorough test cases, +// we should drop this. +#define DEBUG_TDENSEPROTOCOL + +// NOTE: Assertions should *only* be used to detect bugs in code, +// either in TDenseProtocol itself, or in code using it. +// (For example, using the wrong TypeSpec.) +// Invalid data should NEVER cause an assertion failure, +// no matter how grossly corrupted, nor how ingeniously crafted. +#ifdef DEBUG_TDENSEPROTOCOL +#undef NDEBUG +#else +#define NDEBUG +#endif +#include <cassert> + +using std::string; + +#ifdef __GNUC__ +#define UNLIKELY(val) (__builtin_expect((val), 0)) +#else +#define UNLIKELY(val) (val) +#endif + +namespace apache { namespace thrift { namespace protocol { + +const int TDenseProtocol::FP_PREFIX_LEN = + apache::thrift::reflection::local::FP_PREFIX_LEN; + +// Top TypeSpec. TypeSpec of the structure being encoded. +#define TTS (ts_stack_.back()) // type = TypeSpec* +// InDeX. Index into TTS of the current/next field to encode. +#define IDX (idx_stack_.back()) // type = int +// Field TypeSpec. TypeSpec of the current/next field to encode. +#define FTS (TTS->tstruct.specs[IDX]) // type = TypeSpec* +// Field MeTa. Metadata of the current/next field to encode. +#define FMT (TTS->tstruct.metas[IDX]) // type = FieldMeta +// SubType 1/2. TypeSpec of the first/second subtype of this container. +#define ST1 (TTS->tcontainer.subtype1) +#define ST2 (TTS->tcontainer.subtype2) + + +/** + * Checks that @c ttype is indeed the ttype that we should be writing, + * according to our typespec. Aborts if the test fails and debugging in on. + */ +inline void TDenseProtocol::checkTType(const TType ttype) { + assert(!ts_stack_.empty()); + assert(TTS->ttype == ttype); +} + +/** + * Makes sure that the TypeSpec stack is correct for the next object. + * See top-of-file comments. + */ +inline void TDenseProtocol::stateTransition() { + TypeSpec* old_tts = ts_stack_.back(); + ts_stack_.pop_back(); + + // If this is the end of the top-level write, we should have just popped + // the TypeSpec passed to the constructor. + if (ts_stack_.empty()) { + assert(old_tts = type_spec_); + return; + } + + switch (TTS->ttype) { + + case T_STRUCT: + assert(old_tts == FTS); + break; + + case T_LIST: + case T_SET: + assert(old_tts == ST1); + ts_stack_.push_back(old_tts); + break; + + case T_MAP: + assert(old_tts == (mkv_stack_.back() ? ST1 : ST2)); + mkv_stack_.back() = !mkv_stack_.back(); + ts_stack_.push_back(mkv_stack_.back() ? ST1 : ST2); + break; + + default: + assert(!"Invalid TType in stateTransition."); + break; + + } +} + + +/* + * Variable-length quantity functions. + */ + +inline uint32_t TDenseProtocol::vlqRead(uint64_t& vlq) { + uint32_t used = 0; + uint64_t val = 0; + uint8_t buf[10]; // 64 bits / (7 bits/byte) = 10 bytes. + uint32_t buf_size = sizeof(buf); + const uint8_t* borrowed = trans_->borrow(buf, &buf_size); + + // Fast path. TODO(dreiss): Make it faster. + if (borrowed != NULL) { + while (true) { + uint8_t byte = borrowed[used]; + used++; + val = (val << 7) | (byte & 0x7f); + if (!(byte & 0x80)) { + vlq = val; + trans_->consume(used); + return used; + } + // Have to check for invalid data so we don't crash. + if (UNLIKELY(used == sizeof(buf))) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes."); + } + } + } + + // Slow path. + else { + while (true) { + uint8_t byte; + used += trans_->readAll(&byte, 1); + val = (val << 7) | (byte & 0x7f); + if (!(byte & 0x80)) { + vlq = val; + return used; + } + // Might as well check for invalid data on the slow path too. + if (UNLIKELY(used >= sizeof(buf))) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes."); + } + } + } +} + +inline uint32_t TDenseProtocol::vlqWrite(uint64_t vlq) { + uint8_t buf[10]; // 64 bits / (7 bits/byte) = 10 bytes. + int32_t pos = sizeof(buf) - 1; + + // Write the thing from back to front. + buf[pos] = vlq & 0x7f; + vlq >>= 7; + pos--; + + while (vlq > 0) { + assert(pos >= 0); + buf[pos] = (vlq | 0x80); + vlq >>= 7; + pos--; + } + + // Back up one step before writing. + pos++; + + trans_->write(buf+pos, sizeof(buf) - pos); + return sizeof(buf) - pos; +} + + + +/* + * Writing functions. + */ + +uint32_t TDenseProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + throw TApplicationException("TDenseProtocol doesn't work with messages (yet)."); + + int32_t version = (VERSION_2) | ((int32_t)messageType); + uint32_t wsize = 0; + wsize += subWriteI32(version); + wsize += subWriteString(name); + wsize += subWriteI32(seqid); + return wsize; +} + +uint32_t TDenseProtocol::writeMessageEnd() { + return 0; +} + +uint32_t TDenseProtocol::writeStructBegin(const char* name) { + uint32_t xfer = 0; + + // The TypeSpec stack should be empty if this is the top-level read/write. + // If it is, we push the TypeSpec passed to the constructor. + if (ts_stack_.empty()) { + assert(standalone_); + + if (type_spec_ == NULL) { + resetState(); + throw TApplicationException("TDenseProtocol: No type specified."); + } else { + assert(type_spec_->ttype == T_STRUCT); + ts_stack_.push_back(type_spec_); + // Write out a prefix of the structure fingerprint. + trans_->write(type_spec_->fp_prefix, FP_PREFIX_LEN); + xfer += FP_PREFIX_LEN; + } + } + + // We need a new field index for this structure. + idx_stack_.push_back(0); + return 0; +} + +uint32_t TDenseProtocol::writeStructEnd() { + idx_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + uint32_t xfer = 0; + + // Skip over optional fields. + while (FMT.tag != fieldId) { + // TODO(dreiss): Old meta here. + assert(FTS->ttype != T_STOP); + assert(FMT.is_optional); + // Write a zero byte so the reader can skip it. + xfer += subWriteBool(false); + // And advance to the next field. + IDX++; + } + + // TODO(dreiss): give a better exception. + assert(FTS->ttype == fieldType); + + if (FMT.is_optional) { + subWriteBool(true); + xfer += 1; + } + + // writeFieldStop shares all lot of logic up to this point. + // Instead of replicating it all, we just call this method from that one + // and use a gross special case here. + if (UNLIKELY(FTS->ttype != T_STOP)) { + // For normal fields, push the TypeSpec that we're about to use. + ts_stack_.push_back(FTS); + } + return xfer; +} + +uint32_t TDenseProtocol::writeFieldEnd() { + // Just move on to the next field. + IDX++; + return 0; +} + +uint32_t TDenseProtocol::writeFieldStop() { + return TDenseProtocol::writeFieldBegin("", T_STOP, 0); +} + +uint32_t TDenseProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + checkTType(T_MAP); + + assert(keyType == ST1->ttype); + assert(valType == ST2->ttype); + + ts_stack_.push_back(ST1); + mkv_stack_.push_back(true); + + return subWriteI32((int32_t)size); +} + +uint32_t TDenseProtocol::writeMapEnd() { + // Pop off the value type, as well as our entry in the map key/value stack. + // stateTransition takes care of popping off our TypeSpec. + ts_stack_.pop_back(); + mkv_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + checkTType(T_LIST); + + assert(elemType == ST1->ttype); + ts_stack_.push_back(ST1); + return subWriteI32((int32_t)size); +} + +uint32_t TDenseProtocol::writeListEnd() { + // Pop off the element type. stateTransition takes care of popping off ours. + ts_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + checkTType(T_SET); + + assert(elemType == ST1->ttype); + ts_stack_.push_back(ST1); + return subWriteI32((int32_t)size); +} + +uint32_t TDenseProtocol::writeSetEnd() { + // Pop off the element type. stateTransition takes care of popping off ours. + ts_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::writeBool(const bool value) { + checkTType(T_BOOL); + stateTransition(); + return TBinaryProtocol::writeBool(value); +} + +uint32_t TDenseProtocol::writeByte(const int8_t byte) { + checkTType(T_BYTE); + stateTransition(); + return TBinaryProtocol::writeByte(byte); +} + +uint32_t TDenseProtocol::writeI16(const int16_t i16) { + checkTType(T_I16); + stateTransition(); + return vlqWrite(i16); +} + +uint32_t TDenseProtocol::writeI32(const int32_t i32) { + checkTType(T_I32); + stateTransition(); + return vlqWrite(i32); +} + +uint32_t TDenseProtocol::writeI64(const int64_t i64) { + checkTType(T_I64); + stateTransition(); + return vlqWrite(i64); +} + +uint32_t TDenseProtocol::writeDouble(const double dub) { + checkTType(T_DOUBLE); + stateTransition(); + return TBinaryProtocol::writeDouble(dub); +} + +uint32_t TDenseProtocol::writeString(const std::string& str) { + checkTType(T_STRING); + stateTransition(); + return subWriteString(str); +} + +uint32_t TDenseProtocol::writeBinary(const std::string& str) { + return TDenseProtocol::writeString(str); +} + +inline uint32_t TDenseProtocol::subWriteI32(const int32_t i32) { + return vlqWrite(i32); +} + +uint32_t TDenseProtocol::subWriteString(const std::string& str) { + uint32_t size = str.size(); + uint32_t xfer = subWriteI32((int32_t)size); + if (size > 0) { + trans_->write((uint8_t*)str.data(), size); + } + return xfer + size; +} + + + +/* + * Reading functions + * + * These have a lot of the same logic as the writing functions, so if + * something is confusing, look for comments in the corresponding writer. + */ + +uint32_t TDenseProtocol::readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + throw TApplicationException("TDenseProtocol doesn't work with messages (yet)."); + + uint32_t xfer = 0; + int32_t sz; + xfer += subReadI32(sz); + + if (sz < 0) { + // Check for correct version number + int32_t version = sz & VERSION_MASK; + if (version != VERSION_2) { + throw TProtocolException(TProtocolException::BAD_VERSION, "Bad version identifier"); + } + messageType = (TMessageType)(sz & 0x000000ff); + xfer += subReadString(name); + xfer += subReadI32(seqid); + } else { + throw TProtocolException(TProtocolException::BAD_VERSION, "No version identifier... old protocol client in strict mode?"); + } + return xfer; +} + +uint32_t TDenseProtocol::readMessageEnd() { + return 0; +} + +uint32_t TDenseProtocol::readStructBegin(string& name) { + uint32_t xfer = 0; + + if (ts_stack_.empty()) { + assert(standalone_); + + if (type_spec_ == NULL) { + resetState(); + throw TApplicationException("TDenseProtocol: No type specified."); + } else { + assert(type_spec_->ttype == T_STRUCT); + ts_stack_.push_back(type_spec_); + + // Check the fingerprint prefix. + uint8_t buf[FP_PREFIX_LEN]; + xfer += trans_->read(buf, FP_PREFIX_LEN); + if (std::memcmp(buf, type_spec_->fp_prefix, FP_PREFIX_LEN) != 0) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "Fingerprint in data does not match type_spec."); + } + } + } + + // We need a new field index for this structure. + idx_stack_.push_back(0); + return 0; +} + +uint32_t TDenseProtocol::readStructEnd() { + idx_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::readFieldBegin(string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t xfer = 0; + + // For optional fields, check to see if they are there. + while (FMT.is_optional) { + bool is_present; + xfer += subReadBool(is_present); + if (is_present) { + break; + } + IDX++; + } + + // Once we hit a mandatory field, or an optional field that is present, + // we know that FMT and FTS point to the appropriate field. + + fieldId = FMT.tag; + fieldType = FTS->ttype; + + // Normally, we push the TypeSpec that we are about to read, + // but no reading is done for T_STOP. + if (FTS->ttype != T_STOP) { + ts_stack_.push_back(FTS); + } + return xfer; +} + +uint32_t TDenseProtocol::readFieldEnd() { + IDX++; + return 0; +} + +uint32_t TDenseProtocol::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + checkTType(T_MAP); + + uint32_t xfer = 0; + int32_t sizei; + xfer += subReadI32(sizei); + if (sizei < 0) { + resetState(); + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + resetState(); + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + + keyType = ST1->ttype; + valType = ST2->ttype; + + ts_stack_.push_back(ST1); + mkv_stack_.push_back(true); + + return xfer; +} + +uint32_t TDenseProtocol::readMapEnd() { + ts_stack_.pop_back(); + mkv_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::readListBegin(TType& elemType, + uint32_t& size) { + checkTType(T_LIST); + + uint32_t xfer = 0; + int32_t sizei; + xfer += subReadI32(sizei); + if (sizei < 0) { + resetState(); + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + resetState(); + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + + elemType = ST1->ttype; + + ts_stack_.push_back(ST1); + + return xfer; +} + +uint32_t TDenseProtocol::readListEnd() { + ts_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::readSetBegin(TType& elemType, + uint32_t& size) { + checkTType(T_SET); + + uint32_t xfer = 0; + int32_t sizei; + xfer += subReadI32(sizei); + if (sizei < 0) { + resetState(); + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + resetState(); + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + + elemType = ST1->ttype; + + ts_stack_.push_back(ST1); + + return xfer; +} + +uint32_t TDenseProtocol::readSetEnd() { + ts_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::readBool(bool& value) { + checkTType(T_BOOL); + stateTransition(); + return TBinaryProtocol::readBool(value); +} + +uint32_t TDenseProtocol::readByte(int8_t& byte) { + checkTType(T_BYTE); + stateTransition(); + return TBinaryProtocol::readByte(byte); +} + +uint32_t TDenseProtocol::readI16(int16_t& i16) { + checkTType(T_I16); + stateTransition(); + uint64_t u64; + uint32_t rv = vlqRead(u64); + int64_t val = (int64_t)u64; + if (UNLIKELY(val > INT16_MAX || val < INT16_MIN)) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "i16 out of range."); + } + i16 = (int16_t)val; + return rv; +} + +uint32_t TDenseProtocol::readI32(int32_t& i32) { + checkTType(T_I32); + stateTransition(); + uint64_t u64; + uint32_t rv = vlqRead(u64); + int64_t val = (int64_t)u64; + if (UNLIKELY(val > INT32_MAX || val < INT32_MIN)) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "i32 out of range."); + } + i32 = (int32_t)val; + return rv; +} + +uint32_t TDenseProtocol::readI64(int64_t& i64) { + checkTType(T_I64); + stateTransition(); + uint64_t u64; + uint32_t rv = vlqRead(u64); + int64_t val = (int64_t)u64; + if (UNLIKELY(val > INT64_MAX || val < INT64_MIN)) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "i64 out of range."); + } + i64 = (int64_t)val; + return rv; +} + +uint32_t TDenseProtocol::readDouble(double& dub) { + checkTType(T_DOUBLE); + stateTransition(); + return TBinaryProtocol::readDouble(dub); +} + +uint32_t TDenseProtocol::readString(std::string& str) { + checkTType(T_STRING); + stateTransition(); + return subReadString(str); +} + +uint32_t TDenseProtocol::readBinary(std::string& str) { + return TDenseProtocol::readString(str); +} + +uint32_t TDenseProtocol::subReadI32(int32_t& i32) { + uint64_t u64; + uint32_t rv = vlqRead(u64); + int64_t val = (int64_t)u64; + if (UNLIKELY(val > INT32_MAX || val < INT32_MIN)) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "i32 out of range."); + } + i32 = (int32_t)val; + return rv; +} + +uint32_t TDenseProtocol::subReadString(std::string& str) { + uint32_t xfer; + int32_t size; + xfer = subReadI32(size); + return xfer + readStringBody(str, size); +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TDenseProtocol.h b/lib/cpp/src/protocol/TDenseProtocol.h new file mode 100644 index 000000000..7655a479a --- /dev/null +++ b/lib/cpp/src/protocol/TDenseProtocol.h @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TDENSEPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TDENSEPROTOCOL_H_ 1 + +#include "TBinaryProtocol.h" + +namespace apache { namespace thrift { namespace protocol { + +/** + * !!!WARNING!!! + * This class is still highly experimental. Incompatible changes + * WILL be made to it without notice. DO NOT USE IT YET unless + * you are coordinating your testing with the author. + * + * The dense protocol is designed to use as little space as possible. + * + * There are two types of dense protocol instances. Standalone instances + * are not used for RPC and just encoded and decode structures of + * a predetermined type. Non-standalone instances are used for RPC. + * Currently, only standalone instances exist. + * + * To use a standalone dense protocol object, you must set the type_spec + * property (either in the constructor, or with setTypeSpec) to the local + * reflection TypeSpec of the structures you will write to (or read from) the + * protocol instance. + * + * BEST PRACTICES: + * - Never use optional for primitives or containers. + * - Only use optional for structures if they are very big and very rarely set. + * - All integers are variable-length, so you can use i64 without bloating. + * - NEVER EVER change the struct definitions IN ANY WAY without either + * changing your cache keys or talking to dreiss. + * + * TODO(dreiss): New class write with old meta. + * + * We override all of TBinaryProtocol's methods. + * We inherit so that we can can explicitly call TBPs's primitive-writing + * methods within our versions. + * + */ +class TDenseProtocol : public TBinaryProtocol { + protected: + static const int32_t VERSION_MASK = 0xffff0000; + // VERSION_1 (0x80010000) is taken by TBinaryProtocol. + static const int32_t VERSION_2 = 0x80020000; + + public: + typedef apache::thrift::reflection::local::TypeSpec TypeSpec; + static const int FP_PREFIX_LEN; + + /** + * @param tran The transport to use. + * @param type_spec The TypeSpec of the structures using this protocol. + */ + TDenseProtocol(boost::shared_ptr<TTransport> trans, + TypeSpec* type_spec = NULL) : + TBinaryProtocol(trans), + type_spec_(type_spec), + standalone_(true) + {} + + void setTypeSpec(TypeSpec* type_spec) { + type_spec_ = type_spec; + } + TypeSpec* getTypeSpec() { + return type_spec_; + } + + + /* + * Writing functions. + */ + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + virtual uint32_t writeMessageEnd(); + + + virtual uint32_t writeStructBegin(const char* name); + + virtual uint32_t writeStructEnd(); + + virtual uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + virtual uint32_t writeFieldEnd(); + + virtual uint32_t writeFieldStop(); + + virtual uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + virtual uint32_t writeMapEnd(); + + virtual uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + virtual uint32_t writeListEnd(); + + virtual uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + virtual uint32_t writeSetEnd(); + + virtual uint32_t writeBool(const bool value); + + virtual uint32_t writeByte(const int8_t byte); + + virtual uint32_t writeI16(const int16_t i16); + + virtual uint32_t writeI32(const int32_t i32); + + virtual uint32_t writeI64(const int64_t i64); + + virtual uint32_t writeDouble(const double dub); + + virtual uint32_t writeString(const std::string& str); + + virtual uint32_t writeBinary(const std::string& str); + + + /* + * Helper writing functions (don't do state transitions). + */ + inline uint32_t subWriteI32(const int32_t i32); + + inline uint32_t subWriteString(const std::string& str); + + uint32_t subWriteBool(const bool value) { + return TBinaryProtocol::writeBool(value); + } + + + /* + * Reading functions + */ + + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid); + + uint32_t readMessageEnd(); + + uint32_t readStructBegin(std::string& name); + + uint32_t readStructEnd(); + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId); + + uint32_t readFieldEnd(); + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size); + + uint32_t readMapEnd(); + + uint32_t readListBegin(TType& elemType, + uint32_t& size); + + uint32_t readListEnd(); + + uint32_t readSetBegin(TType& elemType, + uint32_t& size); + + uint32_t readSetEnd(); + + uint32_t readBool(bool& value); + + uint32_t readByte(int8_t& byte); + + uint32_t readI16(int16_t& i16); + + uint32_t readI32(int32_t& i32); + + uint32_t readI64(int64_t& i64); + + uint32_t readDouble(double& dub); + + uint32_t readString(std::string& str); + + uint32_t readBinary(std::string& str); + + /* + * Helper reading functions (don't do state transitions). + */ + inline uint32_t subReadI32(int32_t& i32); + + inline uint32_t subReadString(std::string& str); + + uint32_t subReadBool(bool& value) { + return TBinaryProtocol::readBool(value); + } + + + private: + + // Implementation functions, documented in the .cpp. + inline void checkTType(const TType ttype); + inline void stateTransition(); + + // Read and write variable-length integers. + // Uses the same technique as the MIDI file format. + inline uint32_t vlqRead(uint64_t& vlq); + inline uint32_t vlqWrite(uint64_t vlq); + + // Called before throwing an exception to make the object reusable. + void resetState() { + ts_stack_.clear(); + idx_stack_.clear(); + mkv_stack_.clear(); + } + + // TypeSpec of the top-level structure to write, + // for standalone protocol objects. + TypeSpec* type_spec_; + + std::vector<TypeSpec*> ts_stack_; // TypeSpec stack. + std::vector<int> idx_stack_; // InDeX stack. + std::vector<bool> mkv_stack_; // Map Key/Vlue stack. + // True = key, False = value. + + // True iff this is a standalone instance (no RPC). + bool standalone_; +}; + +}}} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_PROTOCOL_TDENSEPROTOCOL_H_ diff --git a/lib/cpp/src/protocol/TJSONProtocol.cpp b/lib/cpp/src/protocol/TJSONProtocol.cpp new file mode 100644 index 000000000..2a9c8f0b2 --- /dev/null +++ b/lib/cpp/src/protocol/TJSONProtocol.cpp @@ -0,0 +1,998 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TJSONProtocol.h" + +#include <math.h> +#include <boost/lexical_cast.hpp> +#include "TBase64Utils.h" +#include <transport/TTransportException.h> + +using namespace apache::thrift::transport; + +namespace apache { namespace thrift { namespace protocol { + + +// Static data + +static const uint8_t kJSONObjectStart = '{'; +static const uint8_t kJSONObjectEnd = '}'; +static const uint8_t kJSONArrayStart = '['; +static const uint8_t kJSONArrayEnd = ']'; +static const uint8_t kJSONNewline = '\n'; +static const uint8_t kJSONPairSeparator = ':'; +static const uint8_t kJSONElemSeparator = ','; +static const uint8_t kJSONBackslash = '\\'; +static const uint8_t kJSONStringDelimiter = '"'; +static const uint8_t kJSONZeroChar = '0'; +static const uint8_t kJSONEscapeChar = 'u'; + +static const std::string kJSONEscapePrefix("\\u00"); + +static const uint32_t kThriftVersion1 = 1; + +static const std::string kThriftNan("NaN"); +static const std::string kThriftInfinity("Infinity"); +static const std::string kThriftNegativeInfinity("-Infinity"); + +static const std::string kTypeNameBool("tf"); +static const std::string kTypeNameByte("i8"); +static const std::string kTypeNameI16("i16"); +static const std::string kTypeNameI32("i32"); +static const std::string kTypeNameI64("i64"); +static const std::string kTypeNameDouble("dbl"); +static const std::string kTypeNameStruct("rec"); +static const std::string kTypeNameString("str"); +static const std::string kTypeNameMap("map"); +static const std::string kTypeNameList("lst"); +static const std::string kTypeNameSet("set"); + +static const std::string &getTypeNameForTypeID(TType typeID) { + switch (typeID) { + case T_BOOL: + return kTypeNameBool; + case T_BYTE: + return kTypeNameByte; + case T_I16: + return kTypeNameI16; + case T_I32: + return kTypeNameI32; + case T_I64: + return kTypeNameI64; + case T_DOUBLE: + return kTypeNameDouble; + case T_STRING: + return kTypeNameString; + case T_STRUCT: + return kTypeNameStruct; + case T_MAP: + return kTypeNameMap; + case T_SET: + return kTypeNameSet; + case T_LIST: + return kTypeNameList; + default: + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + "Unrecognized type"); + } +} + +static TType getTypeIDForTypeName(const std::string &name) { + TType result = T_STOP; // Sentinel value + if (name.length() > 1) { + switch (name[0]) { + case 'd': + result = T_DOUBLE; + break; + case 'i': + switch (name[1]) { + case '8': + result = T_BYTE; + break; + case '1': + result = T_I16; + break; + case '3': + result = T_I32; + break; + case '6': + result = T_I64; + break; + } + break; + case 'l': + result = T_LIST; + break; + case 'm': + result = T_MAP; + break; + case 'r': + result = T_STRUCT; + break; + case 's': + if (name[1] == 't') { + result = T_STRING; + } + else if (name[1] == 'e') { + result = T_SET; + } + break; + case 't': + result = T_BOOL; + break; + } + } + if (result == T_STOP) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + "Unrecognized type"); + } + return result; +} + + +// This table describes the handling for the first 0x30 characters +// 0 : escape using "\u00xx" notation +// 1 : just output index +// <other> : escape using "\<other>" notation +static const uint8_t kJSONCharTable[0x30] = { +// 0 1 2 3 4 5 6 7 8 9 A B C D E F + 0, 0, 0, 0, 0, 0, 0, 0,'b','t','n', 0,'f','r', 0, 0, // 0 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 1 + 1, 1,'"', 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 2 +}; + + +// This string's characters must match up with the elements in kEscapeCharVals. +// I don't have '/' on this list even though it appears on www.json.org -- +// it is not in the RFC +const static std::string kEscapeChars("\"\\bfnrt"); + +// The elements of this array must match up with the sequence of characters in +// kEscapeChars +const static uint8_t kEscapeCharVals[7] = { + '"', '\\', '\b', '\f', '\n', '\r', '\t', +}; + + +// Static helper functions + +// Read 1 character from the transport trans and verify that it is the +// expected character ch. +// Throw a protocol exception if it is not. +static uint32_t readSyntaxChar(TJSONProtocol::LookaheadReader &reader, + uint8_t ch) { + uint8_t ch2 = reader.read(); + if (ch2 != ch) { + throw TProtocolException(TProtocolException::INVALID_DATA, + "Expected \'" + std::string((char *)&ch, 1) + + "\'; got \'" + std::string((char *)&ch2, 1) + + "\'."); + } + return 1; +} + +// Return the integer value of a hex character ch. +// Throw a protocol exception if the character is not [0-9a-f]. +static uint8_t hexVal(uint8_t ch) { + if ((ch >= '0') && (ch <= '9')) { + return ch - '0'; + } + else if ((ch >= 'a') && (ch <= 'f')) { + return ch - 'a'; + } + else { + throw TProtocolException(TProtocolException::INVALID_DATA, + "Expected hex val ([0-9a-f]); got \'" + + std::string((char *)&ch, 1) + "\'."); + } +} + +// Return the hex character representing the integer val. The value is masked +// to make sure it is in the correct range. +static uint8_t hexChar(uint8_t val) { + val &= 0x0F; + if (val < 10) { + return val + '0'; + } + else { + return val + 'a'; + } +} + +// Return true if the character ch is in [-+0-9.Ee]; false otherwise +static bool isJSONNumeric(uint8_t ch) { + switch (ch) { + case '+': + case '-': + case '.': + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case 'E': + case 'e': + return true; + } + return false; +} + + +/** + * Class to serve as base JSON context and as base class for other context + * implementations + */ +class TJSONContext { + + public: + + TJSONContext() {}; + + virtual ~TJSONContext() {}; + + /** + * Write context data to the transport. Default is to do nothing. + */ + virtual uint32_t write(TTransport &trans) { + return 0; + }; + + /** + * Read context data from the transport. Default is to do nothing. + */ + virtual uint32_t read(TJSONProtocol::LookaheadReader &reader) { + return 0; + }; + + /** + * Return true if numbers need to be escaped as strings in this context. + * Default behavior is to return false. + */ + virtual bool escapeNum() { + return false; + } +}; + +// Context class for object member key-value pairs +class JSONPairContext : public TJSONContext { + +public: + + JSONPairContext() : + first_(true), + colon_(true) { + } + + uint32_t write(TTransport &trans) { + if (first_) { + first_ = false; + colon_ = true; + return 0; + } + else { + trans.write(colon_ ? &kJSONPairSeparator : &kJSONElemSeparator, 1); + colon_ = !colon_; + return 1; + } + } + + uint32_t read(TJSONProtocol::LookaheadReader &reader) { + if (first_) { + first_ = false; + colon_ = true; + return 0; + } + else { + uint8_t ch = (colon_ ? kJSONPairSeparator : kJSONElemSeparator); + colon_ = !colon_; + return readSyntaxChar(reader, ch); + } + } + + // Numbers must be turned into strings if they are the key part of a pair + virtual bool escapeNum() { + return colon_; + } + + private: + + bool first_; + bool colon_; +}; + +// Context class for lists +class JSONListContext : public TJSONContext { + +public: + + JSONListContext() : + first_(true) { + } + + uint32_t write(TTransport &trans) { + if (first_) { + first_ = false; + return 0; + } + else { + trans.write(&kJSONElemSeparator, 1); + return 1; + } + } + + uint32_t read(TJSONProtocol::LookaheadReader &reader) { + if (first_) { + first_ = false; + return 0; + } + else { + return readSyntaxChar(reader, kJSONElemSeparator); + } + } + + private: + bool first_; +}; + + +TJSONProtocol::TJSONProtocol(boost::shared_ptr<TTransport> ptrans) : + TProtocol(ptrans), + context_(new TJSONContext()), + reader_(*ptrans) { +} + +TJSONProtocol::~TJSONProtocol() {} + +void TJSONProtocol::pushContext(boost::shared_ptr<TJSONContext> c) { + contexts_.push(context_); + context_ = c; +} + +void TJSONProtocol::popContext() { + context_ = contexts_.top(); + contexts_.pop(); +} + +// Write the character ch as a JSON escape sequence ("\u00xx") +uint32_t TJSONProtocol::writeJSONEscapeChar(uint8_t ch) { + trans_->write((const uint8_t *)kJSONEscapePrefix.c_str(), + kJSONEscapePrefix.length()); + uint8_t outCh = hexChar(ch >> 4); + trans_->write(&outCh, 1); + outCh = hexChar(ch); + trans_->write(&outCh, 1); + return 6; +} + +// Write the character ch as part of a JSON string, escaping as appropriate. +uint32_t TJSONProtocol::writeJSONChar(uint8_t ch) { + if (ch >= 0x30) { + if (ch == kJSONBackslash) { // Only special character >= 0x30 is '\' + trans_->write(&kJSONBackslash, 1); + trans_->write(&kJSONBackslash, 1); + return 2; + } + else { + trans_->write(&ch, 1); + return 1; + } + } + else { + uint8_t outCh = kJSONCharTable[ch]; + // Check if regular character, backslash escaped, or JSON escaped + if (outCh == 1) { + trans_->write(&ch, 1); + return 1; + } + else if (outCh > 1) { + trans_->write(&kJSONBackslash, 1); + trans_->write(&outCh, 1); + return 2; + } + else { + return writeJSONEscapeChar(ch); + } + } +} + +// Write out the contents of the string str as a JSON string, escaping +// characters as appropriate. +uint32_t TJSONProtocol::writeJSONString(const std::string &str) { + uint32_t result = context_->write(*trans_); + result += 2; // For quotes + trans_->write(&kJSONStringDelimiter, 1); + std::string::const_iterator iter(str.begin()); + std::string::const_iterator end(str.end()); + while (iter != end) { + result += writeJSONChar(*iter++); + } + trans_->write(&kJSONStringDelimiter, 1); + return result; +} + +// Write out the contents of the string as JSON string, base64-encoding +// the string's contents, and escaping as appropriate +uint32_t TJSONProtocol::writeJSONBase64(const std::string &str) { + uint32_t result = context_->write(*trans_); + result += 2; // For quotes + trans_->write(&kJSONStringDelimiter, 1); + uint8_t b[4]; + const uint8_t *bytes = (const uint8_t *)str.c_str(); + uint32_t len = str.length(); + while (len >= 3) { + // Encode 3 bytes at a time + base64_encode(bytes, 3, b); + trans_->write(b, 4); + result += 4; + bytes += 3; + len -=3; + } + if (len) { // Handle remainder + base64_encode(bytes, len, b); + trans_->write(b, len + 1); + result += len + 1; + } + trans_->write(&kJSONStringDelimiter, 1); + return result; +} + +// Convert the given integer type to a JSON number, or a string +// if the context requires it (eg: key in a map pair). +template <typename NumberType> +uint32_t TJSONProtocol::writeJSONInteger(NumberType num) { + uint32_t result = context_->write(*trans_); + std::string val(boost::lexical_cast<std::string>(num)); + bool escapeNum = context_->escapeNum(); + if (escapeNum) { + trans_->write(&kJSONStringDelimiter, 1); + result += 1; + } + trans_->write((const uint8_t *)val.c_str(), val.length()); + result += val.length(); + if (escapeNum) { + trans_->write(&kJSONStringDelimiter, 1); + result += 1; + } + return result; +} + +// Convert the given double to a JSON string, which is either the number, +// "NaN" or "Infinity" or "-Infinity". +uint32_t TJSONProtocol::writeJSONDouble(double num) { + uint32_t result = context_->write(*trans_); + std::string val(boost::lexical_cast<std::string>(num)); + + // Normalize output of boost::lexical_cast for NaNs and Infinities + bool special = false; + switch (val[0]) { + case 'N': + case 'n': + val = kThriftNan; + special = true; + break; + case 'I': + case 'i': + val = kThriftInfinity; + special = true; + break; + case '-': + if ((val[1] == 'I') || (val[1] == 'i')) { + val = kThriftNegativeInfinity; + special = true; + } + break; + } + + bool escapeNum = special || context_->escapeNum(); + if (escapeNum) { + trans_->write(&kJSONStringDelimiter, 1); + result += 1; + } + trans_->write((const uint8_t *)val.c_str(), val.length()); + result += val.length(); + if (escapeNum) { + trans_->write(&kJSONStringDelimiter, 1); + result += 1; + } + return result; +} + +uint32_t TJSONProtocol::writeJSONObjectStart() { + uint32_t result = context_->write(*trans_); + trans_->write(&kJSONObjectStart, 1); + pushContext(boost::shared_ptr<TJSONContext>(new JSONPairContext())); + return result + 1; +} + +uint32_t TJSONProtocol::writeJSONObjectEnd() { + popContext(); + trans_->write(&kJSONObjectEnd, 1); + return 1; +} + +uint32_t TJSONProtocol::writeJSONArrayStart() { + uint32_t result = context_->write(*trans_); + trans_->write(&kJSONArrayStart, 1); + pushContext(boost::shared_ptr<TJSONContext>(new JSONListContext())); + return result + 1; +} + +uint32_t TJSONProtocol::writeJSONArrayEnd() { + popContext(); + trans_->write(&kJSONArrayEnd, 1); + return 1; +} + +uint32_t TJSONProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + uint32_t result = writeJSONArrayStart(); + result += writeJSONInteger(kThriftVersion1); + result += writeJSONString(name); + result += writeJSONInteger(messageType); + result += writeJSONInteger(seqid); + return result; +} + +uint32_t TJSONProtocol::writeMessageEnd() { + return writeJSONArrayEnd(); +} + +uint32_t TJSONProtocol::writeStructBegin(const char* name) { + return writeJSONObjectStart(); +} + +uint32_t TJSONProtocol::writeStructEnd() { + return writeJSONObjectEnd(); +} + +uint32_t TJSONProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + uint32_t result = writeJSONInteger(fieldId); + result += writeJSONObjectStart(); + result += writeJSONString(getTypeNameForTypeID(fieldType)); + return result; +} + +uint32_t TJSONProtocol::writeFieldEnd() { + return writeJSONObjectEnd(); +} + +uint32_t TJSONProtocol::writeFieldStop() { + return 0; +} + +uint32_t TJSONProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + uint32_t result = writeJSONArrayStart(); + result += writeJSONString(getTypeNameForTypeID(keyType)); + result += writeJSONString(getTypeNameForTypeID(valType)); + result += writeJSONInteger((int64_t)size); + result += writeJSONObjectStart(); + return result; +} + +uint32_t TJSONProtocol::writeMapEnd() { + return writeJSONObjectEnd() + writeJSONArrayEnd(); +} + +uint32_t TJSONProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + uint32_t result = writeJSONArrayStart(); + result += writeJSONString(getTypeNameForTypeID(elemType)); + result += writeJSONInteger((int64_t)size); + return result; +} + +uint32_t TJSONProtocol::writeListEnd() { + return writeJSONArrayEnd(); +} + +uint32_t TJSONProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + uint32_t result = writeJSONArrayStart(); + result += writeJSONString(getTypeNameForTypeID(elemType)); + result += writeJSONInteger((int64_t)size); + return result; +} + +uint32_t TJSONProtocol::writeSetEnd() { + return writeJSONArrayEnd(); +} + +uint32_t TJSONProtocol::writeBool(const bool value) { + return writeJSONInteger(value); +} + +uint32_t TJSONProtocol::writeByte(const int8_t byte) { + // writeByte() must be handled specially becuase boost::lexical cast sees + // int8_t as a text type instead of an integer type + return writeJSONInteger((int16_t)byte); +} + +uint32_t TJSONProtocol::writeI16(const int16_t i16) { + return writeJSONInteger(i16); +} + +uint32_t TJSONProtocol::writeI32(const int32_t i32) { + return writeJSONInteger(i32); +} + +uint32_t TJSONProtocol::writeI64(const int64_t i64) { + return writeJSONInteger(i64); +} + +uint32_t TJSONProtocol::writeDouble(const double dub) { + return writeJSONDouble(dub); +} + +uint32_t TJSONProtocol::writeString(const std::string& str) { + return writeJSONString(str); +} + +uint32_t TJSONProtocol::writeBinary(const std::string& str) { + return writeJSONBase64(str); +} + + /** + * Reading functions + */ + +// Reads 1 byte and verifies that it matches ch. +uint32_t TJSONProtocol::readJSONSyntaxChar(uint8_t ch) { + return readSyntaxChar(reader_, ch); +} + +// Decodes the four hex parts of a JSON escaped string character and returns +// the character via out. The first two characters must be "00". +uint32_t TJSONProtocol::readJSONEscapeChar(uint8_t *out) { + uint8_t b[2]; + readJSONSyntaxChar(kJSONZeroChar); + readJSONSyntaxChar(kJSONZeroChar); + b[0] = reader_.read(); + b[1] = reader_.read(); + *out = (hexVal(b[0]) << 4) + hexVal(b[1]); + return 4; +} + +// Decodes a JSON string, including unescaping, and returns the string via str +uint32_t TJSONProtocol::readJSONString(std::string &str, bool skipContext) { + uint32_t result = (skipContext ? 0 : context_->read(reader_)); + result += readJSONSyntaxChar(kJSONStringDelimiter); + uint8_t ch; + str.clear(); + while (true) { + ch = reader_.read(); + ++result; + if (ch == kJSONStringDelimiter) { + break; + } + if (ch == kJSONBackslash) { + ch = reader_.read(); + ++result; + if (ch == kJSONEscapeChar) { + result += readJSONEscapeChar(&ch); + } + else { + size_t pos = kEscapeChars.find(ch); + if (pos == std::string::npos) { + throw TProtocolException(TProtocolException::INVALID_DATA, + "Expected control char, got '" + + std::string((const char *)&ch, 1) + "'."); + } + ch = kEscapeCharVals[pos]; + } + } + str += ch; + } + return result; +} + +// Reads a block of base64 characters, decoding it, and returns via str +uint32_t TJSONProtocol::readJSONBase64(std::string &str) { + std::string tmp; + uint32_t result = readJSONString(tmp); + uint8_t *b = (uint8_t *)tmp.c_str(); + uint32_t len = tmp.length(); + str.clear(); + while (len >= 4) { + base64_decode(b, 4); + str.append((const char *)b, 3); + b += 4; + len -= 4; + } + // Don't decode if we hit the end or got a single leftover byte (invalid + // base64 but legal for skip of regular string type) + if (len > 1) { + base64_decode(b, len); + str.append((const char *)b, len - 1); + } + return result; +} + +// Reads a sequence of characters, stopping at the first one that is not +// a valid JSON numeric character. +uint32_t TJSONProtocol::readJSONNumericChars(std::string &str) { + uint32_t result = 0; + str.clear(); + while (true) { + uint8_t ch = reader_.peek(); + if (!isJSONNumeric(ch)) { + break; + } + reader_.read(); + str += ch; + ++result; + } + return result; +} + +// Reads a sequence of characters and assembles them into a number, +// returning them via num +template <typename NumberType> +uint32_t TJSONProtocol::readJSONInteger(NumberType &num) { + uint32_t result = context_->read(reader_); + if (context_->escapeNum()) { + result += readJSONSyntaxChar(kJSONStringDelimiter); + } + std::string str; + result += readJSONNumericChars(str); + try { + num = boost::lexical_cast<NumberType>(str); + } + catch (boost::bad_lexical_cast e) { + throw new TProtocolException(TProtocolException::INVALID_DATA, + "Expected numeric value; got \"" + str + + "\""); + } + if (context_->escapeNum()) { + result += readJSONSyntaxChar(kJSONStringDelimiter); + } + return result; +} + +// Reads a JSON number or string and interprets it as a double. +uint32_t TJSONProtocol::readJSONDouble(double &num) { + uint32_t result = context_->read(reader_); + std::string str; + if (reader_.peek() == kJSONStringDelimiter) { + result += readJSONString(str, true); + // Check for NaN, Infinity and -Infinity + if (str == kThriftNan) { + num = HUGE_VAL/HUGE_VAL; // generates NaN + } + else if (str == kThriftInfinity) { + num = HUGE_VAL; + } + else if (str == kThriftNegativeInfinity) { + num = -HUGE_VAL; + } + else { + if (!context_->escapeNum()) { + // Throw exception -- we should not be in a string in this case + throw new TProtocolException(TProtocolException::INVALID_DATA, + "Numeric data unexpectedly quoted"); + } + try { + num = boost::lexical_cast<double>(str); + } + catch (boost::bad_lexical_cast e) { + throw new TProtocolException(TProtocolException::INVALID_DATA, + "Expected numeric value; got \"" + str + + "\""); + } + } + } + else { + if (context_->escapeNum()) { + // This will throw - we should have had a quote if escapeNum == true + readJSONSyntaxChar(kJSONStringDelimiter); + } + result += readJSONNumericChars(str); + try { + num = boost::lexical_cast<double>(str); + } + catch (boost::bad_lexical_cast e) { + throw new TProtocolException(TProtocolException::INVALID_DATA, + "Expected numeric value; got \"" + str + + "\""); + } + } + return result; +} + +uint32_t TJSONProtocol::readJSONObjectStart() { + uint32_t result = context_->read(reader_); + result += readJSONSyntaxChar(kJSONObjectStart); + pushContext(boost::shared_ptr<TJSONContext>(new JSONPairContext())); + return result; +} + +uint32_t TJSONProtocol::readJSONObjectEnd() { + uint32_t result = readJSONSyntaxChar(kJSONObjectEnd); + popContext(); + return result; +} + +uint32_t TJSONProtocol::readJSONArrayStart() { + uint32_t result = context_->read(reader_); + result += readJSONSyntaxChar(kJSONArrayStart); + pushContext(boost::shared_ptr<TJSONContext>(new JSONListContext())); + return result; +} + +uint32_t TJSONProtocol::readJSONArrayEnd() { + uint32_t result = readJSONSyntaxChar(kJSONArrayEnd); + popContext(); + return result; +} + +uint32_t TJSONProtocol::readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + uint32_t result = readJSONArrayStart(); + uint64_t tmpVal = 0; + result += readJSONInteger(tmpVal); + if (tmpVal != kThriftVersion1) { + throw TProtocolException(TProtocolException::BAD_VERSION, + "Message contained bad version."); + } + result += readJSONString(name); + result += readJSONInteger(tmpVal); + messageType = (TMessageType)tmpVal; + result += readJSONInteger(tmpVal); + seqid = tmpVal; + return result; +} + +uint32_t TJSONProtocol::readMessageEnd() { + return readJSONArrayEnd(); +} + +uint32_t TJSONProtocol::readStructBegin(std::string& name) { + return readJSONObjectStart(); +} + +uint32_t TJSONProtocol::readStructEnd() { + return readJSONObjectEnd(); +} + +uint32_t TJSONProtocol::readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t result = 0; + // Check if we hit the end of the list + uint8_t ch = reader_.peek(); + if (ch == kJSONObjectEnd) { + fieldType = apache::thrift::protocol::T_STOP; + } + else { + uint64_t tmpVal = 0; + std::string tmpStr; + result += readJSONInteger(tmpVal); + fieldId = tmpVal; + result += readJSONObjectStart(); + result += readJSONString(tmpStr); + fieldType = getTypeIDForTypeName(tmpStr); + } + return result; +} + +uint32_t TJSONProtocol::readFieldEnd() { + return readJSONObjectEnd(); +} + +uint32_t TJSONProtocol::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + uint64_t tmpVal = 0; + std::string tmpStr; + uint32_t result = readJSONArrayStart(); + result += readJSONString(tmpStr); + keyType = getTypeIDForTypeName(tmpStr); + result += readJSONString(tmpStr); + valType = getTypeIDForTypeName(tmpStr); + result += readJSONInteger(tmpVal); + size = tmpVal; + result += readJSONObjectStart(); + return result; +} + +uint32_t TJSONProtocol::readMapEnd() { + return readJSONObjectEnd() + readJSONArrayEnd(); +} + +uint32_t TJSONProtocol::readListBegin(TType& elemType, + uint32_t& size) { + uint64_t tmpVal = 0; + std::string tmpStr; + uint32_t result = readJSONArrayStart(); + result += readJSONString(tmpStr); + elemType = getTypeIDForTypeName(tmpStr); + result += readJSONInteger(tmpVal); + size = tmpVal; + return result; +} + +uint32_t TJSONProtocol::readListEnd() { + return readJSONArrayEnd(); +} + +uint32_t TJSONProtocol::readSetBegin(TType& elemType, + uint32_t& size) { + uint64_t tmpVal = 0; + std::string tmpStr; + uint32_t result = readJSONArrayStart(); + result += readJSONString(tmpStr); + elemType = getTypeIDForTypeName(tmpStr); + result += readJSONInteger(tmpVal); + size = tmpVal; + return result; +} + +uint32_t TJSONProtocol::readSetEnd() { + return readJSONArrayEnd(); +} + +uint32_t TJSONProtocol::readBool(bool& value) { + return readJSONInteger(value); +} + +// readByte() must be handled properly becuase boost::lexical cast sees int8_t +// as a text type instead of an integer type +uint32_t TJSONProtocol::readByte(int8_t& byte) { + int16_t tmp = (int16_t) byte; + uint32_t result = readJSONInteger(tmp); + assert(tmp < 256); + byte = (int8_t)tmp; + return result; +} + +uint32_t TJSONProtocol::readI16(int16_t& i16) { + return readJSONInteger(i16); +} + +uint32_t TJSONProtocol::readI32(int32_t& i32) { + return readJSONInteger(i32); +} + +uint32_t TJSONProtocol::readI64(int64_t& i64) { + return readJSONInteger(i64); +} + +uint32_t TJSONProtocol::readDouble(double& dub) { + return readJSONDouble(dub); +} + +uint32_t TJSONProtocol::readString(std::string &str) { + return readJSONString(str); +} + +uint32_t TJSONProtocol::readBinary(std::string &str) { + return readJSONBase64(str); +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TJSONProtocol.h b/lib/cpp/src/protocol/TJSONProtocol.h new file mode 100644 index 000000000..2df499ac0 --- /dev/null +++ b/lib/cpp/src/protocol/TJSONProtocol.h @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TJSONPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TJSONPROTOCOL_H_ 1 + +#include "TProtocol.h" + +#include <stack> + +namespace apache { namespace thrift { namespace protocol { + +// Forward declaration +class TJSONContext; + +/** + * JSON protocol for Thrift. + * + * Implements a protocol which uses JSON as the wire-format. + * + * Thrift types are represented as described below: + * + * 1. Every Thrift integer type is represented as a JSON number. + * + * 2. Thrift doubles are represented as JSON numbers. Some special values are + * represented as strings: + * a. "NaN" for not-a-number values + * b. "Infinity" for postive infinity + * c. "-Infinity" for negative infinity + * + * 3. Thrift string values are emitted as JSON strings, with appropriate + * escaping. + * + * 4. Thrift binary values are encoded into Base64 and emitted as JSON strings. + * The readBinary() method is written such that it will properly skip if + * called on a Thrift string (although it will decode garbage data). + * + * 5. Thrift structs are represented as JSON objects, with the field ID as the + * key, and the field value represented as a JSON object with a single + * key-value pair. The key is a short string identifier for that type, + * followed by the value. The valid type identifiers are: "tf" for bool, + * "i8" for byte, "i16" for 16-bit integer, "i32" for 32-bit integer, "i64" + * for 64-bit integer, "dbl" for double-precision loating point, "str" for + * string (including binary), "rec" for struct ("records"), "map" for map, + * "lst" for list, "set" for set. + * + * 6. Thrift lists and sets are represented as JSON arrays, with the first + * element of the JSON array being the string identifier for the Thrift + * element type and the second element of the JSON array being the count of + * the Thrift elements. The Thrift elements then follow. + * + * 7. Thrift maps are represented as JSON arrays, with the first two elements + * of the JSON array being the string identifiers for the Thrift key type + * and value type, followed by the count of the Thrift pairs, followed by a + * JSON object containing the key-value pairs. Note that JSON keys can only + * be strings, which means that the key type of the Thrift map should be + * restricted to numeric or string types -- in the case of numerics, they + * are serialized as strings. + * + * 8. Thrift messages are represented as JSON arrays, with the protocol + * version #, the message name, the message type, and the sequence ID as + * the first 4 elements. + * + * More discussion of the double handling is probably warranted. The aim of + * the current implementation is to match as closely as possible the behavior + * of Java's Double.toString(), which has no precision loss. Implementors in + * other languages should strive to achieve that where possible. I have not + * yet verified whether boost:lexical_cast, which is doing that work for me in + * C++, loses any precision, but I am leaving this as a future improvement. I + * may try to provide a C component for this, so that other languages could + * bind to the same underlying implementation for maximum consistency. + * + * Note further that JavaScript itself is not capable of representing + * floating point infinities -- presumably when we have a JavaScript Thrift + * client, this would mean that infinities get converted to not-a-number in + * transmission. I don't know of any work-around for this issue. + * + */ +class TJSONProtocol : public TProtocol { + public: + + TJSONProtocol(boost::shared_ptr<TTransport> ptrans); + + ~TJSONProtocol(); + + private: + + void pushContext(boost::shared_ptr<TJSONContext> c); + + void popContext(); + + uint32_t writeJSONEscapeChar(uint8_t ch); + + uint32_t writeJSONChar(uint8_t ch); + + uint32_t writeJSONString(const std::string &str); + + uint32_t writeJSONBase64(const std::string &str); + + template <typename NumberType> + uint32_t writeJSONInteger(NumberType num); + + uint32_t writeJSONDouble(double num); + + uint32_t writeJSONObjectStart() ; + + uint32_t writeJSONObjectEnd(); + + uint32_t writeJSONArrayStart(); + + uint32_t writeJSONArrayEnd(); + + uint32_t readJSONSyntaxChar(uint8_t ch); + + uint32_t readJSONEscapeChar(uint8_t *out); + + uint32_t readJSONString(std::string &str, bool skipContext = false); + + uint32_t readJSONBase64(std::string &str); + + uint32_t readJSONNumericChars(std::string &str); + + template <typename NumberType> + uint32_t readJSONInteger(NumberType &num); + + uint32_t readJSONDouble(double &num); + + uint32_t readJSONObjectStart(); + + uint32_t readJSONObjectEnd(); + + uint32_t readJSONArrayStart(); + + uint32_t readJSONArrayEnd(); + + public: + + /** + * Writing functions. + */ + + uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + uint32_t writeMessageEnd(); + + uint32_t writeStructBegin(const char* name); + + uint32_t writeStructEnd(); + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + uint32_t writeFieldEnd(); + + uint32_t writeFieldStop(); + + uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + uint32_t writeMapEnd(); + + uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + uint32_t writeListEnd(); + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + uint32_t writeSetEnd(); + + uint32_t writeBool(const bool value); + + uint32_t writeByte(const int8_t byte); + + uint32_t writeI16(const int16_t i16); + + uint32_t writeI32(const int32_t i32); + + uint32_t writeI64(const int64_t i64); + + uint32_t writeDouble(const double dub); + + uint32_t writeString(const std::string& str); + + uint32_t writeBinary(const std::string& str); + + /** + * Reading functions + */ + + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid); + + uint32_t readMessageEnd(); + + uint32_t readStructBegin(std::string& name); + + uint32_t readStructEnd(); + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId); + + uint32_t readFieldEnd(); + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size); + + uint32_t readMapEnd(); + + uint32_t readListBegin(TType& elemType, + uint32_t& size); + + uint32_t readListEnd(); + + uint32_t readSetBegin(TType& elemType, + uint32_t& size); + + uint32_t readSetEnd(); + + uint32_t readBool(bool& value); + + uint32_t readByte(int8_t& byte); + + uint32_t readI16(int16_t& i16); + + uint32_t readI32(int32_t& i32); + + uint32_t readI64(int64_t& i64); + + uint32_t readDouble(double& dub); + + uint32_t readString(std::string& str); + + uint32_t readBinary(std::string& str); + + class LookaheadReader { + + public: + + LookaheadReader(TTransport &trans) : + trans_(&trans), + hasData_(false) { + } + + uint8_t read() { + if (hasData_) { + hasData_ = false; + } + else { + trans_->readAll(&data_, 1); + } + return data_; + } + + uint8_t peek() { + if (!hasData_) { + trans_->readAll(&data_, 1); + } + hasData_ = true; + return data_; + } + + private: + TTransport *trans_; + bool hasData_; + uint8_t data_; + }; + + private: + + std::stack<boost::shared_ptr<TJSONContext> > contexts_; + boost::shared_ptr<TJSONContext> context_; + LookaheadReader reader_; +}; + +/** + * Constructs input and output protocol objects given transports. + */ +class TJSONProtocolFactory : public TProtocolFactory { + public: + TJSONProtocolFactory() {} + + virtual ~TJSONProtocolFactory() {} + + boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) { + return boost::shared_ptr<TProtocol>(new TJSONProtocol(trans)); + } +}; + +}}} // apache::thrift::protocol + + +// TODO(dreiss): Move part of ThriftJSONString into a .cpp file and remove this. +#include <transport/TBufferTransports.h> + +namespace apache { namespace thrift { + +template<typename ThriftStruct> + std::string ThriftJSONString(const ThriftStruct& ts) { + using namespace apache::thrift::transport; + using namespace apache::thrift::protocol; + TMemoryBuffer* buffer = new TMemoryBuffer; + boost::shared_ptr<TTransport> trans(buffer); + TJSONProtocol protocol(trans); + + ts.write(&protocol); + + uint8_t* buf; + uint32_t size; + buffer->getBuffer(&buf, &size); + return std::string((char*)buf, (unsigned int)size); +} + +}} // apache::thrift + +#endif // #define _THRIFT_PROTOCOL_TJSONPROTOCOL_H_ 1 diff --git a/lib/cpp/src/protocol/TOneWayProtocol.h b/lib/cpp/src/protocol/TOneWayProtocol.h new file mode 100644 index 000000000..6f08fe1d7 --- /dev/null +++ b/lib/cpp/src/protocol/TOneWayProtocol.h @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TONEWAYPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TONEWAYPROTOCOL_H_ 1 + +#include "TProtocol.h" + +namespace apache { namespace thrift { namespace protocol { + +/** + * Abstract class for implementing a protocol that can only be written, + * not read. + * + */ +class TWriteOnlyProtocol : public TProtocol { + public: + /** + * @param subclass_name The name of the concrete subclass. + */ + TWriteOnlyProtocol(boost::shared_ptr<TTransport> trans, + const std::string& subclass_name) + : TProtocol(trans) + , subclass_(subclass_name) + {} + + // All writing functions remain abstract. + + /** + * Reading functions all throw an exception. + */ + + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readMessageEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readStructBegin(std::string& name) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readStructEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readFieldEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readMapEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readListBegin(TType& elemType, + uint32_t& size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readListEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readSetBegin(TType& elemType, + uint32_t& size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readSetEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readBool(bool& value) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readByte(int8_t& byte) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readI16(int16_t& i16) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readI32(int32_t& i32) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readI64(int64_t& i64) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readDouble(double& dub) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readString(std::string& str) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readBinary(std::string& str) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + private: + std::string subclass_; +}; + + +/** + * Abstract class for implementing a protocol that can only be read, + * not written. + * + */ +class TReadOnlyProtocol : public TProtocol { + public: + /** + * @param subclass_name The name of the concrete subclass. + */ + TReadOnlyProtocol(boost::shared_ptr<TTransport> trans, + const std::string& subclass_name) + : TProtocol(trans) + , subclass_(subclass_name) + {} + + // All reading functions remain abstract. + + /** + * Writing functions all throw an exception. + */ + + uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeMessageEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + + uint32_t writeStructBegin(const char* name) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeStructEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeFieldEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeFieldStop() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeMapEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeListBegin(const TType elemType, + const uint32_t size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeListEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeSetEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeBool(const bool value) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeByte(const int8_t byte) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeI16(const int16_t i16) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeI32(const int32_t i32) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeI64(const int64_t i64) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeDouble(const double dub) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeString(const std::string& str) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeBinary(const std::string& str) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + private: + std::string subclass_; +}; + +}}} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ diff --git a/lib/cpp/src/protocol/TProtocol.h b/lib/cpp/src/protocol/TProtocol.h new file mode 100644 index 000000000..40258277d --- /dev/null +++ b/lib/cpp/src/protocol/TProtocol.h @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TPROTOCOL_H_ 1 + +#include <transport/TTransport.h> +#include <protocol/TProtocolException.h> + +#include <boost/shared_ptr.hpp> +#include <boost/static_assert.hpp> + +#include <netinet/in.h> +#include <sys/types.h> +#include <string> +#include <map> + + +// Use this to get around strict aliasing rules. +// For example, uint64_t i = bitwise_cast<uint64_t>(returns_double()); +// The most obvious implementation is to just cast a pointer, +// but that doesn't work. +// For a pretty in-depth explanation of the problem, see +// http://www.cellperformance.com/mike_acton/2006/06/ (...) +// understanding_strict_aliasing.html +template <typename To, typename From> +static inline To bitwise_cast(From from) { + BOOST_STATIC_ASSERT(sizeof(From) == sizeof(To)); + + // BAD!!! These are all broken with -O2. + //return *reinterpret_cast<To*>(&from); // BAD!!! + //return *static_cast<To*>(static_cast<void*>(&from)); // BAD!!! + //return *(To*)(void*)&from; // BAD!!! + + // Super clean and paritally blessed by section 3.9 of the standard. + //unsigned char c[sizeof(from)]; + //memcpy(c, &from, sizeof(from)); + //To to; + //memcpy(&to, c, sizeof(c)); + //return to; + + // Slightly more questionable. + // Same code emitted by GCC. + //To to; + //memcpy(&to, &from, sizeof(from)); + //return to; + + // Technically undefined, but almost universally supported, + // and the most efficient implementation. + union { + From f; + To t; + } u; + u.f = from; + return u.t; +} + + +namespace apache { namespace thrift { namespace protocol { + +using apache::thrift::transport::TTransport; + +#ifdef HAVE_ENDIAN_H +#include <endian.h> +#endif + +#ifndef __BYTE_ORDER +# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN) +# define __BYTE_ORDER BYTE_ORDER +# define __LITTLE_ENDIAN LITTLE_ENDIAN +# define __BIG_ENDIAN BIG_ENDIAN +# else +# error "Cannot determine endianness" +# endif +#endif + +#if __BYTE_ORDER == __BIG_ENDIAN +# define ntohll(n) (n) +# define htonll(n) (n) +# if defined(__GNUC__) && defined(__GLIBC__) +# include <byteswap.h> +# define htolell(n) bswap_64(n) +# define letohll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +# define bswap_64(n) \ + ( (((n) & 0xff00000000000000ull) >> 56) \ + | (((n) & 0x00ff000000000000ull) >> 40) \ + | (((n) & 0x0000ff0000000000ull) >> 24) \ + | (((n) & 0x000000ff00000000ull) >> 8) \ + | (((n) & 0x00000000ff000000ull) << 8) \ + | (((n) & 0x0000000000ff0000ull) << 24) \ + | (((n) & 0x000000000000ff00ull) << 40) \ + | (((n) & 0x00000000000000ffull) << 56) ) +# define ntolell(n) bswap_64(n) +# define letonll(n) bswap_64(n) +# endif /* GNUC & GLIBC */ +#elif __BYTE_ORDER == __LITTLE_ENDIAN +# define htolell(n) (n) +# define letohll(n) (n) +# if defined(__GNUC__) && defined(__GLIBC__) +# include <byteswap.h> +# define ntohll(n) bswap_64(n) +# define htonll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +# define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) ) +# define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) ) +# endif /* GNUC & GLIBC */ +#else /* __BYTE_ORDER */ +# error "Can't define htonll or ntohll!" +#endif + +/** + * Enumerated definition of the types that the Thrift protocol supports. + * Take special note of the T_END type which is used specifically to mark + * the end of a sequence of fields. + */ +enum TType { + T_STOP = 0, + T_VOID = 1, + T_BOOL = 2, + T_BYTE = 3, + T_I08 = 3, + T_I16 = 6, + T_I32 = 8, + T_U64 = 9, + T_I64 = 10, + T_DOUBLE = 4, + T_STRING = 11, + T_UTF7 = 11, + T_STRUCT = 12, + T_MAP = 13, + T_SET = 14, + T_LIST = 15, + T_UTF8 = 16, + T_UTF16 = 17 +}; + +/** + * Enumerated definition of the message types that the Thrift protocol + * supports. + */ +enum TMessageType { + T_CALL = 1, + T_REPLY = 2, + T_EXCEPTION = 3, + T_ONEWAY = 4 +}; + +/** + * Abstract class for a thrift protocol driver. These are all the methods that + * a protocol must implement. Essentially, there must be some way of reading + * and writing all the base types, plus a mechanism for writing out structs + * with indexed fields. + * + * TProtocol objects should not be shared across multiple encoding contexts, + * as they may need to maintain internal state in some protocols (i.e. XML). + * Note that is is acceptable for the TProtocol module to do its own internal + * buffered reads/writes to the underlying TTransport where appropriate (i.e. + * when parsing an input XML stream, reading should be batched rather than + * looking ahead character by character for a close tag). + * + */ +class TProtocol { + public: + virtual ~TProtocol() {} + + /** + * Writing functions. + */ + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) = 0; + + virtual uint32_t writeMessageEnd() = 0; + + + virtual uint32_t writeStructBegin(const char* name) = 0; + + virtual uint32_t writeStructEnd() = 0; + + virtual uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) = 0; + + virtual uint32_t writeFieldEnd() = 0; + + virtual uint32_t writeFieldStop() = 0; + + virtual uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) = 0; + + virtual uint32_t writeMapEnd() = 0; + + virtual uint32_t writeListBegin(const TType elemType, + const uint32_t size) = 0; + + virtual uint32_t writeListEnd() = 0; + + virtual uint32_t writeSetBegin(const TType elemType, + const uint32_t size) = 0; + + virtual uint32_t writeSetEnd() = 0; + + virtual uint32_t writeBool(const bool value) = 0; + + virtual uint32_t writeByte(const int8_t byte) = 0; + + virtual uint32_t writeI16(const int16_t i16) = 0; + + virtual uint32_t writeI32(const int32_t i32) = 0; + + virtual uint32_t writeI64(const int64_t i64) = 0; + + virtual uint32_t writeDouble(const double dub) = 0; + + virtual uint32_t writeString(const std::string& str) = 0; + + virtual uint32_t writeBinary(const std::string& str) = 0; + + /** + * Reading functions + */ + + virtual uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) = 0; + + virtual uint32_t readMessageEnd() = 0; + + virtual uint32_t readStructBegin(std::string& name) = 0; + + virtual uint32_t readStructEnd() = 0; + + virtual uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) = 0; + + virtual uint32_t readFieldEnd() = 0; + + virtual uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) = 0; + + virtual uint32_t readMapEnd() = 0; + + virtual uint32_t readListBegin(TType& elemType, + uint32_t& size) = 0; + + virtual uint32_t readListEnd() = 0; + + virtual uint32_t readSetBegin(TType& elemType, + uint32_t& size) = 0; + + virtual uint32_t readSetEnd() = 0; + + virtual uint32_t readBool(bool& value) = 0; + + virtual uint32_t readByte(int8_t& byte) = 0; + + virtual uint32_t readI16(int16_t& i16) = 0; + + virtual uint32_t readI32(int32_t& i32) = 0; + + virtual uint32_t readI64(int64_t& i64) = 0; + + virtual uint32_t readDouble(double& dub) = 0; + + virtual uint32_t readString(std::string& str) = 0; + + virtual uint32_t readBinary(std::string& str) = 0; + + uint32_t readBool(std::vector<bool>::reference ref) { + bool value; + uint32_t rv = readBool(value); + ref = value; + return rv; + } + + /** + * Method to arbitrarily skip over data. + */ + uint32_t skip(TType type) { + switch (type) { + case T_BOOL: + { + bool boolv; + return readBool(boolv); + } + case T_BYTE: + { + int8_t bytev; + return readByte(bytev); + } + case T_I16: + { + int16_t i16; + return readI16(i16); + } + case T_I32: + { + int32_t i32; + return readI32(i32); + } + case T_I64: + { + int64_t i64; + return readI64(i64); + } + case T_DOUBLE: + { + double dub; + return readDouble(dub); + } + case T_STRING: + { + std::string str; + return readBinary(str); + } + case T_STRUCT: + { + uint32_t result = 0; + std::string name; + int16_t fid; + TType ftype; + result += readStructBegin(name); + while (true) { + result += readFieldBegin(name, ftype, fid); + if (ftype == T_STOP) { + break; + } + result += skip(ftype); + result += readFieldEnd(); + } + result += readStructEnd(); + return result; + } + case T_MAP: + { + uint32_t result = 0; + TType keyType; + TType valType; + uint32_t i, size; + result += readMapBegin(keyType, valType, size); + for (i = 0; i < size; i++) { + result += skip(keyType); + result += skip(valType); + } + result += readMapEnd(); + return result; + } + case T_SET: + { + uint32_t result = 0; + TType elemType; + uint32_t i, size; + result += readSetBegin(elemType, size); + for (i = 0; i < size; i++) { + result += skip(elemType); + } + result += readSetEnd(); + return result; + } + case T_LIST: + { + uint32_t result = 0; + TType elemType; + uint32_t i, size; + result += readListBegin(elemType, size); + for (i = 0; i < size; i++) { + result += skip(elemType); + } + result += readListEnd(); + return result; + } + default: + return 0; + } + } + + inline boost::shared_ptr<TTransport> getTransport() { + return ptrans_; + } + + // TODO: remove these two calls, they are for backwards + // compatibility + inline boost::shared_ptr<TTransport> getInputTransport() { + return ptrans_; + } + inline boost::shared_ptr<TTransport> getOutputTransport() { + return ptrans_; + } + + protected: + TProtocol(boost::shared_ptr<TTransport> ptrans): + ptrans_(ptrans) { + trans_ = ptrans.get(); + } + + boost::shared_ptr<TTransport> ptrans_; + TTransport* trans_; + + private: + TProtocol() {} +}; + +/** + * Constructs input and output protocol objects given transports. + */ +class TProtocolFactory { + public: + TProtocolFactory() {} + + virtual ~TProtocolFactory() {} + + virtual boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) = 0; +}; + +}}} // apache::thrift::protocol + +#endif // #define _THRIFT_PROTOCOL_TPROTOCOL_H_ 1 diff --git a/lib/cpp/src/protocol/TProtocolException.h b/lib/cpp/src/protocol/TProtocolException.h new file mode 100644 index 000000000..33011b379 --- /dev/null +++ b/lib/cpp/src/protocol/TProtocolException.h @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_ +#define _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_ 1 + +#include <string> + +namespace apache { namespace thrift { namespace protocol { + +/** + * Class to encapsulate all the possible types of protocol errors that may + * occur in various protocol systems. This provides a sort of generic + * wrapper around the shitty UNIX E_ error codes that lets a common code + * base of error handling to be used for various types of protocols, i.e. + * pipes etc. + * + */ +class TProtocolException : public apache::thrift::TException { + public: + + /** + * Error codes for the various types of exceptions. + */ + enum TProtocolExceptionType + { UNKNOWN = 0 + , INVALID_DATA = 1 + , NEGATIVE_SIZE = 2 + , SIZE_LIMIT = 3 + , BAD_VERSION = 4 + , NOT_IMPLEMENTED = 5 + }; + + TProtocolException() : + apache::thrift::TException(), + type_(UNKNOWN) {} + + TProtocolException(TProtocolExceptionType type) : + apache::thrift::TException(), + type_(type) {} + + TProtocolException(const std::string& message) : + apache::thrift::TException(message), + type_(UNKNOWN) {} + + TProtocolException(TProtocolExceptionType type, const std::string& message) : + apache::thrift::TException(message), + type_(type) {} + + virtual ~TProtocolException() throw() {} + + /** + * Returns an error code that provides information about the type of error + * that has occurred. + * + * @return Error code + */ + TProtocolExceptionType getType() { + return type_; + } + + virtual const char* what() const throw() { + if (message_.empty()) { + switch (type_) { + case UNKNOWN : return "TProtocolException: Unknown protocol exception"; + case INVALID_DATA : return "TProtocolException: Invalid data"; + case NEGATIVE_SIZE : return "TProtocolException: Negative size"; + case SIZE_LIMIT : return "TProtocolException: Exceeded size limit"; + case BAD_VERSION : return "TProtocolException: Invalid version"; + case NOT_IMPLEMENTED : return "TProtocolException: Not implemented"; + default : return "TProtocolException: (Invalid exception type)"; + } + } else { + return message_.c_str(); + } + } + + protected: + /** + * Error code + */ + TProtocolExceptionType type_; + +}; + +}}} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_ diff --git a/lib/cpp/src/protocol/TProtocolTap.h b/lib/cpp/src/protocol/TProtocolTap.h new file mode 100644 index 000000000..5580216a3 --- /dev/null +++ b/lib/cpp/src/protocol/TProtocolTap.h @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TPROTOCOLTAP_H_ +#define _THRIFT_PROTOCOL_TPROTOCOLTAP_H_ 1 + +#include <protocol/TOneWayProtocol.h> + +namespace apache { namespace thrift { namespace protocol { + +using apache::thrift::transport::TTransport; + +/** + * Puts a wiretap on a protocol object. Any reads to this class are passed + * through to an enclosed protocol object, but also mirrored as write to a + * second protocol object. + * + */ +class TProtocolTap : public TReadOnlyProtocol { + public: + TProtocolTap(boost::shared_ptr<TProtocol> source, + boost::shared_ptr<TProtocol> sink) + : TReadOnlyProtocol(source->getTransport(), "TProtocolTap") + , source_(source) + , sink_(sink) + {} + + virtual uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + uint32_t rv = source_->readMessageBegin(name, messageType, seqid); + sink_->writeMessageBegin(name, messageType, seqid); + return rv; + } + + virtual uint32_t readMessageEnd() { + uint32_t rv = source_->readMessageEnd(); + sink_->writeMessageEnd(); + return rv; + } + + virtual uint32_t readStructBegin(std::string& name) { + uint32_t rv = source_->readStructBegin(name); + sink_->writeStructBegin(name.c_str()); + return rv; + } + + virtual uint32_t readStructEnd() { + uint32_t rv = source_->readStructEnd(); + sink_->writeStructEnd(); + return rv; + } + + virtual uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t rv = source_->readFieldBegin(name, fieldType, fieldId); + if (fieldType == T_STOP) { + sink_->writeFieldStop(); + } else { + sink_->writeFieldBegin(name.c_str(), fieldType, fieldId); + } + return rv; + } + + + virtual uint32_t readFieldEnd() { + uint32_t rv = source_->readFieldEnd(); + sink_->writeFieldEnd(); + return rv; + } + + virtual uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + uint32_t rv = source_->readMapBegin(keyType, valType, size); + sink_->writeMapBegin(keyType, valType, size); + return rv; + } + + + virtual uint32_t readMapEnd() { + uint32_t rv = source_->readMapEnd(); + sink_->writeMapEnd(); + return rv; + } + + virtual uint32_t readListBegin(TType& elemType, + uint32_t& size) { + uint32_t rv = source_->readListBegin(elemType, size); + sink_->writeListBegin(elemType, size); + return rv; + } + + + virtual uint32_t readListEnd() { + uint32_t rv = source_->readListEnd(); + sink_->writeListEnd(); + return rv; + } + + virtual uint32_t readSetBegin(TType& elemType, + uint32_t& size) { + uint32_t rv = source_->readSetBegin(elemType, size); + sink_->writeSetBegin(elemType, size); + return rv; + } + + + virtual uint32_t readSetEnd() { + uint32_t rv = source_->readSetEnd(); + sink_->writeSetEnd(); + return rv; + } + + virtual uint32_t readBool(bool& value) { + uint32_t rv = source_->readBool(value); + sink_->writeBool(value); + return rv; + } + + virtual uint32_t readByte(int8_t& byte) { + uint32_t rv = source_->readByte(byte); + sink_->writeByte(byte); + return rv; + } + + virtual uint32_t readI16(int16_t& i16) { + uint32_t rv = source_->readI16(i16); + sink_->writeI16(i16); + return rv; + } + + virtual uint32_t readI32(int32_t& i32) { + uint32_t rv = source_->readI32(i32); + sink_->writeI32(i32); + return rv; + } + + virtual uint32_t readI64(int64_t& i64) { + uint32_t rv = source_->readI64(i64); + sink_->writeI64(i64); + return rv; + } + + virtual uint32_t readDouble(double& dub) { + uint32_t rv = source_->readDouble(dub); + sink_->writeDouble(dub); + return rv; + } + + virtual uint32_t readString(std::string& str) { + uint32_t rv = source_->readString(str); + sink_->writeString(str); + return rv; + } + + virtual uint32_t readBinary(std::string& str) { + uint32_t rv = source_->readBinary(str); + sink_->writeBinary(str); + return rv; + } + + private: + boost::shared_ptr<TProtocol> source_; + boost::shared_ptr<TProtocol> sink_; +}; + +}}} // apache::thrift::protocol + +#endif // #define _THRIFT_PROTOCOL_TPROTOCOLTAP_H_ 1 diff --git a/lib/cpp/src/server/TNonblockingServer.cpp b/lib/cpp/src/server/TNonblockingServer.cpp new file mode 100644 index 000000000..45f635cbe --- /dev/null +++ b/lib/cpp/src/server/TNonblockingServer.cpp @@ -0,0 +1,750 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TNonblockingServer.h" +#include <concurrency/Exception.h> + +#include <iostream> +#include <sys/socket.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <netdb.h> +#include <fcntl.h> +#include <errno.h> +#include <assert.h> + +namespace apache { namespace thrift { namespace server { + +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using namespace apache::thrift::concurrency; +using namespace std; + +class TConnection::Task: public Runnable { + public: + Task(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TProtocol> input, + boost::shared_ptr<TProtocol> output, + int taskHandle) : + processor_(processor), + input_(input), + output_(output), + taskHandle_(taskHandle) {} + + void run() { + try { + while (processor_->process(input_, output_)) { + if (!input_->getTransport()->peek()) { + break; + } + } + } catch (TTransportException& ttx) { + cerr << "TNonblockingServer client died: " << ttx.what() << endl; + } catch (TException& x) { + cerr << "TNonblockingServer exception: " << x.what() << endl; + } catch (...) { + cerr << "TNonblockingServer uncaught exception." << endl; + } + + // Signal completion back to the libevent thread via a socketpair + int8_t b = 0; + if (-1 == send(taskHandle_, &b, sizeof(int8_t), 0)) { + GlobalOutput.perror("TNonblockingServer::Task: send ", errno); + } + if (-1 == ::close(taskHandle_)) { + GlobalOutput.perror("TNonblockingServer::Task: close, possible resource leak ", errno); + } + } + + private: + boost::shared_ptr<TProcessor> processor_; + boost::shared_ptr<TProtocol> input_; + boost::shared_ptr<TProtocol> output_; + int taskHandle_; +}; + +void TConnection::init(int socket, short eventFlags, TNonblockingServer* s) { + socket_ = socket; + server_ = s; + appState_ = APP_INIT; + eventFlags_ = 0; + + readBufferPos_ = 0; + readWant_ = 0; + + writeBuffer_ = NULL; + writeBufferSize_ = 0; + writeBufferPos_ = 0; + + socketState_ = SOCKET_RECV; + appState_ = APP_INIT; + + taskHandle_ = -1; + + // Set flags, which also registers the event + setFlags(eventFlags); + + // get input/transports + factoryInputTransport_ = s->getInputTransportFactory()->getTransport(inputTransport_); + factoryOutputTransport_ = s->getOutputTransportFactory()->getTransport(outputTransport_); + + // Create protocol + inputProtocol_ = s->getInputProtocolFactory()->getProtocol(factoryInputTransport_); + outputProtocol_ = s->getOutputProtocolFactory()->getProtocol(factoryOutputTransport_); +} + +void TConnection::workSocket() { + int flags=0, got=0, left=0, sent=0; + uint32_t fetch = 0; + + switch (socketState_) { + case SOCKET_RECV: + // It is an error to be in this state if we already have all the data + assert(readBufferPos_ < readWant_); + + // Double the buffer size until it is big enough + if (readWant_ > readBufferSize_) { + while (readWant_ > readBufferSize_) { + readBufferSize_ *= 2; + } + readBuffer_ = (uint8_t*)std::realloc(readBuffer_, readBufferSize_); + if (readBuffer_ == NULL) { + GlobalOutput("TConnection::workSocket() realloc"); + close(); + return; + } + } + + // Read from the socket + fetch = readWant_ - readBufferPos_; + got = recv(socket_, readBuffer_ + readBufferPos_, fetch, 0); + + if (got > 0) { + // Move along in the buffer + readBufferPos_ += got; + + // Check that we did not overdo it + assert(readBufferPos_ <= readWant_); + + // We are done reading, move onto the next state + if (readBufferPos_ == readWant_) { + transition(); + } + return; + } else if (got == -1) { + // Blocking errors are okay, just move on + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; + } + + if (errno != ECONNRESET) { + GlobalOutput.perror("TConnection::workSocket() recv -1 ", errno); + } + } + + // Whenever we get down here it means a remote disconnect + close(); + + return; + + case SOCKET_SEND: + // Should never have position past size + assert(writeBufferPos_ <= writeBufferSize_); + + // If there is no data to send, then let us move on + if (writeBufferPos_ == writeBufferSize_) { + GlobalOutput("WARNING: Send state with no data to send\n"); + transition(); + return; + } + + flags = 0; + #ifdef MSG_NOSIGNAL + // Note the use of MSG_NOSIGNAL to suppress SIGPIPE errors, instead we + // check for the EPIPE return condition and close the socket in that case + flags |= MSG_NOSIGNAL; + #endif // ifdef MSG_NOSIGNAL + + left = writeBufferSize_ - writeBufferPos_; + sent = send(socket_, writeBuffer_ + writeBufferPos_, left, flags); + + if (sent <= 0) { + // Blocking errors are okay, just move on + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; + } + if (errno != EPIPE) { + GlobalOutput.perror("TConnection::workSocket() send -1 ", errno); + } + close(); + return; + } + + writeBufferPos_ += sent; + + // Did we overdo it? + assert(writeBufferPos_ <= writeBufferSize_); + + // We are done! + if (writeBufferPos_ == writeBufferSize_) { + transition(); + } + + return; + + default: + GlobalOutput.printf("Shit Got Ill. Socket State %d", socketState_); + assert(0); + } +} + +/** + * This is called when the application transitions from one state into + * another. This means that it has finished writing the data that it needed + * to, or finished receiving the data that it needed to. + */ +void TConnection::transition() { + + int sz = 0; + + // Switch upon the state that we are currently in and move to a new state + switch (appState_) { + + case APP_READ_REQUEST: + // We are done reading the request, package the read buffer into transport + // and get back some data from the dispatch function + // If we've used these transport buffers enough times, reset them to avoid bloating + + inputTransport_->resetBuffer(readBuffer_, readBufferPos_); + ++numReadsSinceReset_; + if (numWritesSinceReset_ < 512) { + outputTransport_->resetBuffer(); + } else { + // reset the capacity of the output transport if we used it enough times that it might be bloated + try { + outputTransport_->resetBuffer(true); + numWritesSinceReset_ = 0; + } catch (TTransportException &ttx) { + GlobalOutput.printf("TTransportException: TMemoryBuffer::resetBuffer() %s", ttx.what()); + close(); + return; + } + } + + // Prepend four bytes of blank space to the buffer so we can + // write the frame size there later. + outputTransport_->getWritePtr(4); + outputTransport_->wroteBytes(4); + + if (server_->isThreadPoolProcessing()) { + // We are setting up a Task to do this work and we will wait on it + int sv[2]; + if (-1 == socketpair(AF_LOCAL, SOCK_STREAM, 0, sv)) { + GlobalOutput.perror("TConnection::socketpair() failed ", errno); + // Now we will fall through to the APP_WAIT_TASK block with no response + } else { + // Create task and dispatch to the thread manager + boost::shared_ptr<Runnable> task = + boost::shared_ptr<Runnable>(new Task(server_->getProcessor(), + inputProtocol_, + outputProtocol_, + sv[1])); + // The application is now waiting on the task to finish + appState_ = APP_WAIT_TASK; + + // Create an event to be notified when the task finishes + event_set(&taskEvent_, + taskHandle_ = sv[0], + EV_READ, + TConnection::taskHandler, + this); + + // Attach to the base + event_base_set(server_->getEventBase(), &taskEvent_); + + // Add the event and start up the server + if (-1 == event_add(&taskEvent_, 0)) { + GlobalOutput("TNonblockingServer::serve(): coult not event_add"); + return; + } + try { + server_->addTask(task); + } catch (IllegalStateException & ise) { + // The ThreadManager is not ready to handle any more tasks (it's probably shutting down). + GlobalOutput.printf("IllegalStateException: Server::process() %s", ise.what()); + close(); + } + + // Set this connection idle so that libevent doesn't process more + // data on it while we're still waiting for the threadmanager to + // finish this task + setIdle(); + return; + } + } else { + try { + // Invoke the processor + server_->getProcessor()->process(inputProtocol_, outputProtocol_); + } catch (TTransportException &ttx) { + GlobalOutput.printf("TTransportException: Server::process() %s", ttx.what()); + close(); + return; + } catch (TException &x) { + GlobalOutput.printf("TException: Server::process() %s", x.what()); + close(); + return; + } catch (...) { + GlobalOutput.printf("Server::process() unknown exception"); + close(); + return; + } + } + + // Intentionally fall through here, the call to process has written into + // the writeBuffer_ + + case APP_WAIT_TASK: + // We have now finished processing a task and the result has been written + // into the outputTransport_, so we grab its contents and place them into + // the writeBuffer_ for actual writing by the libevent thread + + // Get the result of the operation + outputTransport_->getBuffer(&writeBuffer_, &writeBufferSize_); + + // If the function call generated return data, then move into the send + // state and get going + // 4 bytes were reserved for frame size + if (writeBufferSize_ > 4) { + + // Move into write state + writeBufferPos_ = 0; + socketState_ = SOCKET_SEND; + + // Put the frame size into the write buffer + int32_t frameSize = (int32_t)htonl(writeBufferSize_ - 4); + memcpy(writeBuffer_, &frameSize, 4); + + // Socket into write mode + appState_ = APP_SEND_RESULT; + setWrite(); + + // Try to work the socket immediately + // workSocket(); + + return; + } + + // In this case, the request was oneway and we should fall through + // right back into the read frame header state + goto LABEL_APP_INIT; + + case APP_SEND_RESULT: + + ++numWritesSinceReset_; + + // N.B.: We also intentionally fall through here into the INIT state! + + LABEL_APP_INIT: + case APP_INIT: + + // reset the input buffer if we used it enough times that it might be bloated + if (numReadsSinceReset_ > 512) + { + void * new_buffer = std::realloc(readBuffer_, 1024); + if (new_buffer == NULL) { + GlobalOutput("TConnection::transition() realloc"); + close(); + return; + } + readBuffer_ = (uint8_t*) new_buffer; + readBufferSize_ = 1024; + numReadsSinceReset_ = 0; + } + + // Clear write buffer variables + writeBuffer_ = NULL; + writeBufferPos_ = 0; + writeBufferSize_ = 0; + + // Set up read buffer for getting 4 bytes + readBufferPos_ = 0; + readWant_ = 4; + + // Into read4 state we go + socketState_ = SOCKET_RECV; + appState_ = APP_READ_FRAME_SIZE; + + // Register read event + setRead(); + + // Try to work the socket right away + // workSocket(); + + return; + + case APP_READ_FRAME_SIZE: + // We just read the request length, deserialize it + sz = *(int32_t*)readBuffer_; + sz = (int32_t)ntohl(sz); + + if (sz <= 0) { + GlobalOutput.printf("TConnection:transition() Negative frame size %d, remote side not using TFramedTransport?", sz); + close(); + return; + } + + // Reset the read buffer + readWant_ = (uint32_t)sz; + readBufferPos_= 0; + + // Move into read request state + appState_ = APP_READ_REQUEST; + + // Work the socket right away + // workSocket(); + + return; + + default: + GlobalOutput.printf("Totally Fucked. Application State %d", appState_); + assert(0); + } +} + +void TConnection::setFlags(short eventFlags) { + // Catch the do nothing case + if (eventFlags_ == eventFlags) { + return; + } + + // Delete a previously existing event + if (eventFlags_ != 0) { + if (event_del(&event_) == -1) { + GlobalOutput("TConnection::setFlags event_del"); + return; + } + } + + // Update in memory structure + eventFlags_ = eventFlags; + + // Do not call event_set if there are no flags + if (!eventFlags_) { + return; + } + + /** + * event_set: + * + * Prepares the event structure &event to be used in future calls to + * event_add() and event_del(). The event will be prepared to call the + * eventHandler using the 'sock' file descriptor to monitor events. + * + * The events can be either EV_READ, EV_WRITE, or both, indicating + * that an application can read or write from the file respectively without + * blocking. + * + * The eventHandler will be called with the file descriptor that triggered + * the event and the type of event which will be one of: EV_TIMEOUT, + * EV_SIGNAL, EV_READ, EV_WRITE. + * + * The additional flag EV_PERSIST makes an event_add() persistent until + * event_del() has been called. + * + * Once initialized, the &event struct can be used repeatedly with + * event_add() and event_del() and does not need to be reinitialized unless + * the eventHandler and/or the argument to it are to be changed. However, + * when an ev structure has been added to libevent using event_add() the + * structure must persist until the event occurs (assuming EV_PERSIST + * is not set) or is removed using event_del(). You may not reuse the same + * ev structure for multiple monitored descriptors; each descriptor needs + * its own ev. + */ + event_set(&event_, socket_, eventFlags_, TConnection::eventHandler, this); + event_base_set(server_->getEventBase(), &event_); + + // Add the event + if (event_add(&event_, 0) == -1) { + GlobalOutput("TConnection::setFlags(): could not event_add"); + } +} + +/** + * Closes a connection + */ +void TConnection::close() { + // Delete the registered libevent + if (event_del(&event_) == -1) { + GlobalOutput("TConnection::close() event_del"); + } + + // Close the socket + if (socket_ > 0) { + ::close(socket_); + } + socket_ = 0; + + // close any factory produced transports + factoryInputTransport_->close(); + factoryOutputTransport_->close(); + + // Give this object back to the server that owns it + server_->returnConnection(this); +} + +void TConnection::checkIdleBufferMemLimit(uint32_t limit) { + if (readBufferSize_ > limit) { + readBufferSize_ = limit; + readBuffer_ = (uint8_t*)std::realloc(readBuffer_, readBufferSize_); + if (readBuffer_ == NULL) { + GlobalOutput("TConnection::checkIdleBufferMemLimit() realloc"); + close(); + } + } +} + +/** + * Creates a new connection either by reusing an object off the stack or + * by allocating a new one entirely + */ +TConnection* TNonblockingServer::createConnection(int socket, short flags) { + // Check the stack + if (connectionStack_.empty()) { + return new TConnection(socket, flags, this); + } else { + TConnection* result = connectionStack_.top(); + connectionStack_.pop(); + result->init(socket, flags, this); + return result; + } +} + +/** + * Returns a connection to the stack + */ +void TNonblockingServer::returnConnection(TConnection* connection) { + if (connectionStackLimit_ && + (connectionStack_.size() >= connectionStackLimit_)) { + delete connection; + } else { + connection->checkIdleBufferMemLimit(idleBufferMemLimit_); + connectionStack_.push(connection); + } +} + +/** + * Server socket had something happen. We accept all waiting client + * connections on fd and assign TConnection objects to handle those requests. + */ +void TNonblockingServer::handleEvent(int fd, short which) { + // Make sure that libevent didn't fuck up the socket handles + assert(fd == serverSocket_); + + // Server socket accepted a new connection + socklen_t addrLen; + struct sockaddr addr; + addrLen = sizeof(addr); + + // Going to accept a new client socket + int clientSocket; + + // Accept as many new clients as possible, even though libevent signaled only + // one, this helps us to avoid having to go back into the libevent engine so + // many times + while ((clientSocket = accept(fd, &addr, &addrLen)) != -1) { + + // Explicitly set this socket to NONBLOCK mode + int flags; + if ((flags = fcntl(clientSocket, F_GETFL, 0)) < 0 || + fcntl(clientSocket, F_SETFL, flags | O_NONBLOCK) < 0) { + GlobalOutput.perror("thriftServerEventHandler: set O_NONBLOCK (fcntl) ", errno); + close(clientSocket); + return; + } + + // Create a new TConnection for this client socket. + TConnection* clientConnection = + createConnection(clientSocket, EV_READ | EV_PERSIST); + + // Fail fast if we could not create a TConnection object + if (clientConnection == NULL) { + GlobalOutput.printf("thriftServerEventHandler: failed TConnection factory"); + close(clientSocket); + return; + } + + // Put this client connection into the proper state + clientConnection->transition(); + } + + // Done looping accept, now we have to make sure the error is due to + // blocking. Any other error is a problem + if (errno != EAGAIN && errno != EWOULDBLOCK) { + GlobalOutput.perror("thriftServerEventHandler: accept() ", errno); + } +} + +/** + * Creates a socket to listen on and binds it to the local port. + */ +void TNonblockingServer::listenSocket() { + int s; + struct addrinfo hints, *res, *res0; + int error; + + char port[sizeof("65536") + 1]; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = PF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + sprintf(port, "%d", port_); + + // Wildcard address + error = getaddrinfo(NULL, port, &hints, &res0); + if (error) { + string errStr = "TNonblockingServer::serve() getaddrinfo " + string(gai_strerror(error)); + GlobalOutput(errStr.c_str()); + return; + } + + // Pick the ipv6 address first since ipv4 addresses can be mapped + // into ipv6 space. + for (res = res0; res; res = res->ai_next) { + if (res->ai_family == AF_INET6 || res->ai_next == NULL) + break; + } + + // Create the server socket + s = socket(res->ai_family, res->ai_socktype, res->ai_protocol); + if (s == -1) { + freeaddrinfo(res0); + throw TException("TNonblockingServer::serve() socket() -1"); + } + + #ifdef IPV6_V6ONLY + int zero = 0; + if (-1 == setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &zero, sizeof(zero))) { + GlobalOutput("TServerSocket::listen() IPV6_V6ONLY"); + } + #endif // #ifdef IPV6_V6ONLY + + + int one = 1; + + // Set reuseaddr to avoid 2MSL delay on server restart + setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)); + + if (bind(s, res->ai_addr, res->ai_addrlen) == -1) { + close(s); + freeaddrinfo(res0); + throw TException("TNonblockingServer::serve() bind"); + } + + // Done with the addr info + freeaddrinfo(res0); + + // Set up this file descriptor for listening + listenSocket(s); +} + +/** + * Takes a socket created by listenSocket() and sets various options on it + * to prepare for use in the server. + */ +void TNonblockingServer::listenSocket(int s) { + // Set socket to nonblocking mode + int flags; + if ((flags = fcntl(s, F_GETFL, 0)) < 0 || + fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) { + close(s); + throw TException("TNonblockingServer::serve() O_NONBLOCK"); + } + + int one = 1; + struct linger ling = {0, 0}; + + // Keepalive to ensure full result flushing + setsockopt(s, SOL_SOCKET, SO_KEEPALIVE, &one, sizeof(one)); + + // Turn linger off to avoid hung sockets + setsockopt(s, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)); + + // Set TCP nodelay if available, MAC OS X Hack + // See http://lists.danga.com/pipermail/memcached/2005-March/001240.html + #ifndef TCP_NOPUSH + setsockopt(s, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); + #endif + + if (listen(s, LISTEN_BACKLOG) == -1) { + close(s); + throw TException("TNonblockingServer::serve() listen"); + } + + // Cool, this socket is good to go, set it as the serverSocket_ + serverSocket_ = s; +} + +/** + * Register the core libevent events onto the proper base. + */ +void TNonblockingServer::registerEvents(event_base* base) { + assert(serverSocket_ != -1); + assert(!eventBase_); + eventBase_ = base; + + // Print some libevent stats + GlobalOutput.printf("libevent %s method %s", + event_get_version(), + event_get_method()); + + // Register the server event + event_set(&serverEvent_, + serverSocket_, + EV_READ | EV_PERSIST, + TNonblockingServer::eventHandler, + this); + event_base_set(eventBase_, &serverEvent_); + + // Add the event and start up the server + if (-1 == event_add(&serverEvent_, 0)) { + throw TException("TNonblockingServer::serve(): coult not event_add"); + } +} + +/** + * Main workhorse function, starts up the server listening on a port and + * loops over the libevent handler. + */ +void TNonblockingServer::serve() { + // Init socket + listenSocket(); + + // Initialize libevent core + registerEvents(static_cast<event_base*>(event_init())); + + // Run the preServe event + if (eventHandler_ != NULL) { + eventHandler_->preServe(); + } + + // Run libevent engine, never returns, invokes calls to eventHandler + event_base_loop(eventBase_, 0); +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TNonblockingServer.h b/lib/cpp/src/server/TNonblockingServer.h new file mode 100644 index 000000000..1684b64a0 --- /dev/null +++ b/lib/cpp/src/server/TNonblockingServer.h @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TNONBLOCKINGSERVER_H_ +#define _THRIFT_SERVER_TNONBLOCKINGSERVER_H_ 1 + +#include <Thrift.h> +#include <server/TServer.h> +#include <transport/TBufferTransports.h> +#include <concurrency/ThreadManager.h> +#include <stack> +#include <string> +#include <errno.h> +#include <cstdlib> +#include <event.h> + +namespace apache { namespace thrift { namespace server { + +using apache::thrift::transport::TMemoryBuffer; +using apache::thrift::protocol::TProtocol; +using apache::thrift::concurrency::Runnable; +using apache::thrift::concurrency::ThreadManager; + +// Forward declaration of class +class TConnection; + +/** + * This is a non-blocking server in C++ for high performance that operates a + * single IO thread. It assumes that all incoming requests are framed with a + * 4 byte length indicator and writes out responses using the same framing. + * + * It does not use the TServerTransport framework, but rather has socket + * operations hardcoded for use with select. + * + */ +class TNonblockingServer : public TServer { + private: + + // Listen backlog + static const int LISTEN_BACKLOG = 1024; + + // Default limit on size of idle connection pool + static const size_t CONNECTION_STACK_LIMIT = 1024; + + // Maximum size of buffer allocated to idle connection + static const uint32_t IDLE_BUFFER_MEM_LIMIT = 8192; + + // Server socket file descriptor + int serverSocket_; + + // Port server runs on + int port_; + + // For processing via thread pool, may be NULL + boost::shared_ptr<ThreadManager> threadManager_; + + // Is thread pool processing? + bool threadPoolProcessing_; + + // The event base for libevent + event_base* eventBase_; + + // Event struct, for use with eventBase_ + struct event serverEvent_; + + // Number of TConnection object we've created + size_t numTConnections_; + + // Limit for how many TConnection objects to cache + size_t connectionStackLimit_; + + /** + * Max read buffer size for an idle connection. When we place an idle + * TConnection into connectionStack_, we insure that its read buffer is + * reduced to this size to insure that idle connections don't hog memory. + */ + uint32_t idleBufferMemLimit_; + + /** + * This is a stack of all the objects that have been created but that + * are NOT currently in use. When we close a connection, we place it on this + * stack so that the object can be reused later, rather than freeing the + * memory and reallocating a new object later. + */ + std::stack<TConnection*> connectionStack_; + + void handleEvent(int fd, short which); + + public: + TNonblockingServer(boost::shared_ptr<TProcessor> processor, + int port) : + TServer(processor), + serverSocket_(-1), + port_(port), + threadPoolProcessing_(false), + eventBase_(NULL), + numTConnections_(0), + connectionStackLimit_(CONNECTION_STACK_LIMIT), + idleBufferMemLimit_(IDLE_BUFFER_MEM_LIMIT) {} + + TNonblockingServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TProtocolFactory> protocolFactory, + int port, + boost::shared_ptr<ThreadManager> threadManager = boost::shared_ptr<ThreadManager>()) : + TServer(processor), + serverSocket_(-1), + port_(port), + threadManager_(threadManager), + eventBase_(NULL), + numTConnections_(0), + connectionStackLimit_(CONNECTION_STACK_LIMIT), + idleBufferMemLimit_(IDLE_BUFFER_MEM_LIMIT) { + setInputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory())); + setOutputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory())); + setInputProtocolFactory(protocolFactory); + setOutputProtocolFactory(protocolFactory); + setThreadManager(threadManager); + } + + TNonblockingServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TTransportFactory> inputTransportFactory, + boost::shared_ptr<TTransportFactory> outputTransportFactory, + boost::shared_ptr<TProtocolFactory> inputProtocolFactory, + boost::shared_ptr<TProtocolFactory> outputProtocolFactory, + int port, + boost::shared_ptr<ThreadManager> threadManager = boost::shared_ptr<ThreadManager>()) : + TServer(processor), + serverSocket_(0), + port_(port), + threadManager_(threadManager), + eventBase_(NULL), + numTConnections_(0), + connectionStackLimit_(CONNECTION_STACK_LIMIT), + idleBufferMemLimit_(IDLE_BUFFER_MEM_LIMIT) { + setInputTransportFactory(inputTransportFactory); + setOutputTransportFactory(outputTransportFactory); + setInputProtocolFactory(inputProtocolFactory); + setOutputProtocolFactory(outputProtocolFactory); + setThreadManager(threadManager); + } + + ~TNonblockingServer() {} + + void setThreadManager(boost::shared_ptr<ThreadManager> threadManager) { + threadManager_ = threadManager; + threadPoolProcessing_ = (threadManager != NULL); + } + + boost::shared_ptr<ThreadManager> getThreadManager() { + return threadManager_; + } + + /** + * Get the maximum number of unused TConnection we will hold in reserve. + * + * @return the current limit on TConnection pool size. + */ + size_t getConnectionStackLimit() const { + return connectionStackLimit_; + } + + /** + * Set the maximum number of unused TConnection we will hold in reserve. + * + * @param sz the new limit for TConnection pool size. + */ + void setConnectionStackLimit(size_t sz) { + connectionStackLimit_ = sz; + } + + bool isThreadPoolProcessing() const { + return threadPoolProcessing_; + } + + void addTask(boost::shared_ptr<Runnable> task) { + threadManager_->add(task); + } + + event_base* getEventBase() const { + return eventBase_; + } + + void incrementNumConnections() { + ++numTConnections_; + } + + void decrementNumConnections() { + --numTConnections_; + } + + size_t getNumConnections() { + return numTConnections_; + } + + size_t getNumIdleConnections() { + return connectionStack_.size(); + } + + /** + * Get the maximum limit of memory allocated to idle TConnection objects. + * + * @return # bytes beyond which we will shrink buffers when idle. + */ + size_t getIdleBufferMemLimit() const { + return idleBufferMemLimit_; + } + + /** + * Set the maximum limit of memory allocated to idle TConnection objects. + * If a TConnection object goes idle with more than this much memory + * allocated to its buffer, we shrink it to this value. + * + * @param limit of bytes beyond which we will shrink buffers when idle. + */ + void setIdleBufferMemLimit(size_t limit) { + idleBufferMemLimit_ = limit; + } + + TConnection* createConnection(int socket, short flags); + + void returnConnection(TConnection* connection); + + static void eventHandler(int fd, short which, void* v) { + ((TNonblockingServer*)v)->handleEvent(fd, which); + } + + void listenSocket(); + + void listenSocket(int fd); + + void registerEvents(event_base* base); + + void serve(); +}; + +/** + * Two states for sockets, recv and send mode + */ +enum TSocketState { + SOCKET_RECV, + SOCKET_SEND +}; + +/** + * Four states for the nonblocking servr: + * 1) initialize + * 2) read 4 byte frame size + * 3) read frame of data + * 4) send back data (if any) + */ +enum TAppState { + APP_INIT, + APP_READ_FRAME_SIZE, + APP_READ_REQUEST, + APP_WAIT_TASK, + APP_SEND_RESULT +}; + +/** + * Represents a connection that is handled via libevent. This connection + * essentially encapsulates a socket that has some associated libevent state. + */ +class TConnection { + private: + + class Task; + + // Server handle + TNonblockingServer* server_; + + // Socket handle + int socket_; + + // Libevent object + struct event event_; + + // Libevent flags + short eventFlags_; + + // Socket mode + TSocketState socketState_; + + // Application state + TAppState appState_; + + // How much data needed to read + uint32_t readWant_; + + // Where in the read buffer are we + uint32_t readBufferPos_; + + // Read buffer + uint8_t* readBuffer_; + + // Read buffer size + uint32_t readBufferSize_; + + // Write buffer + uint8_t* writeBuffer_; + + // Write buffer size + uint32_t writeBufferSize_; + + // How far through writing are we? + uint32_t writeBufferPos_; + + // How many times have we read since our last buffer reset? + uint32_t numReadsSinceReset_; + + // How many times have we written since our last buffer reset? + uint32_t numWritesSinceReset_; + + // Task handle + int taskHandle_; + + // Task event + struct event taskEvent_; + + // Transport to read from + boost::shared_ptr<TMemoryBuffer> inputTransport_; + + // Transport that processor writes to + boost::shared_ptr<TMemoryBuffer> outputTransport_; + + // extra transport generated by transport factory (e.g. BufferedRouterTransport) + boost::shared_ptr<TTransport> factoryInputTransport_; + boost::shared_ptr<TTransport> factoryOutputTransport_; + + // Protocol decoder + boost::shared_ptr<TProtocol> inputProtocol_; + + // Protocol encoder + boost::shared_ptr<TProtocol> outputProtocol_; + + // Go into read mode + void setRead() { + setFlags(EV_READ | EV_PERSIST); + } + + // Go into write mode + void setWrite() { + setFlags(EV_WRITE | EV_PERSIST); + } + + // Set socket idle + void setIdle() { + setFlags(0); + } + + // Set event flags + void setFlags(short eventFlags); + + // Libevent handlers + void workSocket(); + + // Close this client and reset + void close(); + + public: + + // Constructor + TConnection(int socket, short eventFlags, TNonblockingServer *s) { + readBuffer_ = (uint8_t*)std::malloc(1024); + if (readBuffer_ == NULL) { + throw new apache::thrift::TException("Out of memory."); + } + readBufferSize_ = 1024; + + numReadsSinceReset_ = 0; + numWritesSinceReset_ = 0; + + // Allocate input and output tranpsorts + // these only need to be allocated once per TConnection (they don't need to be + // reallocated on init() call) + inputTransport_ = boost::shared_ptr<TMemoryBuffer>(new TMemoryBuffer(readBuffer_, readBufferSize_)); + outputTransport_ = boost::shared_ptr<TMemoryBuffer>(new TMemoryBuffer()); + + init(socket, eventFlags, s); + server_->incrementNumConnections(); + } + + ~TConnection() { + server_->decrementNumConnections(); + } + + /** + * Check read buffer against a given limit and shrink it if exceeded. + * + * @param limit we limit buffer size to. + */ + void checkIdleBufferMemLimit(uint32_t limit); + + // Initialize + void init(int socket, short eventFlags, TNonblockingServer *s); + + // Transition into a new state + void transition(); + + // Handler wrapper + static void eventHandler(int fd, short /* which */, void* v) { + assert(fd == ((TConnection*)v)->socket_); + ((TConnection*)v)->workSocket(); + } + + // Handler wrapper for task block + static void taskHandler(int fd, short /* which */, void* v) { + assert(fd == ((TConnection*)v)->taskHandle_); + if (-1 == ::close(((TConnection*)v)->taskHandle_)) { + GlobalOutput.perror("TConnection::taskHandler close handle failed, resource leak ", errno); + } + ((TConnection*)v)->transition(); + } + +}; + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TSIMPLESERVER_H_ diff --git a/lib/cpp/src/server/TServer.cpp b/lib/cpp/src/server/TServer.cpp new file mode 100644 index 000000000..6b692ab02 --- /dev/null +++ b/lib/cpp/src/server/TServer.cpp @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <sys/time.h> +#include <sys/resource.h> +#include <unistd.h> + +namespace apache { namespace thrift { namespace server { + +int increase_max_fds(int max_fds=(1<<24)) { + struct rlimit fdmaxrl; + + for(fdmaxrl.rlim_cur = max_fds, fdmaxrl.rlim_max = max_fds; + max_fds && (setrlimit(RLIMIT_NOFILE, &fdmaxrl) < 0); + fdmaxrl.rlim_cur = max_fds, fdmaxrl.rlim_max = max_fds) { + max_fds /= 2; + } + + return fdmaxrl.rlim_cur; +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TServer.h b/lib/cpp/src/server/TServer.h new file mode 100644 index 000000000..5c4c588d4 --- /dev/null +++ b/lib/cpp/src/server/TServer.h @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TSERVER_H_ +#define _THRIFT_SERVER_TSERVER_H_ 1 + +#include <TProcessor.h> +#include <transport/TServerTransport.h> +#include <protocol/TBinaryProtocol.h> +#include <concurrency/Thread.h> + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace server { + +using apache::thrift::TProcessor; +using apache::thrift::protocol::TBinaryProtocolFactory; +using apache::thrift::protocol::TProtocol; +using apache::thrift::protocol::TProtocolFactory; +using apache::thrift::transport::TServerTransport; +using apache::thrift::transport::TTransport; +using apache::thrift::transport::TTransportFactory; + +/** + * Virtual interface class that can handle events from the server core. To + * use this you should subclass it and implement the methods that you care + * about. Your subclass can also store local data that you may care about, + * such as additional "arguments" to these methods (stored in the object + * instance's state). + */ +class TServerEventHandler { + public: + + virtual ~TServerEventHandler() {} + + /** + * Called before the server begins. + */ + virtual void preServe() {} + + /** + * Called when a new client has connected and is about to being processing. + */ + virtual void clientBegin(boost::shared_ptr<TProtocol> /* input */, + boost::shared_ptr<TProtocol> /* output */) {} + + /** + * Called when a client has finished making requests. + */ + virtual void clientEnd(boost::shared_ptr<TProtocol> /* input */, + boost::shared_ptr<TProtocol> /* output */) {} + + protected: + + /** + * Prevent direct instantiation. + */ + TServerEventHandler() {} + +}; + +/** + * Thrift server. + * + */ +class TServer : public concurrency::Runnable { + public: + + virtual ~TServer() {} + + virtual void serve() = 0; + + virtual void stop() {} + + // Allows running the server as a Runnable thread + virtual void run() { + serve(); + } + + boost::shared_ptr<TProcessor> getProcessor() { + return processor_; + } + + boost::shared_ptr<TServerTransport> getServerTransport() { + return serverTransport_; + } + + boost::shared_ptr<TTransportFactory> getInputTransportFactory() { + return inputTransportFactory_; + } + + boost::shared_ptr<TTransportFactory> getOutputTransportFactory() { + return outputTransportFactory_; + } + + boost::shared_ptr<TProtocolFactory> getInputProtocolFactory() { + return inputProtocolFactory_; + } + + boost::shared_ptr<TProtocolFactory> getOutputProtocolFactory() { + return outputProtocolFactory_; + } + + boost::shared_ptr<TServerEventHandler> getEventHandler() { + return eventHandler_; + } + +protected: + TServer(boost::shared_ptr<TProcessor> processor): + processor_(processor) { + setInputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory())); + setOutputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory())); + setInputProtocolFactory(boost::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory())); + setOutputProtocolFactory(boost::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory())); + } + + TServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport): + processor_(processor), + serverTransport_(serverTransport) { + setInputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory())); + setOutputTransportFactory(boost::shared_ptr<TTransportFactory>(new TTransportFactory())); + setInputProtocolFactory(boost::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory())); + setOutputProtocolFactory(boost::shared_ptr<TProtocolFactory>(new TBinaryProtocolFactory())); + } + + TServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> transportFactory, + boost::shared_ptr<TProtocolFactory> protocolFactory): + processor_(processor), + serverTransport_(serverTransport), + inputTransportFactory_(transportFactory), + outputTransportFactory_(transportFactory), + inputProtocolFactory_(protocolFactory), + outputProtocolFactory_(protocolFactory) {} + + TServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> inputTransportFactory, + boost::shared_ptr<TTransportFactory> outputTransportFactory, + boost::shared_ptr<TProtocolFactory> inputProtocolFactory, + boost::shared_ptr<TProtocolFactory> outputProtocolFactory): + processor_(processor), + serverTransport_(serverTransport), + inputTransportFactory_(inputTransportFactory), + outputTransportFactory_(outputTransportFactory), + inputProtocolFactory_(inputProtocolFactory), + outputProtocolFactory_(outputProtocolFactory) {} + + + // Class variables + boost::shared_ptr<TProcessor> processor_; + boost::shared_ptr<TServerTransport> serverTransport_; + + boost::shared_ptr<TTransportFactory> inputTransportFactory_; + boost::shared_ptr<TTransportFactory> outputTransportFactory_; + + boost::shared_ptr<TProtocolFactory> inputProtocolFactory_; + boost::shared_ptr<TProtocolFactory> outputProtocolFactory_; + + boost::shared_ptr<TServerEventHandler> eventHandler_; + +public: + void setInputTransportFactory(boost::shared_ptr<TTransportFactory> inputTransportFactory) { + inputTransportFactory_ = inputTransportFactory; + } + + void setOutputTransportFactory(boost::shared_ptr<TTransportFactory> outputTransportFactory) { + outputTransportFactory_ = outputTransportFactory; + } + + void setInputProtocolFactory(boost::shared_ptr<TProtocolFactory> inputProtocolFactory) { + inputProtocolFactory_ = inputProtocolFactory; + } + + void setOutputProtocolFactory(boost::shared_ptr<TProtocolFactory> outputProtocolFactory) { + outputProtocolFactory_ = outputProtocolFactory; + } + + void setServerEventHandler(boost::shared_ptr<TServerEventHandler> eventHandler) { + eventHandler_ = eventHandler; + } + +}; + +/** + * Helper function to increase the max file descriptors limit + * for the current process and all of its children. + * By default, tries to increase it to as much as 2^24. + */ + int increase_max_fds(int max_fds=(1<<24)); + + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TSERVER_H_ diff --git a/lib/cpp/src/server/TSimpleServer.cpp b/lib/cpp/src/server/TSimpleServer.cpp new file mode 100644 index 000000000..394ce21e2 --- /dev/null +++ b/lib/cpp/src/server/TSimpleServer.cpp @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "server/TSimpleServer.h" +#include "transport/TTransportException.h" +#include <string> +#include <iostream> + +namespace apache { namespace thrift { namespace server { + +using namespace std; +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using boost::shared_ptr; + +/** + * A simple single-threaded application server. Perfect for unit tests! + * + */ +void TSimpleServer::serve() { + + shared_ptr<TTransport> client; + shared_ptr<TTransport> inputTransport; + shared_ptr<TTransport> outputTransport; + shared_ptr<TProtocol> inputProtocol; + shared_ptr<TProtocol> outputProtocol; + + try { + // Start the server listening + serverTransport_->listen(); + } catch (TTransportException& ttx) { + cerr << "TSimpleServer::run() listen(): " << ttx.what() << endl; + return; + } + + // Run the preServe event + if (eventHandler_ != NULL) { + eventHandler_->preServe(); + } + + // Fetch client from server + while (!stop_) { + try { + client = serverTransport_->accept(); + inputTransport = inputTransportFactory_->getTransport(client); + outputTransport = outputTransportFactory_->getTransport(client); + inputProtocol = inputProtocolFactory_->getProtocol(inputTransport); + outputProtocol = outputProtocolFactory_->getProtocol(outputTransport); + if (eventHandler_ != NULL) { + eventHandler_->clientBegin(inputProtocol, outputProtocol); + } + try { + while (processor_->process(inputProtocol, outputProtocol)) { + // Peek ahead, is the remote side closed? + if (!inputTransport->peek()) { + break; + } + } + } catch (TTransportException& ttx) { + cerr << "TSimpleServer client died: " << ttx.what() << endl; + } catch (TException& tx) { + cerr << "TSimpleServer exception: " << tx.what() << endl; + } + if (eventHandler_ != NULL) { + eventHandler_->clientEnd(inputProtocol, outputProtocol); + } + inputTransport->close(); + outputTransport->close(); + client->close(); + } catch (TTransportException& ttx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + cerr << "TServerTransport died on accept: " << ttx.what() << endl; + continue; + } catch (TException& tx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + cerr << "Some kind of accept exception: " << tx.what() << endl; + continue; + } catch (string s) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + cerr << "TThreadPoolServer: Unknown exception: " << s << endl; + break; + } + } + + if (stop_) { + try { + serverTransport_->close(); + } catch (TTransportException &ttx) { + cerr << "TServerTransport failed on close: " << ttx.what() << endl; + } + stop_ = false; + } +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TSimpleServer.h b/lib/cpp/src/server/TSimpleServer.h new file mode 100644 index 000000000..c4fc91c78 --- /dev/null +++ b/lib/cpp/src/server/TSimpleServer.h @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TSIMPLESERVER_H_ +#define _THRIFT_SERVER_TSIMPLESERVER_H_ 1 + +#include "server/TServer.h" +#include "transport/TServerTransport.h" + +namespace apache { namespace thrift { namespace server { + +/** + * This is the most basic simple server. It is single-threaded and runs a + * continuous loop of accepting a single connection, processing requests on + * that connection until it closes, and then repeating. It is a good example + * of how to extend the TServer interface. + * + */ +class TSimpleServer : public TServer { + public: + TSimpleServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> transportFactory, + boost::shared_ptr<TProtocolFactory> protocolFactory) : + TServer(processor, serverTransport, transportFactory, protocolFactory), + stop_(false) {} + + TSimpleServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> inputTransportFactory, + boost::shared_ptr<TTransportFactory> outputTransportFactory, + boost::shared_ptr<TProtocolFactory> inputProtocolFactory, + boost::shared_ptr<TProtocolFactory> outputProtocolFactory): + TServer(processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory), + stop_(false) {} + + ~TSimpleServer() {} + + void serve(); + + void stop() { + stop_ = true; + } + + protected: + bool stop_; + +}; + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TSIMPLESERVER_H_ diff --git a/lib/cpp/src/server/TThreadPoolServer.cpp b/lib/cpp/src/server/TThreadPoolServer.cpp new file mode 100644 index 000000000..0894cfa5f --- /dev/null +++ b/lib/cpp/src/server/TThreadPoolServer.cpp @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "server/TThreadPoolServer.h" +#include "transport/TTransportException.h" +#include "concurrency/Thread.h" +#include "concurrency/ThreadManager.h" +#include <string> +#include <iostream> + +namespace apache { namespace thrift { namespace server { + +using boost::shared_ptr; +using namespace std; +using namespace apache::thrift; +using namespace apache::thrift::concurrency; +using namespace apache::thrift::protocol;; +using namespace apache::thrift::transport; + +class TThreadPoolServer::Task : public Runnable { + +public: + + Task(TThreadPoolServer &server, + shared_ptr<TProcessor> processor, + shared_ptr<TProtocol> input, + shared_ptr<TProtocol> output) : + server_(server), + processor_(processor), + input_(input), + output_(output) { + } + + ~Task() {} + + void run() { + boost::shared_ptr<TServerEventHandler> eventHandler = + server_.getEventHandler(); + if (eventHandler != NULL) { + eventHandler->clientBegin(input_, output_); + } + try { + while (processor_->process(input_, output_)) { + if (!input_->getTransport()->peek()) { + break; + } + } + } catch (TTransportException& ttx) { + // This is reasonably expected, client didn't send a full request so just + // ignore him + // string errStr = string("TThreadPoolServer client died: ") + ttx.what(); + // GlobalOutput(errStr.c_str()); + } catch (TException& x) { + string errStr = string("TThreadPoolServer exception: ") + x.what(); + GlobalOutput(errStr.c_str()); + } catch (std::exception &x) { + string errStr = string("TThreadPoolServer, std::exception: ") + x.what(); + GlobalOutput(errStr.c_str()); + } + + if (eventHandler != NULL) { + eventHandler->clientEnd(input_, output_); + } + + try { + input_->getTransport()->close(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadPoolServer input close failed: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + try { + output_->getTransport()->close(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadPoolServer output close failed: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + + } + + private: + TServer& server_; + shared_ptr<TProcessor> processor_; + shared_ptr<TProtocol> input_; + shared_ptr<TProtocol> output_; + +}; + +TThreadPoolServer::TThreadPoolServer(shared_ptr<TProcessor> processor, + shared_ptr<TServerTransport> serverTransport, + shared_ptr<TTransportFactory> transportFactory, + shared_ptr<TProtocolFactory> protocolFactory, + shared_ptr<ThreadManager> threadManager) : + TServer(processor, serverTransport, transportFactory, protocolFactory), + threadManager_(threadManager), + stop_(false), timeout_(0) {} + +TThreadPoolServer::TThreadPoolServer(shared_ptr<TProcessor> processor, + shared_ptr<TServerTransport> serverTransport, + shared_ptr<TTransportFactory> inputTransportFactory, + shared_ptr<TTransportFactory> outputTransportFactory, + shared_ptr<TProtocolFactory> inputProtocolFactory, + shared_ptr<TProtocolFactory> outputProtocolFactory, + shared_ptr<ThreadManager> threadManager) : + TServer(processor, serverTransport, inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory), + threadManager_(threadManager), + stop_(false), timeout_(0) {} + + +TThreadPoolServer::~TThreadPoolServer() {} + +void TThreadPoolServer::serve() { + shared_ptr<TTransport> client; + shared_ptr<TTransport> inputTransport; + shared_ptr<TTransport> outputTransport; + shared_ptr<TProtocol> inputProtocol; + shared_ptr<TProtocol> outputProtocol; + + try { + // Start the server listening + serverTransport_->listen(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadPoolServer::run() listen(): ") + ttx.what(); + GlobalOutput(errStr.c_str()); + return; + } + + // Run the preServe event + if (eventHandler_ != NULL) { + eventHandler_->preServe(); + } + + while (!stop_) { + try { + client.reset(); + inputTransport.reset(); + outputTransport.reset(); + inputProtocol.reset(); + outputProtocol.reset(); + + // Fetch client from server + client = serverTransport_->accept(); + + // Make IO transports + inputTransport = inputTransportFactory_->getTransport(client); + outputTransport = outputTransportFactory_->getTransport(client); + inputProtocol = inputProtocolFactory_->getProtocol(inputTransport); + outputProtocol = outputProtocolFactory_->getProtocol(outputTransport); + + // Add to threadmanager pool + threadManager_->add(shared_ptr<TThreadPoolServer::Task>(new TThreadPoolServer::Task(*this, processor_, inputProtocol, outputProtocol)), timeout_); + + } catch (TTransportException& ttx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) { + string errStr = string("TThreadPoolServer: TServerTransport died on accept: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + continue; + } catch (TException& tx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + string errStr = string("TThreadPoolServer: Caught TException: ") + tx.what(); + GlobalOutput(errStr.c_str()); + continue; + } catch (string s) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + string errStr = "TThreadPoolServer: Unknown exception: " + s; + GlobalOutput(errStr.c_str()); + break; + } + } + + // If stopped manually, join the existing threads + if (stop_) { + try { + serverTransport_->close(); + threadManager_->join(); + } catch (TException &tx) { + string errStr = string("TThreadPoolServer: Exception shutting down: ") + tx.what(); + GlobalOutput(errStr.c_str()); + } + stop_ = false; + } + +} + +int64_t TThreadPoolServer::getTimeout() const { + return timeout_; +} + +void TThreadPoolServer::setTimeout(int64_t value) { + timeout_ = value; +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TThreadPoolServer.h b/lib/cpp/src/server/TThreadPoolServer.h new file mode 100644 index 000000000..7b7e90647 --- /dev/null +++ b/lib/cpp/src/server/TThreadPoolServer.h @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TTHREADPOOLSERVER_H_ +#define _THRIFT_SERVER_TTHREADPOOLSERVER_H_ 1 + +#include <concurrency/ThreadManager.h> +#include <server/TServer.h> +#include <transport/TServerTransport.h> + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace server { + +using apache::thrift::concurrency::ThreadManager; +using apache::thrift::protocol::TProtocolFactory; +using apache::thrift::transport::TServerTransport; +using apache::thrift::transport::TTransportFactory; + +class TThreadPoolServer : public TServer { + public: + class Task; + + TThreadPoolServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> transportFactory, + boost::shared_ptr<TProtocolFactory> protocolFactory, + boost::shared_ptr<ThreadManager> threadManager); + + TThreadPoolServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> inputTransportFactory, + boost::shared_ptr<TTransportFactory> outputTransportFactory, + boost::shared_ptr<TProtocolFactory> inputProtocolFactory, + boost::shared_ptr<TProtocolFactory> outputProtocolFactory, + boost::shared_ptr<ThreadManager> threadManager); + + virtual ~TThreadPoolServer(); + + virtual void serve(); + + virtual int64_t getTimeout() const; + + virtual void setTimeout(int64_t value); + + virtual void stop() { + stop_ = true; + serverTransport_->interrupt(); + } + + protected: + + boost::shared_ptr<ThreadManager> threadManager_; + + volatile bool stop_; + + volatile int64_t timeout_; + +}; + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TTHREADPOOLSERVER_H_ diff --git a/lib/cpp/src/server/TThreadedServer.cpp b/lib/cpp/src/server/TThreadedServer.cpp new file mode 100644 index 000000000..cc30f8ff7 --- /dev/null +++ b/lib/cpp/src/server/TThreadedServer.cpp @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "server/TThreadedServer.h" +#include "transport/TTransportException.h" +#include "concurrency/PosixThreadFactory.h" + +#include <string> +#include <iostream> +#include <pthread.h> +#include <unistd.h> + +namespace apache { namespace thrift { namespace server { + +using boost::shared_ptr; +using namespace std; +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using namespace apache::thrift::concurrency; + +class TThreadedServer::Task: public Runnable { + +public: + + Task(TThreadedServer& server, + shared_ptr<TProcessor> processor, + shared_ptr<TProtocol> input, + shared_ptr<TProtocol> output) : + server_(server), + processor_(processor), + input_(input), + output_(output) { + } + + ~Task() {} + + void run() { + boost::shared_ptr<TServerEventHandler> eventHandler = + server_.getEventHandler(); + if (eventHandler != NULL) { + eventHandler->clientBegin(input_, output_); + } + try { + while (processor_->process(input_, output_)) { + if (!input_->getTransport()->peek()) { + break; + } + } + } catch (TTransportException& ttx) { + string errStr = string("TThreadedServer client died: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } catch (TException& x) { + string errStr = string("TThreadedServer exception: ") + x.what(); + GlobalOutput(errStr.c_str()); + } catch (...) { + GlobalOutput("TThreadedServer uncaught exception."); + } + if (eventHandler != NULL) { + eventHandler->clientEnd(input_, output_); + } + + try { + input_->getTransport()->close(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadedServer input close failed: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + try { + output_->getTransport()->close(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadedServer output close failed: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + + // Remove this task from parent bookkeeping + { + Synchronized s(server_.tasksMonitor_); + server_.tasks_.erase(this); + if (server_.tasks_.empty()) { + server_.tasksMonitor_.notify(); + } + } + + } + + private: + TThreadedServer& server_; + friend class TThreadedServer; + + shared_ptr<TProcessor> processor_; + shared_ptr<TProtocol> input_; + shared_ptr<TProtocol> output_; +}; + + +TThreadedServer::TThreadedServer(shared_ptr<TProcessor> processor, + shared_ptr<TServerTransport> serverTransport, + shared_ptr<TTransportFactory> transportFactory, + shared_ptr<TProtocolFactory> protocolFactory): + TServer(processor, serverTransport, transportFactory, protocolFactory), + stop_(false) { + threadFactory_ = shared_ptr<PosixThreadFactory>(new PosixThreadFactory()); +} + +TThreadedServer::TThreadedServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> transportFactory, + boost::shared_ptr<TProtocolFactory> protocolFactory, + boost::shared_ptr<ThreadFactory> threadFactory): + TServer(processor, serverTransport, transportFactory, protocolFactory), + threadFactory_(threadFactory), + stop_(false) { +} + +TThreadedServer::~TThreadedServer() {} + +void TThreadedServer::serve() { + + shared_ptr<TTransport> client; + shared_ptr<TTransport> inputTransport; + shared_ptr<TTransport> outputTransport; + shared_ptr<TProtocol> inputProtocol; + shared_ptr<TProtocol> outputProtocol; + + try { + // Start the server listening + serverTransport_->listen(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadedServer::run() listen(): ") +ttx.what(); + GlobalOutput(errStr.c_str()); + return; + } + + // Run the preServe event + if (eventHandler_ != NULL) { + eventHandler_->preServe(); + } + + while (!stop_) { + try { + client.reset(); + inputTransport.reset(); + outputTransport.reset(); + inputProtocol.reset(); + outputProtocol.reset(); + + // Fetch client from server + client = serverTransport_->accept(); + + // Make IO transports + inputTransport = inputTransportFactory_->getTransport(client); + outputTransport = outputTransportFactory_->getTransport(client); + inputProtocol = inputProtocolFactory_->getProtocol(inputTransport); + outputProtocol = outputProtocolFactory_->getProtocol(outputTransport); + + TThreadedServer::Task* task = new TThreadedServer::Task(*this, + processor_, + inputProtocol, + outputProtocol); + + // Create a task + shared_ptr<Runnable> runnable = + shared_ptr<Runnable>(task); + + // Create a thread for this task + shared_ptr<Thread> thread = + shared_ptr<Thread>(threadFactory_->newThread(runnable)); + + // Insert thread into the set of threads + { + Synchronized s(tasksMonitor_); + tasks_.insert(task); + } + + // Start the thread! + thread->start(); + + } catch (TTransportException& ttx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) { + string errStr = string("TThreadedServer: TServerTransport died on accept: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + continue; + } catch (TException& tx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + string errStr = string("TThreadedServer: Caught TException: ") + tx.what(); + GlobalOutput(errStr.c_str()); + continue; + } catch (string s) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + string errStr = "TThreadedServer: Unknown exception: " + s; + GlobalOutput(errStr.c_str()); + break; + } + } + + // If stopped manually, make sure to close server transport + if (stop_) { + try { + serverTransport_->close(); + } catch (TException &tx) { + string errStr = string("TThreadedServer: Exception shutting down: ") + tx.what(); + GlobalOutput(errStr.c_str()); + } + try { + Synchronized s(tasksMonitor_); + while (!tasks_.empty()) { + tasksMonitor_.wait(); + } + } catch (TException &tx) { + string errStr = string("TThreadedServer: Exception joining workers: ") + tx.what(); + GlobalOutput(errStr.c_str()); + } + stop_ = false; + } + +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TThreadedServer.h b/lib/cpp/src/server/TThreadedServer.h new file mode 100644 index 000000000..4d0811aaa --- /dev/null +++ b/lib/cpp/src/server/TThreadedServer.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TTHREADEDSERVER_H_ +#define _THRIFT_SERVER_TTHREADEDSERVER_H_ 1 + +#include <server/TServer.h> +#include <transport/TServerTransport.h> +#include <concurrency/Monitor.h> +#include <concurrency/Thread.h> + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace server { + +using apache::thrift::TProcessor; +using apache::thrift::transport::TServerTransport; +using apache::thrift::transport::TTransportFactory; +using apache::thrift::concurrency::Monitor; +using apache::thrift::concurrency::ThreadFactory; + +class TThreadedServer : public TServer { + + public: + class Task; + + TThreadedServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> transportFactory, + boost::shared_ptr<TProtocolFactory> protocolFactory); + + TThreadedServer(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TServerTransport> serverTransport, + boost::shared_ptr<TTransportFactory> transportFactory, + boost::shared_ptr<TProtocolFactory> protocolFactory, + boost::shared_ptr<ThreadFactory> threadFactory); + + virtual ~TThreadedServer(); + + virtual void serve(); + + void stop() { + stop_ = true; + serverTransport_->interrupt(); + } + + protected: + boost::shared_ptr<ThreadFactory> threadFactory_; + volatile bool stop_; + + Monitor tasksMonitor_; + std::set<Task*> tasks_; + +}; + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TTHREADEDSERVER_H_ diff --git a/lib/cpp/src/transport/TBufferTransports.cpp b/lib/cpp/src/transport/TBufferTransports.cpp new file mode 100644 index 000000000..7a7e5e928 --- /dev/null +++ b/lib/cpp/src/transport/TBufferTransports.cpp @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <cassert> +#include <algorithm> + +#include <transport/TBufferTransports.h> + +using std::string; + +namespace apache { namespace thrift { namespace transport { + + +uint32_t TBufferedTransport::readSlow(uint8_t* buf, uint32_t len) { + uint32_t want = len; + uint32_t have = rBound_ - rBase_; + + // We should only take the slow path if we can't satisfy the read + // with the data already in the buffer. + assert(have < want); + + // Copy out whatever we have. + if (have > 0) { + memcpy(buf, rBase_, have); + want -= have; + buf += have; + } + // Get more from underlying transport up to buffer size. + // Note that this makes a lot of sense if len < rBufSize_ + // and almost no sense otherwise. TODO(dreiss): Fix that + // case (possibly including some readv hotness). + setReadBuffer(rBuf_.get(), transport_->read(rBuf_.get(), rBufSize_)); + + // Hand over whatever we have. + uint32_t give = std::min(want, static_cast<uint32_t>(rBound_ - rBase_)); + memcpy(buf, rBase_, give); + rBase_ += give; + want -= give; + + return (len - want); +} + +void TBufferedTransport::writeSlow(const uint8_t* buf, uint32_t len) { + uint32_t have_bytes = wBase_ - wBuf_.get(); + uint32_t space = wBound_ - wBase_; + // We should only take the slow path if we can't accomodate the write + // with the free space already in the buffer. + assert(wBound_ - wBase_ < static_cast<ptrdiff_t>(len)); + + // Now here's the tricky question: should we copy data from buf into our + // internal buffer and write it from there, or should we just write out + // the current internal buffer in one syscall and write out buf in another. + // If our currently buffered data plus buf is at least double our buffer + // size, we will have to do two syscalls no matter what (except in the + // degenerate case when our buffer is empty), so there is no use copying. + // Otherwise, there is sort of a sliding scale. If we have N-1 bytes + // buffered and need to write 2, it would be crazy to do two syscalls. + // On the other hand, if we have 2 bytes buffered and are writing 2N-3, + // we can save a syscall in the short term by loading up our buffer, writing + // it out, and copying the rest of the bytes into our buffer. Of course, + // if we get another 2-byte write, we haven't saved any syscalls at all, + // and have just copied nearly 2N bytes for nothing. Finding a perfect + // policy would require predicting the size of future writes, so we're just + // going to always eschew syscalls if we have less than 2N bytes to write. + + // The case where we have to do two syscalls. + // This case also covers the case where the buffer is empty, + // but it is clearer (I think) to think of it as two separate cases. + if ((have_bytes + len >= 2*wBufSize_) || (have_bytes == 0)) { + // TODO(dreiss): writev + if (have_bytes > 0) { + transport_->write(wBuf_.get(), have_bytes); + } + transport_->write(buf, len); + wBase_ = wBuf_.get(); + return; + } + + // Fill up our internal buffer for a write. + memcpy(wBase_, buf, space); + buf += space; + len -= space; + transport_->write(wBuf_.get(), wBufSize_); + + // Copy the rest into our buffer. + assert(len < wBufSize_); + memcpy(wBuf_.get(), buf, len); + wBase_ = wBuf_.get() + len; + return; +} + +const uint8_t* TBufferedTransport::borrowSlow(uint8_t* buf, uint32_t* len) { + // If the request is bigger than our buffer, we are hosed. + if (*len > rBufSize_) { + return NULL; + } + + // The number of bytes of data we have already. + uint32_t have = rBound_ - rBase_; + // The number of additional bytes we need from the underlying transport. + int32_t need = *len - have; + // The space from the start of the buffer to the end of our data. + uint32_t offset = rBound_ - rBuf_.get(); + assert(need > 0); + + // If we have less than half our buffer space available, shift the data + // we have down to the start. If the borrow is big compared to our buffer, + // this could be kind of a waste, but if the borrow is small, it frees up + // space at the end of our buffer to do a bigger single read from the + // underlying transport. Also, if our needs extend past the end of the + // buffer, we have to do a copy no matter what. + if ((offset > rBufSize_/2) || (offset + need > rBufSize_)) { + memmove(rBuf_.get(), rBase_, have); + setReadBuffer(rBuf_.get(), have); + } + + // First try to fill up the buffer. + uint32_t got = transport_->read(rBound_, rBufSize_ - have); + rBound_ += got; + need -= got; + + // If that fails, readAll until we get what we need. + if (need > 0) { + rBound_ += transport_->readAll(rBound_, need); + } + + *len = rBound_ - rBase_; + return rBase_; +} + +void TBufferedTransport::flush() { + // Write out any data waiting in the write buffer. + uint32_t have_bytes = wBase_ - wBuf_.get(); + if (have_bytes > 0) { + // Note that we reset wBase_ prior to the underlying write + // to ensure we're in a sane state (i.e. internal buffer cleaned) + // if the underlying write throws up an exception + wBase_ = wBuf_.get(); + transport_->write(wBuf_.get(), have_bytes); + } + + // Flush the underlying transport. + transport_->flush(); +} + + +uint32_t TFramedTransport::readSlow(uint8_t* buf, uint32_t len) { + uint32_t want = len; + uint32_t have = rBound_ - rBase_; + + // We should only take the slow path if we can't satisfy the read + // with the data already in the buffer. + assert(have < want); + + // Copy out whatever we have. + if (have > 0) { + memcpy(buf, rBase_, have); + want -= have; + buf += have; + } + + // Read another frame. + readFrame(); + + // TODO(dreiss): Should we warn when reads cross frames? + + // Hand over whatever we have. + uint32_t give = std::min(want, static_cast<uint32_t>(rBound_ - rBase_)); + memcpy(buf, rBase_, give); + rBase_ += give; + want -= give; + + return (len - want); +} + +void TFramedTransport::readFrame() { + // TODO(dreiss): Think about using readv here, even though it would + // result in (gasp) read-ahead. + + // Read the size of the next frame. + int32_t sz; + transport_->readAll((uint8_t*)&sz, sizeof(sz)); + sz = ntohl(sz); + + if (sz < 0) { + throw TTransportException("Frame size has negative value"); + } + + // Read the frame payload, and reset markers. + if (sz > static_cast<int32_t>(rBufSize_)) { + rBuf_.reset(new uint8_t[sz]); + rBufSize_ = sz; + } + transport_->readAll(rBuf_.get(), sz); + setReadBuffer(rBuf_.get(), sz); +} + +void TFramedTransport::writeSlow(const uint8_t* buf, uint32_t len) { + // Double buffer size until sufficient. + uint32_t have = wBase_ - wBuf_.get(); + while (wBufSize_ < len + have) { + wBufSize_ *= 2; + } + + // TODO(dreiss): Consider modifying this class to use malloc/free + // so we can use realloc here. + + // Allocate new buffer. + uint8_t* new_buf = new uint8_t[wBufSize_]; + + // Copy the old buffer to the new one. + memcpy(new_buf, wBuf_.get(), have); + + // Now point buf to the new one. + wBuf_.reset(new_buf); + wBase_ = wBuf_.get() + have; + wBound_ = wBuf_.get() + wBufSize_; + + // Copy the data into the new buffer. + memcpy(wBase_, buf, len); + wBase_ += len; +} + +void TFramedTransport::flush() { + int32_t sz_hbo, sz_nbo; + assert(wBufSize_ > sizeof(sz_nbo)); + + // Slip the frame size into the start of the buffer. + sz_hbo = wBase_ - (wBuf_.get() + sizeof(sz_nbo)); + sz_nbo = (int32_t)htonl((uint32_t)(sz_hbo)); + memcpy(wBuf_.get(), (uint8_t*)&sz_nbo, sizeof(sz_nbo)); + + if (sz_hbo > 0) { + // Note that we reset wBase_ (with a pad for the frame size) + // prior to the underlying write to ensure we're in a sane state + // (i.e. internal buffer cleaned) if the underlying write throws + // up an exception + wBase_ = wBuf_.get() + sizeof(sz_nbo); + + // Write size and frame body. + transport_->write(wBuf_.get(), sizeof(sz_nbo)+sz_hbo); + } + + // Flush the underlying transport. + transport_->flush(); +} + +const uint8_t* TFramedTransport::borrowSlow(uint8_t* buf, uint32_t* len) { + // Don't try to be clever with shifting buffers. + // If the fast path failed let the protocol use its slow path. + // Besides, who is going to try to borrow across messages? + return NULL; +} + + +void TMemoryBuffer::computeRead(uint32_t len, uint8_t** out_start, uint32_t* out_give) { + // Correct rBound_ so we can use the fast path in the future. + rBound_ = wBase_; + + // Decide how much to give. + uint32_t give = std::min(len, available_read()); + + *out_start = rBase_; + *out_give = give; + + // Preincrement rBase_ so the caller doesn't have to. + rBase_ += give; +} + +uint32_t TMemoryBuffer::readSlow(uint8_t* buf, uint32_t len) { + uint8_t* start; + uint32_t give; + computeRead(len, &start, &give); + + // Copy into the provided buffer. + memcpy(buf, start, give); + + return give; +} + +uint32_t TMemoryBuffer::readAppendToString(std::string& str, uint32_t len) { + // Don't get some stupid assertion failure. + if (buffer_ == NULL) { + return 0; + } + + uint8_t* start; + uint32_t give; + computeRead(len, &start, &give); + + // Append to the provided string. + str.append((char*)start, give); + + return give; +} + +void TMemoryBuffer::ensureCanWrite(uint32_t len) { + // Check available space + uint32_t avail = available_write(); + if (len <= avail) { + return; + } + + if (!owner_) { + throw TTransportException("Insufficient space in external MemoryBuffer"); + } + + // Grow the buffer as necessary. + while (len > avail) { + bufferSize_ *= 2; + wBound_ = buffer_ + bufferSize_; + avail = available_write(); + } + + // Allocate into a new pointer so we don't bork ours if it fails. + void* new_buffer = std::realloc(buffer_, bufferSize_); + if (new_buffer == NULL) { + throw TTransportException("Out of memory."); + } + + ptrdiff_t offset = (uint8_t*)new_buffer - buffer_; + buffer_ += offset; + rBase_ += offset; + rBound_ += offset; + wBase_ += offset; + wBound_ += offset; +} + +void TMemoryBuffer::writeSlow(const uint8_t* buf, uint32_t len) { + ensureCanWrite(len); + + // Copy into the buffer and increment wBase_. + memcpy(wBase_, buf, len); + wBase_ += len; +} + +void TMemoryBuffer::wroteBytes(uint32_t len) { + uint32_t avail = available_write(); + if (len > avail) { + throw TTransportException("Client wrote more bytes than size of buffer."); + } + wBase_ += len; +} + +const uint8_t* TMemoryBuffer::borrowSlow(uint8_t* buf, uint32_t* len) { + rBound_ = wBase_; + if (available_read() >= *len) { + *len = available_read(); + return rBase_; + } + return NULL; +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TBufferTransports.h b/lib/cpp/src/transport/TBufferTransports.h new file mode 100644 index 000000000..1908205ff --- /dev/null +++ b/lib/cpp/src/transport/TBufferTransports.h @@ -0,0 +1,667 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TBUFFERTRANSPORTS_H_ +#define _THRIFT_TRANSPORT_TBUFFERTRANSPORTS_H_ 1 + +#include <cstring> +#include "boost/scoped_array.hpp" + +#include <transport/TTransport.h> + +#ifdef __GNUC__ +#define TDB_LIKELY(val) (__builtin_expect((val), 1)) +#define TDB_UNLIKELY(val) (__builtin_expect((val), 0)) +#else +#define TDB_LIKELY(val) (val) +#define TDB_UNLIKELY(val) (val) +#endif + +namespace apache { namespace thrift { namespace transport { + + +/** + * Base class for all transports that use read/write buffers for performance. + * + * TBufferBase is designed to implement the fast-path "memcpy" style + * operations that work in the common case. It does so with small and + * (eventually) nonvirtual, inlinable methods. TBufferBase is an abstract + * class. Subclasses are expected to define the "slow path" operations + * that have to be done when the buffers are full or empty. + * + */ +class TBufferBase : public TTransport { + + public: + + /** + * Fast-path read. + * + * When we have enough data buffered to fulfill the read, we can satisfy it + * with a single memcpy, then adjust our internal pointers. If the buffer + * is empty, we call out to our slow path, implemented by a subclass. + * This method is meant to eventually be nonvirtual and inlinable. + */ + uint32_t read(uint8_t* buf, uint32_t len) { + uint8_t* new_rBase = rBase_ + len; + if (TDB_LIKELY(new_rBase <= rBound_)) { + std::memcpy(buf, rBase_, len); + rBase_ = new_rBase; + return len; + } + return readSlow(buf, len); + } + + /** + * Fast-path write. + * + * When we have enough empty space in our buffer to accomodate the write, we + * can satisfy it with a single memcpy, then adjust our internal pointers. + * If the buffer is full, we call out to our slow path, implemented by a + * subclass. This method is meant to eventually be nonvirtual and + * inlinable. + */ + void write(const uint8_t* buf, uint32_t len) { + uint8_t* new_wBase = wBase_ + len; + if (TDB_LIKELY(new_wBase <= wBound_)) { + std::memcpy(wBase_, buf, len); + wBase_ = new_wBase; + return; + } + writeSlow(buf, len); + } + + /** + * Fast-path borrow. A lot like the fast-path read. + */ + const uint8_t* borrow(uint8_t* buf, uint32_t* len) { + if (TDB_LIKELY(static_cast<ptrdiff_t>(*len) <= rBound_ - rBase_)) { + // With strict aliasing, writing to len shouldn't force us to + // refetch rBase_ from memory. TODO(dreiss): Verify this. + *len = rBound_ - rBase_; + return rBase_; + } + return borrowSlow(buf, len); + } + + /** + * Consume doesn't require a slow path. + */ + void consume(uint32_t len) { + if (TDB_LIKELY(static_cast<ptrdiff_t>(len) <= rBound_ - rBase_)) { + rBase_ += len; + } else { + throw TTransportException(TTransportException::BAD_ARGS, + "consume did not follow a borrow."); + } + } + + + protected: + + /// Slow path read. + virtual uint32_t readSlow(uint8_t* buf, uint32_t len) = 0; + + /// Slow path write. + virtual void writeSlow(const uint8_t* buf, uint32_t len) = 0; + + /** + * Slow path borrow. + * + * POSTCONDITION: return == NULL || rBound_ - rBase_ >= *len + */ + virtual const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len) = 0; + + /** + * Trivial constructor. + * + * Initialize pointers safely. Constructing is not a very + * performance-sensitive operation, so it is okay to just leave it to + * the concrete class to set up pointers correctly. + */ + TBufferBase() + : rBase_(NULL) + , rBound_(NULL) + , wBase_(NULL) + , wBound_(NULL) + {} + + /// Convenience mutator for setting the read buffer. + void setReadBuffer(uint8_t* buf, uint32_t len) { + rBase_ = buf; + rBound_ = buf+len; + } + + /// Convenience mutator for setting the write buffer. + void setWriteBuffer(uint8_t* buf, uint32_t len) { + wBase_ = buf; + wBound_ = buf+len; + } + + virtual ~TBufferBase() {} + + /// Reads begin here. + uint8_t* rBase_; + /// Reads may extend to just before here. + uint8_t* rBound_; + + /// Writes begin here. + uint8_t* wBase_; + /// Writes may extend to just before here. + uint8_t* wBound_; +}; + + +/** + * Base class for all transport which wraps transport to new one. + */ +class TUnderlyingTransport : public TBufferBase { + public: + static const int DEFAULT_BUFFER_SIZE = 512; + + virtual bool peek() { + return (rBase_ < rBound_) || transport_->peek(); + } + + void open() { + transport_->open(); + } + + bool isOpen() { + return transport_->isOpen(); + } + + void close() { + flush(); + transport_->close(); + } + + boost::shared_ptr<TTransport> getUnderlyingTransport() { + return transport_; + } + + protected: + boost::shared_ptr<TTransport> transport_; + + uint32_t rBufSize_; + uint32_t wBufSize_; + boost::scoped_array<uint8_t> rBuf_; + boost::scoped_array<uint8_t> wBuf_; + + TUnderlyingTransport(boost::shared_ptr<TTransport> transport, uint32_t sz) + : transport_(transport) + , rBufSize_(sz) + , wBufSize_(sz) + , rBuf_(new uint8_t[rBufSize_]) + , wBuf_(new uint8_t[wBufSize_]) {} + + TUnderlyingTransport(boost::shared_ptr<TTransport> transport) + : transport_(transport) + , rBufSize_(DEFAULT_BUFFER_SIZE) + , wBufSize_(DEFAULT_BUFFER_SIZE) + , rBuf_(new uint8_t[rBufSize_]) + , wBuf_(new uint8_t[wBufSize_]) {} + + TUnderlyingTransport(boost::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz) + : transport_(transport) + , rBufSize_(rsz) + , wBufSize_(wsz) + , rBuf_(new uint8_t[rBufSize_]) + , wBuf_(new uint8_t[wBufSize_]) {} +}; + +/** + * Buffered transport. For reads it will read more data than is requested + * and will serve future data out of a local buffer. For writes, data is + * stored to an in memory buffer before being written out. + * + */ +class TBufferedTransport : public TUnderlyingTransport { + public: + + /// Use default buffer sizes. + TBufferedTransport(boost::shared_ptr<TTransport> transport) + : TUnderlyingTransport(transport) + { + initPointers(); + } + + /// Use specified buffer sizes. + TBufferedTransport(boost::shared_ptr<TTransport> transport, uint32_t sz) + : TUnderlyingTransport(transport, sz) + { + initPointers(); + } + + /// Use specified read and write buffer sizes. + TBufferedTransport(boost::shared_ptr<TTransport> transport, uint32_t rsz, uint32_t wsz) + : TUnderlyingTransport(transport, rsz, wsz) + { + initPointers(); + } + + virtual bool peek() { + /* shigin: see THRIFT-96 discussion */ + if (rBase_ == rBound_) { + setReadBuffer(rBuf_.get(), transport_->read(rBuf_.get(), rBufSize_)); + } + return (rBound_ > rBase_); + } + virtual uint32_t readSlow(uint8_t* buf, uint32_t len); + + virtual void writeSlow(const uint8_t* buf, uint32_t len); + + void flush(); + + + /** + * The following behavior is currently implemented by TBufferedTransport, + * but that may change in a future version: + * 1/ If len is at most rBufSize_, borrow will never return NULL. + * Depending on the underlying transport, it could throw an exception + * or hang forever. + * 2/ Some borrow requests may copy bytes internally. However, + * if len is at most rBufSize_/2, none of the copied bytes + * will ever have to be copied again. For optimial performance, + * stay under this limit. + */ + virtual const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len); + + protected: + void initPointers() { + setReadBuffer(rBuf_.get(), 0); + setWriteBuffer(wBuf_.get(), wBufSize_); + // Write size never changes. + } +}; + + +/** + * Wraps a transport into a buffered one. + * + */ +class TBufferedTransportFactory : public TTransportFactory { + public: + TBufferedTransportFactory() {} + + virtual ~TBufferedTransportFactory() {} + + /** + * Wraps the transport into a buffered one. + */ + virtual boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> trans) { + return boost::shared_ptr<TTransport>(new TBufferedTransport(trans)); + } + +}; + + +/** + * Framed transport. All writes go into an in-memory buffer until flush is + * called, at which point the transport writes the length of the entire + * binary chunk followed by the data payload. This allows the receiver on the + * other end to always do fixed-length reads. + * + */ +class TFramedTransport : public TUnderlyingTransport { + public: + + /// Use default buffer sizes. + TFramedTransport(boost::shared_ptr<TTransport> transport) + : TUnderlyingTransport(transport) + { + initPointers(); + } + + TFramedTransport(boost::shared_ptr<TTransport> transport, uint32_t sz) + : TUnderlyingTransport(transport, sz) + { + initPointers(); + } + + virtual uint32_t readSlow(uint8_t* buf, uint32_t len); + + virtual void writeSlow(const uint8_t* buf, uint32_t len); + + virtual void flush(); + + const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len); + + protected: + /** + * Reads a frame of input from the underlying stream. + */ + void readFrame(); + + void initPointers() { + setReadBuffer(NULL, 0); + setWriteBuffer(wBuf_.get(), wBufSize_); + + // Pad the buffer so we can insert the size later. + int32_t pad = 0; + this->write((uint8_t*)&pad, sizeof(pad)); + } +}; + +/** + * Wraps a transport into a framed one. + * + */ +class TFramedTransportFactory : public TTransportFactory { + public: + TFramedTransportFactory() {} + + virtual ~TFramedTransportFactory() {} + + /** + * Wraps the transport into a framed one. + */ + virtual boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> trans) { + return boost::shared_ptr<TTransport>(new TFramedTransport(trans)); + } + +}; + + +/** + * A memory buffer is a tranpsort that simply reads from and writes to an + * in memory buffer. Anytime you call write on it, the data is simply placed + * into a buffer, and anytime you call read, data is read from that buffer. + * + * The buffers are allocated using C constructs malloc,realloc, and the size + * doubles as necessary. We've considered using scoped + * + */ +class TMemoryBuffer : public TBufferBase { + private: + + // Common initialization done by all constructors. + void initCommon(uint8_t* buf, uint32_t size, bool owner, uint32_t wPos) { + if (buf == NULL && size != 0) { + assert(owner); + buf = (uint8_t*)std::malloc(size); + if (buf == NULL) { + throw TTransportException("Out of memory"); + } + } + + buffer_ = buf; + bufferSize_ = size; + + rBase_ = buffer_; + rBound_ = buffer_ + wPos; + // TODO(dreiss): Investigate NULL-ing this if !owner. + wBase_ = buffer_ + wPos; + wBound_ = buffer_ + bufferSize_; + + owner_ = owner; + + // rBound_ is really an artifact. In principle, it should always be + // equal to wBase_. We update it in a few places (computeRead, etc.). + } + + public: + static const uint32_t defaultSize = 1024; + + /** + * This enum specifies how a TMemoryBuffer should treat + * memory passed to it via constructors or resetBuffer. + * + * OBSERVE: + * TMemoryBuffer will simply store a pointer to the memory. + * It is the callers responsibility to ensure that the pointer + * remains valid for the lifetime of the TMemoryBuffer, + * and that it is properly cleaned up. + * Note that no data can be written to observed buffers. + * + * COPY: + * TMemoryBuffer will make an internal copy of the buffer. + * The caller has no responsibilities. + * + * TAKE_OWNERSHIP: + * TMemoryBuffer will become the "owner" of the buffer, + * and will be responsible for freeing it. + * The membory must have been allocated with malloc. + */ + enum MemoryPolicy + { OBSERVE = 1 + , COPY = 2 + , TAKE_OWNERSHIP = 3 + }; + + /** + * Construct a TMemoryBuffer with a default-sized buffer, + * owned by the TMemoryBuffer object. + */ + TMemoryBuffer() { + initCommon(NULL, defaultSize, true, 0); + } + + /** + * Construct a TMemoryBuffer with a buffer of a specified size, + * owned by the TMemoryBuffer object. + * + * @param sz The initial size of the buffer. + */ + TMemoryBuffer(uint32_t sz) { + initCommon(NULL, sz, true, 0); + } + + /** + * Construct a TMemoryBuffer with buf as its initial contents. + * + * @param buf The initial contents of the buffer. + * Note that, while buf is a non-const pointer, + * TMemoryBuffer will not write to it if policy == OBSERVE, + * so it is safe to const_cast<uint8_t*>(whatever). + * @param sz The size of @c buf. + * @param policy See @link MemoryPolicy @endlink . + */ + TMemoryBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE) { + if (buf == NULL && sz != 0) { + throw TTransportException(TTransportException::BAD_ARGS, + "TMemoryBuffer given null buffer with non-zero size."); + } + + switch (policy) { + case OBSERVE: + case TAKE_OWNERSHIP: + initCommon(buf, sz, policy == TAKE_OWNERSHIP, sz); + break; + case COPY: + initCommon(NULL, sz, true, 0); + this->write(buf, sz); + break; + default: + throw TTransportException(TTransportException::BAD_ARGS, + "Invalid MemoryPolicy for TMemoryBuffer"); + } + } + + ~TMemoryBuffer() { + if (owner_) { + std::free(buffer_); + } + } + + bool isOpen() { + return true; + } + + bool peek() { + return (rBase_ < wBase_); + } + + void open() {} + + void close() {} + + // TODO(dreiss): Make bufPtr const. + void getBuffer(uint8_t** bufPtr, uint32_t* sz) { + *bufPtr = rBase_; + *sz = wBase_ - rBase_; + } + + std::string getBufferAsString() { + if (buffer_ == NULL) { + return ""; + } + uint8_t* buf; + uint32_t sz; + getBuffer(&buf, &sz); + return std::string((char*)buf, (std::string::size_type)sz); + } + + void appendBufferToString(std::string& str) { + if (buffer_ == NULL) { + return; + } + uint8_t* buf; + uint32_t sz; + getBuffer(&buf, &sz); + str.append((char*)buf, sz); + } + + void resetBuffer(bool reset_capacity = false) { + if (reset_capacity) + { + assert(owner_); + + void* new_buffer = std::realloc(buffer_, defaultSize); + + if (new_buffer == NULL) { + throw TTransportException("Out of memory."); + } + + buffer_ = (uint8_t*) new_buffer; + bufferSize_ = defaultSize; + + wBound_ = buffer_ + bufferSize_; + } + + rBase_ = buffer_; + rBound_ = buffer_; + wBase_ = buffer_; + // It isn't safe to write into a buffer we don't own. + if (!owner_) { + wBound_ = wBase_; + bufferSize_ = 0; + } + } + + /// See constructor documentation. + void resetBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE) { + // Use a variant of the copy-and-swap trick for assignment operators. + // This is sub-optimal in terms of performance for two reasons: + // 1/ The constructing and swapping of the (small) values + // in the temporary object takes some time, and is not necessary. + // 2/ If policy == COPY, we allocate the new buffer before + // freeing the old one, precluding the possibility of + // reusing that memory. + // I doubt that either of these problems could be optimized away, + // but the second is probably no a common case, and the first is minor. + // I don't expect resetBuffer to be a common operation, so I'm willing to + // bite the performance bullet to make the method this simple. + + // Construct the new buffer. + TMemoryBuffer new_buffer(buf, sz, policy); + // Move it into ourself. + this->swap(new_buffer); + // Our old self gets destroyed. + } + + std::string readAsString(uint32_t len) { + std::string str; + (void)readAppendToString(str, len); + return str; + } + + uint32_t readAppendToString(std::string& str, uint32_t len); + + void readEnd() { + if (rBase_ == wBase_) { + resetBuffer(); + } + } + + uint32_t available_read() const { + // Remember, wBase_ is the real rBound_. + return wBase_ - rBase_; + } + + uint32_t available_write() const { + return wBound_ - wBase_; + } + + // Returns a pointer to where the client can write data to append to + // the TMemoryBuffer, and ensures the buffer is big enough to accomodate a + // write of the provided length. The returned pointer is very convenient for + // passing to read(), recv(), or similar. You must call wroteBytes() as soon + // as data is written or the buffer will not be aware that data has changed. + uint8_t* getWritePtr(uint32_t len) { + ensureCanWrite(len); + return wBase_; + } + + // Informs the buffer that the client has written 'len' bytes into storage + // that had been provided by getWritePtr(). + void wroteBytes(uint32_t len); + + protected: + void swap(TMemoryBuffer& that) { + using std::swap; + swap(buffer_, that.buffer_); + swap(bufferSize_, that.bufferSize_); + + swap(rBase_, that.rBase_); + swap(rBound_, that.rBound_); + swap(wBase_, that.wBase_); + swap(wBound_, that.wBound_); + + swap(owner_, that.owner_); + } + + // Make sure there's at least 'len' bytes available for writing. + void ensureCanWrite(uint32_t len); + + // Compute the position and available data for reading. + void computeRead(uint32_t len, uint8_t** out_start, uint32_t* out_give); + + uint32_t readSlow(uint8_t* buf, uint32_t len); + + void writeSlow(const uint8_t* buf, uint32_t len); + + const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len); + + // Data buffer + uint8_t* buffer_; + + // Allocated buffer size + uint32_t bufferSize_; + + // Is this object the owner of the buffer? + bool owner_; + + // Don't forget to update constrctors, initCommon, and swap if + // you add new members. +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TBUFFERTRANSPORTS_H_ diff --git a/lib/cpp/src/transport/TFDTransport.cpp b/lib/cpp/src/transport/TFDTransport.cpp new file mode 100644 index 000000000..a042f8b74 --- /dev/null +++ b/lib/cpp/src/transport/TFDTransport.cpp @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <cerrno> +#include <exception> + +#include <transport/TFDTransport.h> + +#include <unistd.h> + +using namespace std; + +namespace apache { namespace thrift { namespace transport { + +void TFDTransport::close() { + if (!isOpen()) { + return; + } + + int rv = ::close(fd_); + int errno_copy = errno; + fd_ = -1; + // Have to check uncaught_exception because this is called in the destructor. + if (rv < 0 && !std::uncaught_exception()) { + throw TTransportException(TTransportException::UNKNOWN, + "TFDTransport::close()", + errno_copy); + } +} + +uint32_t TFDTransport::read(uint8_t* buf, uint32_t len) { + ssize_t rv = ::read(fd_, buf, len); + if (rv < 0) { + int errno_copy = errno; + throw TTransportException(TTransportException::UNKNOWN, + "TFDTransport::read()", + errno_copy); + } + return rv; +} + +void TFDTransport::write(const uint8_t* buf, uint32_t len) { + while (len > 0) { + ssize_t rv = ::write(fd_, buf, len); + + if (rv < 0) { + int errno_copy = errno; + throw TTransportException(TTransportException::UNKNOWN, + "TFDTransport::write()", + errno_copy); + } else if (rv == 0) { + throw TTransportException(TTransportException::END_OF_FILE, + "TFDTransport::write()"); + } + + buf += rv; + len -= rv; + } +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TFDTransport.h b/lib/cpp/src/transport/TFDTransport.h new file mode 100644 index 000000000..bda5d82a9 --- /dev/null +++ b/lib/cpp/src/transport/TFDTransport.h @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TFDTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TFDTRANSPORT_H_ 1 + +#include <string> +#include <sys/time.h> + +#include "TTransport.h" +#include "TServerSocket.h" + +namespace apache { namespace thrift { namespace transport { + +/** + * Dead-simple wrapper around a file descriptor. + * + */ +class TFDTransport : public TTransport { + public: + enum ClosePolicy + { NO_CLOSE_ON_DESTROY = 0 + , CLOSE_ON_DESTROY = 1 + }; + + TFDTransport(int fd, ClosePolicy close_policy = NO_CLOSE_ON_DESTROY) + : fd_(fd) + , close_policy_(close_policy) + {} + + ~TFDTransport() { + if (close_policy_ == CLOSE_ON_DESTROY) { + close(); + } + } + + bool isOpen() { return fd_ >= 0; } + + void open() {} + + void close(); + + uint32_t read(uint8_t* buf, uint32_t len); + + void write(const uint8_t* buf, uint32_t len); + + void setFD(int fd) { fd_ = fd; } + int getFD() { return fd_; } + + protected: + int fd_; + ClosePolicy close_policy_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TFDTRANSPORT_H_ diff --git a/lib/cpp/src/transport/TFileTransport.cpp b/lib/cpp/src/transport/TFileTransport.cpp new file mode 100644 index 000000000..f67b9e355 --- /dev/null +++ b/lib/cpp/src/transport/TFileTransport.cpp @@ -0,0 +1,953 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "TFileTransport.h" +#include "TTransportUtils.h" + +#include <pthread.h> +#ifdef HAVE_SYS_TIME_H +#include <sys/time.h> +#else +#include <time.h> +#endif +#include <fcntl.h> +#include <errno.h> +#include <unistd.h> +#ifdef HAVE_STRINGS_H +#include <strings.h> +#endif +#include <cstdlib> +#include <cstring> +#include <iostream> +#include <sys/stat.h> + +namespace apache { namespace thrift { namespace transport { + +using boost::shared_ptr; +using namespace std; +using namespace apache::thrift::protocol; + +#ifndef HAVE_CLOCK_GETTIME + +/** + * Fake clock_gettime for systems like darwin + * + */ +#define CLOCK_REALTIME 0 +static int clock_gettime(int clk_id /*ignored*/, struct timespec *tp) { + struct timeval now; + + int rv = gettimeofday(&now, NULL); + if (rv != 0) { + return rv; + } + + tp->tv_sec = now.tv_sec; + tp->tv_nsec = now.tv_usec * 1000; + return 0; +} +#endif + +TFileTransport::TFileTransport(string path, bool readOnly) + : readState_() + , readBuff_(NULL) + , currentEvent_(NULL) + , readBuffSize_(DEFAULT_READ_BUFF_SIZE) + , readTimeout_(NO_TAIL_READ_TIMEOUT) + , chunkSize_(DEFAULT_CHUNK_SIZE) + , eventBufferSize_(DEFAULT_EVENT_BUFFER_SIZE) + , flushMaxUs_(DEFAULT_FLUSH_MAX_US) + , flushMaxBytes_(DEFAULT_FLUSH_MAX_BYTES) + , maxEventSize_(DEFAULT_MAX_EVENT_SIZE) + , maxCorruptedEvents_(DEFAULT_MAX_CORRUPTED_EVENTS) + , eofSleepTime_(DEFAULT_EOF_SLEEP_TIME_US) + , corruptedEventSleepTime_(DEFAULT_CORRUPTED_SLEEP_TIME_US) + , writerThreadId_(0) + , dequeueBuffer_(NULL) + , enqueueBuffer_(NULL) + , closing_(false) + , forceFlush_(false) + , filename_(path) + , fd_(0) + , bufferAndThreadInitialized_(false) + , offset_(0) + , lastBadChunk_(0) + , numCorruptedEventsInChunk_(0) + , readOnly_(readOnly) +{ + // initialize all the condition vars/mutexes + pthread_mutex_init(&mutex_, NULL); + pthread_cond_init(¬Full_, NULL); + pthread_cond_init(¬Empty_, NULL); + pthread_cond_init(&flushed_, NULL); + + openLogFile(); +} + +void TFileTransport::resetOutputFile(int fd, string filename, int64_t offset) { + filename_ = filename; + offset_ = offset; + + // check if current file is still open + if (fd_ > 0) { + // flush any events in the queue + flush(); + GlobalOutput.printf("error, current file (%s) not closed", filename_.c_str()); + if (-1 == ::close(fd_)) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: resetOutputFile() ::close() ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error in file close", errno_copy); + } + } + + if (fd) { + fd_ = fd; + } else { + // open file if the input fd is 0 + openLogFile(); + } +} + + +TFileTransport::~TFileTransport() { + // flush the buffer if a writer thread is active + if (writerThreadId_ > 0) { + // reduce the flush timeout so that closing is quicker + setFlushMaxUs(300*1000); + + // flush output buffer + flush(); + + // set state to closing + closing_ = true; + + // TODO: make sure event queue is empty + // currently only the write buffer is flushed + // we dont actually wait until the queue is empty. This shouldn't be a big + // deal in the common case because writing is quick + + pthread_join(writerThreadId_, NULL); + writerThreadId_ = 0; + } + + if (dequeueBuffer_) { + delete dequeueBuffer_; + dequeueBuffer_ = NULL; + } + + if (enqueueBuffer_) { + delete enqueueBuffer_; + enqueueBuffer_ = NULL; + } + + if (readBuff_) { + delete[] readBuff_; + readBuff_ = NULL; + } + + if (currentEvent_) { + delete currentEvent_; + currentEvent_ = NULL; + } + + // close logfile + if (fd_ > 0) { + if(-1 == ::close(fd_)) { + GlobalOutput.perror("TFileTransport: ~TFileTransport() ::close() ", errno); + } + } +} + +bool TFileTransport::initBufferAndWriteThread() { + if (bufferAndThreadInitialized_) { + T_ERROR("Trying to double-init TFileTransport"); + return false; + } + + if (writerThreadId_ == 0) { + if (pthread_create(&writerThreadId_, NULL, startWriterThread, (void *)this) != 0) { + T_ERROR("Could not create writer thread"); + return false; + } + } + + dequeueBuffer_ = new TFileTransportBuffer(eventBufferSize_); + enqueueBuffer_ = new TFileTransportBuffer(eventBufferSize_); + bufferAndThreadInitialized_ = true; + + return true; +} + +void TFileTransport::write(const uint8_t* buf, uint32_t len) { + if (readOnly_) { + throw TTransportException("TFileTransport: attempting to write to file opened readonly"); + } + + enqueueEvent(buf, len, false); +} + +void TFileTransport::enqueueEvent(const uint8_t* buf, uint32_t eventLen, bool blockUntilFlush) { + // can't enqueue more events if file is going to close + if (closing_) { + return; + } + + // make sure that event size is valid + if ( (maxEventSize_ > 0) && (eventLen > maxEventSize_) ) { + T_ERROR("msg size is greater than max event size: %u > %u\n", eventLen, maxEventSize_); + return; + } + + if (eventLen == 0) { + T_ERROR("cannot enqueue an empty event"); + return; + } + + eventInfo* toEnqueue = new eventInfo(); + toEnqueue->eventBuff_ = (uint8_t *)std::malloc((sizeof(uint8_t) * eventLen) + 4); + // first 4 bytes is the event length + memcpy(toEnqueue->eventBuff_, (void*)(&eventLen), 4); + // actual event contents + memcpy(toEnqueue->eventBuff_ + 4, buf, eventLen); + toEnqueue->eventSize_ = eventLen + 4; + + // lock mutex + pthread_mutex_lock(&mutex_); + + // make sure that enqueue buffer is initialized and writer thread is running + if (!bufferAndThreadInitialized_) { + if (!initBufferAndWriteThread()) { + delete toEnqueue; + pthread_mutex_unlock(&mutex_); + return; + } + } + + // Can't enqueue while buffer is full + while (enqueueBuffer_->isFull()) { + pthread_cond_wait(¬Full_, &mutex_); + } + + // add to the buffer + if (!enqueueBuffer_->addEvent(toEnqueue)) { + delete toEnqueue; + pthread_mutex_unlock(&mutex_); + return; + } + + // signal anybody who's waiting for the buffer to be non-empty + pthread_cond_signal(¬Empty_); + + if (blockUntilFlush) { + pthread_cond_wait(&flushed_, &mutex_); + } + + // this really should be a loop where it makes sure it got flushed + // because condition variables can get triggered by the os for no reason + // it is probably a non-factor for the time being + pthread_mutex_unlock(&mutex_); +} + +bool TFileTransport::swapEventBuffers(struct timespec* deadline) { + pthread_mutex_lock(&mutex_); + if (deadline != NULL) { + // if we were handed a deadline time struct, do a timed wait + pthread_cond_timedwait(¬Empty_, &mutex_, deadline); + } else { + // just wait until the buffer gets an item + pthread_cond_wait(¬Empty_, &mutex_); + } + + bool swapped = false; + + // could be empty if we timed out + if (!enqueueBuffer_->isEmpty()) { + TFileTransportBuffer *temp = enqueueBuffer_; + enqueueBuffer_ = dequeueBuffer_; + dequeueBuffer_ = temp; + + swapped = true; + } + + // unlock the mutex and signal if required + pthread_mutex_unlock(&mutex_); + + if (swapped) { + pthread_cond_signal(¬Full_); + } + + return swapped; +} + + +void TFileTransport::writerThread() { + // open file if it is not open + if(!fd_) { + openLogFile(); + } + + // set the offset to the correct value (EOF) + try { + seekToEnd(); + } catch (TException &te) { + } + + // throw away any partial events + offset_ += readState_.lastDispatchPtr_; + ftruncate(fd_, offset_); + readState_.resetAllValues(); + + // Figure out the next time by which a flush must take place + + struct timespec ts_next_flush; + getNextFlushTime(&ts_next_flush); + uint32_t unflushed = 0; + + while(1) { + // this will only be true when the destructor is being invoked + if(closing_) { + // empty out both the buffers + if (enqueueBuffer_->isEmpty() && dequeueBuffer_->isEmpty()) { + if (-1 == ::close(fd_)) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: writerThread() ::close() ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error in file close", errno_copy); + } + // just be safe and sync to disk + fsync(fd_); + fd_ = 0; + pthread_exit(NULL); + return; + } + } + + if (swapEventBuffers(&ts_next_flush)) { + eventInfo* outEvent; + while (NULL != (outEvent = dequeueBuffer_->getNext())) { + if (!outEvent) { + T_DEBUG_L(1, "Got an empty event"); + return; + } + + // sanity check on event + if ((maxEventSize_ > 0) && (outEvent->eventSize_ > maxEventSize_)) { + T_ERROR("msg size is greater than max event size: %u > %u\n", outEvent->eventSize_, maxEventSize_); + continue; + } + + // If chunking is required, then make sure that msg does not cross chunk boundary + if ((outEvent->eventSize_ > 0) && (chunkSize_ != 0)) { + + // event size must be less than chunk size + if(outEvent->eventSize_ > chunkSize_) { + T_ERROR("TFileTransport: event size(%u) is greater than chunk size(%u): skipping event", + outEvent->eventSize_, chunkSize_); + continue; + } + + int64_t chunk1 = offset_/chunkSize_; + int64_t chunk2 = (offset_ + outEvent->eventSize_ - 1)/chunkSize_; + + // if adding this event will cross a chunk boundary, pad the chunk with zeros + if (chunk1 != chunk2) { + // refetch the offset to keep in sync + offset_ = lseek(fd_, 0, SEEK_CUR); + int32_t padding = (int32_t)((offset_/chunkSize_ + 1)*chunkSize_ - offset_); + + uint8_t zeros[padding]; + bzero(zeros, padding); + if (-1 == ::write(fd_, zeros, padding)) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: writerThread() error while padding zeros ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error while padding zeros", errno_copy); + } + unflushed += padding; + offset_ += padding; + } + } + + // write the dequeued event to the file + if (outEvent->eventSize_ > 0) { + if (-1 == ::write(fd_, outEvent->eventBuff_, outEvent->eventSize_)) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: error while writing event ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error while writing event", errno_copy); + } + + unflushed += outEvent->eventSize_; + offset_ += outEvent->eventSize_; + } + } + dequeueBuffer_->reset(); + } + + bool flushTimeElapsed = false; + struct timespec current_time; + clock_gettime(CLOCK_REALTIME, ¤t_time); + + if (current_time.tv_sec > ts_next_flush.tv_sec || + (current_time.tv_sec == ts_next_flush.tv_sec && current_time.tv_nsec > ts_next_flush.tv_nsec)) { + flushTimeElapsed = true; + getNextFlushTime(&ts_next_flush); + } + + // couple of cases from which a flush could be triggered + if ((flushTimeElapsed && unflushed > 0) || + unflushed > flushMaxBytes_ || + forceFlush_) { + + // sync (force flush) file to disk + fsync(fd_); + unflushed = 0; + + // notify anybody waiting for flush completion + forceFlush_ = false; + pthread_cond_broadcast(&flushed_); + } + } +} + +void TFileTransport::flush() { + // file must be open for writing for any flushing to take place + if (writerThreadId_ <= 0) { + return; + } + // wait for flush to take place + pthread_mutex_lock(&mutex_); + + forceFlush_ = true; + + while (forceFlush_) { + pthread_cond_wait(&flushed_, &mutex_); + } + + pthread_mutex_unlock(&mutex_); +} + + +uint32_t TFileTransport::readAll(uint8_t* buf, uint32_t len) { + uint32_t have = 0; + uint32_t get = 0; + + while (have < len) { + get = read(buf+have, len-have); + if (get <= 0) { + throw TEOFException(); + } + have += get; + } + + return have; +} + +uint32_t TFileTransport::read(uint8_t* buf, uint32_t len) { + // check if there an event is ready to be read + if (!currentEvent_) { + currentEvent_ = readEvent(); + } + + // did not manage to read an event from the file. This could have happened + // if the timeout expired or there was some other error + if (!currentEvent_) { + return 0; + } + + // read as much of the current event as possible + int32_t remaining = currentEvent_->eventSize_ - currentEvent_->eventBuffPos_; + if (remaining <= (int32_t)len) { + // copy over anything thats remaining + if (remaining > 0) { + memcpy(buf, + currentEvent_->eventBuff_ + currentEvent_->eventBuffPos_, + remaining); + } + delete(currentEvent_); + currentEvent_ = NULL; + return remaining; + } + + // read as much as possible + memcpy(buf, currentEvent_->eventBuff_ + currentEvent_->eventBuffPos_, len); + currentEvent_->eventBuffPos_ += len; + return len; +} + +eventInfo* TFileTransport::readEvent() { + int readTries = 0; + + if (!readBuff_) { + readBuff_ = new uint8_t[readBuffSize_]; + } + + while (1) { + // read from the file if read buffer is exhausted + if (readState_.bufferPtr_ == readState_.bufferLen_) { + // advance the offset pointer + offset_ += readState_.bufferLen_; + readState_.bufferLen_ = ::read(fd_, readBuff_, readBuffSize_); + // if (readState_.bufferLen_) { + // T_DEBUG_L(1, "Amount read: %u (offset: %lu)", readState_.bufferLen_, offset_); + // } + readState_.bufferPtr_ = 0; + readState_.lastDispatchPtr_ = 0; + + // read error + if (readState_.bufferLen_ == -1) { + readState_.resetAllValues(); + GlobalOutput("TFileTransport: error while reading from file"); + throw TTransportException("TFileTransport: error while reading from file"); + } else if (readState_.bufferLen_ == 0) { // EOF + // wait indefinitely if there is no timeout + if (readTimeout_ == TAIL_READ_TIMEOUT) { + usleep(eofSleepTime_); + continue; + } else if (readTimeout_ == NO_TAIL_READ_TIMEOUT) { + // reset state + readState_.resetState(0); + return NULL; + } else if (readTimeout_ > 0) { + // timeout already expired once + if (readTries > 0) { + readState_.resetState(0); + return NULL; + } else { + usleep(readTimeout_ * 1000); + readTries++; + continue; + } + } + } + } + + readTries = 0; + + // attempt to read an event from the buffer + while(readState_.bufferPtr_ < readState_.bufferLen_) { + if (readState_.readingSize_) { + if(readState_.eventSizeBuffPos_ == 0) { + if ( (offset_ + readState_.bufferPtr_)/chunkSize_ != + ((offset_ + readState_.bufferPtr_ + 3)/chunkSize_)) { + // skip one byte towards chunk boundary + // T_DEBUG_L(1, "Skipping a byte"); + readState_.bufferPtr_++; + continue; + } + } + + readState_.eventSizeBuff_[readState_.eventSizeBuffPos_++] = + readBuff_[readState_.bufferPtr_++]; + if (readState_.eventSizeBuffPos_ == 4) { + // 0 length event indicates padding + if (*((uint32_t *)(readState_.eventSizeBuff_)) == 0) { + // T_DEBUG_L(1, "Got padding"); + readState_.resetState(readState_.lastDispatchPtr_); + continue; + } + // got a valid event + readState_.readingSize_ = false; + if (readState_.event_) { + delete(readState_.event_); + } + readState_.event_ = new eventInfo(); + readState_.event_->eventSize_ = *((uint32_t *)(readState_.eventSizeBuff_)); + + // check if the event is corrupted and perform recovery if required + if (isEventCorrupted()) { + performRecovery(); + // start from the top + break; + } + } + } else { + if (!readState_.event_->eventBuff_) { + readState_.event_->eventBuff_ = new uint8_t[readState_.event_->eventSize_]; + readState_.event_->eventBuffPos_ = 0; + } + // take either the entire event or the remaining bytes in the buffer + int reclaimBuffer = min((uint32_t)(readState_.bufferLen_ - readState_.bufferPtr_), + readState_.event_->eventSize_ - readState_.event_->eventBuffPos_); + + // copy data from read buffer into event buffer + memcpy(readState_.event_->eventBuff_ + readState_.event_->eventBuffPos_, + readBuff_ + readState_.bufferPtr_, + reclaimBuffer); + + // increment position ptrs + readState_.event_->eventBuffPos_ += reclaimBuffer; + readState_.bufferPtr_ += reclaimBuffer; + + // check if the event has been read in full + if (readState_.event_->eventBuffPos_ == readState_.event_->eventSize_) { + // set the completed event to the current event + eventInfo* completeEvent = readState_.event_; + completeEvent->eventBuffPos_ = 0; + + readState_.event_ = NULL; + readState_.resetState(readState_.bufferPtr_); + + // exit criteria + return completeEvent; + } + } + } + + } +} + +bool TFileTransport::isEventCorrupted() { + // an error is triggered if: + if ( (maxEventSize_ > 0) && (readState_.event_->eventSize_ > maxEventSize_)) { + // 1. Event size is larger than user-speficied max-event size + T_ERROR("Read corrupt event. Event size(%u) greater than max event size (%u)", + readState_.event_->eventSize_, maxEventSize_); + return true; + } else if (readState_.event_->eventSize_ > chunkSize_) { + // 2. Event size is larger than chunk size + T_ERROR("Read corrupt event. Event size(%u) greater than chunk size (%u)", + readState_.event_->eventSize_, chunkSize_); + return true; + } else if( ((offset_ + readState_.bufferPtr_ - 4)/chunkSize_) != + ((offset_ + readState_.bufferPtr_ + readState_.event_->eventSize_ - 1)/chunkSize_) ) { + // 3. size indicates that event crosses chunk boundary + T_ERROR("Read corrupt event. Event crosses chunk boundary. Event size:%u Offset:%ld", + readState_.event_->eventSize_, offset_ + readState_.bufferPtr_ + 4); + return true; + } + + return false; +} + +void TFileTransport::performRecovery() { + // perform some kickass recovery + uint32_t curChunk = getCurChunk(); + if (lastBadChunk_ == curChunk) { + numCorruptedEventsInChunk_++; + } else { + lastBadChunk_ = curChunk; + numCorruptedEventsInChunk_ = 1; + } + + if (numCorruptedEventsInChunk_ < maxCorruptedEvents_) { + // maybe there was an error in reading the file from disk + // seek to the beginning of chunk and try again + seekToChunk(curChunk); + } else { + + // just skip ahead to the next chunk if we not already at the last chunk + if (curChunk != (getNumChunks() - 1)) { + seekToChunk(curChunk + 1); + } else if (readTimeout_ == TAIL_READ_TIMEOUT) { + // if tailing the file, wait until there is enough data to start + // the next chunk + while(curChunk == (getNumChunks() - 1)) { + usleep(DEFAULT_CORRUPTED_SLEEP_TIME_US); + } + seekToChunk(curChunk + 1); + } else { + // pretty hosed at this stage, rewind the file back to the last successful + // point and punt on the error + readState_.resetState(readState_.lastDispatchPtr_); + currentEvent_ = NULL; + char errorMsg[1024]; + sprintf(errorMsg, "TFileTransport: log file corrupted at offset: %lu", + offset_ + readState_.lastDispatchPtr_); + GlobalOutput(errorMsg); + throw TTransportException(errorMsg); + } + } + +} + +void TFileTransport::seekToChunk(int32_t chunk) { + if (fd_ <= 0) { + throw TTransportException("File not open"); + } + + int32_t numChunks = getNumChunks(); + + // file is empty, seeking to chunk is pointless + if (numChunks == 0) { + return; + } + + // negative indicates reverse seek (from the end) + if (chunk < 0) { + chunk += numChunks; + } + + // too large a value for reverse seek, just seek to beginning + if (chunk < 0) { + T_DEBUG("Incorrect value for reverse seek. Seeking to beginning...", chunk) + chunk = 0; + } + + // cannot seek past EOF + bool seekToEnd = false; + uint32_t minEndOffset = 0; + if (chunk >= numChunks) { + T_DEBUG("Trying to seek past EOF. Seeking to EOF instead..."); + seekToEnd = true; + chunk = numChunks - 1; + // this is the min offset to process events till + minEndOffset = lseek(fd_, 0, SEEK_END); + } + + off_t newOffset = off_t(chunk) * chunkSize_; + offset_ = lseek(fd_, newOffset, SEEK_SET); + readState_.resetAllValues(); + currentEvent_ = NULL; + if (offset_ == -1) { + GlobalOutput("TFileTransport: lseek error in seekToChunk"); + throw TTransportException("TFileTransport: lseek error in seekToChunk"); + } + + // seek to EOF if user wanted to go to last chunk + if (seekToEnd) { + uint32_t oldReadTimeout = getReadTimeout(); + setReadTimeout(NO_TAIL_READ_TIMEOUT); + // keep on reading unti the last event at point of seekChunk call + while (readEvent() && ((offset_ + readState_.bufferPtr_) < minEndOffset)) {}; + setReadTimeout(oldReadTimeout); + } + +} + +void TFileTransport::seekToEnd() { + seekToChunk(getNumChunks()); +} + +uint32_t TFileTransport::getNumChunks() { + if (fd_ <= 0) { + return 0; + } + + struct stat f_info; + int rv = fstat(fd_, &f_info); + + if (rv < 0) { + int errno_copy = errno; + throw TTransportException(TTransportException::UNKNOWN, + "TFileTransport::getNumChunks() (fstat)", + errno_copy); + } + + if (f_info.st_size > 0) { + return ((f_info.st_size)/chunkSize_) + 1; + } + + // empty file has no chunks + return 0; +} + +uint32_t TFileTransport::getCurChunk() { + return offset_/chunkSize_; +} + +// Utility Functions +void TFileTransport::openLogFile() { + mode_t mode = readOnly_ ? S_IRUSR | S_IRGRP | S_IROTH : S_IRUSR | S_IWUSR| S_IRGRP | S_IROTH; + int flags = readOnly_ ? O_RDONLY : O_RDWR | O_CREAT | O_APPEND; + fd_ = ::open(filename_.c_str(), flags, mode); + offset_ = 0; + + // make sure open call was successful + if(fd_ == -1) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: openLogFile() ::open() file: " + filename_, errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, filename_, errno_copy); + } + +} + +void TFileTransport::getNextFlushTime(struct timespec* ts_next_flush) { + clock_gettime(CLOCK_REALTIME, ts_next_flush); + ts_next_flush->tv_nsec += (flushMaxUs_ % 1000000) * 1000; + if (ts_next_flush->tv_nsec > 1000000000) { + ts_next_flush->tv_nsec -= 1000000000; + ts_next_flush->tv_sec += 1; + } + ts_next_flush->tv_sec += flushMaxUs_ / 1000000; +} + +TFileTransportBuffer::TFileTransportBuffer(uint32_t size) + : bufferMode_(WRITE) + , writePoint_(0) + , readPoint_(0) + , size_(size) +{ + buffer_ = new eventInfo*[size]; +} + +TFileTransportBuffer::~TFileTransportBuffer() { + if (buffer_) { + for (uint32_t i = 0; i < writePoint_; i++) { + delete buffer_[i]; + } + delete[] buffer_; + buffer_ = NULL; + } +} + +bool TFileTransportBuffer::addEvent(eventInfo *event) { + if (bufferMode_ == READ) { + GlobalOutput("Trying to write to a buffer in read mode"); + } + if (writePoint_ < size_) { + buffer_[writePoint_++] = event; + return true; + } else { + // buffer is full + return false; + } +} + +eventInfo* TFileTransportBuffer::getNext() { + if (bufferMode_ == WRITE) { + bufferMode_ = READ; + } + if (readPoint_ < writePoint_) { + return buffer_[readPoint_++]; + } else { + // no more entries + return NULL; + } +} + +void TFileTransportBuffer::reset() { + if (bufferMode_ == WRITE || writePoint_ > readPoint_) { + T_DEBUG("Resetting a buffer with unread entries"); + } + // Clean up the old entries + for (uint32_t i = 0; i < writePoint_; i++) { + delete buffer_[i]; + } + bufferMode_ = WRITE; + writePoint_ = 0; + readPoint_ = 0; +} + +bool TFileTransportBuffer::isFull() { + return writePoint_ == size_; +} + +bool TFileTransportBuffer::isEmpty() { + return writePoint_ == 0; +} + +TFileProcessor::TFileProcessor(shared_ptr<TProcessor> processor, + shared_ptr<TProtocolFactory> protocolFactory, + shared_ptr<TFileReaderTransport> inputTransport): + processor_(processor), + inputProtocolFactory_(protocolFactory), + outputProtocolFactory_(protocolFactory), + inputTransport_(inputTransport) { + + // default the output transport to a null transport (common case) + outputTransport_ = shared_ptr<TNullTransport>(new TNullTransport()); +} + +TFileProcessor::TFileProcessor(shared_ptr<TProcessor> processor, + shared_ptr<TProtocolFactory> inputProtocolFactory, + shared_ptr<TProtocolFactory> outputProtocolFactory, + shared_ptr<TFileReaderTransport> inputTransport): + processor_(processor), + inputProtocolFactory_(inputProtocolFactory), + outputProtocolFactory_(outputProtocolFactory), + inputTransport_(inputTransport) { + + // default the output transport to a null transport (common case) + outputTransport_ = shared_ptr<TNullTransport>(new TNullTransport()); +} + +TFileProcessor::TFileProcessor(shared_ptr<TProcessor> processor, + shared_ptr<TProtocolFactory> protocolFactory, + shared_ptr<TFileReaderTransport> inputTransport, + shared_ptr<TTransport> outputTransport): + processor_(processor), + inputProtocolFactory_(protocolFactory), + outputProtocolFactory_(protocolFactory), + inputTransport_(inputTransport), + outputTransport_(outputTransport) {}; + +void TFileProcessor::process(uint32_t numEvents, bool tail) { + shared_ptr<TProtocol> inputProtocol = inputProtocolFactory_->getProtocol(inputTransport_); + shared_ptr<TProtocol> outputProtocol = outputProtocolFactory_->getProtocol(outputTransport_); + + // set the read timeout to 0 if tailing is required + int32_t oldReadTimeout = inputTransport_->getReadTimeout(); + if (tail) { + // save old read timeout so it can be restored + inputTransport_->setReadTimeout(TFileTransport::TAIL_READ_TIMEOUT); + } + + uint32_t numProcessed = 0; + while(1) { + // bad form to use exceptions for flow control but there is really + // no other way around it + try { + processor_->process(inputProtocol, outputProtocol); + numProcessed++; + if ( (numEvents > 0) && (numProcessed == numEvents)) { + return; + } + } catch (TEOFException& teof) { + if (!tail) { + break; + } + } catch (TException &te) { + cerr << te.what() << endl; + break; + } + } + + // restore old read timeout + if (tail) { + inputTransport_->setReadTimeout(oldReadTimeout); + } + +} + +void TFileProcessor::processChunk() { + shared_ptr<TProtocol> inputProtocol = inputProtocolFactory_->getProtocol(inputTransport_); + shared_ptr<TProtocol> outputProtocol = outputProtocolFactory_->getProtocol(outputTransport_); + + uint32_t curChunk = inputTransport_->getCurChunk(); + + while(1) { + // bad form to use exceptions for flow control but there is really + // no other way around it + try { + processor_->process(inputProtocol, outputProtocol); + if (curChunk != inputTransport_->getCurChunk()) { + break; + } + } catch (TEOFException& teof) { + break; + } catch (TException &te) { + cerr << te.what() << endl; + break; + } + } +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TFileTransport.h b/lib/cpp/src/transport/TFileTransport.h new file mode 100644 index 000000000..fbaf2cd0d --- /dev/null +++ b/lib/cpp/src/transport/TFileTransport.h @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TFILETRANSPORT_H_ +#define _THRIFT_TRANSPORT_TFILETRANSPORT_H_ 1 + +#include "TTransport.h" +#include "Thrift.h" +#include "TProcessor.h" + +#include <string> +#include <stdio.h> + +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace transport { + +using apache::thrift::TProcessor; +using apache::thrift::protocol::TProtocolFactory; + +// Data pertaining to a single event +typedef struct eventInfo { + uint8_t* eventBuff_; + uint32_t eventSize_; + uint32_t eventBuffPos_; + + eventInfo():eventBuff_(NULL), eventSize_(0), eventBuffPos_(0){}; + ~eventInfo() { + if (eventBuff_) { + delete[] eventBuff_; + } + } +} eventInfo; + +// information about current read state +typedef struct readState { + eventInfo* event_; + + // keep track of event size + uint8_t eventSizeBuff_[4]; + uint8_t eventSizeBuffPos_; + bool readingSize_; + + // read buffer variables + int32_t bufferPtr_; + int32_t bufferLen_; + + // last successful dispatch point + int32_t lastDispatchPtr_; + + void resetState(uint32_t lastDispatchPtr) { + readingSize_ = true; + eventSizeBuffPos_ = 0; + lastDispatchPtr_ = lastDispatchPtr; + } + + void resetAllValues() { + resetState(0); + bufferPtr_ = 0; + bufferLen_ = 0; + if (event_) { + delete(event_); + } + event_ = 0; + } + + readState() { + event_ = 0; + resetAllValues(); + } + + ~readState() { + if (event_) { + delete(event_); + } + } + +} readState; + +/** + * TFileTransportBuffer - buffer class used by TFileTransport for queueing up events + * to be written to disk. Should be used in the following way: + * 1) Buffer created + * 2) Buffer written to (addEvent) + * 3) Buffer read from (getNext) + * 4) Buffer reset (reset) + * 5) Go back to 2, or destroy buffer + * + * The buffer should never be written to after it is read from, unless it is reset first. + * Note: The above rules are enforced mainly for debugging its sole client TFileTransport + * which uses the buffer in this way. + * + */ +class TFileTransportBuffer { + public: + TFileTransportBuffer(uint32_t size); + ~TFileTransportBuffer(); + + bool addEvent(eventInfo *event); + eventInfo* getNext(); + void reset(); + bool isFull(); + bool isEmpty(); + + private: + TFileTransportBuffer(); // should not be used + + enum mode { + WRITE, + READ + }; + mode bufferMode_; + + uint32_t writePoint_; + uint32_t readPoint_; + uint32_t size_; + eventInfo** buffer_; +}; + +/** + * Abstract interface for transports used to read files + */ +class TFileReaderTransport : virtual public TTransport { + public: + virtual int32_t getReadTimeout() = 0; + virtual void setReadTimeout(int32_t readTimeout) = 0; + + virtual uint32_t getNumChunks() = 0; + virtual uint32_t getCurChunk() = 0; + virtual void seekToChunk(int32_t chunk) = 0; + virtual void seekToEnd() = 0; +}; + +/** + * Abstract interface for transports used to write files + */ +class TFileWriterTransport : virtual public TTransport { + public: + virtual uint32_t getChunkSize() = 0; + virtual void setChunkSize(uint32_t chunkSize) = 0; +}; + +/** + * File implementation of a transport. Reads and writes are done to a + * file on disk. + * + */ +class TFileTransport : public TFileReaderTransport, + public TFileWriterTransport { + public: + TFileTransport(std::string path, bool readOnly=false); + ~TFileTransport(); + + // TODO: what is the correct behaviour for this? + // the log file is generally always open + bool isOpen() { + return true; + } + + void write(const uint8_t* buf, uint32_t len); + void flush(); + + uint32_t readAll(uint8_t* buf, uint32_t len); + uint32_t read(uint8_t* buf, uint32_t len); + + // log-file specific functions + void seekToChunk(int32_t chunk); + void seekToEnd(); + uint32_t getNumChunks(); + uint32_t getCurChunk(); + + // for changing the output file + void resetOutputFile(int fd, std::string filename, int64_t offset); + + // Setter/Getter functions for user-controllable options + void setReadBuffSize(uint32_t readBuffSize) { + if (readBuffSize) { + readBuffSize_ = readBuffSize; + } + } + uint32_t getReadBuffSize() { + return readBuffSize_; + } + + static const int32_t TAIL_READ_TIMEOUT = -1; + static const int32_t NO_TAIL_READ_TIMEOUT = 0; + void setReadTimeout(int32_t readTimeout) { + readTimeout_ = readTimeout; + } + int32_t getReadTimeout() { + return readTimeout_; + } + + void setChunkSize(uint32_t chunkSize) { + if (chunkSize) { + chunkSize_ = chunkSize; + } + } + uint32_t getChunkSize() { + return chunkSize_; + } + + void setEventBufferSize(uint32_t bufferSize) { + if (bufferAndThreadInitialized_) { + GlobalOutput("Cannot change the buffer size after writer thread started"); + return; + } + eventBufferSize_ = bufferSize; + } + + uint32_t getEventBufferSize() { + return eventBufferSize_; + } + + void setFlushMaxUs(uint32_t flushMaxUs) { + if (flushMaxUs) { + flushMaxUs_ = flushMaxUs; + } + } + uint32_t getFlushMaxUs() { + return flushMaxUs_; + } + + void setFlushMaxBytes(uint32_t flushMaxBytes) { + if (flushMaxBytes) { + flushMaxBytes_ = flushMaxBytes; + } + } + uint32_t getFlushMaxBytes() { + return flushMaxBytes_; + } + + void setMaxEventSize(uint32_t maxEventSize) { + maxEventSize_ = maxEventSize; + } + uint32_t getMaxEventSize() { + return maxEventSize_; + } + + void setMaxCorruptedEvents(uint32_t maxCorruptedEvents) { + maxCorruptedEvents_ = maxCorruptedEvents; + } + uint32_t getMaxCorruptedEvents() { + return maxCorruptedEvents_; + } + + void setEofSleepTimeUs(uint32_t eofSleepTime) { + if (eofSleepTime) { + eofSleepTime_ = eofSleepTime; + } + } + uint32_t getEofSleepTimeUs() { + return eofSleepTime_; + } + + private: + // helper functions for writing to a file + void enqueueEvent(const uint8_t* buf, uint32_t eventLen, bool blockUntilFlush); + bool swapEventBuffers(struct timespec* deadline); + bool initBufferAndWriteThread(); + + // control for writer thread + static void* startWriterThread(void* ptr) { + (((TFileTransport*)ptr)->writerThread()); + return 0; + } + void writerThread(); + + // helper functions for reading from a file + eventInfo* readEvent(); + + // event corruption-related functions + bool isEventCorrupted(); + void performRecovery(); + + // Utility functions + void openLogFile(); + void getNextFlushTime(struct timespec* ts_next_flush); + + // Class variables + readState readState_; + uint8_t* readBuff_; + eventInfo* currentEvent_; + + uint32_t readBuffSize_; + static const uint32_t DEFAULT_READ_BUFF_SIZE = 1 * 1024 * 1024; + + int32_t readTimeout_; + static const int32_t DEFAULT_READ_TIMEOUT_MS = 200; + + // size of chunks that file will be split up into + uint32_t chunkSize_; + static const uint32_t DEFAULT_CHUNK_SIZE = 16 * 1024 * 1024; + + // size of event buffers + uint32_t eventBufferSize_; + static const uint32_t DEFAULT_EVENT_BUFFER_SIZE = 10000; + + // max number of microseconds that can pass without flushing + uint32_t flushMaxUs_; + static const uint32_t DEFAULT_FLUSH_MAX_US = 3000000; + + // max number of bytes that can be written without flushing + uint32_t flushMaxBytes_; + static const uint32_t DEFAULT_FLUSH_MAX_BYTES = 1000 * 1024; + + // max event size + uint32_t maxEventSize_; + static const uint32_t DEFAULT_MAX_EVENT_SIZE = 0; + + // max number of corrupted events per chunk + uint32_t maxCorruptedEvents_; + static const uint32_t DEFAULT_MAX_CORRUPTED_EVENTS = 0; + + // sleep duration when EOF is hit + uint32_t eofSleepTime_; + static const uint32_t DEFAULT_EOF_SLEEP_TIME_US = 500 * 1000; + + // sleep duration when a corrupted event is encountered + uint32_t corruptedEventSleepTime_; + static const uint32_t DEFAULT_CORRUPTED_SLEEP_TIME_US = 1 * 1000 * 1000; + + // writer thread id + pthread_t writerThreadId_; + + // buffers to hold data before it is flushed. Each element of the buffer stores a msg that + // needs to be written to the file. The buffers are swapped by the writer thread. + TFileTransportBuffer *dequeueBuffer_; + TFileTransportBuffer *enqueueBuffer_; + + // conditions used to block when the buffer is full or empty + pthread_cond_t notFull_, notEmpty_; + volatile bool closing_; + + // To keep track of whether the buffer has been flushed + pthread_cond_t flushed_; + volatile bool forceFlush_; + + // Mutex that is grabbed when enqueueing and swapping the read/write buffers + pthread_mutex_t mutex_; + + // File information + std::string filename_; + int fd_; + + // Whether the writer thread and buffers have been initialized + bool bufferAndThreadInitialized_; + + // Offset within the file + off_t offset_; + + // event corruption information + uint32_t lastBadChunk_; + uint32_t numCorruptedEventsInChunk_; + + bool readOnly_; +}; + +// Exception thrown when EOF is hit +class TEOFException : public TTransportException { + public: + TEOFException(): + TTransportException(TTransportException::END_OF_FILE) {}; +}; + + +// wrapper class to process events from a file containing thrift events +class TFileProcessor { + public: + /** + * Constructor that defaults output transport to null transport + * + * @param processor processes log-file events + * @param protocolFactory protocol factory + * @param inputTransport file transport + */ + TFileProcessor(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TProtocolFactory> protocolFactory, + boost::shared_ptr<TFileReaderTransport> inputTransport); + + TFileProcessor(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TProtocolFactory> inputProtocolFactory, + boost::shared_ptr<TProtocolFactory> outputProtocolFactory, + boost::shared_ptr<TFileReaderTransport> inputTransport); + + /** + * Constructor + * + * @param processor processes log-file events + * @param protocolFactory protocol factory + * @param inputTransport input file transport + * @param output output transport + */ + TFileProcessor(boost::shared_ptr<TProcessor> processor, + boost::shared_ptr<TProtocolFactory> protocolFactory, + boost::shared_ptr<TFileReaderTransport> inputTransport, + boost::shared_ptr<TTransport> outputTransport); + + /** + * processes events from the file + * + * @param numEvents number of events to process (0 for unlimited) + * @param tail tails the file if true + */ + void process(uint32_t numEvents, bool tail); + + /** + * process events until the end of the chunk + * + */ + void processChunk(); + + private: + boost::shared_ptr<TProcessor> processor_; + boost::shared_ptr<TProtocolFactory> inputProtocolFactory_; + boost::shared_ptr<TProtocolFactory> outputProtocolFactory_; + boost::shared_ptr<TFileReaderTransport> inputTransport_; + boost::shared_ptr<TTransport> outputTransport_; +}; + + +}}} // apache::thrift::transport + +#endif // _THRIFT_TRANSPORT_TFILETRANSPORT_H_ diff --git a/lib/cpp/src/transport/THttpClient.cpp b/lib/cpp/src/transport/THttpClient.cpp new file mode 100644 index 000000000..59f233968 --- /dev/null +++ b/lib/cpp/src/transport/THttpClient.cpp @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <cstdlib> +#include <sstream> + +#include "THttpClient.h" +#include "TSocket.h" + +namespace apache { namespace thrift { namespace transport { + +using namespace std; + +/** + * Http client implementation. + * + */ + +// Yeah, yeah, hacky to put these here, I know. +static const char* CRLF = "\r\n"; +static const int CRLF_LEN = 2; + +THttpClient::THttpClient(boost::shared_ptr<TTransport> transport, string host, string path) : + transport_(transport), + host_(host), + path_(path), + readHeaders_(true), + chunked_(false), + chunkedDone_(false), + chunkSize_(0), + contentLength_(0), + httpBuf_(NULL), + httpPos_(0), + httpBufLen_(0), + httpBufSize_(1024) { + init(); +} + +THttpClient::THttpClient(string host, int port, string path) : + host_(host), + path_(path), + readHeaders_(true), + chunked_(false), + chunkedDone_(false), + chunkSize_(0), + contentLength_(0), + httpBuf_(NULL), + httpPos_(0), + httpBufLen_(0), + httpBufSize_(1024) { + transport_ = boost::shared_ptr<TTransport>(new TSocket(host, port)); + init(); +} + +void THttpClient::init() { + httpBuf_ = (char*)std::malloc(httpBufSize_+1); + if (httpBuf_ == NULL) { + throw TTransportException("Out of memory."); + } + httpBuf_[httpBufLen_] = '\0'; +} + +THttpClient::~THttpClient() { + if (httpBuf_ != NULL) { + std::free(httpBuf_); + } +} + +uint32_t THttpClient::read(uint8_t* buf, uint32_t len) { + if (readBuffer_.available_read() == 0) { + readBuffer_.resetBuffer(); + uint32_t got = readMoreData(); + if (got == 0) { + return 0; + } + } + return readBuffer_.read(buf, len); +} + +void THttpClient::readEnd() { + // Read any pending chunked data (footers etc.) + if (chunked_) { + while (!chunkedDone_) { + readChunked(); + } + } +} + +uint32_t THttpClient::readMoreData() { + // Get more data! + refill(); + + if (readHeaders_) { + readHeaders(); + } + + if (chunked_) { + return readChunked(); + } else { + return readContent(contentLength_); + } +} + +uint32_t THttpClient::readChunked() { + uint32_t length = 0; + + char* line = readLine(); + uint32_t chunkSize = parseChunkSize(line); + if (chunkSize == 0) { + readChunkedFooters(); + } else { + // Read data content + length += readContent(chunkSize); + // Read trailing CRLF after content + readLine(); + } + return length; +} + +void THttpClient::readChunkedFooters() { + // End of data, read footer lines until a blank one appears + while (true) { + char* line = readLine(); + if (strlen(line) == 0) { + chunkedDone_ = true; + break; + } + } +} + +uint32_t THttpClient::parseChunkSize(char* line) { + char* semi = strchr(line, ';'); + if (semi != NULL) { + *semi = '\0'; + } + int size = 0; + sscanf(line, "%x", &size); + return (uint32_t)size; +} + +uint32_t THttpClient::readContent(uint32_t size) { + uint32_t need = size; + while (need > 0) { + uint32_t avail = httpBufLen_ - httpPos_; + if (avail == 0) { + // We have given all the data, reset position to head of the buffer + httpPos_ = 0; + httpBufLen_ = 0; + refill(); + + // Now have available however much we read + avail = httpBufLen_; + } + uint32_t give = avail; + if (need < give) { + give = need; + } + readBuffer_.write((uint8_t*)(httpBuf_+httpPos_), give); + httpPos_ += give; + need -= give; + } + return size; +} + +char* THttpClient::readLine() { + while (true) { + char* eol = NULL; + + eol = strstr(httpBuf_+httpPos_, CRLF); + + // No CRLF yet? + if (eol == NULL) { + // Shift whatever we have now to front and refill + shift(); + refill(); + } else { + // Return pointer to next line + *eol = '\0'; + char* line = httpBuf_+httpPos_; + httpPos_ = (eol-httpBuf_) + CRLF_LEN; + return line; + } + } + +} + +void THttpClient::shift() { + if (httpBufLen_ > httpPos_) { + // Shift down remaining data and read more + uint32_t length = httpBufLen_ - httpPos_; + memmove(httpBuf_, httpBuf_+httpPos_, length); + httpBufLen_ = length; + } else { + httpBufLen_ = 0; + } + httpPos_ = 0; + httpBuf_[httpBufLen_] = '\0'; +} + +void THttpClient::refill() { + uint32_t avail = httpBufSize_ - httpBufLen_; + if (avail <= (httpBufSize_ / 4)) { + httpBufSize_ *= 2; + httpBuf_ = (char*)std::realloc(httpBuf_, httpBufSize_+1); + if (httpBuf_ == NULL) { + throw TTransportException("Out of memory."); + } + } + + // Read more data + uint32_t got = transport_->read((uint8_t*)(httpBuf_+httpBufLen_), httpBufSize_-httpBufLen_); + httpBufLen_ += got; + httpBuf_[httpBufLen_] = '\0'; + + if (got == 0) { + throw TTransportException("Could not refill buffer"); + } +} + +void THttpClient::readHeaders() { + // Initialize headers state variables + contentLength_ = 0; + chunked_ = false; + chunkedDone_ = false; + chunkSize_ = 0; + + // Control state flow + bool statusLine = true; + bool finished = false; + + // Loop until headers are finished + while (true) { + char* line = readLine(); + + if (strlen(line) == 0) { + if (finished) { + readHeaders_ = false; + return; + } else { + // Must have been an HTTP 100, keep going for another status line + statusLine = true; + } + } else { + if (statusLine) { + statusLine = false; + finished = parseStatusLine(line); + } else { + parseHeader(line); + } + } + } +} + +bool THttpClient::parseStatusLine(char* status) { + char* http = status; + + char* code = strchr(http, ' '); + if (code == NULL) { + throw TTransportException(string("Bad Status: ") + status); + } + + *code = '\0'; + while (*(code++) == ' '); + + char* msg = strchr(code, ' '); + if (msg == NULL) { + throw TTransportException(string("Bad Status: ") + status); + } + *msg = '\0'; + + if (strcmp(code, "200") == 0) { + // HTTP 200 = OK, we got the response + return true; + } else if (strcmp(code, "100") == 0) { + // HTTP 100 = continue, just keep reading + return false; + } else { + throw TTransportException(string("Bad Status: ") + status); + } +} + +void THttpClient::parseHeader(char* header) { + char* colon = strchr(header, ':'); + if (colon == NULL) { + return; + } + uint32_t sz = colon - header; + char* value = colon+1; + + if (strncmp(header, "Transfer-Encoding", sz) == 0) { + if (strstr(value, "chunked") != NULL) { + chunked_ = true; + } + } else if (strncmp(header, "Content-Length", sz) == 0) { + chunked_ = false; + contentLength_ = atoi(value); + } +} + +void THttpClient::write(const uint8_t* buf, uint32_t len) { + writeBuffer_.write(buf, len); +} + +void THttpClient::flush() { + // Fetch the contents of the write buffer + uint8_t* buf; + uint32_t len; + writeBuffer_.getBuffer(&buf, &len); + + // Construct the HTTP header + std::ostringstream h; + h << + "POST " << path_ << " HTTP/1.1" << CRLF << + "Host: " << host_ << CRLF << + "Content-Type: application/x-thrift" << CRLF << + "Content-Length: " << len << CRLF << + "Accept: application/x-thrift" << CRLF << + "User-Agent: C++/THttpClient" << CRLF << + CRLF; + string header = h.str(); + + // Write the header, then the data, then flush + transport_->write((const uint8_t*)header.c_str(), header.size()); + transport_->write(buf, len); + transport_->flush(); + + // Reset the buffer and header variables + writeBuffer_.resetBuffer(); + readHeaders_ = true; +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/THttpClient.h b/lib/cpp/src/transport/THttpClient.h new file mode 100644 index 000000000..f4be4c1a6 --- /dev/null +++ b/lib/cpp/src/transport/THttpClient.h @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_THTTPCLIENT_H_ +#define _THRIFT_TRANSPORT_THTTPCLIENT_H_ 1 + +#include <transport/TBufferTransports.h> + +namespace apache { namespace thrift { namespace transport { + +/** + * HTTP client implementation of the thrift transport. This was irritating + * to write, but the alternatives in C++ land are daunting. Linking CURL + * requires 23 dynamic libraries last time I checked (WTF?!?). All we have + * here is a VERY basic HTTP/1.1 client which supports HTTP 100 Continue, + * chunked transfer encoding, keepalive, etc. Tested against Apache. + * + */ +class THttpClient : public TTransport { + public: + THttpClient(boost::shared_ptr<TTransport> transport, std::string host, std::string path=""); + + THttpClient(std::string host, int port, std::string path=""); + + virtual ~THttpClient(); + + void open() { + transport_->open(); + } + + bool isOpen() { + return transport_->isOpen(); + } + + bool peek() { + return transport_->peek(); + } + + void close() { + transport_->close(); + } + + uint32_t read(uint8_t* buf, uint32_t len); + + void readEnd(); + + void write(const uint8_t* buf, uint32_t len); + + void flush(); + + private: + void init(); + + protected: + + boost::shared_ptr<TTransport> transport_; + + TMemoryBuffer writeBuffer_; + TMemoryBuffer readBuffer_; + + std::string host_; + std::string path_; + + bool readHeaders_; + bool chunked_; + bool chunkedDone_; + uint32_t chunkSize_; + uint32_t contentLength_; + + char* httpBuf_; + uint32_t httpPos_; + uint32_t httpBufLen_; + uint32_t httpBufSize_; + + uint32_t readMoreData(); + char* readLine(); + + void readHeaders(); + void parseHeader(char* header); + bool parseStatusLine(char* status); + + uint32_t readChunked(); + void readChunkedFooters(); + uint32_t parseChunkSize(char* line); + + uint32_t readContent(uint32_t size); + + void refill(); + void shift(); + +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_THTTPCLIENT_H_ diff --git a/lib/cpp/src/transport/TServerSocket.cpp b/lib/cpp/src/transport/TServerSocket.cpp new file mode 100644 index 000000000..9b47aa539 --- /dev/null +++ b/lib/cpp/src/transport/TServerSocket.cpp @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <cstring> +#include <sys/socket.h> +#include <sys/poll.h> +#include <sys/types.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <netdb.h> +#include <fcntl.h> +#include <errno.h> + +#include "TSocket.h" +#include "TServerSocket.h" +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace transport { + +using namespace std; +using boost::shared_ptr; + +TServerSocket::TServerSocket(int port) : + port_(port), + serverSocket_(-1), + acceptBacklog_(1024), + sendTimeout_(0), + recvTimeout_(0), + retryLimit_(0), + retryDelay_(0), + tcpSendBuffer_(0), + tcpRecvBuffer_(0), + intSock1_(-1), + intSock2_(-1) {} + +TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout) : + port_(port), + serverSocket_(-1), + acceptBacklog_(1024), + sendTimeout_(sendTimeout), + recvTimeout_(recvTimeout), + retryLimit_(0), + retryDelay_(0), + tcpSendBuffer_(0), + tcpRecvBuffer_(0), + intSock1_(-1), + intSock2_(-1) {} + +TServerSocket::~TServerSocket() { + close(); +} + +void TServerSocket::setSendTimeout(int sendTimeout) { + sendTimeout_ = sendTimeout; +} + +void TServerSocket::setRecvTimeout(int recvTimeout) { + recvTimeout_ = recvTimeout; +} + +void TServerSocket::setRetryLimit(int retryLimit) { + retryLimit_ = retryLimit; +} + +void TServerSocket::setRetryDelay(int retryDelay) { + retryDelay_ = retryDelay; +} + +void TServerSocket::setTcpSendBuffer(int tcpSendBuffer) { + tcpSendBuffer_ = tcpSendBuffer; +} + +void TServerSocket::setTcpRecvBuffer(int tcpRecvBuffer) { + tcpRecvBuffer_ = tcpRecvBuffer; +} + +void TServerSocket::listen() { + int sv[2]; + if (-1 == socketpair(AF_LOCAL, SOCK_STREAM, 0, sv)) { + GlobalOutput.perror("TServerSocket::listen() socketpair() ", errno); + intSock1_ = -1; + intSock2_ = -1; + } else { + intSock1_ = sv[1]; + intSock2_ = sv[0]; + } + + struct addrinfo hints, *res, *res0; + int error; + char port[sizeof("65536") + 1]; + std::memset(&hints, 0, sizeof(hints)); + hints.ai_family = PF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + sprintf(port, "%d", port_); + + // Wildcard address + error = getaddrinfo(NULL, port, &hints, &res0); + if (error) { + GlobalOutput.printf("getaddrinfo %d: %s", error, gai_strerror(error)); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not resolve host for server socket."); + } + + // Pick the ipv6 address first since ipv4 addresses can be mapped + // into ipv6 space. + for (res = res0; res; res = res->ai_next) { + if (res->ai_family == AF_INET6 || res->ai_next == NULL) + break; + } + + serverSocket_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol); + if (serverSocket_ == -1) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() socket() ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not create server socket.", errno_copy); + } + + // Set reusaddress to prevent 2MSL delay on accept + int one = 1; + if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_REUSEADDR, + &one, sizeof(one))) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_REUSEADDR ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_REUSEADDR", errno_copy); + } + + // Set TCP buffer sizes + if (tcpSendBuffer_ > 0) { + if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_SNDBUF, + &tcpSendBuffer_, sizeof(tcpSendBuffer_))) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_SNDBUF ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_SNDBUF", errno_copy); + } + } + + if (tcpRecvBuffer_ > 0) { + if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_RCVBUF, + &tcpRecvBuffer_, sizeof(tcpRecvBuffer_))) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_RCVBUF ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_RCVBUF", errno_copy); + } + } + + // Defer accept + #ifdef TCP_DEFER_ACCEPT + if (-1 == setsockopt(serverSocket_, SOL_SOCKET, TCP_DEFER_ACCEPT, + &one, sizeof(one))) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() setsockopt() TCP_DEFER_ACCEPT ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not set TCP_DEFER_ACCEPT", errno_copy); + } + #endif // #ifdef TCP_DEFER_ACCEPT + + #ifdef IPV6_V6ONLY + int zero = 0; + if (-1 == setsockopt(serverSocket_, IPPROTO_IPV6, IPV6_V6ONLY, + &zero, sizeof(zero))) { + GlobalOutput.perror("TServerSocket::listen() IPV6_V6ONLY ", errno); + } + #endif // #ifdef IPV6_V6ONLY + + // Turn linger off, don't want to block on calls to close + struct linger ling = {0, 0}; + if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_LINGER, + &ling, sizeof(ling))) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_LINGER ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_LINGER", errno_copy); + } + + // TCP Nodelay, speed over bandwidth + if (-1 == setsockopt(serverSocket_, IPPROTO_TCP, TCP_NODELAY, + &one, sizeof(one))) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() setsockopt() TCP_NODELAY ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not set TCP_NODELAY", errno_copy); + } + + // Set NONBLOCK on the accept socket + int flags = fcntl(serverSocket_, F_GETFL, 0); + if (flags == -1) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() fcntl() F_GETFL ", errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy); + } + + if (-1 == fcntl(serverSocket_, F_SETFL, flags | O_NONBLOCK)) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() fcntl() O_NONBLOCK ", errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy); + } + + // prepare the port information + // we may want to try to bind more than once, since SO_REUSEADDR doesn't + // always seem to work. The client can configure the retry variables. + int retries = 0; + do { + if (0 == bind(serverSocket_, res->ai_addr, res->ai_addrlen)) { + break; + } + + // use short circuit evaluation here to only sleep if we need to + } while ((retries++ < retryLimit_) && (sleep(retryDelay_) == 0)); + + // free addrinfo + freeaddrinfo(res0); + + // throw an error if we failed to bind properly + if (retries > retryLimit_) { + char errbuf[1024]; + sprintf(errbuf, "TServerSocket::listen() BIND %d", port_); + GlobalOutput(errbuf); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not bind"); + } + + // Call listen + if (-1 == ::listen(serverSocket_, acceptBacklog_)) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not listen", errno_copy); + } + + // The socket is now listening! +} + +shared_ptr<TTransport> TServerSocket::acceptImpl() { + if (serverSocket_ < 0) { + throw TTransportException(TTransportException::NOT_OPEN, "TServerSocket not listening"); + } + + struct pollfd fds[2]; + + int maxEintrs = 5; + int numEintrs = 0; + + while (true) { + std::memset(fds, 0 , sizeof(fds)); + fds[0].fd = serverSocket_; + fds[0].events = POLLIN; + if (intSock2_ >= 0) { + fds[1].fd = intSock2_; + fds[1].events = POLLIN; + } + int ret = poll(fds, 2, -1); + + if (ret < 0) { + // error cases + if (errno == EINTR && (numEintrs++ < maxEintrs)) { + // EINTR needs to be handled manually and we can tolerate + // a certain number + continue; + } + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::acceptImpl() poll() ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy); + } else if (ret > 0) { + // Check for an interrupt signal + if (intSock2_ >= 0 && (fds[1].revents & POLLIN)) { + int8_t buf; + if (-1 == recv(intSock2_, &buf, sizeof(int8_t), 0)) { + GlobalOutput.perror("TServerSocket::acceptImpl() recv() interrupt ", errno); + } + throw TTransportException(TTransportException::INTERRUPTED); + } + + // Check for the actual server socket being ready + if (fds[0].revents & POLLIN) { + break; + } + } else { + GlobalOutput("TServerSocket::acceptImpl() poll 0"); + throw TTransportException(TTransportException::UNKNOWN); + } + } + + struct sockaddr_storage clientAddress; + int size = sizeof(clientAddress); + int clientSocket = ::accept(serverSocket_, + (struct sockaddr *) &clientAddress, + (socklen_t *) &size); + + if (clientSocket < 0) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::acceptImpl() ::accept() ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "accept()", errno_copy); + } + + // Make sure client socket is blocking + int flags = fcntl(clientSocket, F_GETFL, 0); + if (flags == -1) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::acceptImpl() fcntl() F_GETFL ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "fcntl(F_GETFL)", errno_copy); + } + + if (-1 == fcntl(clientSocket, F_SETFL, flags & ~O_NONBLOCK)) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::acceptImpl() fcntl() F_SETFL ~O_NONBLOCK ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "fcntl(F_SETFL)", errno_copy); + } + + shared_ptr<TSocket> client(new TSocket(clientSocket)); + if (sendTimeout_ > 0) { + client->setSendTimeout(sendTimeout_); + } + if (recvTimeout_ > 0) { + client->setRecvTimeout(recvTimeout_); + } + + return client; +} + +void TServerSocket::interrupt() { + if (intSock1_ >= 0) { + int8_t byte = 0; + if (-1 == send(intSock1_, &byte, sizeof(int8_t), 0)) { + GlobalOutput.perror("TServerSocket::interrupt() send() ", errno); + } + } +} + +void TServerSocket::close() { + if (serverSocket_ >= 0) { + shutdown(serverSocket_, SHUT_RDWR); + ::close(serverSocket_); + } + if (intSock1_ >= 0) { + ::close(intSock1_); + } + if (intSock2_ >= 0) { + ::close(intSock2_); + } + serverSocket_ = -1; + intSock1_ = -1; + intSock2_ = -1; +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TServerSocket.h b/lib/cpp/src/transport/TServerSocket.h new file mode 100644 index 000000000..a6be01737 --- /dev/null +++ b/lib/cpp/src/transport/TServerSocket.h @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSERVERSOCKET_H_ +#define _THRIFT_TRANSPORT_TSERVERSOCKET_H_ 1 + +#include "TServerTransport.h" +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace transport { + +class TSocket; + +/** + * Server socket implementation of TServerTransport. Wrapper around a unix + * socket listen and accept calls. + * + */ +class TServerSocket : public TServerTransport { + public: + TServerSocket(int port); + TServerSocket(int port, int sendTimeout, int recvTimeout); + + ~TServerSocket(); + + void setSendTimeout(int sendTimeout); + void setRecvTimeout(int recvTimeout); + + void setRetryLimit(int retryLimit); + void setRetryDelay(int retryDelay); + + void setTcpSendBuffer(int tcpSendBuffer); + void setTcpRecvBuffer(int tcpRecvBuffer); + + void listen(); + void close(); + + void interrupt(); + + protected: + boost::shared_ptr<TTransport> acceptImpl(); + + private: + int port_; + int serverSocket_; + int acceptBacklog_; + int sendTimeout_; + int recvTimeout_; + int retryLimit_; + int retryDelay_; + int tcpSendBuffer_; + int tcpRecvBuffer_; + + int intSock1_; + int intSock2_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TSERVERSOCKET_H_ diff --git a/lib/cpp/src/transport/TServerTransport.h b/lib/cpp/src/transport/TServerTransport.h new file mode 100644 index 000000000..40bbc6c78 --- /dev/null +++ b/lib/cpp/src/transport/TServerTransport.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSERVERTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TSERVERTRANSPORT_H_ 1 + +#include "TTransport.h" +#include "TTransportException.h" +#include <boost/shared_ptr.hpp> + +namespace apache { namespace thrift { namespace transport { + +/** + * Server transport framework. A server needs to have some facility for + * creating base transports to read/write from. + * + */ +class TServerTransport { + public: + virtual ~TServerTransport() {} + + /** + * Starts the server transport listening for new connections. Prior to this + * call most transports will not return anything when accept is called. + * + * @throws TTransportException if we were unable to listen + */ + virtual void listen() {} + + /** + * Gets a new dynamically allocated transport object and passes it to the + * caller. Note that it is the explicit duty of the caller to free the + * allocated object. The returned TTransport object must always be in the + * opened state. NULL should never be returned, instead an Exception should + * always be thrown. + * + * @return A new TTransport object + * @throws TTransportException if there is an error + */ + boost::shared_ptr<TTransport> accept() { + boost::shared_ptr<TTransport> result = acceptImpl(); + if (result == NULL) { + throw TTransportException("accept() may not return NULL"); + } + return result; + } + + /** + * For "smart" TServerTransport implementations that work in a multi + * threaded context this can be used to break out of an accept() call. + * It is expected that the transport will throw a TTransportException + * with the interrupted error code. + */ + virtual void interrupt() {} + + /** + * Closes this transport such that future calls to accept will do nothing. + */ + virtual void close() = 0; + + protected: + TServerTransport() {} + + /** + * Subclasses should implement this function for accept. + * + * @return A newly allocated TTransport object + * @throw TTransportException If an error occurs + */ + virtual boost::shared_ptr<TTransport> acceptImpl() = 0; + +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TSERVERTRANSPORT_H_ diff --git a/lib/cpp/src/transport/TShortReadTransport.h b/lib/cpp/src/transport/TShortReadTransport.h new file mode 100644 index 000000000..3df8a57ca --- /dev/null +++ b/lib/cpp/src/transport/TShortReadTransport.h @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSHORTREADTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TSHORTREADTRANSPORT_H_ 1 + +#include <cstdlib> + +#include <transport/TTransport.h> + +namespace apache { namespace thrift { namespace transport { namespace test { + +/** + * This class is only meant for testing. It wraps another transport. + * Calls to read are passed through with some probability. Otherwise, + * the read amount is randomly reduced before being passed through. + * + */ +class TShortReadTransport : public TTransport { + public: + TShortReadTransport(boost::shared_ptr<TTransport> transport, double full_prob) + : transport_(transport) + , fullProb_(full_prob) + {} + + bool isOpen() { + return transport_->isOpen(); + } + + bool peek() { + return transport_->peek(); + } + + void open() { + transport_->open(); + } + + void close() { + transport_->close(); + } + + uint32_t read(uint8_t* buf, uint32_t len) { + if (len == 0) { + return 0; + } + + if (rand()/(double)RAND_MAX >= fullProb_) { + len = 1 + rand()%len; + } + return transport_->read(buf, len); + } + + void write(const uint8_t* buf, uint32_t len) { + transport_->write(buf, len); + } + + void flush() { + transport_->flush(); + } + + const uint8_t* borrow(uint8_t* buf, uint32_t* len) { + return transport_->borrow(buf, len); + } + + void consume(uint32_t len) { + return transport_->consume(len); + } + + boost::shared_ptr<TTransport> getUnderlyingTransport() { + return transport_; + } + + protected: + boost::shared_ptr<TTransport> transport_; + double fullProb_; +}; + +}}}} // apache::thrift::transport::test + +#endif // #ifndef _THRIFT_TRANSPORT_TSHORTREADTRANSPORT_H_ diff --git a/lib/cpp/src/transport/TSimpleFileTransport.cpp b/lib/cpp/src/transport/TSimpleFileTransport.cpp new file mode 100644 index 000000000..e58a57430 --- /dev/null +++ b/lib/cpp/src/transport/TSimpleFileTransport.cpp @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TSimpleFileTransport.h" + +#include <sys/types.h> +#include <sys/stat.h> +#include <fcntl.h> + +namespace apache { namespace thrift { namespace transport { + +TSimpleFileTransport:: +TSimpleFileTransport(const std::string& path, bool read, bool write) + : TFDTransport(-1, TFDTransport::CLOSE_ON_DESTROY) { + int flags = 0; + if (read && write) { + flags = O_RDWR; + } else if (read) { + flags = O_RDONLY; + } else if (write) { + flags = O_WRONLY; + } else { + throw TTransportException("Neither READ nor WRITE specified"); + } + if (write) { + flags |= O_CREAT | O_APPEND; + } + int fd = ::open(path.c_str(), + flags, + S_IRUSR | S_IWUSR| S_IRGRP | S_IROTH); + if (fd < 0) { + throw TTransportException("failed to open file for writing: " + path); + } + setFD(fd); + open(); +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TSimpleFileTransport.h b/lib/cpp/src/transport/TSimpleFileTransport.h new file mode 100644 index 000000000..6cc52ea1a --- /dev/null +++ b/lib/cpp/src/transport/TSimpleFileTransport.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSIMPLEFILETRANSPORT_H_ +#define _THRIFT_TRANSPORT_TSIMPLEFILETRANSPORT_H_ 1 + +#include "TFDTransport.h" + +namespace apache { namespace thrift { namespace transport { + +/** + * Dead-simple wrapper around a file. + * + * Writeable files are opened with O_CREAT and O_APPEND + */ +class TSimpleFileTransport : public TFDTransport { + public: + TSimpleFileTransport(const std::string& path, + bool read = true, + bool write = false); +}; + +}}} // apache::thrift::transport + +#endif // _THRIFT_TRANSPORT_TSIMPLEFILETRANSPORT_H_ diff --git a/lib/cpp/src/transport/TSocket.cpp b/lib/cpp/src/transport/TSocket.cpp new file mode 100644 index 000000000..3395dabdc --- /dev/null +++ b/lib/cpp/src/transport/TSocket.cpp @@ -0,0 +1,589 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <config.h> +#include <cstring> +#include <sstream> +#include <sys/socket.h> +#include <sys/poll.h> +#include <sys/types.h> +#include <arpa/inet.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <netdb.h> +#include <unistd.h> +#include <errno.h> +#include <fcntl.h> + +#include "concurrency/Monitor.h" +#include "TSocket.h" +#include "TTransportException.h" + +namespace apache { namespace thrift { namespace transport { + +using namespace std; + +// Global var to track total socket sys calls +uint32_t g_socket_syscalls = 0; + +/** + * TSocket implementation. + * + */ + +TSocket::TSocket(string host, int port) : + host_(host), + port_(port), + socket_(-1), + connTimeout_(0), + sendTimeout_(0), + recvTimeout_(0), + lingerOn_(1), + lingerVal_(0), + noDelay_(1), + maxRecvRetries_(5) { + recvTimeval_.tv_sec = (int)(recvTimeout_/1000); + recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000); +} + +TSocket::TSocket() : + host_(""), + port_(0), + socket_(-1), + connTimeout_(0), + sendTimeout_(0), + recvTimeout_(0), + lingerOn_(1), + lingerVal_(0), + noDelay_(1), + maxRecvRetries_(5) { + recvTimeval_.tv_sec = (int)(recvTimeout_/1000); + recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000); +} + +TSocket::TSocket(int socket) : + host_(""), + port_(0), + socket_(socket), + connTimeout_(0), + sendTimeout_(0), + recvTimeout_(0), + lingerOn_(1), + lingerVal_(0), + noDelay_(1), + maxRecvRetries_(5) { + recvTimeval_.tv_sec = (int)(recvTimeout_/1000); + recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000); +} + +TSocket::~TSocket() { + close(); +} + +bool TSocket::isOpen() { + return (socket_ >= 0); +} + +bool TSocket::peek() { + if (!isOpen()) { + return false; + } + uint8_t buf; + int r = recv(socket_, &buf, 1, MSG_PEEK); + if (r == -1) { + int errno_copy = errno; + #ifdef __FreeBSD__ + /* shigin: + * freebsd returns -1 and ECONNRESET if socket was closed by + * the other side + */ + if (errno_copy == ECONNRESET) + { + close(); + return false; + } + #endif + GlobalOutput.perror("TSocket::peek() recv() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "recv()", errno_copy); + } + return (r > 0); +} + +void TSocket::openConnection(struct addrinfo *res) { + if (isOpen()) { + throw TTransportException(TTransportException::ALREADY_OPEN); + } + + socket_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol); + if (socket_ == -1) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() socket() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "socket()", errno_copy); + } + + // Send timeout + if (sendTimeout_ > 0) { + setSendTimeout(sendTimeout_); + } + + // Recv timeout + if (recvTimeout_ > 0) { + setRecvTimeout(recvTimeout_); + } + + // Linger + setLinger(lingerOn_, lingerVal_); + + // No delay + setNoDelay(noDelay_); + + // Set the socket to be non blocking for connect if a timeout exists + int flags = fcntl(socket_, F_GETFL, 0); + if (connTimeout_ > 0) { + if (-1 == fcntl(socket_, F_SETFL, flags | O_NONBLOCK)) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() fcntl() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy); + } + } else { + if (-1 == fcntl(socket_, F_SETFL, flags & ~O_NONBLOCK)) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() fcntl " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy); + } + } + + // Connect the socket + int ret = connect(socket_, res->ai_addr, res->ai_addrlen); + + // success case + if (ret == 0) { + goto done; + } + + if (errno != EINPROGRESS) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() connect() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "connect() failed", errno_copy); + } + + + struct pollfd fds[1]; + std::memset(fds, 0 , sizeof(fds)); + fds[0].fd = socket_; + fds[0].events = POLLOUT; + ret = poll(fds, 1, connTimeout_); + + if (ret > 0) { + // Ensure the socket is connected and that there are no errors set + int val; + socklen_t lon; + lon = sizeof(int); + int ret2 = getsockopt(socket_, SOL_SOCKET, SO_ERROR, (void *)&val, &lon); + if (ret2 == -1) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() getsockopt() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "getsockopt()", errno_copy); + } + // no errors on socket, go to town + if (val == 0) { + goto done; + } + GlobalOutput.perror("TSocket::open() error on socket (after poll) " + getSocketInfo(), val); + throw TTransportException(TTransportException::NOT_OPEN, "socket open() error", val); + } else if (ret == 0) { + // socket timed out + string errStr = "TSocket::open() timed out " + getSocketInfo(); + GlobalOutput(errStr.c_str()); + throw TTransportException(TTransportException::NOT_OPEN, "open() timed out"); + } else { + // error on poll() + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() poll() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "poll() failed", errno_copy); + } + + done: + // Set socket back to normal mode (blocking) + fcntl(socket_, F_SETFL, flags); +} + +void TSocket::open() { + if (isOpen()) { + throw TTransportException(TTransportException::ALREADY_OPEN); + } + + // Validate port number + if (port_ < 0 || port_ > 65536) { + throw TTransportException(TTransportException::NOT_OPEN, "Specified port is invalid"); + } + + struct addrinfo hints, *res, *res0; + res = NULL; + res0 = NULL; + int error; + char port[sizeof("65536")]; + std::memset(&hints, 0, sizeof(hints)); + hints.ai_family = PF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + sprintf(port, "%d", port_); + + error = getaddrinfo(host_.c_str(), port, &hints, &res0); + + if (error) { + string errStr = "TSocket::open() getaddrinfo() " + getSocketInfo() + string(gai_strerror(error)); + GlobalOutput(errStr.c_str()); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not resolve host for client socket."); + } + + // Cycle through all the returned addresses until one + // connects or push the exception up. + for (res = res0; res; res = res->ai_next) { + try { + openConnection(res); + break; + } catch (TTransportException& ttx) { + if (res->ai_next) { + close(); + } else { + close(); + freeaddrinfo(res0); // cleanup on failure + throw; + } + } + } + + // Free address structure memory + freeaddrinfo(res0); +} + +void TSocket::close() { + if (socket_ >= 0) { + shutdown(socket_, SHUT_RDWR); + ::close(socket_); + } + socket_ = -1; +} + +uint32_t TSocket::read(uint8_t* buf, uint32_t len) { + if (socket_ < 0) { + throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket"); + } + + int32_t retries = 0; + + // EAGAIN can be signalled both when a timeout has occurred and when + // the system is out of resources (an awesome undocumented feature). + // The following is an approximation of the time interval under which + // EAGAIN is taken to indicate an out of resources error. + uint32_t eagainThresholdMicros = 0; + if (recvTimeout_) { + // if a readTimeout is specified along with a max number of recv retries, then + // the threshold will ensure that the read timeout is not exceeded even in the + // case of resource errors + eagainThresholdMicros = (recvTimeout_*1000)/ ((maxRecvRetries_>0) ? maxRecvRetries_ : 2); + } + + try_again: + // Read from the socket + struct timeval begin; + gettimeofday(&begin, NULL); + int got = recv(socket_, buf, len, 0); + int errno_copy = errno; //gettimeofday can change errno + struct timeval end; + gettimeofday(&end, NULL); + uint32_t readElapsedMicros = (((end.tv_sec - begin.tv_sec) * 1000 * 1000) + + (((uint64_t)(end.tv_usec - begin.tv_usec)))); + ++g_socket_syscalls; + + // Check for error on read + if (got < 0) { + if (errno_copy == EAGAIN) { + // check if this is the lack of resources or timeout case + if (!eagainThresholdMicros || (readElapsedMicros < eagainThresholdMicros)) { + if (retries++ < maxRecvRetries_) { + usleep(50); + goto try_again; + } else { + throw TTransportException(TTransportException::TIMED_OUT, + "EAGAIN (unavailable resources)"); + } + } else { + // infer that timeout has been hit + throw TTransportException(TTransportException::TIMED_OUT, + "EAGAIN (timed out)"); + } + } + + // If interrupted, try again + if (errno_copy == EINTR && retries++ < maxRecvRetries_) { + goto try_again; + } + + // Now it's not a try again case, but a real probblez + GlobalOutput.perror("TSocket::read() recv() " + getSocketInfo(), errno_copy); + + // If we disconnect with no linger time + if (errno_copy == ECONNRESET) { + #ifdef __FreeBSD__ + /* shigin: freebsd doesn't follow POSIX semantic of recv and fails with + * ECONNRESET if peer performed shutdown + */ + close(); + return 0; + #else + throw TTransportException(TTransportException::NOT_OPEN, "ECONNRESET"); + #endif + } + + // This ish isn't open + if (errno_copy == ENOTCONN) { + throw TTransportException(TTransportException::NOT_OPEN, "ENOTCONN"); + } + + // Timed out! + if (errno_copy == ETIMEDOUT) { + throw TTransportException(TTransportException::TIMED_OUT, "ETIMEDOUT"); + } + + // Some other error, whatevz + throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy); + } + + // The remote host has closed the socket + if (got == 0) { + close(); + return 0; + } + + // Pack data into string + return got; +} + +void TSocket::write(const uint8_t* buf, uint32_t len) { + if (socket_ < 0) { + throw TTransportException(TTransportException::NOT_OPEN, "Called write on non-open socket"); + } + + uint32_t sent = 0; + + while (sent < len) { + + int flags = 0; + #ifdef MSG_NOSIGNAL + // Note the use of MSG_NOSIGNAL to suppress SIGPIPE errors, instead we + // check for the EPIPE return condition and close the socket in that case + flags |= MSG_NOSIGNAL; + #endif // ifdef MSG_NOSIGNAL + + int b = send(socket_, buf + sent, len - sent, flags); + ++g_socket_syscalls; + + // Fail on a send error + if (b < 0) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::write() send() " + getSocketInfo(), errno_copy); + + if (errno == EPIPE || errno == ECONNRESET || errno == ENOTCONN) { + close(); + throw TTransportException(TTransportException::NOT_OPEN, "write() send()", errno_copy); + } + + throw TTransportException(TTransportException::UNKNOWN, "write() send()", errno_copy); + } + + // Fail on blocked send + if (b == 0) { + throw TTransportException(TTransportException::NOT_OPEN, "Socket send returned 0."); + } + sent += b; + } +} + +std::string TSocket::getHost() { + return host_; +} + +int TSocket::getPort() { + return port_; +} + +void TSocket::setHost(string host) { + host_ = host; +} + +void TSocket::setPort(int port) { + port_ = port; +} + +void TSocket::setLinger(bool on, int linger) { + lingerOn_ = on; + lingerVal_ = linger; + if (socket_ < 0) { + return; + } + + struct linger l = {(lingerOn_ ? 1 : 0), lingerVal_}; + int ret = setsockopt(socket_, SOL_SOCKET, SO_LINGER, &l, sizeof(l)); + if (ret == -1) { + int errno_copy = errno; // Copy errno because we're allocating memory. + GlobalOutput.perror("TSocket::setLinger() setsockopt() " + getSocketInfo(), errno_copy); + } +} + +void TSocket::setNoDelay(bool noDelay) { + noDelay_ = noDelay; + if (socket_ < 0) { + return; + } + + // Set socket to NODELAY + int v = noDelay_ ? 1 : 0; + int ret = setsockopt(socket_, IPPROTO_TCP, TCP_NODELAY, &v, sizeof(v)); + if (ret == -1) { + int errno_copy = errno; // Copy errno because we're allocating memory. + GlobalOutput.perror("TSocket::setNoDelay() setsockopt() " + getSocketInfo(), errno_copy); + } +} + +void TSocket::setConnTimeout(int ms) { + connTimeout_ = ms; +} + +void TSocket::setRecvTimeout(int ms) { + if (ms < 0) { + char errBuf[512]; + sprintf(errBuf, "TSocket::setRecvTimeout with negative input: %d\n", ms); + GlobalOutput(errBuf); + return; + } + recvTimeout_ = ms; + + if (socket_ < 0) { + return; + } + + recvTimeval_.tv_sec = (int)(recvTimeout_/1000); + recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000); + + // Copy because poll may modify + struct timeval r = recvTimeval_; + int ret = setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &r, sizeof(r)); + if (ret == -1) { + int errno_copy = errno; // Copy errno because we're allocating memory. + GlobalOutput.perror("TSocket::setRecvTimeout() setsockopt() " + getSocketInfo(), errno_copy); + } +} + +void TSocket::setSendTimeout(int ms) { + if (ms < 0) { + char errBuf[512]; + sprintf(errBuf, "TSocket::setSendTimeout with negative input: %d\n", ms); + GlobalOutput(errBuf); + return; + } + sendTimeout_ = ms; + + if (socket_ < 0) { + return; + } + + struct timeval s = {(int)(sendTimeout_/1000), + (int)((sendTimeout_%1000)*1000)}; + int ret = setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, &s, sizeof(s)); + if (ret == -1) { + int errno_copy = errno; // Copy errno because we're allocating memory. + GlobalOutput.perror("TSocket::setSendTimeout() setsockopt() " + getSocketInfo(), errno_copy); + } +} + +void TSocket::setMaxRecvRetries(int maxRecvRetries) { + maxRecvRetries_ = maxRecvRetries; +} + +string TSocket::getSocketInfo() { + std::ostringstream oss; + oss << "<Host: " << host_ << " Port: " << port_ << ">"; + return oss.str(); +} + +std::string TSocket::getPeerHost() { + if (peerHost_.empty()) { + struct sockaddr_storage addr; + socklen_t addrLen = sizeof(addr); + + if (socket_ < 0) { + return host_; + } + + int rv = getpeername(socket_, (sockaddr*) &addr, &addrLen); + + if (rv != 0) { + return peerHost_; + } + + char clienthost[NI_MAXHOST]; + char clientservice[NI_MAXSERV]; + + getnameinfo((sockaddr*) &addr, addrLen, + clienthost, sizeof(clienthost), + clientservice, sizeof(clientservice), 0); + + peerHost_ = clienthost; + } + return peerHost_; +} + +std::string TSocket::getPeerAddress() { + if (peerAddress_.empty()) { + struct sockaddr_storage addr; + socklen_t addrLen = sizeof(addr); + + if (socket_ < 0) { + return peerAddress_; + } + + int rv = getpeername(socket_, (sockaddr*) &addr, &addrLen); + + if (rv != 0) { + return peerAddress_; + } + + char clienthost[NI_MAXHOST]; + char clientservice[NI_MAXSERV]; + + getnameinfo((sockaddr*) &addr, addrLen, + clienthost, sizeof(clienthost), + clientservice, sizeof(clientservice), + NI_NUMERICHOST|NI_NUMERICSERV); + + peerAddress_ = clienthost; + peerPort_ = std::atoi(clientservice); + } + return peerAddress_; +} + +int TSocket::getPeerPort() { + getPeerAddress(); + return peerPort_; +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TSocket.h b/lib/cpp/src/transport/TSocket.h new file mode 100644 index 000000000..b0f445aa3 --- /dev/null +++ b/lib/cpp/src/transport/TSocket.h @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSOCKET_H_ +#define _THRIFT_TRANSPORT_TSOCKET_H_ 1 + +#include <string> +#include <sys/time.h> + +#include "TTransport.h" +#include "TServerSocket.h" + +namespace apache { namespace thrift { namespace transport { + +/** + * TCP Socket implementation of the TTransport interface. + * + */ +class TSocket : public TTransport { + /** + * We allow the TServerSocket acceptImpl() method to access the private + * members of a socket so that it can access the TSocket(int socket) + * constructor which creates a socket object from the raw UNIX socket + * handle. + */ + friend class TServerSocket; + + public: + /** + * Constructs a new socket. Note that this does NOT actually connect the + * socket. + * + */ + TSocket(); + + /** + * Constructs a new socket. Note that this does NOT actually connect the + * socket. + * + * @param host An IP address or hostname to connect to + * @param port The port to connect on + */ + TSocket(std::string host, int port); + + /** + * Destroyes the socket object, closing it if necessary. + */ + virtual ~TSocket(); + + /** + * Whether the socket is alive. + * + * @return Is the socket alive? + */ + bool isOpen(); + + /** + * Calls select on the socket to see if there is more data available. + */ + bool peek(); + + /** + * Creates and opens the UNIX socket. + * + * @throws TTransportException If the socket could not connect + */ + virtual void open(); + + /** + * Shuts down communications on the socket. + */ + void close(); + + /** + * Reads from the underlying socket. + */ + uint32_t read(uint8_t* buf, uint32_t len); + + /** + * Writes to the underlying socket. + */ + void write(const uint8_t* buf, uint32_t len); + + /** + * Get the host that the socket is connected to + * + * @return string host identifier + */ + std::string getHost(); + + /** + * Get the port that the socket is connected to + * + * @return int port number + */ + int getPort(); + + /** + * Set the host that socket will connect to + * + * @param host host identifier + */ + void setHost(std::string host); + + /** + * Set the port that socket will connect to + * + * @param port port number + */ + void setPort(int port); + + /** + * Controls whether the linger option is set on the socket. + * + * @param on Whether SO_LINGER is on + * @param linger If linger is active, the number of seconds to linger for + */ + void setLinger(bool on, int linger); + + /** + * Whether to enable/disable Nagle's algorithm. + * + * @param noDelay Whether or not to disable the algorithm. + * @return + */ + void setNoDelay(bool noDelay); + + /** + * Set the connect timeout + */ + void setConnTimeout(int ms); + + /** + * Set the receive timeout + */ + void setRecvTimeout(int ms); + + /** + * Set the send timeout + */ + void setSendTimeout(int ms); + + /** + * Set the max number of recv retries in case of an EAGAIN + * error + */ + void setMaxRecvRetries(int maxRecvRetries); + + /** + * Get socket information formated as a string <Host: x Port: x> + */ + std::string getSocketInfo(); + + /** + * Returns the DNS name of the host to which the socket is connected + */ + std::string getPeerHost(); + + /** + * Returns the address of the host to which the socket is connected + */ + std::string getPeerAddress(); + + /** + * Returns the port of the host to which the socket is connected + **/ + int getPeerPort(); + + + protected: + /** + * Constructor to create socket from raw UNIX handle. Never called directly + * but used by the TServerSocket class. + */ + TSocket(int socket); + + /** connect, called by open */ + void openConnection(struct addrinfo *res); + + /** Host to connect to */ + std::string host_; + + /** Peer hostname */ + std::string peerHost_; + + /** Peer address */ + std::string peerAddress_; + + /** Peer port */ + int peerPort_; + + /** Port number to connect on */ + int port_; + + /** Underlying UNIX socket handle */ + int socket_; + + /** Connect timeout in ms */ + int connTimeout_; + + /** Send timeout in ms */ + int sendTimeout_; + + /** Recv timeout in ms */ + int recvTimeout_; + + /** Linger on */ + bool lingerOn_; + + /** Linger val */ + int lingerVal_; + + /** Nodelay */ + bool noDelay_; + + /** Recv EGAIN retries */ + int maxRecvRetries_; + + /** Recv timeout timeval */ + struct timeval recvTimeval_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TSOCKET_H_ + diff --git a/lib/cpp/src/transport/TSocketPool.cpp b/lib/cpp/src/transport/TSocketPool.cpp new file mode 100644 index 000000000..1150282bb --- /dev/null +++ b/lib/cpp/src/transport/TSocketPool.cpp @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <algorithm> +#include <iostream> + +#include "TSocketPool.h" + +namespace apache { namespace thrift { namespace transport { + +using namespace std; + +using boost::shared_ptr; + +/** + * TSocketPoolServer implementation + * + */ +TSocketPoolServer::TSocketPoolServer() + : host_(""), + port_(0), + socket_(-1), + lastFailTime_(0), + consecutiveFailures_(0) {} + +/** + * Constructor for TSocketPool server + */ +TSocketPoolServer::TSocketPoolServer(const string &host, int port) + : host_(host), + port_(port), + socket_(-1), + lastFailTime_(0), + consecutiveFailures_(0) {} + +/** + * TSocketPool implementation. + * + */ + +TSocketPool::TSocketPool() : TSocket(), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) { +} + +TSocketPool::TSocketPool(const vector<string> &hosts, + const vector<int> &ports) : TSocket(), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) +{ + if (hosts.size() != ports.size()) { + GlobalOutput("TSocketPool::TSocketPool: hosts.size != ports.size"); + throw TTransportException(TTransportException::BAD_ARGS); + } + + for (unsigned int i = 0; i < hosts.size(); ++i) { + addServer(hosts[i], ports[i]); + } +} + +TSocketPool::TSocketPool(const vector<pair<string, int> >& servers) : TSocket(), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) +{ + for (unsigned i = 0; i < servers.size(); ++i) { + addServer(servers[i].first, servers[i].second); + } +} + +TSocketPool::TSocketPool(const vector< shared_ptr<TSocketPoolServer> >& servers) : TSocket(), + servers_(servers), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) +{ +} + +TSocketPool::TSocketPool(const string& host, int port) : TSocket(), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) +{ + addServer(host, port); +} + +TSocketPool::~TSocketPool() { + vector< shared_ptr<TSocketPoolServer> >::const_iterator iter = servers_.begin(); + vector< shared_ptr<TSocketPoolServer> >::const_iterator iterEnd = servers_.end(); + for (; iter != iterEnd; ++iter) { + setCurrentServer(*iter); + TSocketPool::close(); + } +} + +void TSocketPool::addServer(const string& host, int port) { + servers_.push_back(shared_ptr<TSocketPoolServer>(new TSocketPoolServer(host, port))); +} + +void TSocketPool::setServers(const vector< shared_ptr<TSocketPoolServer> >& servers) { + servers_ = servers; +} + +void TSocketPool::getServers(vector< shared_ptr<TSocketPoolServer> >& servers) { + servers = servers_; +} + +void TSocketPool::setNumRetries(int numRetries) { + numRetries_ = numRetries; +} + +void TSocketPool::setRetryInterval(int retryInterval) { + retryInterval_ = retryInterval; +} + + +void TSocketPool::setMaxConsecutiveFailures(int maxConsecutiveFailures) { + maxConsecutiveFailures_ = maxConsecutiveFailures; +} + +void TSocketPool::setRandomize(bool randomize) { + randomize_ = randomize; +} + +void TSocketPool::setAlwaysTryLast(bool alwaysTryLast) { + alwaysTryLast_ = alwaysTryLast; +} + +void TSocketPool::setCurrentServer(const shared_ptr<TSocketPoolServer> &server) { + currentServer_ = server; + host_ = server->host_; + port_ = server->port_; + socket_ = server->socket_; +} + +/* TODO: without apc we ignore a lot of functionality from the php version */ +void TSocketPool::open() { + if (randomize_) { + random_shuffle(servers_.begin(), servers_.end()); + } + + unsigned int numServers = servers_.size(); + for (unsigned int i = 0; i < numServers; ++i) { + + shared_ptr<TSocketPoolServer> &server = servers_[i]; + bool retryIntervalPassed = (server->lastFailTime_ == 0); + bool isLastServer = alwaysTryLast_ ? (i == (numServers - 1)) : false; + + // Impersonate the server socket + setCurrentServer(server); + + if (isOpen()) { + // already open means we're done + return; + } + + if (server->lastFailTime_ > 0) { + // The server was marked as down, so check if enough time has elapsed to retry + int elapsedTime = time(NULL) - server->lastFailTime_; + if (elapsedTime > retryInterval_) { + retryIntervalPassed = true; + } + } + + if (retryIntervalPassed || isLastServer) { + for (int j = 0; j < numRetries_; ++j) { + try { + TSocket::open(); + + // Copy over the opened socket so that we can keep it persistent + server->socket_ = socket_; + + // reset lastFailTime_ is required + if (server->lastFailTime_) { + server->lastFailTime_ = 0; + } + + // success + return; + } catch (TException e) { + string errStr = "TSocketPool::open failed "+getSocketInfo()+": "+e.what(); + GlobalOutput(errStr.c_str()); + // connection failed + } + } + + ++server->consecutiveFailures_; + if (server->consecutiveFailures_ > maxConsecutiveFailures_) { + // Mark server as down + server->consecutiveFailures_ = 0; + server->lastFailTime_ = time(NULL); + } + } + } + + GlobalOutput("TSocketPool::open: all connections failed"); + throw TTransportException(TTransportException::NOT_OPEN); +} + +void TSocketPool::close() { + if (isOpen()) { + TSocket::close(); + currentServer_->socket_ = -1; + } +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TSocketPool.h b/lib/cpp/src/transport/TSocketPool.h new file mode 100644 index 000000000..8c506695a --- /dev/null +++ b/lib/cpp/src/transport/TSocketPool.h @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSOCKETPOOL_H_ +#define _THRIFT_TRANSPORT_TSOCKETPOOL_H_ 1 + +#include <vector> +#include "TSocket.h" + +namespace apache { namespace thrift { namespace transport { + + /** + * Class to hold server information for TSocketPool + * + */ +class TSocketPoolServer { + + public: + /** + * Default constructor for server info + */ + TSocketPoolServer(); + + /** + * Constructor for TSocketPool server + */ + TSocketPoolServer(const std::string &host, int port); + + // Host name + std::string host_; + + // Port to connect on + int port_; + + // Socket for the server + int socket_; + + // Last time connecting to this server failed + int lastFailTime_; + + // Number of consecutive times connecting to this server failed + int consecutiveFailures_; +}; + +/** + * TCP Socket implementation of the TTransport interface. + * + */ +class TSocketPool : public TSocket { + + public: + + /** + * Socket pool constructor + */ + TSocketPool(); + + /** + * Socket pool constructor + * + * @param hosts list of host names + * @param ports list of port names + */ + TSocketPool(const std::vector<std::string> &hosts, + const std::vector<int> &ports); + + /** + * Socket pool constructor + * + * @param servers list of pairs of host name and port + */ + TSocketPool(const std::vector<std::pair<std::string, int> >& servers); + + /** + * Socket pool constructor + * + * @param servers list of TSocketPoolServers + */ + TSocketPool(const std::vector< boost::shared_ptr<TSocketPoolServer> >& servers); + + /** + * Socket pool constructor + * + * @param host single host + * @param port single port + */ + TSocketPool(const std::string& host, int port); + + /** + * Destroyes the socket object, closing it if necessary. + */ + virtual ~TSocketPool(); + + /** + * Add a server to the pool + */ + void addServer(const std::string& host, int port); + + /** + * Set list of servers in this pool + */ + void setServers(const std::vector< boost::shared_ptr<TSocketPoolServer> >& servers); + + /** + * Get list of servers in this pool + */ + void getServers(std::vector< boost::shared_ptr<TSocketPoolServer> >& servers); + + /** + * Sets how many times to keep retrying a host in the connect function. + */ + void setNumRetries(int numRetries); + + /** + * Sets how long to wait until retrying a host if it was marked down + */ + void setRetryInterval(int retryInterval); + + /** + * Sets how many times to keep retrying a host before marking it as down. + */ + void setMaxConsecutiveFailures(int maxConsecutiveFailures); + + /** + * Turns randomization in connect order on or off. + */ + void setRandomize(bool randomize); + + /** + * Whether to always try the last server. + */ + void setAlwaysTryLast(bool alwaysTryLast); + + /** + * Creates and opens the UNIX socket. + */ + void open(); + + /* + * Closes the UNIX socket + */ + void close(); + + protected: + + void setCurrentServer(const boost::shared_ptr<TSocketPoolServer> &server); + + /** List of servers to connect to */ + std::vector< boost::shared_ptr<TSocketPoolServer> > servers_; + + /** Current server */ + boost::shared_ptr<TSocketPoolServer> currentServer_; + + /** How many times to retry each host in connect */ + int numRetries_; + + /** Retry interval in seconds, how long to not try a host if it has been + * marked as down. + */ + int retryInterval_; + + /** Max consecutive failures before marking a host down. */ + int maxConsecutiveFailures_; + + /** Try hosts in order? or Randomized? */ + bool randomize_; + + /** Always try last host, even if marked down? */ + bool alwaysTryLast_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TSOCKETPOOL_H_ + diff --git a/lib/cpp/src/transport/TTransport.h b/lib/cpp/src/transport/TTransport.h new file mode 100644 index 000000000..eb0d5df8a --- /dev/null +++ b/lib/cpp/src/transport/TTransport.h @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TTRANSPORT_H_ 1 + +#include <Thrift.h> +#include <boost/shared_ptr.hpp> +#include <transport/TTransportException.h> +#include <string> + +namespace apache { namespace thrift { namespace transport { + +/** + * Generic interface for a method of transporting data. A TTransport may be + * capable of either reading or writing, but not necessarily both. + * + */ +class TTransport { + public: + /** + * Virtual deconstructor. + */ + virtual ~TTransport() {} + + /** + * Whether this transport is open. + */ + virtual bool isOpen() { + return false; + } + + /** + * Tests whether there is more data to read or if the remote side is + * still open. By default this is true whenever the transport is open, + * but implementations should add logic to test for this condition where + * possible (i.e. on a socket). + * This is used by a server to check if it should listen for another + * request. + */ + virtual bool peek() { + return isOpen(); + } + + /** + * Opens the transport for communications. + * + * @return bool Whether the transport was successfully opened + * @throws TTransportException if opening failed + */ + virtual void open() { + throw TTransportException(TTransportException::NOT_OPEN, "Cannot open base TTransport."); + } + + /** + * Closes the transport. + */ + virtual void close() { + throw TTransportException(TTransportException::NOT_OPEN, "Cannot close base TTransport."); + } + + /** + * Attempt to read up to the specified number of bytes into the string. + * + * @param buf Reference to the location to write the data + * @param len How many bytes to read + * @return How many bytes were actually read + * @throws TTransportException If an error occurs + */ + virtual uint32_t read(uint8_t* /* buf */, uint32_t /* len */) { + throw TTransportException(TTransportException::NOT_OPEN, "Base TTransport cannot read."); + } + + /** + * Reads the given amount of data in its entirety no matter what. + * + * @param s Reference to location for read data + * @param len How many bytes to read + * @return How many bytes read, which must be equal to size + * @throws TTransportException If insufficient data was read + */ + virtual uint32_t readAll(uint8_t* buf, uint32_t len) { + uint32_t have = 0; + uint32_t get = 0; + + while (have < len) { + get = read(buf+have, len-have); + if (get <= 0) { + throw TTransportException("No more data to read."); + } + have += get; + } + + return have; + } + + /** + * Called when read is completed. + * This can be over-ridden to perform a transport-specific action + * e.g. logging the request to a file + * + */ + virtual void readEnd() { + // default behaviour is to do nothing + return; + } + + /** + * Writes the string in its entirety to the buffer. + * + * @param buf The data to write out + * @throws TTransportException if an error occurs + */ + virtual void write(const uint8_t* /* buf */, uint32_t /* len */) { + throw TTransportException(TTransportException::NOT_OPEN, "Base TTransport cannot write."); + } + + /** + * Called when write is completed. + * This can be over-ridden to perform a transport-specific action + * at the end of a request. + * + */ + virtual void writeEnd() { + // default behaviour is to do nothing + return; + } + + /** + * Flushes any pending data to be written. Typically used with buffered + * transport mechanisms. + * + * @throws TTransportException if an error occurs + */ + virtual void flush() {} + + /** + * Attempts to return a pointer to \c len bytes, possibly copied into \c buf. + * Does not consume the bytes read (i.e.: a later read will return the same + * data). This method is meant to support protocols that need to read + * variable-length fields. They can attempt to borrow the maximum amount of + * data that they will need, then consume (see next method) what they + * actually use. Some transports will not support this method and others + * will fail occasionally, so protocols must be prepared to use read if + * borrow fails. + * + * @oaram buf A buffer where the data can be stored if needed. + * If borrow doesn't return buf, then the contents of + * buf after the call are undefined. + * @param len *len should initially contain the number of bytes to borrow. + * If borrow succeeds, *len will contain the number of bytes + * available in the returned pointer. This will be at least + * what was requested, but may be more if borrow returns + * a pointer to an internal buffer, rather than buf. + * If borrow fails, the contents of *len are undefined. + * @return If the borrow succeeds, return a pointer to the borrowed data. + * This might be equal to \c buf, or it might be a pointer into + * the transport's internal buffers. + * @throws TTransportException if an error occurs + */ + virtual const uint8_t* borrow(uint8_t* /* buf */, uint32_t* /* len */) { + return NULL; + } + + /** + * Remove len bytes from the transport. This should always follow a borrow + * of at least len bytes, and should always succeed. + * TODO(dreiss): Is there any transport that could borrow but fail to + * consume, or that would require a buffer to dump the consumed data? + * + * @param len How many bytes to consume + * @throws TTransportException If an error occurs + */ + virtual void consume(uint32_t /* len */) { + throw TTransportException(TTransportException::NOT_OPEN, "Base TTransport cannot consume."); + } + + protected: + /** + * Simple constructor. + */ + TTransport() {} +}; + +/** + * Generic factory class to make an input and output transport out of a + * source transport. Commonly used inside servers to make input and output + * streams out of raw clients. + * + */ +class TTransportFactory { + public: + TTransportFactory() {} + + virtual ~TTransportFactory() {} + + /** + * Default implementation does nothing, just returns the transport given. + */ + virtual boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> trans) { + return trans; + } + +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORT_H_ diff --git a/lib/cpp/src/transport/TTransportException.cpp b/lib/cpp/src/transport/TTransportException.cpp new file mode 100644 index 000000000..f0aaedc2f --- /dev/null +++ b/lib/cpp/src/transport/TTransportException.cpp @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <transport/TTransportException.h> +#include <boost/lexical_cast.hpp> +#include <cstring> +#include <config.h> + +using std::string; +using boost::lexical_cast; + +namespace apache { namespace thrift { namespace transport { + +}}} // apache::thrift::transport + diff --git a/lib/cpp/src/transport/TTransportException.h b/lib/cpp/src/transport/TTransportException.h new file mode 100644 index 000000000..330785cea --- /dev/null +++ b/lib/cpp/src/transport/TTransportException.h @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TTRANSPORTEXCEPTION_H_ +#define _THRIFT_TRANSPORT_TTRANSPORTEXCEPTION_H_ 1 + +#include <string> +#include <Thrift.h> + +namespace apache { namespace thrift { namespace transport { + +/** + * Class to encapsulate all the possible types of transport errors that may + * occur in various transport systems. This provides a sort of generic + * wrapper around the shitty UNIX E_ error codes that lets a common code + * base of error handling to be used for various types of transports, i.e. + * pipes etc. + * + */ +class TTransportException : public apache::thrift::TException { + public: + /** + * Error codes for the various types of exceptions. + */ + enum TTransportExceptionType + { UNKNOWN = 0 + , NOT_OPEN = 1 + , ALREADY_OPEN = 2 + , TIMED_OUT = 3 + , END_OF_FILE = 4 + , INTERRUPTED = 5 + , BAD_ARGS = 6 + , CORRUPTED_DATA = 7 + , INTERNAL_ERROR = 8 + }; + + TTransportException() : + apache::thrift::TException(), + type_(UNKNOWN) {} + + TTransportException(TTransportExceptionType type) : + apache::thrift::TException(), + type_(type) {} + + TTransportException(const std::string& message) : + apache::thrift::TException(message), + type_(UNKNOWN) {} + + TTransportException(TTransportExceptionType type, const std::string& message) : + apache::thrift::TException(message), + type_(type) {} + + TTransportException(TTransportExceptionType type, + const std::string& message, + int errno_copy) : + apache::thrift::TException(message + ": " + TOutput::strerror_s(errno_copy)), + type_(type) {} + + virtual ~TTransportException() throw() {} + + /** + * Returns an error code that provides information about the type of error + * that has occurred. + * + * @return Error code + */ + TTransportExceptionType getType() const throw() { + return type_; + } + + virtual const char* what() const throw() { + if (message_.empty()) { + switch (type_) { + case UNKNOWN : return "TTransportException: Unknown transport exception"; + case NOT_OPEN : return "TTransportException: Transport not open"; + case ALREADY_OPEN : return "TTransportException: Transport already open"; + case TIMED_OUT : return "TTransportException: Timed out"; + case END_OF_FILE : return "TTransportException: End of file"; + case INTERRUPTED : return "TTransportException: Interrupted"; + case BAD_ARGS : return "TTransportException: Invalid arguments"; + case CORRUPTED_DATA : return "TTransportException: Corrupted Data"; + case INTERNAL_ERROR : return "TTransportException: Internal error"; + default : return "TTransportException: (Invalid exception type)"; + } + } else { + return message_.c_str(); + } + } + + protected: + /** Just like strerror_r but returns a C++ string object. */ + std::string strerror_s(int errno_copy); + + /** Error code */ + TTransportExceptionType type_; + +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORTEXCEPTION_H_ diff --git a/lib/cpp/src/transport/TTransportUtils.cpp b/lib/cpp/src/transport/TTransportUtils.cpp new file mode 100644 index 000000000..a840fa6c1 --- /dev/null +++ b/lib/cpp/src/transport/TTransportUtils.cpp @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <transport/TTransportUtils.h> + +using std::string; + +namespace apache { namespace thrift { namespace transport { + +uint32_t TPipedTransport::read(uint8_t* buf, uint32_t len) { + uint32_t need = len; + + // We don't have enough data yet + if (rLen_-rPos_ < need) { + // Copy out whatever we have + if (rLen_-rPos_ > 0) { + memcpy(buf, rBuf_+rPos_, rLen_-rPos_); + need -= rLen_-rPos_; + buf += rLen_-rPos_; + rPos_ = rLen_; + } + + // Double the size of the underlying buffer if it is full + if (rLen_ == rBufSize_) { + rBufSize_ *=2; + rBuf_ = (uint8_t *)std::realloc(rBuf_, sizeof(uint8_t) * rBufSize_); + } + + // try to fill up the buffer + rLen_ += srcTrans_->read(rBuf_+rPos_, rBufSize_ - rPos_); + } + + + // Hand over whatever we have + uint32_t give = need; + if (rLen_-rPos_ < give) { + give = rLen_-rPos_; + } + if (give > 0) { + memcpy(buf, rBuf_+rPos_, give); + rPos_ += give; + need -= give; + } + + return (len - need); +} + +void TPipedTransport::write(const uint8_t* buf, uint32_t len) { + if (len == 0) { + return; + } + + // Make the buffer as big as it needs to be + if ((len + wLen_) >= wBufSize_) { + uint32_t newBufSize = wBufSize_*2; + while ((len + wLen_) >= newBufSize) { + newBufSize *= 2; + } + wBuf_ = (uint8_t *)std::realloc(wBuf_, sizeof(uint8_t) * newBufSize); + wBufSize_ = newBufSize; + } + + // Copy into the buffer + memcpy(wBuf_ + wLen_, buf, len); + wLen_ += len; +} + +void TPipedTransport::flush() { + // Write out any data waiting in the write buffer + if (wLen_ > 0) { + srcTrans_->write(wBuf_, wLen_); + wLen_ = 0; + } + + // Flush the underlying transport + srcTrans_->flush(); +} + +TPipedFileReaderTransport::TPipedFileReaderTransport(boost::shared_ptr<TFileReaderTransport> srcTrans, boost::shared_ptr<TTransport> dstTrans) + : TPipedTransport(srcTrans, dstTrans), + srcTrans_(srcTrans) { +} + +TPipedFileReaderTransport::~TPipedFileReaderTransport() { +} + +bool TPipedFileReaderTransport::isOpen() { + return TPipedTransport::isOpen(); +} + +bool TPipedFileReaderTransport::peek() { + return TPipedTransport::peek(); +} + +void TPipedFileReaderTransport::open() { + TPipedTransport::open(); +} + +void TPipedFileReaderTransport::close() { + TPipedTransport::close(); +} + +uint32_t TPipedFileReaderTransport::read(uint8_t* buf, uint32_t len) { + return TPipedTransport::read(buf, len); +} + +uint32_t TPipedFileReaderTransport::readAll(uint8_t* buf, uint32_t len) { + uint32_t have = 0; + uint32_t get = 0; + + while (have < len) { + get = read(buf+have, len-have); + if (get <= 0) { + throw TEOFException(); + } + have += get; + } + + return have; +} + +void TPipedFileReaderTransport::readEnd() { + TPipedTransport::readEnd(); +} + +void TPipedFileReaderTransport::write(const uint8_t* buf, uint32_t len) { + TPipedTransport::write(buf, len); +} + +void TPipedFileReaderTransport::writeEnd() { + TPipedTransport::writeEnd(); +} + +void TPipedFileReaderTransport::flush() { + TPipedTransport::flush(); +} + +int32_t TPipedFileReaderTransport::getReadTimeout() { + return srcTrans_->getReadTimeout(); +} + +void TPipedFileReaderTransport::setReadTimeout(int32_t readTimeout) { + srcTrans_->setReadTimeout(readTimeout); +} + +uint32_t TPipedFileReaderTransport::getNumChunks() { + return srcTrans_->getNumChunks(); +} + +uint32_t TPipedFileReaderTransport::getCurChunk() { + return srcTrans_->getCurChunk(); +} + +void TPipedFileReaderTransport::seekToChunk(int32_t chunk) { + srcTrans_->seekToChunk(chunk); +} + +void TPipedFileReaderTransport::seekToEnd() { + srcTrans_->seekToEnd(); +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TTransportUtils.h b/lib/cpp/src/transport/TTransportUtils.h new file mode 100644 index 000000000..d65c91674 --- /dev/null +++ b/lib/cpp/src/transport/TTransportUtils.h @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_ +#define _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_ 1 + +#include <cstdlib> +#include <cstring> +#include <string> +#include <algorithm> +#include <transport/TTransport.h> +// Include the buffered transports that used to be defined here. +#include <transport/TBufferTransports.h> +#include <transport/TFileTransport.h> + +namespace apache { namespace thrift { namespace transport { + +/** + * The null transport is a dummy transport that doesn't actually do anything. + * It's sort of an analogy to /dev/null, you can never read anything from it + * and it will let you write anything you want to it, though it won't actually + * go anywhere. + * + */ +class TNullTransport : public TTransport { + public: + TNullTransport() {} + + ~TNullTransport() {} + + bool isOpen() { + return true; + } + + void open() {} + + void write(const uint8_t* /* buf */, uint32_t /* len */) { + return; + } + +}; + + +/** + * TPipedTransport. This transport allows piping of a request from one + * transport to another either when readEnd() or writeEnd(). The typical + * use case for this is to log a request or a reply to disk. + * The underlying buffer expands to a keep a copy of the entire + * request/response. + * + */ +class TPipedTransport : virtual public TTransport { + public: + TPipedTransport(boost::shared_ptr<TTransport> srcTrans, + boost::shared_ptr<TTransport> dstTrans) : + srcTrans_(srcTrans), + dstTrans_(dstTrans), + rBufSize_(512), rPos_(0), rLen_(0), + wBufSize_(512), wLen_(0) { + + // default is to to pipe the request when readEnd() is called + pipeOnRead_ = true; + pipeOnWrite_ = false; + + rBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * rBufSize_); + wBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * wBufSize_); + } + + TPipedTransport(boost::shared_ptr<TTransport> srcTrans, + boost::shared_ptr<TTransport> dstTrans, + uint32_t sz) : + srcTrans_(srcTrans), + dstTrans_(dstTrans), + rBufSize_(512), rPos_(0), rLen_(0), + wBufSize_(sz), wLen_(0) { + + rBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * rBufSize_); + wBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * wBufSize_); + } + + ~TPipedTransport() { + std::free(rBuf_); + std::free(wBuf_); + } + + bool isOpen() { + return srcTrans_->isOpen(); + } + + bool peek() { + if (rPos_ >= rLen_) { + // Double the size of the underlying buffer if it is full + if (rLen_ == rBufSize_) { + rBufSize_ *=2; + rBuf_ = (uint8_t *)std::realloc(rBuf_, sizeof(uint8_t) * rBufSize_); + } + + // try to fill up the buffer + rLen_ += srcTrans_->read(rBuf_+rPos_, rBufSize_ - rPos_); + } + return (rLen_ > rPos_); + } + + + void open() { + srcTrans_->open(); + } + + void close() { + srcTrans_->close(); + } + + void setPipeOnRead(bool pipeVal) { + pipeOnRead_ = pipeVal; + } + + void setPipeOnWrite(bool pipeVal) { + pipeOnWrite_ = pipeVal; + } + + uint32_t read(uint8_t* buf, uint32_t len); + + void readEnd() { + + if (pipeOnRead_) { + dstTrans_->write(rBuf_, rPos_); + dstTrans_->flush(); + } + + srcTrans_->readEnd(); + + // If requests are being pipelined, copy down our read-ahead data, + // then reset our state. + int read_ahead = rLen_ - rPos_; + memcpy(rBuf_, rBuf_ + rPos_, read_ahead); + rPos_ = 0; + rLen_ = read_ahead; + } + + void write(const uint8_t* buf, uint32_t len); + + void writeEnd() { + if (pipeOnWrite_) { + dstTrans_->write(wBuf_, wLen_); + dstTrans_->flush(); + } + } + + void flush(); + + boost::shared_ptr<TTransport> getTargetTransport() { + return dstTrans_; + } + + protected: + boost::shared_ptr<TTransport> srcTrans_; + boost::shared_ptr<TTransport> dstTrans_; + + uint8_t* rBuf_; + uint32_t rBufSize_; + uint32_t rPos_; + uint32_t rLen_; + + uint8_t* wBuf_; + uint32_t wBufSize_; + uint32_t wLen_; + + bool pipeOnRead_; + bool pipeOnWrite_; +}; + + +/** + * Wraps a transport into a pipedTransport instance. + * + */ +class TPipedTransportFactory : public TTransportFactory { + public: + TPipedTransportFactory() {} + TPipedTransportFactory(boost::shared_ptr<TTransport> dstTrans) { + initializeTargetTransport(dstTrans); + } + virtual ~TPipedTransportFactory() {} + + /** + * Wraps the base transport into a piped transport. + */ + virtual boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> srcTrans) { + return boost::shared_ptr<TTransport>(new TPipedTransport(srcTrans, dstTrans_)); + } + + virtual void initializeTargetTransport(boost::shared_ptr<TTransport> dstTrans) { + if (dstTrans_.get() == NULL) { + dstTrans_ = dstTrans; + } else { + throw TException("Target transport already initialized"); + } + } + + protected: + boost::shared_ptr<TTransport> dstTrans_; +}; + +/** + * TPipedFileTransport. This is just like a TTransport, except that + * it is a templatized class, so that clients who rely on a specific + * TTransport can still access the original transport. + * + */ +class TPipedFileReaderTransport : public TPipedTransport, + public TFileReaderTransport { + public: + TPipedFileReaderTransport(boost::shared_ptr<TFileReaderTransport> srcTrans, boost::shared_ptr<TTransport> dstTrans); + + ~TPipedFileReaderTransport(); + + // TTransport functions + bool isOpen(); + bool peek(); + void open(); + void close(); + uint32_t read(uint8_t* buf, uint32_t len); + uint32_t readAll(uint8_t* buf, uint32_t len); + void readEnd(); + void write(const uint8_t* buf, uint32_t len); + void writeEnd(); + void flush(); + + // TFileReaderTransport functions + int32_t getReadTimeout(); + void setReadTimeout(int32_t readTimeout); + uint32_t getNumChunks(); + uint32_t getCurChunk(); + void seekToChunk(int32_t chunk); + void seekToEnd(); + + protected: + // shouldn't be used + TPipedFileReaderTransport(); + boost::shared_ptr<TFileReaderTransport> srcTrans_; +}; + +/** + * Creates a TPipedFileReaderTransport from a filepath and a destination transport + * + */ +class TPipedFileReaderTransportFactory : public TPipedTransportFactory { + public: + TPipedFileReaderTransportFactory() {} + TPipedFileReaderTransportFactory(boost::shared_ptr<TTransport> dstTrans) + : TPipedTransportFactory(dstTrans) + {} + virtual ~TPipedFileReaderTransportFactory() {} + + boost::shared_ptr<TTransport> getTransport(boost::shared_ptr<TTransport> srcTrans) { + boost::shared_ptr<TFileReaderTransport> pFileReaderTransport = boost::dynamic_pointer_cast<TFileReaderTransport>(srcTrans); + if (pFileReaderTransport.get() != NULL) { + return getFileReaderTransport(pFileReaderTransport); + } else { + return boost::shared_ptr<TTransport>(); + } + } + + boost::shared_ptr<TFileReaderTransport> getFileReaderTransport(boost::shared_ptr<TFileReaderTransport> srcTrans) { + return boost::shared_ptr<TFileReaderTransport>(new TPipedFileReaderTransport(srcTrans, dstTrans_)); + } +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_ diff --git a/lib/cpp/src/transport/TZlibTransport.cpp b/lib/cpp/src/transport/TZlibTransport.cpp new file mode 100644 index 000000000..2f14e906b --- /dev/null +++ b/lib/cpp/src/transport/TZlibTransport.cpp @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <cassert> +#include <cstring> +#include <algorithm> +#include <transport/TZlibTransport.h> +#include <zlib.h> + +using std::string; + +namespace apache { namespace thrift { namespace transport { + +// Don't call this outside of the constructor. +void TZlibTransport::initZlib() { + int rv; + bool r_init = false; + try { + rstream_ = new z_stream; + wstream_ = new z_stream; + + rstream_->zalloc = Z_NULL; + wstream_->zalloc = Z_NULL; + rstream_->zfree = Z_NULL; + wstream_->zfree = Z_NULL; + rstream_->opaque = Z_NULL; + wstream_->opaque = Z_NULL; + + rstream_->next_in = crbuf_; + wstream_->next_in = uwbuf_; + rstream_->next_out = urbuf_; + wstream_->next_out = cwbuf_; + rstream_->avail_in = 0; + wstream_->avail_in = 0; + rstream_->avail_out = urbuf_size_; + wstream_->avail_out = cwbuf_size_; + + rv = inflateInit(rstream_); + checkZlibRv(rv, rstream_->msg); + + // Have to set this flag so we know whether to de-initialize. + r_init = true; + + rv = deflateInit(wstream_, Z_DEFAULT_COMPRESSION); + checkZlibRv(rv, wstream_->msg); + } + + catch (...) { + if (r_init) { + rv = inflateEnd(rstream_); + checkZlibRvNothrow(rv, rstream_->msg); + } + // There is no way we can get here if wstream_ was initialized. + + throw; + } +} + +inline void TZlibTransport::checkZlibRv(int status, const char* message) { + if (status != Z_OK) { + throw TZlibTransportException(status, message); + } +} + +inline void TZlibTransport::checkZlibRvNothrow(int status, const char* message) { + if (status != Z_OK) { + string output = "TZlibTransport: zlib failure in destructor: " + + TZlibTransportException::errorMessage(status, message); + GlobalOutput(output.c_str()); + } +} + +TZlibTransport::~TZlibTransport() { + int rv; + rv = inflateEnd(rstream_); + checkZlibRvNothrow(rv, rstream_->msg); + rv = deflateEnd(wstream_); + checkZlibRvNothrow(rv, wstream_->msg); + + delete[] urbuf_; + delete[] crbuf_; + delete[] uwbuf_; + delete[] cwbuf_; + delete rstream_; + delete wstream_; +} + +bool TZlibTransport::isOpen() { + return (readAvail() > 0) || transport_->isOpen(); +} + +// READING STRATEGY +// +// We have two buffers for reading: one containing the compressed data (crbuf_) +// and one containing the uncompressed data (urbuf_). When read is called, +// we repeat the following steps until we have satisfied the request: +// - Copy data from urbuf_ into the caller's buffer. +// - If we had enough, return. +// - If urbuf_ is empty, read some data into it from the underlying transport. +// - Inflate data from crbuf_ into urbuf_. +// +// In standalone objects, we set input_ended_ to true when inflate returns +// Z_STREAM_END. This allows to make sure that a checksum was verified. + +inline int TZlibTransport::readAvail() { + return urbuf_size_ - rstream_->avail_out - urpos_; +} + +uint32_t TZlibTransport::read(uint8_t* buf, uint32_t len) { + int need = len; + + // TODO(dreiss): Skip urbuf on big reads. + + while (true) { + // Copy out whatever we have available, then give them the min of + // what we have and what they want, then advance indices. + int give = std::min(readAvail(), need); + memcpy(buf, urbuf_ + urpos_, give); + need -= give; + buf += give; + urpos_ += give; + + // If they were satisfied, we are done. + if (need == 0) { + return len; + } + + // If we get to this point, we need to get some more data. + + // If zlib has reported the end of a stream, we can't really do any more. + if (input_ended_) { + return len - need; + } + + // The uncompressed read buffer is empty, so reset the stream fields. + rstream_->next_out = urbuf_; + rstream_->avail_out = urbuf_size_; + urpos_ = 0; + + // If we don't have any more compressed data available, + // read some from the underlying transport. + if (rstream_->avail_in == 0) { + uint32_t got = transport_->read(crbuf_, crbuf_size_); + if (got == 0) { + return len - need; + } + rstream_->next_in = crbuf_; + rstream_->avail_in = got; + } + + // We have some compressed data now. Uncompress it. + int zlib_rv = inflate(rstream_, Z_SYNC_FLUSH); + + if (zlib_rv == Z_STREAM_END) { + if (standalone_) { + input_ended_ = true; + } + } else { + checkZlibRv(zlib_rv, rstream_->msg); + } + + // Okay. The read buffer should have whatever we can give it now. + // Loop back to the start and try to give some more. + } +} + + +// WRITING STRATEGY +// +// We buffer up small writes before sending them to zlib, so our logic is: +// - Is the write big? +// - Send the buffer to zlib. +// - Send this data to zlib. +// - Is the write small? +// - Is there insufficient space in the buffer for it? +// - Send the buffer to zlib. +// - Copy the data to the buffer. +// +// We have two buffers for writing also: the uncompressed buffer (mentioned +// above) and the compressed buffer. When sending data to zlib we loop over +// the following until the source (uncompressed buffer or big write) is empty: +// - Is there no more space in the compressed buffer? +// - Write the compressed buffer to the underlying transport. +// - Deflate from the source into the compressed buffer. + +void TZlibTransport::write(const uint8_t* buf, uint32_t len) { + // zlib's "deflate" function has enough logic in it that I think + // we're better off (performance-wise) buffering up small writes. + if ((int)len > MIN_DIRECT_DEFLATE_SIZE) { + flushToZlib(uwbuf_, uwpos_); + uwpos_ = 0; + flushToZlib(buf, len); + } else if (len > 0) { + if (uwbuf_size_ - uwpos_ < (int)len) { + flushToZlib(uwbuf_, uwpos_); + uwpos_ = 0; + } + memcpy(uwbuf_ + uwpos_, buf, len); + uwpos_ += len; + } +} + +void TZlibTransport::flush() { + flushToZlib(uwbuf_, uwpos_, true); + assert((int)wstream_->avail_out != cwbuf_size_); + transport_->write(cwbuf_, cwbuf_size_ - wstream_->avail_out); + transport_->flush(); +} + +void TZlibTransport::flushToZlib(const uint8_t* buf, int len, bool finish) { + int flush = (finish ? Z_FINISH : Z_NO_FLUSH); + + wstream_->next_in = const_cast<uint8_t*>(buf); + wstream_->avail_in = len; + + while (wstream_->avail_in > 0 || finish) { + // If our ouput buffer is full, flush to the underlying transport. + if (wstream_->avail_out == 0) { + transport_->write(cwbuf_, cwbuf_size_); + wstream_->next_out = cwbuf_; + wstream_->avail_out = cwbuf_size_; + } + + int zlib_rv = deflate(wstream_, flush); + + if (finish && zlib_rv == Z_STREAM_END) { + assert(wstream_->avail_in == 0); + break; + } + + checkZlibRv(zlib_rv, wstream_->msg); + } +} + +const uint8_t* TZlibTransport::borrow(uint8_t* buf, uint32_t* len) { + // Don't try to be clever with shifting buffers. + // If we have enough data, give a pointer to it, + // otherwise let the protcol use its slow path. + if (readAvail() >= (int)*len) { + *len = (uint32_t)readAvail(); + return urbuf_ + urpos_; + } + return NULL; +} + +void TZlibTransport::consume(uint32_t len) { + if (readAvail() >= (int)len) { + urpos_ += len; + } else { + throw TTransportException(TTransportException::BAD_ARGS, + "consume did not follow a borrow."); + } +} + +void TZlibTransport::verifyChecksum() { + if (!standalone_) { + throw TTransportException( + TTransportException::BAD_ARGS, + "TZLibTransport can only verify checksums for standalone objects."); + } + + if (!input_ended_) { + // This should only be called when reading is complete, + // but it's possible that the whole checksum has not been fed to zlib yet. + // We try to read an extra byte here to force zlib to finish the stream. + // It might not always be easy to "unread" this byte, + // but we throw an exception if we get it, which is not really + // a recoverable error, so it doesn't matter. + uint8_t buf[1]; + uint32_t got = this->read(buf, sizeof(buf)); + if (got || !input_ended_) { + throw TTransportException( + TTransportException::CORRUPTED_DATA, + "Zlib stream not complete."); + } + } + + // If the checksum had been bad, we would have gotten an error while + // inflating. +} + + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TZlibTransport.h b/lib/cpp/src/transport/TZlibTransport.h new file mode 100644 index 000000000..1439d9de7 --- /dev/null +++ b/lib/cpp/src/transport/TZlibTransport.h @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 1 + +#include <boost/lexical_cast.hpp> +#include <transport/TTransport.h> + +struct z_stream_s; + +namespace apache { namespace thrift { namespace transport { + +class TZlibTransportException : public TTransportException { + public: + TZlibTransportException(int status, const char* msg) : + TTransportException(TTransportException::INTERNAL_ERROR, + errorMessage(status, msg)), + zlib_status_(status), + zlib_msg_(msg == NULL ? "(null)" : msg) {} + + virtual ~TZlibTransportException() throw() {} + + int getZlibStatus() { return zlib_status_; } + std::string getZlibMessage() { return zlib_msg_; } + + static std::string errorMessage(int status, const char* msg) { + std::string rv = "zlib error: "; + if (msg) { + rv += msg; + } else { + rv += "(no message)"; + } + rv += " (status = "; + rv += boost::lexical_cast<std::string>(status); + rv += ")"; + return rv; + } + + int zlib_status_; + std::string zlib_msg_; +}; + +/** + * This transport uses zlib's compressed format on the "far" side. + * + * There are two kinds of TZlibTransport objects: + * - Standalone objects are used to encode self-contained chunks of data + * (like structures). They include checksums. + * - Non-standalone transports are used for RPC. They are not implemented yet. + * + * TODO(dreiss): Don't do an extra copy of the compressed data if + * the underlying transport is TBuffered or TMemory. + * + */ +class TZlibTransport : public TTransport { + public: + + /** + * @param transport The transport to read compressed data from + * and write compressed data to. + * @param use_for_rpc True if this object will be used for RPC, + * false if this is a standalone object. + * @param urbuf_size Uncompressed buffer size for reading. + * @param crbuf_size Compressed buffer size for reading. + * @param uwbuf_size Uncompressed buffer size for writing. + * @param cwbuf_size Compressed buffer size for writing. + * + * TODO(dreiss): Write a constructor that isn't a pain. + */ + TZlibTransport(boost::shared_ptr<TTransport> transport, + bool use_for_rpc, + int urbuf_size = DEFAULT_URBUF_SIZE, + int crbuf_size = DEFAULT_CRBUF_SIZE, + int uwbuf_size = DEFAULT_UWBUF_SIZE, + int cwbuf_size = DEFAULT_CWBUF_SIZE) : + transport_(transport), + standalone_(!use_for_rpc), + urpos_(0), + uwpos_(0), + input_ended_(false), + output_flushed_(false), + urbuf_size_(urbuf_size), + crbuf_size_(crbuf_size), + uwbuf_size_(uwbuf_size), + cwbuf_size_(cwbuf_size), + urbuf_(NULL), + crbuf_(NULL), + uwbuf_(NULL), + cwbuf_(NULL), + rstream_(NULL), + wstream_(NULL) + { + + if (!standalone_) { + throw TTransportException( + TTransportException::BAD_ARGS, + "TZLibTransport has not been tested for RPC."); + } + + if (uwbuf_size_ < MIN_DIRECT_DEFLATE_SIZE) { + // Have to copy this into a local because of a linking issue. + int minimum = MIN_DIRECT_DEFLATE_SIZE; + throw TTransportException( + TTransportException::BAD_ARGS, + "TZLibTransport: uncompressed write buffer must be at least" + + boost::lexical_cast<std::string>(minimum) + "."); + } + + try { + urbuf_ = new uint8_t[urbuf_size]; + crbuf_ = new uint8_t[crbuf_size]; + uwbuf_ = new uint8_t[uwbuf_size]; + cwbuf_ = new uint8_t[cwbuf_size]; + + // Don't call this outside of the constructor. + initZlib(); + + } catch (...) { + delete[] urbuf_; + delete[] crbuf_; + delete[] uwbuf_; + delete[] cwbuf_; + throw; + } + } + + // Don't call this outside of the constructor. + void initZlib(); + + ~TZlibTransport(); + + bool isOpen(); + + void open() { + transport_->open(); + } + + void close() { + transport_->close(); + } + + uint32_t read(uint8_t* buf, uint32_t len); + + void write(const uint8_t* buf, uint32_t len); + + void flush(); + + const uint8_t* borrow(uint8_t* buf, uint32_t* len); + + void consume(uint32_t len); + + void verifyChecksum(); + + /** + * TODO(someone_smart): Choose smart defaults. + */ + static const int DEFAULT_URBUF_SIZE = 128; + static const int DEFAULT_CRBUF_SIZE = 1024; + static const int DEFAULT_UWBUF_SIZE = 128; + static const int DEFAULT_CWBUF_SIZE = 1024; + + protected: + + inline void checkZlibRv(int status, const char* msg); + inline void checkZlibRvNothrow(int status, const char* msg); + inline int readAvail(); + void flushToZlib(const uint8_t* buf, int len, bool finish = false); + + // Writes smaller than this are buffered up. + // Larger (or equal) writes are dumped straight to zlib. + static const int MIN_DIRECT_DEFLATE_SIZE = 32; + + boost::shared_ptr<TTransport> transport_; + bool standalone_; + + int urpos_; + int uwpos_; + + /// True iff zlib has reached the end of a stream. + /// This is only ever true in standalone protcol objects. + bool input_ended_; + /// True iff we have flushed the output stream. + /// This is only ever true in standalone protcol objects. + bool output_flushed_; + + int urbuf_size_; + int crbuf_size_; + int uwbuf_size_; + int cwbuf_size_; + + uint8_t* urbuf_; + uint8_t* crbuf_; + uint8_t* uwbuf_; + uint8_t* cwbuf_; + + struct z_stream_s* rstream_; + struct z_stream_s* wstream_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ diff --git a/lib/cpp/thrift-nb.pc.in b/lib/cpp/thrift-nb.pc.in new file mode 100644 index 000000000..ae0518874 --- /dev/null +++ b/lib/cpp/thrift-nb.pc.in @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ + +Name: Thrift +Description: Thrift Nonblocking API +Version: @VERSION@ +Requires: thrift = @VERSION@ +Libs: -L${libdir} -lthriftnb +Cflags: -I${includedir}/thrift diff --git a/lib/cpp/thrift-z.pc.in b/lib/cpp/thrift-z.pc.in new file mode 100644 index 000000000..72f46bf94 --- /dev/null +++ b/lib/cpp/thrift-z.pc.in @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ + +Name: Thrift +Description: Thrift Zlib API +Version: @VERSION@ +Requires: thrift = @VERSION@ +Libs: -L${libdir} -lthriftz +Cflags: -I${includedir}/thrift diff --git a/lib/cpp/thrift.pc.in b/lib/cpp/thrift.pc.in new file mode 100644 index 000000000..7aec09f1d --- /dev/null +++ b/lib/cpp/thrift.pc.in @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ + +Name: Thrift +Description: Thrift C++ API +Version: @VERSION@ +Libs: -L${libdir} -lthrift +Cflags: -I${includedir}/thrift diff --git a/lib/csharp/Makefile.am b/lib/csharp/Makefile.am new file mode 100644 index 000000000..4047011cb --- /dev/null +++ b/lib/csharp/Makefile.am @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +THRIFTCODE= \ + src/Collections/THashSet.cs \ + src/Protocol/TBase.cs \ + src/Protocol/TProtocolException.cs \ + src/Protocol/TProtocolFactory.cs \ + src/Protocol/TList.cs \ + src/Protocol/TSet.cs \ + src/Protocol/TMap.cs \ + src/Protocol/TProtocolUtil.cs \ + src/Protocol/TMessageType.cs \ + src/Protocol/TProtocol.cs \ + src/Protocol/TType.cs \ + src/Protocol/TField.cs \ + src/Protocol/TMessage.cs \ + src/Protocol/TStruct.cs \ + src/Protocol/TBinaryProtocol.cs \ + src/Server/TThreadedServer.cs \ + src/Server/TThreadPoolServer.cs \ + src/Server/TSimpleServer.cs \ + src/Server/TServer.cs \ + src/Transport/TBufferedTransport.cs \ + src/Transport/TTransport.cs \ + src/Transport/TSocket.cs \ + src/Transport/TTransportException.cs \ + src/Transport/TStreamTransport.cs \ + src/Transport/TServerTransport.cs \ + src/Transport/TServerSocket.cs \ + src/Transport/TTransportFactory.cs \ + src/TProcessor.cs \ + src/TApplicationException.cs + + +CSC=gmcs + +if NET_2_0 +MONO_DEFINES=/d:NET_2_0 +endif + +all-local: Thrift.dll + +Thrift.dll: $(THRIFTCODE) + $(CSC) $(THRIFTCODE) /out:Thrift.dll /target:library $(MONO_DEFINES) + +clean-local: + $(RM) Thrift.dll + +EXTRA_DIST = \ + $(THRIFTCODE) \ + ThriftMSBuildTask \ + src/Thrift.csproj \ + src/Thrift.sln diff --git a/lib/csharp/README b/lib/csharp/README new file mode 100644 index 000000000..b7dc5de32 --- /dev/null +++ b/lib/csharp/README @@ -0,0 +1,26 @@ +Thrift C# Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Using Thrift with C# +==================== + +Thrift requires Mono >= 1.2.6 or .NET framework >= 3.5 diff --git a/lib/csharp/ThriftMSBuildTask/Properties/AssemblyInfo.cs b/lib/csharp/ThriftMSBuildTask/Properties/AssemblyInfo.cs new file mode 100644 index 000000000..d79c2039f --- /dev/null +++ b/lib/csharp/ThriftMSBuildTask/Properties/AssemblyInfo.cs @@ -0,0 +1,55 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("ThriftMSBuildTask")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("ThriftMSBuildTask")] +[assembly: AssemblyCopyright("Copyright © 2009 The Apache Software Foundation")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("5095e09d-7b95-4be1-b250-e1c1db1c485e")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("1.0.*")] +[assembly: AssemblyVersion("1.0.0.0")] +[assembly: AssemblyFileVersion("1.0.0.0")] diff --git a/lib/csharp/ThriftMSBuildTask/ThriftBuild.cs b/lib/csharp/ThriftMSBuildTask/ThriftBuild.cs new file mode 100644 index 000000000..4389e0a6e --- /dev/null +++ b/lib/csharp/ThriftMSBuildTask/ThriftBuild.cs @@ -0,0 +1,242 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.Build.Framework; +using Microsoft.Build.Utilities; +using Microsoft.Build.Tasks; +using System.IO; +using System.Diagnostics; + +namespace ThriftMSBuildTask +{ + /// <summary> + /// MSBuild Task to generate csharp from .thrift files, and compile the code into a library: ThriftImpl.dll + /// </summary> + public class ThriftBuild : Task + { + /// <summary> + /// The full path to the thrift.exe compiler + /// </summary> + [Required] + public ITaskItem ThriftExecutable + { + get; + set; + } + + /// <summary> + /// The full path to a thrift.dll C# library + /// </summary> + [Required] + public ITaskItem ThriftLibrary + { + get; + set; + } + + /// <summary> + /// A direcotry containing .thrift files + /// </summary> + [Required] + public ITaskItem ThriftDefinitionDir + { + get; + set; + } + + /// <summary> + /// The name of the auto-gen and compiled thrift library. It will placed in + /// the same directory as ThriftLibrary + /// </summary> + [Required] + public ITaskItem OutputName + { + get; + set; + } + + /// <summary> + /// The full path to the compiled ThriftLibrary. This allows msbuild tasks to use this + /// output as a variable for use elsewhere. + /// </summary> + [Output] + public ITaskItem ThriftImplementation + { + get { return thriftImpl; } + } + + private ITaskItem thriftImpl; + private const string lastCompilationName = "LAST_COMP_TIMESTAMP"; + + //use the Message Build Task to write something to build log + private void LogMessage(string text, MessageImportance importance) + { + Message m = new Message(); + m.Text = text; + m.Importance = importance.ToString(); + m.BuildEngine = this.BuildEngine; + m.Execute(); + } + + //recursively find .cs files in srcDir, paths should initially be non-null and empty + private void FindSourcesHelper(string srcDir, List<string> paths) + { + string[] files = Directory.GetFiles(srcDir, "*.cs"); + foreach (string f in files) + { + paths.Add(f); + } + string[] dirs = Directory.GetDirectories(srcDir); + foreach (string dir in dirs) + { + FindSourcesHelper(dir, paths); + } + } + + /// <summary> + /// Quote paths with spaces + /// </summary> + private string SafePath(string path) + { + if (path.Contains(' ') && !path.StartsWith("\"")) + { + return "\"" + path + "\""; + } + return path; + } + + private ITaskItem[] FindSources(string srcDir) + { + List<string> files = new List<string>(); + FindSourcesHelper(srcDir, files); + ITaskItem[] items = new ITaskItem[files.Count]; + for (int i = 0; i < items.Length; i++) + { + items[i] = new TaskItem(files[i]); + } + return items; + } + + private string LastWriteTime(string defDir) + { + string[] files = Directory.GetFiles(defDir, "*.thrift"); + DateTime d = (new DirectoryInfo(defDir)).LastWriteTime; + foreach(string file in files) + { + FileInfo f = new FileInfo(file); + DateTime curr = f.LastWriteTime; + if (DateTime.Compare(curr, d) > 0) + { + d = curr; + } + } + return d.ToFileTimeUtc().ToString(); + } + + public override bool Execute() + { + string defDir = SafePath(ThriftDefinitionDir.ItemSpec); + //look for last compilation timestamp + string lastBuildPath = Path.Combine(defDir, lastCompilationName); + DirectoryInfo defDirInfo = new DirectoryInfo(defDir); + string lastWrite = LastWriteTime(defDir); + if (File.Exists(lastBuildPath)) + { + string lastComp = File.ReadAllText(lastBuildPath); + //don't recompile if the thrift library has been updated since lastComp + FileInfo f = new FileInfo(ThriftLibrary.ItemSpec); + string thriftLibTime = f.LastWriteTimeUtc.ToFileTimeUtc().ToString(); + if (lastComp.CompareTo(thriftLibTime) < 0) + { + //new thrift library, do a compile + lastWrite = thriftLibTime; + } + else if (lastComp == lastWrite || (lastComp == thriftLibTime && lastComp.CompareTo(lastWrite) > 0)) + { + //the .thrift dir hasn't been written to since last compilation, don't need to do anything + LogMessage("ThriftImpl up-to-date", MessageImportance.High); + return true; + } + } + + //find the directory of the thriftlibrary (that's where output will go) + FileInfo thriftLibInfo = new FileInfo(SafePath(ThriftLibrary.ItemSpec)); + string thriftDir = thriftLibInfo.Directory.FullName; + + string genDir = Path.Combine(thriftDir, "gen-csharp"); + if (Directory.Exists(genDir)) + { + try + { + Directory.Delete(genDir, true); + } + catch { /*eh i tried, just over-write now*/} + } + + //run the thrift executable to generate C# + foreach (string thriftFile in Directory.GetFiles(defDir, "*.thrift")) + { + LogMessage("Generating code for: " + thriftFile, MessageImportance.Normal); + Process p = new Process(); + p.StartInfo.FileName = SafePath(ThriftExecutable.ItemSpec); + p.StartInfo.Arguments = "--gen csharp -o " + SafePath(thriftDir) + " -r " + thriftFile; + p.StartInfo.UseShellExecute = false; + p.StartInfo.CreateNoWindow = true; + p.StartInfo.RedirectStandardOutput = false; + p.Start(); + p.WaitForExit(); + if (p.ExitCode != 0) + { + LogMessage("thrift.exe failed to compile " + thriftFile, MessageImportance.High); + return false; + } + if (p.ExitCode != 0) + { + LogMessage("thrift.exe failed to compile " + thriftFile, MessageImportance.High); + return false; + } + } + + Csc csc = new Csc(); + csc.TargetType = "library"; + csc.References = new ITaskItem[] { new TaskItem(ThriftLibrary.ItemSpec) }; + csc.EmitDebugInformation = true; + string outputPath = Path.Combine(thriftDir, OutputName.ItemSpec); + csc.OutputAssembly = new TaskItem(outputPath); + csc.Sources = FindSources(Path.Combine(thriftDir, "gen-csharp")); + csc.BuildEngine = this.BuildEngine; + LogMessage("Compiling generated cs...", MessageImportance.Normal); + if (!csc.Execute()) + { + return false; + } + + //write file to defDir to indicate a build was successfully completed + File.WriteAllText(lastBuildPath, lastWrite); + + thriftImpl = new TaskItem(outputPath); + + return true; + } + } +} diff --git a/lib/csharp/ThriftMSBuildTask/ThriftMSBuildTask.csproj b/lib/csharp/ThriftMSBuildTask/ThriftMSBuildTask.csproj new file mode 100644 index 000000000..02110eae6 --- /dev/null +++ b/lib/csharp/ThriftMSBuildTask/ThriftMSBuildTask.csproj @@ -0,0 +1,66 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="3.5" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <PropertyGroup> + <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration> + <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform> + <ProductVersion>9.0.21022</ProductVersion> + <SchemaVersion>2.0</SchemaVersion> + <ProjectGuid>{EC0A0231-66EA-4593-A792-C6CA3BB8668E}</ProjectGuid> + <OutputType>Library</OutputType> + <AppDesignerFolder>Properties</AppDesignerFolder> + <RootNamespace>ThriftMSBuildTask</RootNamespace> + <AssemblyName>ThriftMSBuildTask</AssemblyName> + <TargetFrameworkVersion>v3.5</TargetFrameworkVersion> + <FileAlignment>512</FileAlignment> + <SccProjectName>SAK</SccProjectName> + <SccLocalPath>SAK</SccLocalPath> + <SccAuxPath>SAK</SccAuxPath> + <SccProvider>SAK</SccProvider> + </PropertyGroup> + <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' "> + <DebugSymbols>true</DebugSymbols> + <DebugType>full</DebugType> + <Optimize>false</Optimize> + <OutputPath>bin\Debug\</OutputPath> + <DefineConstants>DEBUG;TRACE</DefineConstants> + <ErrorReport>prompt</ErrorReport> + <WarningLevel>4</WarningLevel> + </PropertyGroup> + <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' "> + <DebugType>pdbonly</DebugType> + <Optimize>true</Optimize> + <OutputPath>bin\Release\</OutputPath> + <DefineConstants>TRACE</DefineConstants> + <ErrorReport>prompt</ErrorReport> + <WarningLevel>4</WarningLevel> + </PropertyGroup> + <ItemGroup> + <Reference Include="Microsoft.Build.Framework" /> + <Reference Include="Microsoft.Build.Tasks" /> + <Reference Include="Microsoft.Build.Utilities" /> + <Reference Include="System" /> + <Reference Include="System.Core"> + <RequiredTargetFramework>3.5</RequiredTargetFramework> + </Reference> + <Reference Include="System.Xml.Linq"> + <RequiredTargetFramework>3.5</RequiredTargetFramework> + </Reference> + <Reference Include="System.Data.DataSetExtensions"> + <RequiredTargetFramework>3.5</RequiredTargetFramework> + </Reference> + <Reference Include="System.Data" /> + <Reference Include="System.Xml" /> + </ItemGroup> + <ItemGroup> + <Compile Include="ThriftBuild.cs" /> + <Compile Include="Properties\AssemblyInfo.cs" /> + </ItemGroup> + <Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" /> + <!-- To modify your build process, add your task inside one of the targets below and uncomment it. + Other similar extension points exist, see Microsoft.Common.targets. + <Target Name="BeforeBuild"> + </Target> + <Target Name="AfterBuild"> + </Target> + --> +</Project> diff --git a/lib/csharp/src/Collections/THashSet.cs b/lib/csharp/src/Collections/THashSet.cs new file mode 100644 index 000000000..a9957693e --- /dev/null +++ b/lib/csharp/src/Collections/THashSet.cs @@ -0,0 +1,142 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace Thrift.Collections +{ + public class THashSet<T> : ICollection<T> + { +#if NET_2_0 + TDictSet<T> set = new TDictSet<T>(); +#else + HashSet<T> set = new HashSet<T>(); +#endif + public int Count + { + get { return set.Count; } + } + + public bool IsReadOnly + { + get { return false; } + } + + public void Add(T item) + { + set.Add(item); + } + + public void Clear() + { + set.Clear(); + } + + public bool Contains(T item) + { + return set.Contains(item); + } + + public void CopyTo(T[] array, int arrayIndex) + { + set.CopyTo(array, arrayIndex); + } + + public IEnumerator GetEnumerator() + { + return set.GetEnumerator(); + } + + IEnumerator<T> IEnumerable<T>.GetEnumerator() + { + return ((IEnumerable<T>)set).GetEnumerator(); + } + + public bool Remove(T item) + { + return set.Remove(item); + } + +#if NET_2_0 + private class TDictSet<V> : ICollection<V> + { + Dictionary<V, TDictSet<V>> dict = new Dictionary<V, TDictSet<V>>(); + + public int Count + { + get { return dict.Count; } + } + + public bool IsReadOnly + { + get { return false; } + } + + public IEnumerator GetEnumerator() + { + return ((IEnumerable)dict.Keys).GetEnumerator(); + } + + IEnumerator<V> IEnumerable<V>.GetEnumerator() + { + return dict.Keys.GetEnumerator(); + } + + public bool Add(V item) + { + if (!dict.ContainsKey(item)) + { + dict[item] = this; + return true; + } + + return false; + } + + void ICollection<V>.Add(V item) + { + Add(item); + } + + public void Clear() + { + dict.Clear(); + } + + public bool Contains(V item) + { + return dict.ContainsKey(item); + } + + public void CopyTo(V[] array, int arrayIndex) + { + dict.Keys.CopyTo(array, arrayIndex); + } + + public bool Remove(V item) + { + return dict.Remove(item); + } + } +#endif + } + +} diff --git a/lib/csharp/src/Protocol/TBase.cs b/lib/csharp/src/Protocol/TBase.cs new file mode 100644 index 000000000..1969bb3d2 --- /dev/null +++ b/lib/csharp/src/Protocol/TBase.cs @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +namespace Thrift.Protocol +{ + public interface TBase + { + /// + /// Reads the TObject from the given input protocol. + /// + void Read(TProtocol tProtocol); + + /// + /// Writes the objects out to the protocol + /// + void Write(TProtocol tProtocol); + } +} diff --git a/lib/csharp/src/Protocol/TBinaryProtocol.cs b/lib/csharp/src/Protocol/TBinaryProtocol.cs new file mode 100644 index 000000000..14ca43b74 --- /dev/null +++ b/lib/csharp/src/Protocol/TBinaryProtocol.cs @@ -0,0 +1,392 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Text; +using Thrift.Transport; + +namespace Thrift.Protocol +{ + public class TBinaryProtocol : TProtocol + { + protected const uint VERSION_MASK = 0xffff0000; + protected const uint VERSION_1 = 0x80010000; + + protected bool strictRead_ = false; + protected bool strictWrite_ = true; + + protected int readLength_; + protected bool checkReadLength_ = false; + + + #region BinaryProtocol Factory + /** + * Factory + */ + public class Factory : TProtocolFactory { + + protected bool strictRead_ = false; + protected bool strictWrite_ = true; + + public Factory() + :this(false, true) + { + } + + public Factory(bool strictRead, bool strictWrite) + { + strictRead_ = strictRead; + strictWrite_ = strictWrite; + } + + public TProtocol GetProtocol(TTransport trans) { + return new TBinaryProtocol(trans, strictRead_, strictWrite_); + } + } + + #endregion + + public TBinaryProtocol(TTransport trans) + : this(trans, false, true) + { + } + + public TBinaryProtocol(TTransport trans, bool strictRead, bool strictWrite) + :base(trans) + { + strictRead_ = strictRead; + strictWrite_ = strictWrite; + } + + #region Write Methods + + public override void WriteMessageBegin(TMessage message) + { + if (strictWrite_) + { + uint version = VERSION_1 | (uint)(message.Type); + WriteI32((int)version); + WriteString(message.Name); + WriteI32(message.SeqID); + } + else + { + WriteString(message.Name); + WriteByte((byte)message.Type); + WriteI32(message.SeqID); + } + } + + public override void WriteMessageEnd() + { + } + + public override void WriteStructBegin(TStruct struc) + { + } + + public override void WriteStructEnd() + { + } + + public override void WriteFieldBegin(TField field) + { + WriteByte((byte)field.Type); + WriteI16(field.ID); + } + + public override void WriteFieldEnd() + { + } + + public override void WriteFieldStop() + { + WriteByte((byte)TType.Stop); + } + + public override void WriteMapBegin(TMap map) + { + WriteByte((byte)map.KeyType); + WriteByte((byte)map.ValueType); + WriteI32(map.Count); + } + + public override void WriteMapEnd() + { + } + + public override void WriteListBegin(TList list) + { + WriteByte((byte)list.ElementType); + WriteI32(list.Count); + } + + public override void WriteListEnd() + { + } + + public override void WriteSetBegin(TSet set) + { + WriteByte((byte)set.ElementType); + WriteI32(set.Count); + } + + public override void WriteSetEnd() + { + } + + public override void WriteBool(bool b) + { + WriteByte(b ? (byte)1 : (byte)0); + } + + private byte[] bout = new byte[1]; + public override void WriteByte(byte b) + { + bout[0] = b; + trans.Write(bout, 0, 1); + } + + private byte[] i16out = new byte[2]; + public override void WriteI16(short s) + { + i16out[0] = (byte)(0xff & (s >> 8)); + i16out[1] = (byte)(0xff & s); + trans.Write(i16out, 0, 2); + } + + private byte[] i32out = new byte[4]; + public override void WriteI32(int i32) + { + i32out[0] = (byte)(0xff & (i32 >> 24)); + i32out[1] = (byte)(0xff & (i32 >> 16)); + i32out[2] = (byte)(0xff & (i32 >> 8)); + i32out[3] = (byte)(0xff & i32); + trans.Write(i32out, 0, 4); + } + + private byte[] i64out = new byte[8]; + public override void WriteI64(long i64) + { + i64out[0] = (byte)(0xff & (i64 >> 56)); + i64out[1] = (byte)(0xff & (i64 >> 48)); + i64out[2] = (byte)(0xff & (i64 >> 40)); + i64out[3] = (byte)(0xff & (i64 >> 32)); + i64out[4] = (byte)(0xff & (i64 >> 24)); + i64out[5] = (byte)(0xff & (i64 >> 16)); + i64out[6] = (byte)(0xff & (i64 >> 8)); + i64out[7] = (byte)(0xff & i64); + trans.Write(i64out, 0, 8); + } + + public override void WriteDouble(double d) + { + WriteI64(BitConverter.DoubleToInt64Bits(d)); + } + + public override void WriteBinary(byte[] b) + { + WriteI32(b.Length); + trans.Write(b, 0, b.Length); + } + + #endregion + + #region ReadMethods + + public override TMessage ReadMessageBegin() + { + TMessage message = new TMessage(); + int size = ReadI32(); + if (size < 0) + { + uint version = (uint)size & VERSION_MASK; + if (version != VERSION_1) + { + throw new TProtocolException(TProtocolException.BAD_VERSION, "Bad version in ReadMessageBegin: " + version); + } + message.Type = (TMessageType)(size & 0x000000ff); + message.Name = ReadString(); + message.SeqID = ReadI32(); + } + else + { + if (strictRead_) + { + throw new TProtocolException(TProtocolException.BAD_VERSION, "Missing version in readMessageBegin, old client?"); + } + message.Name = ReadStringBody(size); + message.Type = (TMessageType)ReadByte(); + message.SeqID = ReadI32(); + } + return message; + } + + public override void ReadMessageEnd() + { + } + + public override TStruct ReadStructBegin() + { + return new TStruct(); + } + + public override void ReadStructEnd() + { + } + + public override TField ReadFieldBegin() + { + TField field = new TField(); + field.Type = (TType)ReadByte(); + + if (field.Type != TType.Stop) + { + field.ID = ReadI16(); + } + + return field; + } + + public override void ReadFieldEnd() + { + } + + public override TMap ReadMapBegin() + { + TMap map = new TMap(); + map.KeyType = (TType)ReadByte(); + map.ValueType = (TType)ReadByte(); + map.Count = ReadI32(); + + return map; + } + + public override void ReadMapEnd() + { + } + + public override TList ReadListBegin() + { + TList list = new TList(); + list.ElementType = (TType)ReadByte(); + list.Count = ReadI32(); + + return list; + } + + public override void ReadListEnd() + { + } + + public override TSet ReadSetBegin() + { + TSet set = new TSet(); + set.ElementType = (TType)ReadByte(); + set.Count = ReadI32(); + + return set; + } + + public override void ReadSetEnd() + { + } + + public override bool ReadBool() + { + return ReadByte() == 1; + } + + private byte[] bin = new byte[1]; + public override byte ReadByte() + { + ReadAll(bin, 0, 1); + return bin[0]; + } + + private byte[] i16in = new byte[2]; + public override short ReadI16() + { + ReadAll(i16in, 0, 2); + return (short)(((i16in[0] & 0xff) << 8) | ((i16in[1] & 0xff))); + } + + private byte[] i32in = new byte[4]; + public override int ReadI32() + { + ReadAll(i32in, 0, 4); + return (int)(((i32in[0] & 0xff) << 24) | ((i32in[1] & 0xff) << 16) | ((i32in[2] & 0xff) << 8) | ((i32in[3] & 0xff))); + } + + private byte[] i64in = new byte[8]; + public override long ReadI64() + { + ReadAll(i64in, 0, 8); + return (long)(((long)(i64in[0] & 0xff) << 56) | ((long)(i64in[1] & 0xff) << 48) | ((long)(i64in[2] & 0xff) << 40) | ((long)(i64in[3] & 0xff) << 32) | + ((long)(i64in[4] & 0xff) << 24) | ((long)(i64in[5] & 0xff) << 16) | ((long)(i64in[6] & 0xff) << 8) | ((long)(i64in[7] & 0xff))); + } + + public override double ReadDouble() + { + return BitConverter.Int64BitsToDouble(ReadI64()); + } + + public void SetReadLength(int readLength) + { + readLength_ = readLength; + checkReadLength_ = true; + } + + protected void CheckReadLength(int length) + { + if (checkReadLength_) + { + readLength_ -= length; + if (readLength_ < 0) + { + throw new Exception("Message length exceeded: " + length); + } + } + } + + public override byte[] ReadBinary() + { + int size = ReadI32(); + CheckReadLength(size); + byte[] buf = new byte[size]; + trans.ReadAll(buf, 0, size); + return buf; + } + private string ReadStringBody(int size) + { + CheckReadLength(size); + byte[] buf = new byte[size]; + trans.ReadAll(buf, 0, size); + return Encoding.UTF8.GetString(buf); + } + + private int ReadAll(byte[] buf, int off, int len) + { + CheckReadLength(len); + return trans.ReadAll(buf, off, len); + } + + #endregion + } +} diff --git a/lib/csharp/src/Protocol/TField.cs b/lib/csharp/src/Protocol/TField.cs new file mode 100644 index 000000000..485c994bb --- /dev/null +++ b/lib/csharp/src/Protocol/TField.cs @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift.Protocol +{ + public struct TField + { + private string name; + private TType type; + private short id; + + public TField(string name, TType type, short id) + :this() + { + this.name = name; + this.type = type; + this.id = id; + } + + public string Name + { + get { return name; } + set { name = value; } + } + + public TType Type + { + get { return type; } + set { type = value; } + } + + public short ID + { + get { return id; } + set { id = value; } + } + } +} diff --git a/lib/csharp/src/Protocol/TList.cs b/lib/csharp/src/Protocol/TList.cs new file mode 100644 index 000000000..dbc5c40e8 --- /dev/null +++ b/lib/csharp/src/Protocol/TList.cs @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift.Protocol +{ + public struct TList + { + private TType elementType; + private int count; + + public TList(TType elementType, int count) + :this() + { + this.elementType = elementType; + this.count = count; + } + + public TType ElementType + { + get { return elementType; } + set { elementType = value; } + } + + public int Count + { + get { return count; } + set { count = value; } + } + } +} diff --git a/lib/csharp/src/Protocol/TMap.cs b/lib/csharp/src/Protocol/TMap.cs new file mode 100644 index 000000000..8b53f8990 --- /dev/null +++ b/lib/csharp/src/Protocol/TMap.cs @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift.Protocol +{ + public struct TMap + { + private TType keyType; + private TType valueType; + private int count; + + public TMap(TType keyType, TType valueType, int count) + :this() + { + this.keyType = keyType; + this.valueType = valueType; + this.count = count; + } + + public TType KeyType + { + get { return keyType; } + set { keyType = value; } + } + + public TType ValueType + { + get { return valueType; } + set { valueType = value; } + } + + public int Count + { + get { return count; } + set { count = value; } + } + } +} diff --git a/lib/csharp/src/Protocol/TMessage.cs b/lib/csharp/src/Protocol/TMessage.cs new file mode 100644 index 000000000..8cb6e0b1a --- /dev/null +++ b/lib/csharp/src/Protocol/TMessage.cs @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift.Protocol +{ + public struct TMessage + { + private string name; + private TMessageType type; + private int seqID; + + public TMessage(string name, TMessageType type, int seqid) + :this() + { + this.name = name; + this.type = type; + this.seqID = seqid; + } + + public string Name + { + get { return name; } + set { name = value; } + } + + public TMessageType Type + { + get { return type; } + set { type = value; } + } + + public int SeqID + { + get { return seqID; } + set { seqID = value; } + } + } +} diff --git a/lib/csharp/src/Protocol/TMessageType.cs b/lib/csharp/src/Protocol/TMessageType.cs new file mode 100644 index 000000000..ab07cf6c0 --- /dev/null +++ b/lib/csharp/src/Protocol/TMessageType.cs @@ -0,0 +1,31 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Protocol +{ + public enum TMessageType + { + Call = 1, + Reply = 2, + Exception = 3, + Oneway = 4 + } +} diff --git a/lib/csharp/src/Protocol/TProtocol.cs b/lib/csharp/src/Protocol/TProtocol.cs new file mode 100644 index 000000000..acf9c1b3d --- /dev/null +++ b/lib/csharp/src/Protocol/TProtocol.cs @@ -0,0 +1,87 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Text; +using Thrift.Transport; + +namespace Thrift.Protocol +{ + public abstract class TProtocol + { + protected TTransport trans; + + protected TProtocol(TTransport trans) + { + this.trans = trans; + } + + public TTransport Transport + { + get { return trans; } + } + + public abstract void WriteMessageBegin(TMessage message); + public abstract void WriteMessageEnd(); + public abstract void WriteStructBegin(TStruct struc); + public abstract void WriteStructEnd(); + public abstract void WriteFieldBegin(TField field); + public abstract void WriteFieldEnd(); + public abstract void WriteFieldStop(); + public abstract void WriteMapBegin(TMap map); + public abstract void WriteMapEnd(); + public abstract void WriteListBegin(TList list); + public abstract void WriteListEnd(); + public abstract void WriteSetBegin(TSet set); + public abstract void WriteSetEnd(); + public abstract void WriteBool(bool b); + public abstract void WriteByte(byte b); + public abstract void WriteI16(short i16); + public abstract void WriteI32(int i32); + public abstract void WriteI64(long i64); + public abstract void WriteDouble(double d); + public void WriteString(string s) { + WriteBinary(Encoding.UTF8.GetBytes(s)); + } + public abstract void WriteBinary(byte[] b); + + public abstract TMessage ReadMessageBegin(); + public abstract void ReadMessageEnd(); + public abstract TStruct ReadStructBegin(); + public abstract void ReadStructEnd(); + public abstract TField ReadFieldBegin(); + public abstract void ReadFieldEnd(); + public abstract TMap ReadMapBegin(); + public abstract void ReadMapEnd(); + public abstract TList ReadListBegin(); + public abstract void ReadListEnd(); + public abstract TSet ReadSetBegin(); + public abstract void ReadSetEnd(); + public abstract bool ReadBool(); + public abstract byte ReadByte(); + public abstract short ReadI16(); + public abstract int ReadI32(); + public abstract long ReadI64(); + public abstract double ReadDouble(); + public string ReadString() { + return Encoding.UTF8.GetString(ReadBinary()); + } + public abstract byte[] ReadBinary(); + } +} diff --git a/lib/csharp/src/Protocol/TProtocolException.cs b/lib/csharp/src/Protocol/TProtocolException.cs new file mode 100644 index 000000000..9c2504766 --- /dev/null +++ b/lib/csharp/src/Protocol/TProtocolException.cs @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Protocol +{ + class TProtocolException : Exception + { + public const int UNKNOWN = 0; + public const int INVALID_DATA = 1; + public const int NEGATIVE_SIZE = 2; + public const int SIZE_LIMIT = 3; + public const int BAD_VERSION = 4; + + protected int type_ = UNKNOWN; + + public TProtocolException() + : base() + { + } + + public TProtocolException(int type) + : base() + { + type_ = type; + } + + public TProtocolException(int type, String message) + : base(message) + { + type_ = type; + } + + public TProtocolException(String message) + : base(message) + { + } + + public int getType() + { + return type_; + } + } +} diff --git a/lib/csharp/src/Protocol/TProtocolFactory.cs b/lib/csharp/src/Protocol/TProtocolFactory.cs new file mode 100644 index 000000000..ae976acd5 --- /dev/null +++ b/lib/csharp/src/Protocol/TProtocolFactory.cs @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using Thrift.Transport; + +namespace Thrift.Protocol +{ + public interface TProtocolFactory + { + TProtocol GetProtocol(TTransport trans); + } +} diff --git a/lib/csharp/src/Protocol/TProtocolUtil.cs b/lib/csharp/src/Protocol/TProtocolUtil.cs new file mode 100644 index 000000000..57cef0eff --- /dev/null +++ b/lib/csharp/src/Protocol/TProtocolUtil.cs @@ -0,0 +1,94 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Protocol +{ + public static class TProtocolUtil + { + public static void Skip(TProtocol prot, TType type) + { + switch (type) + { + case TType.Bool: + prot.ReadBool(); + break; + case TType.Byte: + prot.ReadByte(); + break; + case TType.I16: + prot.ReadI16(); + break; + case TType.I32: + prot.ReadI32(); + break; + case TType.I64: + prot.ReadI64(); + break; + case TType.Double: + prot.ReadDouble(); + break; + case TType.String: + // Don't try to decode the string, just skip it. + prot.ReadBinary(); + break; + case TType.Struct: + prot.ReadStructBegin(); + while (true) + { + TField field = prot.ReadFieldBegin(); + if (field.Type == TType.Stop) + { + break; + } + Skip(prot, field.Type); + prot.ReadFieldEnd(); + } + prot.ReadStructEnd(); + break; + case TType.Map: + TMap map = prot.ReadMapBegin(); + for (int i = 0; i < map.Count; i++) + { + Skip(prot, map.KeyType); + Skip(prot, map.ValueType); + } + prot.ReadMapEnd(); + break; + case TType.Set: + TSet set = prot.ReadSetBegin(); + for (int i = 0; i < set.Count; i++) + { + Skip(prot, set.ElementType); + } + prot.ReadSetEnd(); + break; + case TType.List: + TList list = prot.ReadListBegin(); + for (int i = 0; i < list.Count; i++) + { + Skip(prot, list.ElementType); + } + prot.ReadListEnd(); + break; + } + } + } +} diff --git a/lib/csharp/src/Protocol/TSet.cs b/lib/csharp/src/Protocol/TSet.cs new file mode 100644 index 000000000..ac73992dc --- /dev/null +++ b/lib/csharp/src/Protocol/TSet.cs @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift.Protocol +{ + public struct TSet + { + private TType elementType; + private int count; + + public TSet(TType elementType, int count) + :this() + { + this.elementType = elementType; + this.count = count; + } + + public TType ElementType + { + get { return elementType; } + set { elementType = value; } + } + + public int Count + { + get { return count; } + set { count = value; } + } + } +} diff --git a/lib/csharp/src/Protocol/TStruct.cs b/lib/csharp/src/Protocol/TStruct.cs new file mode 100644 index 000000000..0cac2733e --- /dev/null +++ b/lib/csharp/src/Protocol/TStruct.cs @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift.Protocol +{ + public struct TStruct + { + private string name; + + public TStruct(string name) + :this() + { + this.name = name; + } + + public string Name + { + get { return name; } + set { name = value; } + } + } +} diff --git a/lib/csharp/src/Protocol/TType.cs b/lib/csharp/src/Protocol/TType.cs new file mode 100644 index 000000000..c2d78edc8 --- /dev/null +++ b/lib/csharp/src/Protocol/TType.cs @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Protocol +{ + public enum TType : byte + { + Stop = 0, + Void = 1, + Bool = 2, + Byte = 3, + Double = 4, + I16 = 6, + I32 = 8, + I64 = 10, + String = 11, + Struct = 12, + Map = 13, + Set = 14, + List = 15 + } +} diff --git a/lib/csharp/src/Server/TServer.cs b/lib/csharp/src/Server/TServer.cs new file mode 100644 index 000000000..61a9416fc --- /dev/null +++ b/lib/csharp/src/Server/TServer.cs @@ -0,0 +1,135 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using Thrift.Protocol; +using Thrift.Transport; +using System.IO; + +namespace Thrift.Server +{ + public abstract class TServer + { + /** + * Core processor + */ + protected TProcessor processor; + + /** + * Server transport + */ + protected TServerTransport serverTransport; + + /** + * Input Transport Factory + */ + protected TTransportFactory inputTransportFactory; + + /** + * Output Transport Factory + */ + protected TTransportFactory outputTransportFactory; + + /** + * Input Protocol Factory + */ + protected TProtocolFactory inputProtocolFactory; + + /** + * Output Protocol Factory + */ + protected TProtocolFactory outputProtocolFactory; + public delegate void LogDelegate(string str); + protected LogDelegate logDelegate; + + /** + * Default constructors. + */ + + public TServer(TProcessor processor, + TServerTransport serverTransport) + :this(processor, serverTransport, new TTransportFactory(), new TTransportFactory(), new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), DefaultLogDelegate) + { + } + + public TServer(TProcessor processor, + TServerTransport serverTransport, + LogDelegate logDelegate) + : this(processor, serverTransport, new TTransportFactory(), new TTransportFactory(), new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), DefaultLogDelegate) + { + } + + public TServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory) + :this(processor, + serverTransport, + transportFactory, + transportFactory, + new TBinaryProtocol.Factory(), + new TBinaryProtocol.Factory(), + DefaultLogDelegate) + { + } + + public TServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) + :this(processor, + serverTransport, + transportFactory, + transportFactory, + protocolFactory, + protocolFactory, + DefaultLogDelegate) + { + } + + public TServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + LogDelegate logDelegate) + { + this.processor = processor; + this.serverTransport = serverTransport; + this.inputTransportFactory = inputTransportFactory; + this.outputTransportFactory = outputTransportFactory; + this.inputProtocolFactory = inputProtocolFactory; + this.outputProtocolFactory = outputProtocolFactory; + this.logDelegate = logDelegate; + } + + /** + * The run method fires up the server and gets things going. + */ + public abstract void Serve(); + + public abstract void Stop(); + + protected static void DefaultLogDelegate(string s) + { + Console.Error.WriteLine(s); + } + } +} + diff --git a/lib/csharp/src/Server/TSimpleServer.cs b/lib/csharp/src/Server/TSimpleServer.cs new file mode 100644 index 000000000..34a51de4a --- /dev/null +++ b/lib/csharp/src/Server/TSimpleServer.cs @@ -0,0 +1,148 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using Thrift.Transport; +using Thrift.Protocol; + +namespace Thrift.Server +{ + /// <summary> + /// Simple single-threaded server for testing + /// </summary> + public class TSimpleServer : TServer + { + private bool stop = false; + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport) + :base(processor, serverTransport, new TTransportFactory(), new TTransportFactory(), new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), DefaultLogDelegate) + { + } + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport, + LogDelegate logDel) + : base(processor, serverTransport, new TTransportFactory(), new TTransportFactory(), new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), logDel) + { + } + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory) + :base(processor, + serverTransport, + transportFactory, + transportFactory, + new TBinaryProtocol.Factory(), + new TBinaryProtocol.Factory(), + DefaultLogDelegate) + { + } + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) + :base(processor, + serverTransport, + transportFactory, + transportFactory, + protocolFactory, + protocolFactory, + DefaultLogDelegate) + { + } + + public override void Serve() + { + try + { + serverTransport.Listen(); + } + catch (TTransportException ttx) + { + logDelegate(ttx.ToString()); + return; + } + + while (!stop) + { + TTransport client = null; + TTransport inputTransport = null; + TTransport outputTransport = null; + TProtocol inputProtocol = null; + TProtocol outputProtocol = null; + try + { + client = serverTransport.Accept(); + if (client != null) + { + inputTransport = inputTransportFactory.GetTransport(client); + outputTransport = outputTransportFactory.GetTransport(client); + inputProtocol = inputProtocolFactory.GetProtocol(inputTransport); + outputProtocol = outputProtocolFactory.GetProtocol(outputTransport); + while (processor.Process(inputProtocol, outputProtocol)) { } + } + } + catch (TTransportException ttx) + { + // Client died, just move on + if (stop) + { + logDelegate("TSimpleServer was shutting down, caught " + ttx.GetType().Name); + } + } + catch (Exception x) + { + logDelegate(x.ToString()); + } + + if (inputTransport != null) + { + inputTransport.Close(); + } + + if (outputTransport != null) + { + outputTransport.Close(); + } + } + + if (stop) + { + try + { + serverTransport.Close(); + } + catch (TTransportException ttx) + { + logDelegate("TServerTranport failed on close: " + ttx.Message); + } + stop = false; + } + } + + public override void Stop() + { + stop = true; + serverTransport.Close(); + } + } +} diff --git a/lib/csharp/src/Server/TThreadPoolServer.cs b/lib/csharp/src/Server/TThreadPoolServer.cs new file mode 100644 index 000000000..efc71f01e --- /dev/null +++ b/lib/csharp/src/Server/TThreadPoolServer.cs @@ -0,0 +1,186 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Threading; +using Thrift.Protocol; +using Thrift.Transport; + +namespace Thrift.Server +{ + /// <summary> + /// Server that uses C# built-in ThreadPool to spawn threads when handling requests + /// </summary> + public class TThreadPoolServer : TServer + { + private const int DEFAULT_MIN_THREADS = 10; + private const int DEFAULT_MAX_THREADS = 100; + private volatile bool stop = false; + + public TThreadPoolServer(TProcessor processor, TServerTransport serverTransport) + :this(processor, serverTransport, + new TTransportFactory(), new TTransportFactory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), + DEFAULT_MIN_THREADS, DEFAULT_MAX_THREADS, DefaultLogDelegate) + { + } + + public TThreadPoolServer(TProcessor processor, TServerTransport serverTransport, LogDelegate logDelegate) + : this(processor, serverTransport, + new TTransportFactory(), new TTransportFactory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), + DEFAULT_MIN_THREADS, DEFAULT_MAX_THREADS, logDelegate) + { + } + + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) + :this(processor, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory, + DEFAULT_MIN_THREADS, DEFAULT_MAX_THREADS, DefaultLogDelegate) + { + } + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + int minThreadPoolThreads, int maxThreadPoolThreads, LogDelegate logDel) + :base(processor, serverTransport, inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory, logDel) + { + if (!ThreadPool.SetMinThreads(minThreadPoolThreads, minThreadPoolThreads)) + { + throw new Exception("Error: could not SetMinThreads in ThreadPool"); + } + if (!ThreadPool.SetMaxThreads(maxThreadPoolThreads, maxThreadPoolThreads)) + { + throw new Exception("Error: could not SetMaxThreads in ThreadPool"); + } + } + + /// <summary> + /// Use new ThreadPool thread for each new client connection + /// </summary> + public override void Serve() + { + try + { + serverTransport.Listen(); + } + catch (TTransportException ttx) + { + logDelegate("Error, could not listen on ServerTransport: " + ttx); + return; + } + + while (!stop) + { + int failureCount = 0; + try + { + TTransport client = serverTransport.Accept(); + ThreadPool.QueueUserWorkItem(this.Execute, client); + } + catch (TTransportException ttx) + { + if (stop) + { + logDelegate("TThreadPoolServer was shutting down, caught " + ttx.GetType().Name); + } + else + { + ++failureCount; + logDelegate(ttx.ToString()); + } + + } + } + + if (stop) + { + try + { + serverTransport.Close(); + } + catch (TTransportException ttx) + { + logDelegate("TServerTransport failed on close: " + ttx.Message); + } + stop = false; + } + } + + /// <summary> + /// Loops on processing a client forever + /// threadContext will be a TTransport instance + /// </summary> + /// <param name="threadContext"></param> + private void Execute(Object threadContext) + { + TTransport client = (TTransport)threadContext; + TTransport inputTransport = null; + TTransport outputTransport = null; + TProtocol inputProtocol = null; + TProtocol outputProtocol = null; + try + { + inputTransport = inputTransportFactory.GetTransport(client); + outputTransport = outputTransportFactory.GetTransport(client); + inputProtocol = inputProtocolFactory.GetProtocol(inputTransport); + outputProtocol = outputProtocolFactory.GetProtocol(outputTransport); + while (processor.Process(inputProtocol, outputProtocol)) + { + //keep processing requests until client disconnects + } + } + catch (TTransportException) + { + // Assume the client died and continue silently + //Console.WriteLine(ttx); + } + + catch (Exception x) + { + logDelegate("Error: " + x); + } + + if (inputTransport != null) + { + inputTransport.Close(); + } + if (outputTransport != null) + { + outputTransport.Close(); + } + } + + public override void Stop() + { + stop = true; + serverTransport.Close(); + } + } +} diff --git a/lib/csharp/src/Server/TThreadedServer.cs b/lib/csharp/src/Server/TThreadedServer.cs new file mode 100644 index 000000000..75206f15c --- /dev/null +++ b/lib/csharp/src/Server/TThreadedServer.cs @@ -0,0 +1,234 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Threading; +using Thrift.Collections; +using Thrift.Protocol; +using Thrift.Transport; + +namespace Thrift.Server +{ + /// <summary> + /// Server that uses C# threads (as opposed to the ThreadPool) when handling requests + /// </summary> + public class TThreadedServer : TServer + { + private const int DEFAULT_MAX_THREADS = 100; + private volatile bool stop = false; + private readonly int maxThreads; + + private Queue<TTransport> clientQueue; + private THashSet<Thread> clientThreads; + private object clientLock; + private Thread workerThread; + + public TThreadedServer(TProcessor processor, TServerTransport serverTransport) + : this(processor, serverTransport, + new TTransportFactory(), new TTransportFactory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), + DEFAULT_MAX_THREADS, DefaultLogDelegate) + { + } + + public TThreadedServer(TProcessor processor, TServerTransport serverTransport, LogDelegate logDelegate) + : this(processor, serverTransport, + new TTransportFactory(), new TTransportFactory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), + DEFAULT_MAX_THREADS, logDelegate) + { + } + + + public TThreadedServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) + : this(processor, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory, + DEFAULT_MAX_THREADS, DefaultLogDelegate) + { + } + + public TThreadedServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + int maxThreads, LogDelegate logDel) + : base(processor, serverTransport, inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory, logDel) + { + this.maxThreads = maxThreads; + clientQueue = new Queue<TTransport>(); + clientLock = new object(); + clientThreads = new THashSet<Thread>(); + } + + /// <summary> + /// Use new Thread for each new client connection. block until numConnections < maxTHreads + /// </summary> + public override void Serve() + { + try + { + //start worker thread + workerThread = new Thread(new ThreadStart(Execute)); + workerThread.Start(); + serverTransport.Listen(); + } + catch (TTransportException ttx) + { + logDelegate("Error, could not listen on ServerTransport: " + ttx); + return; + } + + while (!stop) + { + int failureCount = 0; + try + { + TTransport client = serverTransport.Accept(); + lock (clientLock) + { + clientQueue.Enqueue(client); + Monitor.Pulse(clientLock); + } + } + catch (TTransportException ttx) + { + if (stop) + { + logDelegate("TThreadPoolServer was shutting down, caught " + ttx); + } + else + { + ++failureCount; + logDelegate(ttx.ToString()); + } + + } + } + + if (stop) + { + try + { + serverTransport.Close(); + } + catch (TTransportException ttx) + { + logDelegate("TServeTransport failed on close: " + ttx.Message); + } + stop = false; + } + } + + /// <summary> + /// Loops on processing a client forever + /// threadContext will be a TTransport instance + /// </summary> + /// <param name="threadContext"></param> + private void Execute() + { + while (!stop) + { + TTransport client; + Thread t; + lock (clientLock) + { + //don't dequeue if too many connections + while (clientThreads.Count >= maxThreads) + { + Monitor.Wait(clientLock); + } + + while (clientQueue.Count == 0) + { + Monitor.Wait(clientLock); + } + + client = clientQueue.Dequeue(); + t = new Thread(new ParameterizedThreadStart(ClientWorker)); + clientThreads.Add(t); + } + //start processing requests from client on new thread + t.Start(client); + } + } + + private void ClientWorker(Object context) + { + TTransport client = (TTransport)context; + TTransport inputTransport = null; + TTransport outputTransport = null; + TProtocol inputProtocol = null; + TProtocol outputProtocol = null; + try + { + inputTransport = inputTransportFactory.GetTransport(client); + outputTransport = outputTransportFactory.GetTransport(client); + inputProtocol = inputProtocolFactory.GetProtocol(inputTransport); + outputProtocol = outputProtocolFactory.GetProtocol(outputTransport); + while (processor.Process(inputProtocol, outputProtocol)) + { + //keep processing requests until client disconnects + } + } + catch (TTransportException) + { + } + catch (Exception x) + { + logDelegate("Error: " + x); + } + + if (inputTransport != null) + { + inputTransport.Close(); + } + if (outputTransport != null) + { + outputTransport.Close(); + } + + lock (clientLock) + { + clientThreads.Remove(Thread.CurrentThread); + Monitor.Pulse(clientLock); + } + return; + } + + public override void Stop() + { + stop = true; + serverTransport.Close(); + //clean up all the threads myself + workerThread.Abort(); + foreach (Thread t in clientThreads) + { + t.Abort(); + } + } + } +} diff --git a/lib/csharp/src/TApplicationException.cs b/lib/csharp/src/TApplicationException.cs new file mode 100644 index 000000000..127196866 --- /dev/null +++ b/lib/csharp/src/TApplicationException.cs @@ -0,0 +1,131 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using Thrift.Protocol; + +namespace Thrift +{ + public class TApplicationException : Exception + { + protected ExceptionType type; + + public TApplicationException() + { + } + + public TApplicationException(ExceptionType type) + { + this.type = type; + } + + public TApplicationException(ExceptionType type, string message) + : base(message) + { + this.type = type; + } + + public static TApplicationException Read(TProtocol iprot) + { + TField field; + + string message = null; + ExceptionType type = ExceptionType.Unknown; + + while (true) + { + field = iprot.ReadFieldBegin(); + if (field.Type == TType.Stop) + { + break; + } + + switch (field.ID) + { + case 1: + if (field.Type == TType.String) + { + message = iprot.ReadString(); + } + else + { + TProtocolUtil.Skip(iprot, field.Type); + } + break; + case 2: + if (field.Type == TType.I32) + { + type = (ExceptionType)iprot.ReadI32(); + } + else + { + TProtocolUtil.Skip(iprot, field.Type); + } + break; + default: + TProtocolUtil.Skip(iprot, field.Type); + break; + } + + iprot.ReadFieldEnd(); + } + + iprot.ReadStructEnd(); + + return new TApplicationException(type, message); + } + + public void Write(TProtocol oprot) + { + TStruct struc = new TStruct("TApplicationException"); + TField field = new TField(); + + oprot.WriteStructBegin(struc); + + if (!String.IsNullOrEmpty(Message)) + { + field.Name = "message"; + field.Type = TType.String; + field.ID = 1; + oprot.WriteFieldBegin(field); + oprot.WriteString(Message); + oprot.WriteFieldEnd(); + } + + field.Name = "type"; + field.Type = TType.I32; + field.ID = 2; + oprot.WriteFieldBegin(field); + oprot.WriteI32((int)type); + oprot.WriteFieldEnd(); + oprot.WriteFieldStop(); + oprot.WriteStructEnd(); + } + + public enum ExceptionType + { + Unknown, + UnknownMethod, + InvalidMessageType, + WrongMethodName, + BadSequenceID, + MissingResult + } + } +} diff --git a/lib/csharp/src/TProcessor.cs b/lib/csharp/src/TProcessor.cs new file mode 100644 index 000000000..cbb55b798 --- /dev/null +++ b/lib/csharp/src/TProcessor.cs @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using Thrift.Protocol; + +namespace Thrift +{ + public interface TProcessor + { + bool Process(TProtocol iprot, TProtocol oprot); + } +} diff --git a/lib/csharp/src/Thrift.csproj b/lib/csharp/src/Thrift.csproj new file mode 100644 index 000000000..1eb4355d5 --- /dev/null +++ b/lib/csharp/src/Thrift.csproj @@ -0,0 +1,77 @@ +<Project ToolsVersion="3.5" DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <PropertyGroup> + <Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration> + <Platform Condition=" '$(Platform)' == '' ">AnyCPU</Platform> + <ProjectGuid>{499EB63C-D74C-47E8-AE48-A2FC94538E9D}</ProjectGuid> + <ProductVersion>9.0.21022</ProductVersion> + <SchemaVersion>2.0</SchemaVersion> + <OutputType>Library</OutputType> + <NoStandardLibraries>false</NoStandardLibraries> + <AssemblyName>Thrift</AssemblyName> + <TargetFrameworkVersion>v3.5</TargetFrameworkVersion> + <FileAlignment>512</FileAlignment> + <RootNamespace>Thrift</RootNamespace> + <SccProjectName>SAK</SccProjectName> + <SccLocalPath>SAK</SccLocalPath> + <SccAuxPath>SAK</SccAuxPath> + <SccProvider>SAK</SccProvider> + </PropertyGroup> + <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|AnyCPU' "> + <DebugSymbols>true</DebugSymbols> + <DebugType>full</DebugType> + <Optimize>false</Optimize> + <OutputPath>bin\Debug\</OutputPath> + <DefineConstants>DEBUG;TRACE</DefineConstants> + <ErrorReport>prompt</ErrorReport> + <WarningLevel>4</WarningLevel> + </PropertyGroup> + <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' "> + <DebugType>pdbonly</DebugType> + <Optimize>true</Optimize> + <OutputPath>bin\Release\</OutputPath> + <DefineConstants>TRACE</DefineConstants> + <ErrorReport>prompt</ErrorReport> + <WarningLevel>4</WarningLevel> + </PropertyGroup> + <ItemGroup> + <Reference Include="System" /> + <Reference Include="System.Core"> + <RequiredTargetFramework>3.5</RequiredTargetFramework> + </Reference> + </ItemGroup> + <ItemGroup> + <Compile Include="Collections\THashSet.cs" /> + <Compile Include="Protocol\TBase.cs" /> + <Compile Include="Protocol\TBinaryProtocol.cs" /> + <Compile Include="Protocol\TField.cs" /> + <Compile Include="Protocol\TList.cs" /> + <Compile Include="Protocol\TMap.cs" /> + <Compile Include="Protocol\TMessage.cs" /> + <Compile Include="Protocol\TMessageType.cs" /> + <Compile Include="Protocol\TProtocol.cs" /> + <Compile Include="Protocol\TProtocolException.cs" /> + <Compile Include="Protocol\TProtocolFactory.cs" /> + <Compile Include="Protocol\TProtocolUtil.cs" /> + <Compile Include="Protocol\TSet.cs" /> + <Compile Include="Protocol\TStruct.cs" /> + <Compile Include="Protocol\TType.cs" /> + <Compile Include="Server\TThreadedServer.cs" /> + <Compile Include="Server\TServer.cs" /> + <Compile Include="Server\TSimpleServer.cs" /> + <Compile Include="Server\TThreadPoolServer.cs" /> + <Compile Include="TApplicationException.cs" /> + <Compile Include="TProcessor.cs" /> + <Compile Include="Transport\TBufferedTransport.cs" /> + <Compile Include="Transport\TServerSocket.cs" /> + <Compile Include="Transport\TServerTransport.cs" /> + <Compile Include="Transport\TSocket.cs" /> + <Compile Include="Transport\TStreamTransport.cs" /> + <Compile Include="Transport\TTransport.cs" /> + <Compile Include="Transport\TTransportException.cs" /> + <Compile Include="Transport\TTransportFactory.cs" /> + </ItemGroup> + <Import Project="$(MSBuildBinPath)\Microsoft.CSHARP.Targets" /> + <ProjectExtensions> + <VisualStudio AllowExistingFolder="true" /> + </ProjectExtensions> +</Project>
\ No newline at end of file diff --git a/lib/csharp/src/Thrift.sln b/lib/csharp/src/Thrift.sln new file mode 100644 index 000000000..cb7342cd3 --- /dev/null +++ b/lib/csharp/src/Thrift.sln @@ -0,0 +1,51 @@ + +Microsoft Visual Studio Solution File, Format Version 10.00 +# Visual Studio 2008 +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Thrift", "Thrift.csproj", "{499EB63C-D74C-47E8-AE48-A2FC94538E9D}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ThriftTest", "..\..\..\test\csharp\ThriftTest\ThriftTest.csproj", "{48DD757F-CA95-4DD7-BDA4-58DB6F108C2C}" + ProjectSection(ProjectDependencies) = postProject + {499EB63C-D74C-47E8-AE48-A2FC94538E9D} = {499EB63C-D74C-47E8-AE48-A2FC94538E9D} + EndProjectSection +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ThriftMSBuildTask", "..\ThriftMSBuildTask\ThriftMSBuildTask.csproj", "{EC0A0231-66EA-4593-A792-C6CA3BB8668E}" +EndProject +Global + GlobalSection(SourceCodeControl) = preSolution + SccNumberOfProjects = 4 + SccProjectName0 = Perforce\u0020Project + SccLocalPath0 = ..\\..\\.. + SccProvider0 = MSSCCI:Perforce\u0020SCM + SccProjectFilePathRelativizedFromConnection0 = lib\\csharp\\src\\ + SccProjectUniqueName1 = Thrift.csproj + SccLocalPath1 = ..\\..\\.. + SccProjectFilePathRelativizedFromConnection1 = lib\\csharp\\src\\ + SccProjectUniqueName2 = ..\\..\\..\\test\\csharp\\ThriftTest\\ThriftTest.csproj + SccLocalPath2 = ..\\..\\.. + SccProjectFilePathRelativizedFromConnection2 = test\\csharp\\ThriftTest\\ + SccProjectUniqueName3 = ..\\ThriftMSBuildTask\\ThriftMSBuildTask.csproj + SccLocalPath3 = ..\\..\\.. + SccProjectFilePathRelativizedFromConnection3 = lib\\csharp\\ThriftMSBuildTask\\ + EndGlobalSection + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {499EB63C-D74C-47E8-AE48-A2FC94538E9D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {499EB63C-D74C-47E8-AE48-A2FC94538E9D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {499EB63C-D74C-47E8-AE48-A2FC94538E9D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {499EB63C-D74C-47E8-AE48-A2FC94538E9D}.Release|Any CPU.Build.0 = Release|Any CPU + {48DD757F-CA95-4DD7-BDA4-58DB6F108C2C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {48DD757F-CA95-4DD7-BDA4-58DB6F108C2C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {48DD757F-CA95-4DD7-BDA4-58DB6F108C2C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {48DD757F-CA95-4DD7-BDA4-58DB6F108C2C}.Release|Any CPU.Build.0 = Release|Any CPU + {EC0A0231-66EA-4593-A792-C6CA3BB8668E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {EC0A0231-66EA-4593-A792-C6CA3BB8668E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {EC0A0231-66EA-4593-A792-C6CA3BB8668E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {EC0A0231-66EA-4593-A792-C6CA3BB8668E}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection +EndGlobal diff --git a/lib/csharp/src/Transport/TBufferedTransport.cs b/lib/csharp/src/Transport/TBufferedTransport.cs new file mode 100644 index 000000000..28a855a55 --- /dev/null +++ b/lib/csharp/src/Transport/TBufferedTransport.cs @@ -0,0 +1,100 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.IO; + +namespace Thrift.Transport +{ + public class TBufferedTransport : TTransport + { + private BufferedStream inputBuffer; + private BufferedStream outputBuffer; + private int bufSize; + private TStreamTransport transport; + + public TBufferedTransport(TStreamTransport transport) + :this(transport, 1024) + { + + } + + public TBufferedTransport(TStreamTransport transport, int bufSize) + { + this.bufSize = bufSize; + this.transport = transport; + InitBuffers(); + } + + private void InitBuffers() + { + if (transport.InputStream != null) + { + inputBuffer = new BufferedStream(transport.InputStream, bufSize); + } + if (transport.OutputStream != null) + { + outputBuffer = new BufferedStream(transport.OutputStream, bufSize); + } + } + + public TTransport UnderlyingTransport + { + get { return transport; } + } + + public override bool IsOpen + { + get { return transport.IsOpen; } + } + + public override void Open() + { + transport.Open(); + InitBuffers(); + } + + public override void Close() + { + if (inputBuffer != null && inputBuffer.CanRead) + { + inputBuffer.Close(); + } + if (outputBuffer != null && outputBuffer.CanWrite) + { + outputBuffer.Close(); + } + } + + public override int Read(byte[] buf, int off, int len) + { + return inputBuffer.Read(buf, off, len); + } + + public override void Write(byte[] buf, int off, int len) + { + outputBuffer.Write(buf, off, len); + } + + public override void Flush() + { + outputBuffer.Flush(); + } + } +} diff --git a/lib/csharp/src/Transport/TServerSocket.cs b/lib/csharp/src/Transport/TServerSocket.cs new file mode 100644 index 000000000..2658fce89 --- /dev/null +++ b/lib/csharp/src/Transport/TServerSocket.cs @@ -0,0 +1,157 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Net.Sockets; + + +namespace Thrift.Transport +{ + public class TServerSocket : TServerTransport + { + /** + * Underlying server with socket + */ + private TcpListener server = null; + + /** + * Port to listen on + */ + private int port = 0; + + /** + * Timeout for client sockets from accept + */ + private int clientTimeout = 0; + + /** + * Whether or not to wrap new TSocket connections in buffers + */ + private bool useBufferedSockets = false; + + /** + * Creates a server socket from underlying socket object + */ + public TServerSocket(TcpListener listener) + :this(listener, 0) + { + } + + /** + * Creates a server socket from underlying socket object + */ + public TServerSocket(TcpListener listener, int clientTimeout) + { + this.server = listener; + this.clientTimeout = clientTimeout; + } + + /** + * Creates just a port listening server socket + */ + public TServerSocket(int port) + : this(port, 0) + { + } + + /** + * Creates just a port listening server socket + */ + public TServerSocket(int port, int clientTimeout) + :this(port, clientTimeout, false) + { + } + + public TServerSocket(int port, int clientTimeout, bool useBufferedSockets) + { + this.port = port; + this.clientTimeout = clientTimeout; + this.useBufferedSockets = useBufferedSockets; + try + { + // Make server socket + server = new TcpListener(System.Net.IPAddress.Any, this.port); + } + catch (Exception) + { + server = null; + throw new TTransportException("Could not create ServerSocket on port " + port + "."); + } + } + + public override void Listen() + { + // Make sure not to block on accept + if (server != null) + { + try + { + server.Start(); + } + catch (SocketException sx) + { + throw new TTransportException("Could not accept on listening socket: " + sx.Message); + } + } + } + + protected override TTransport AcceptImpl() + { + if (server == null) + { + throw new TTransportException(TTransportException.ExceptionType.NotOpen, "No underlying server socket."); + } + try + { + TcpClient result = server.AcceptTcpClient(); + TSocket result2 = new TSocket(result); + result2.Timeout = clientTimeout; + if (useBufferedSockets) + { + TBufferedTransport result3 = new TBufferedTransport(result2); + return result3; + } + else + { + return result2; + } + } + catch (Exception ex) + { + throw new TTransportException(ex.ToString()); + } + } + + public override void Close() + { + if (server != null) + { + try + { + server.Stop(); + } + catch (Exception ex) + { + throw new TTransportException("WARNING: Could not close server socket: " + ex); + } + server = null; + } + } + } +} diff --git a/lib/csharp/src/Transport/TServerTransport.cs b/lib/csharp/src/Transport/TServerTransport.cs new file mode 100644 index 000000000..9cb52e5c6 --- /dev/null +++ b/lib/csharp/src/Transport/TServerTransport.cs @@ -0,0 +1,39 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Transport +{ + public abstract class TServerTransport + { + public abstract void Listen(); + public abstract void Close(); + protected abstract TTransport AcceptImpl(); + + public TTransport Accept() + { + TTransport transport = AcceptImpl(); + if (transport == null) { + throw new TTransportException("accept() may not return NULL"); + } + return transport; + } + } +} diff --git a/lib/csharp/src/Transport/TSocket.cs b/lib/csharp/src/Transport/TSocket.cs new file mode 100644 index 000000000..18cf1547b --- /dev/null +++ b/lib/csharp/src/Transport/TSocket.cs @@ -0,0 +1,144 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Net.Sockets; + +namespace Thrift.Transport +{ + public class TSocket : TStreamTransport + { + private TcpClient client = null; + private string host = null; + private int port = 0; + private int timeout = 0; + + public TSocket(TcpClient client) + { + this.client = client; + + if (IsOpen) + { + inputStream = client.GetStream(); + outputStream = client.GetStream(); + } + } + + public TSocket(string host, int port) : this(host, port, 0) + { + } + + public TSocket(string host, int port, int timeout) + { + this.host = host; + this.port = port; + this.timeout = timeout; + + InitSocket(); + } + + private void InitSocket() + { + client = new TcpClient(); + client.ReceiveTimeout = client.SendTimeout = timeout; + } + + public int Timeout + { + set + { + client.ReceiveTimeout = client.SendTimeout = timeout = value; + } + } + + public TcpClient TcpClient + { + get + { + return client; + } + } + + public string Host + { + get + { + return host; + } + } + + public int Port + { + get + { + return port; + } + } + + public override bool IsOpen + { + get + { + if (client == null) + { + return false; + } + + return client.Connected; + } + } + + public override void Open() + { + if (IsOpen) + { + throw new TTransportException(TTransportException.ExceptionType.AlreadyOpen, "Socket already connected"); + } + + if (String.IsNullOrEmpty(host)) + { + throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open null host"); + } + + if (port <= 0) + { + throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open without port"); + } + + if (client == null) + { + InitSocket(); + } + + client.Connect(host, port); + inputStream = client.GetStream(); + outputStream = client.GetStream(); + } + + public override void Close() + { + base.Close(); + if (client != null) + { + client.Close(); + client = null; + } + } + } +} diff --git a/lib/csharp/src/Transport/TStreamTransport.cs b/lib/csharp/src/Transport/TStreamTransport.cs new file mode 100644 index 000000000..7681e0d97 --- /dev/null +++ b/lib/csharp/src/Transport/TStreamTransport.cs @@ -0,0 +1,103 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.IO; + +namespace Thrift.Transport +{ + public class TStreamTransport : TTransport + { + protected Stream inputStream; + protected Stream outputStream; + + public TStreamTransport() + { + } + + public TStreamTransport(Stream inputStream, Stream outputStream) + { + this.inputStream = inputStream; + this.outputStream = outputStream; + } + + public Stream OutputStream + { + get { return outputStream; } + } + + public Stream InputStream + { + get { return inputStream; } + } + + public override bool IsOpen + { + get { return true; } + } + + public override void Open() + { + } + + public override void Close() + { + if (inputStream != null) + { + inputStream.Close(); + inputStream = null; + } + if (outputStream != null) + { + outputStream.Close(); + outputStream = null; + } + } + + public override int Read(byte[] buf, int off, int len) + { + if (inputStream == null) + { + throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot read from null inputstream"); + } + + return inputStream.Read(buf, off, len); + } + + public override void Write(byte[] buf, int off, int len) + { + if (outputStream == null) + { + throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot write to null outputstream"); + } + + outputStream.Write(buf, off, len); + } + + public override void Flush() + { + if (outputStream == null) + { + throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot flush null outputstream"); + } + + outputStream.Flush(); + } + } +} diff --git a/lib/csharp/src/Transport/TTransport.cs b/lib/csharp/src/Transport/TTransport.cs new file mode 100644 index 000000000..83f6776c0 --- /dev/null +++ b/lib/csharp/src/Transport/TTransport.cs @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Transport +{ + public abstract class TTransport + { + public abstract bool IsOpen + { + get; + } + + public bool Peek() + { + return IsOpen; + } + + public abstract void Open(); + + public abstract void Close(); + + public abstract int Read(byte[] buf, int off, int len); + + public int ReadAll(byte[] buf, int off, int len) + { + int got = 0; + int ret = 0; + + while (got < len) + { + ret = Read(buf, off + got, len - got); + if (ret <= 0) + { + throw new TTransportException("Cannot read, Remote side has closed"); + } + got += ret; + } + + return got; + } + + public abstract void Write(byte[] buf, int off, int len); + + public virtual void Flush() + { + } + } +} diff --git a/lib/csharp/src/Transport/TTransportException.cs b/lib/csharp/src/Transport/TTransportException.cs new file mode 100644 index 000000000..fe10faa5e --- /dev/null +++ b/lib/csharp/src/Transport/TTransportException.cs @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Transport +{ + public class TTransportException : Exception + { + protected ExceptionType type; + + public TTransportException() + : base() + { + } + + public TTransportException(ExceptionType type) + : this() + { + this.type = type; + } + + public TTransportException(ExceptionType type, string message) + : base(message) + { + this.type = type; + } + + public TTransportException(string message) + : base(message) + { + } + + public ExceptionType Type + { + get { return type; } + } + + public enum ExceptionType + { + Unknown, + NotOpen, + AlreadyOpen, + TimedOut, + EndOfFile + } + } +} diff --git a/lib/csharp/src/Transport/TTransportFactory.cs b/lib/csharp/src/Transport/TTransportFactory.cs new file mode 100644 index 000000000..3d3694db2 --- /dev/null +++ b/lib/csharp/src/Transport/TTransportFactory.cs @@ -0,0 +1,38 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Transport +{ + /// <summary> + /// From Mark Slee & Aditya Agarwal of Facebook: + /// Factory class used to create wrapped instance of Transports. + /// This is used primarily in servers, which get Transports from + /// a ServerTransport and then may want to mutate them (i.e. create + /// a BufferedTransport from the underlying base transport) + /// </summary> + public class TTransportFactory + { + public virtual TTransport GetTransport(TTransport trans) + { + return trans; + } + } +} diff --git a/lib/erl/Makefile b/lib/erl/Makefile new file mode 100644 index 000000000..77fe8b6a6 --- /dev/null +++ b/lib/erl/Makefile @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +MODULES = \ + src + +all clean docs: + for dir in $(MODULES); do \ + (cd $$dir; ${MAKE} $@); \ + done + +install: all + echo 'No install target, sorry.' + +check: all + +distclean: clean + +# Hack to make "make dist" work. +# This should not work, but it appears to. +distdir: diff --git a/lib/erl/README b/lib/erl/README new file mode 100644 index 000000000..ddb6946f3 --- /dev/null +++ b/lib/erl/README @@ -0,0 +1,56 @@ +Thrift Erlang Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Example +======= + +Example session using thrift_client: + +118> f(), {ok, C} = thrift_client:start_link("localhost", 9090, thriftTest_thrif +t). +{ok,<0.271.0>} +119> thrift_client:call(C, testVoid, []). +{ok,ok} +120> thrift_client:call(C, testVoid, [asdf]). +{error,{bad_args,testVoid,[asdf]}} +121> thrift_client:call(C, testI32, [123]). +{ok,123} +122> thrift_client:call(C, testOneway, [1]). +{ok,ok} +123> catch thrift_client:call(C, testXception, ["foo"]). +{error,{no_function,testXception}} +124> catch thrift_client:call(C, testException, ["foo"]). +{ok,ok} +125> catch thrift_client:call(C, testException, ["Xception"]). +{xception,1001,"This is an Xception"} +126> thrift_client:call(C, testException, ["Xception"]). + +=ERROR REPORT==== 24-Feb-2008::23:00:23 === +Error in process <0.269.0> with exit value: {{nocatch,{xception,1001,"This is an + Xception"}},[{thrift_client,call,3},{erl_eval,do_apply,5},{shell,exprs,6},{shel +l,eval_loop,3}]} + +** exited: {{nocatch,{xception,1001,"This is an Xception"}}, + [{thrift_client,call,3}, + {erl_eval,do_apply,5}, + {shell,exprs,6}, + {shell,eval_loop,3}]} ** diff --git a/lib/erl/build/beamver b/lib/erl/build/beamver new file mode 100644 index 000000000..2b5f77b33 --- /dev/null +++ b/lib/erl/build/beamver @@ -0,0 +1,59 @@ +#!/bin/sh + +# erlwareSys: otp/build/beamver,v 1.1 2002/02/14 11:45:20 hal Exp $ + +# usage: beamver <beam_file> +# +# if there's a usable -vsn() attribute, print it and exit with status 0 +# otherwise, print nothing and exit with status 1 + +# From the Erlang shell: +# +# 5> code:which(acca_inets). +# "/home/martin/work/otp/releases/<app>/../../acca/ebin/<app>.beam" +# +# 8> beam_lib:version(code:which(<app>)). +# {ok,{<app>,['$Id: beamver,v 1.1.1.1 2003/06/13 21:43:21 mlogan Exp $ ']}} + +# TMPFILE looks like this: +# +# io:format("hello ~p~n", +# beam_lib:version("/home/hal/work/otp/acca/ebin/acca_inets.beam")]). + +TMPFILE=/tmp/beamver.$$ + +# exit with failure if we can't read the file +test -f "$1" || exit 1 +BEAMFILE=\"$1\" + +cat > $TMPFILE <<_EOF +io:format("~p~n", + [beam_lib:version($BEAMFILE)]). +_EOF + +# beam_result is {ok,{Module_name, Beam_version} or {error,beam_lib,{Reason}} +beam_result=`erl -noshell \ + -s file eval $TMPFILE \ + -s erlang halt` + +rm -f $TMPFILE + +# sed regexes: +# remove brackets and anything outside them +# remove quotes and anything outside them +# remove apostrophes and anything outside them +# remove leading and trailing spaces + +case $beam_result in +\{ok*) + echo $beam_result | sed -e 's/.*\[\(.*\)].*/\1/' \ + -e 's/.*\"\(.*\)\".*/\1/' \ + -e "s/.*\'\(.*\)\'.*/\1/" \ + -e 's/ *$//' -e 's/^ *//' + exit 0 + ;; +*) + exit 1 + ;; +esac + diff --git a/lib/erl/build/buildtargets.mk b/lib/erl/build/buildtargets.mk new file mode 100644 index 000000000..db52b785a --- /dev/null +++ b/lib/erl/build/buildtargets.mk @@ -0,0 +1,15 @@ +EBIN ?= ../ebin +ESRC ?= . +EMULATOR = beam + +ERLC_WFLAGS = -W +ERLC = erlc $(ERLC_WFLAGS) $(ERLC_FLAGS) +ERL = erl -boot start_clean + +$(EBIN)/%.beam: $(ESRC)/%.erl + @echo " ERLC $<" + @$(ERLC) $(ERL_FLAGS) $(ERL_COMPILE_FLAGS) -o$(EBIN) $< + +.erl.beam: + $(ERLC) $(ERL_FLAGS) $(ERL_COMPILE_FLAGS) -o$(dir $@) $< + diff --git a/lib/erl/build/colors.mk b/lib/erl/build/colors.mk new file mode 100644 index 000000000..4d69c41d8 --- /dev/null +++ b/lib/erl/build/colors.mk @@ -0,0 +1,24 @@ +# Colors to assist visual inspection of make output. + +# Colors +LGRAY=$$'\e[0;37m' +DGRAY=$$'\e[1;30m' +LGREEN=$$'\e[1;32m' +LBLUE=$$'\e[1;34m' +LCYAN=$$'\e[1;36m' +LPURPLE=$$'\e[1;35m' +LRED=$$'\e[1;31m' +NO_COLOR=$$'\e[0m' +DEFAULT=$$'\e[0m' +BLACK=$$'\e[0;30m' +BLUE=$$'\e[0;34m' +GREEN=$$'\e[0;32m' +CYAN=$$'\e[0;36m' +RED=$$'\e[0;31m' +PURPLE=$$'\e[0;35m' +BROWN=$$'\e[0;33m' +YELLOW=$$'\e[1;33m' +WHITE=$$'\e[1;37m' + +BOLD=$$'\e[1;37m' +OFF=$$'\e[0m' diff --git a/lib/erl/build/docs.mk b/lib/erl/build/docs.mk new file mode 100644 index 000000000..b0b7377f6 --- /dev/null +++ b/lib/erl/build/docs.mk @@ -0,0 +1,12 @@ +EDOC_PATH=../../../tools/utilities + +#single place to include docs from. +docs: + @mkdir -p ../doc + @echo -n $${MY_BLUE:-$(BLUE)}; \ + $(EDOC_PATH)/edoc $(APP_NAME); \ + if [ $$? -eq 0 ]; then \ + echo $${MY_LRED:-$(LRED)}"$$d Doc Failed"; \ + fi; \ + echo -n $(OFF)$(NO_COLOR) + diff --git a/lib/erl/build/mime.types b/lib/erl/build/mime.types new file mode 100644 index 000000000..d6e3c0d0b --- /dev/null +++ b/lib/erl/build/mime.types @@ -0,0 +1,98 @@ + +application/activemessage +application/andrew-inset +application/applefile +application/atomicmail +application/dca-rft +application/dec-dx +application/mac-binhex40 hqx +application/mac-compactpro cpt +application/macwriteii +application/msword doc +application/news-message-id +application/news-transmission +application/octet-stream bin dms lha lzh exe class +application/oda oda +application/pdf pdf +application/postscript ai eps ps +application/powerpoint ppt +application/remote-printing +application/rtf rtf +application/slate +application/wita +application/wordperfect5.1 +application/x-bcpio bcpio +application/x-cdlink vcd +application/x-compress Z +application/x-cpio cpio +application/x-csh csh +application/x-director dcr dir dxr +application/x-dvi dvi +application/x-gtar gtar +application/x-gzip gz +application/x-hdf hdf +application/x-httpd-cgi cgi +application/x-koan skp skd skt skm +application/x-latex latex +application/x-mif mif +application/x-netcdf nc cdf +application/x-sh sh +application/x-shar shar +application/x-stuffit sit +application/x-sv4cpio sv4cpio +application/x-sv4crc sv4crc +application/x-tar tar +application/x-tcl tcl +application/x-tex tex +application/x-texinfo texinfo texi +application/x-troff t tr roff +application/x-troff-man man +application/x-troff-me me +application/x-troff-ms ms +application/x-ustar ustar +application/x-wais-source src +application/zip zip +audio/basic au snd +audio/mpeg mpga mp2 +audio/x-aiff aif aiff aifc +audio/x-pn-realaudio ram +audio/x-pn-realaudio-plugin rpm +audio/x-realaudio ra +audio/x-wav wav +chemical/x-pdb pdb xyz +image/gif gif +image/ief ief +image/jpeg jpeg jpg jpe +image/png png +image/tiff tiff tif +image/x-cmu-raster ras +image/x-portable-anymap pnm +image/x-portable-bitmap pbm +image/x-portable-graymap pgm +image/x-portable-pixmap ppm +image/x-rgb rgb +image/x-xbitmap xbm +image/x-xpixmap xpm +image/x-xwindowdump xwd +message/external-body +message/news +message/partial +message/rfc822 +multipart/alternative +multipart/appledouble +multipart/digest +multipart/mixed +multipart/parallel +text/html html htm +text/x-server-parsed-html shtml +text/plain txt +text/richtext rtx +text/tab-separated-values tsv +text/x-setext etx +text/x-sgml sgml sgm +video/mpeg mpeg mpg mpe +video/quicktime qt mov +video/x-msvideo avi +video/x-sgi-movie movie +x-conference/x-cooltalk ice +x-world/x-vrml wrl vrml diff --git a/lib/erl/build/otp.mk b/lib/erl/build/otp.mk new file mode 100644 index 000000000..1d16e2c83 --- /dev/null +++ b/lib/erl/build/otp.mk @@ -0,0 +1,146 @@ +# +----------------------------------------------------------------------+ +# $Id: otp.mk,v 1.4 2004/07/01 14:57:10 tfee Exp $ +# +----------------------------------------------------------------------+ + +# otp.mk +# - to be included in all OTP Makefiles +# installed to /usr/local/include/erlang/otp.mk + +# gmake looks in /usr/local/include - that's hard-coded +# users of this file will use +# include erlang/top.mk + +# most interface files will be installed to $ERL_RUN_TOP/app-vsn/include/*.hrl + +# group owner for library/include directories +ERLANGDEV_GROUP=erlangdev + +# ERL_TOP is root of Erlang source tree +# ERL_RUN_TOP is root of Erlang target tree (some Ericsson Makefiles use $ROOT) +# ERLANG_OTP is target root for Erlang code +# - see sasl/systools reference manual page; grep "TEST" + +# OS_TYPE is FreeBSD, NetBSD, OpenBSD, Linux, SCO_SV, SunOS. +OS_TYPE=${shell uname} + +# MHOST is the host where this Makefile runs. +MHOST=${shell hostname -s} +ERL_COMPILE_FLAGS+=-W0 + +# The location of the erlang runtime system. +ifndef ERL_RUN_TOP +ERL_RUN_TOP=/usr/local/lib/erlang +endif + + +# Edit to reflect local environment. +# ifeq (${OS_TYPE},Linux) +# ERL_RUN_TOP=/usr/local/lib/erlang +# Note* ERL_RUN_TOP can be determined by starting an +# erlang shell and typing code:root_dir(). +# ERL_TOP=a symbolic link to the actual source top, which changes from version to version +# Note* ERL_TOP is the directory where the erlang +# source files reside. Make sure to run ./configure there. +# TARGET=i686-pc-linux-gnu +# Note* Target can be found in $ERL_TOP/erts +# endif + +# See above for directions. +ifeq (${OS_TYPE},Linux) +ERL_TOP=/opt/OTP_SRC +TARGET=i686-pc-linux-gnu +endif + +ERLANG_OTP=/usr/local/erlang/otp +VAR_OTP=/var/otp + + +# Aliases for common binaries +# Note - CFLAGS is modified in erlang.conf + + +################################ +# SunOS +################################ +ifeq (${OS_TYPE},SunOS) + + CC=gcc + CXX=g++ + AR=/usr/ccs/bin/ar + ARFLAGS=-rv + CXXFLAGS+=${CFLAGS} -I/usr/include/g++ + LD=/usr/ccs/bin/ld + RANLIB=/usr/ccs/bin/ranlib + +CFLAGS+=-Wall -pedantic -ansi -O +CORE=*.core +endif + + +################################ +# FreeBSD +################################ +ifeq (${OS_TYPE},FreeBSD) + + ifdef LINUXBIN + COMPAT_LINUX=/compat/linux + CC=${COMPAT_LINUX}/usr/bin/gcc + CXX=${COMPAT_LINUX}/usr/bin/g++ + AR=${COMPAT_LINUX}/usr/bin/ar + ARFLAGS=-rv + CXXFLAGS+=-fhandle-exceptions ${CFLAGS} -I${COMPAT_LINUX}/usr/include/g++ + LD=${COMPAT_LINUX}/usr/bin/ld + RANLIB=${COMPAT_LINUX}/usr/bin/ranlib + BRANDELF=brandelf -t Linux + else + CC=gcc + CXX=g++ + AR=/usr/bin/ar + ARFLAGS=-rv + CXXFLAGS+=-fhandle-exceptions ${CFLAGS} -I/usr/include/g++ + LD=/usr/bin/ld + RANLIB=/usr/bin/ranlib + BRANDELF=@true + + ifdef USES_PTHREADS + CFLAGS+=-D_THREAD_SAFE + LDFLAGS+=-lc_r + + # -pthread flag for 3.0+ + ifneq (${shell uname -r | cut -d. -f1},2) + CFLAGS+=-pthread + endif + endif + endif + +CFLAGS+=-Wall -pedantic -ansi -O -DFREEBSD +CORE=*.core +endif + +################################ +# OpenBSD +################################ +ifeq (${OS_TYPE},OpenBSD) + + CC=gcc + CXX=g++ + AR=/usr/bin/ar + ARFLAGS=-rv + CXXFLAGS+=${CFLAGS} -I/usr/include/g++ + LD=/usr/bin/ld + RANLIB=/usr/bin/ranlib + + ifdef USES_PTHREADS + CFLAGS+=-D_THREAD_SAFE + LDFLAGS+=-lc_r + + # -pthread flag for 3.0+ + ifneq (${shell uname -r | cut -d. -f1},2) + CFLAGS+=-pthread + endif + endif + +CFLAGS+=-Wall -pedantic -ansi -O -DOPENBSD +CORE=*.core +endif + diff --git a/lib/erl/build/otp_subdir.mk b/lib/erl/build/otp_subdir.mk new file mode 100644 index 000000000..2a36c6588 --- /dev/null +++ b/lib/erl/build/otp_subdir.mk @@ -0,0 +1,85 @@ +# Comment by tfee 2004-07-01 +# ========================== +# This file is a mod of the stock OTP one. +# The change allows make to stop when a compile error occurs. +# This file needs to go into two places: +# /usr/local/include/erlang +# /opt/OTP_SRC/make +# +# where OTP_SRC is a symbolic link to a peer directory containing +# the otp source, e.g. otp_src_R9C-2. +# +# After installing OTP, running sudo make install in otp/build +# will push this file out to the two places listed above. +# +# The mod involves setting the shell variable $short_circuit, which we +# introduce - ie it is not in the stock file. This variable is tested +# to affect execution flow and is also returned to affect the flow in +# the calling script (this one). The latter step is necessary because +# of the recursion involved. +# ===================================================================== + + +# ``The contents of this file are subject to the Erlang Public License, +# Version 1.1, (the "License"); you may not use this file except in +# compliance with the License. You should have received a copy of the +# Erlang Public License along with this software. If not, it can be +# retrieved via the world wide web at http://www.erlang.org/. +# +# Software distributed under the License is distributed on an "AS IS" +# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See +# the License for the specific language governing rights and limitations +# under the License. +# +# The Initial Developer of the Original Code is Ericsson Utvecklings AB. +# Portions created by Ericsson are Copyright 1999, Ericsson Utvecklings +# AB. All Rights Reserved.'' +# +# $Id: otp_subdir.mk,v 1.5 2004/07/12 15:12:23 jeinhorn Exp $ +# +# +# Make include file for otp + +.PHONY: debug opt release docs release_docs tests release_tests \ + clean depend + +# +# Targets that don't affect documentation directories +# +debug opt release docs release_docs tests release_tests clean depend: prepare + @set -e ; \ + app_pwd=`pwd` ; \ + if test -f vsn.mk; then \ + echo "=== Entering application" `basename $$app_pwd` ; \ + fi ; \ + case "$(MAKE)" in *clearmake*) tflag="-T";; *) tflag="";; esac; \ + short_circuit=0 ; \ + for d in $(SUB_DIRECTORIES); do \ + if [[ $$short_circuit = 0 ]]; then \ + if test -f $$d/SKIP ; then \ + echo "=== Skipping subdir $$d, reason:" ; \ + cat $$d/SKIP ; \ + echo "===" ; \ + else \ + if test ! -d $$d ; then \ + echo "=== Skipping subdir $$d, it is missing" ; \ + else \ + xflag="" ; \ + if test -f $$d/ignore_config_record.inf; then \ + xflag=$$tflag ; \ + fi ; \ + (cd $$d && $(MAKE) $$xflag $@) ; \ + if [[ $$? != 0 ]]; then \ + short_circuit=1 ; \ + fi ; \ + fi ; \ + fi ; \ + fi ; \ + done ; \ + if test -f vsn.mk; then \ + echo "=== Leaving application" `basename $$app_pwd` ; \ + fi ; \ + exit $$short_circuit + +prepare: + echo diff --git a/lib/erl/build/raw_test.mk b/lib/erl/build/raw_test.mk new file mode 100644 index 000000000..bf8535d14 --- /dev/null +++ b/lib/erl/build/raw_test.mk @@ -0,0 +1,29 @@ +# for testing erlang files directly. The set up for a +# this type of test would be +# files to test reside in lib/<app_name>/src and the test files which are +# just plain erlang code reside in lib/<app_name>/test +# +# This color codes emitted while the tests run assume that you are using +# a white-on-black display schema. If not, e.g. if you use a white +# background, you will not be able to read the "WHITE" text. +# You can override this by supplying your own "white" color, +# which may in fact be black! You do this by defining an environment +# variable named "MY_WHITE" and setting it to $'\e[0;30m' (which is +# simply bash's way of specifying "Escape [ 0 ; 3 0 m"). +# Similarly, you can set your versions of the standard colors +# found in colors.mk. + +test: + @TEST_MODULES=`ls *_test.erl`; \ + trap "echo $(OFF)$(NO_COLOR); exit 1;" 1 2 3 6; \ + for d in $$TEST_MODULES; do \ + echo $${MY_GREEN:-$(GREEN)}"Testing File $$d" $${MY_WHITE:-$(WHITE)}; \ + echo -n $${MY_BLUE:-$(BLUE)}; \ + erl -name $(APP_NAME) $(TEST_LIBS) \ + -s `basename $$d .erl` all -s init stop -noshell; \ + if [ $$? -ne 0 ]; then \ + echo $${MY_LRED:-$(LRED)}"$$d Test Failed"; \ + fi; \ + echo -n $(OFF)$(NO_COLOR); \ + done + diff --git a/lib/erl/include/thrift_constants.hrl b/lib/erl/include/thrift_constants.hrl new file mode 100644 index 000000000..36eb49bf2 --- /dev/null +++ b/lib/erl/include/thrift_constants.hrl @@ -0,0 +1,54 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +%% TType +-define(tType_STOP, 0). +-define(tType_VOID, 1). +-define(tType_BOOL, 2). +-define(tType_BYTE, 3). +-define(tType_DOUBLE, 4). +-define(tType_I16, 6). +-define(tType_I32, 8). +-define(tType_I64, 10). +-define(tType_STRING, 11). +-define(tType_STRUCT, 12). +-define(tType_MAP, 13). +-define(tType_SET, 14). +-define(tType_LIST, 15). + +% TMessageType +-define(tMessageType_CALL, 1). +-define(tMessageType_REPLY, 2). +-define(tMessageType_EXCEPTION, 3). +-define(tMessageType_ONEWAY, 4). + +% TApplicationException +-define(TApplicationException_Structure, + {struct, [{1, string}, + {2, i32}]}). + +-record('TApplicationException', {message, type}). + +-define(TApplicationException_UNKNOWN, 0). +-define(TApplicationException_UNKNOWN_METHOD, 1). +-define(TApplicationException_INVALID_MESSAGE_TYPE, 2). +-define(TApplicationException_WRONG_METHOD_NAME, 3). +-define(TApplicationException_BAD_SEQUENCE_ID, 4). +-define(TApplicationException_MISSING_RESULT, 5). + diff --git a/lib/erl/include/thrift_protocol.hrl b/lib/erl/include/thrift_protocol.hrl new file mode 100644 index 000000000..f4e1901f7 --- /dev/null +++ b/lib/erl/include/thrift_protocol.hrl @@ -0,0 +1,31 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-ifndef(THRIFT_PROTOCOL_INCLUDED). +-define(THRIFT_PROTOCOL_INCLUDED, yea). + +-record(protocol_message_begin, {name, type, seqid}). +-record(protocol_struct_begin, {name}). +-record(protocol_field_begin, {name, type, id}). +-record(protocol_map_begin, {ktype, vtype, size}). +-record(protocol_list_begin, {etype, size}). +-record(protocol_set_begin, {etype, size}). + + +-endif. diff --git a/lib/erl/src/Makefile b/lib/erl/src/Makefile new file mode 100644 index 000000000..980af8122 --- /dev/null +++ b/lib/erl/src/Makefile @@ -0,0 +1,116 @@ +# $Id: Makefile,v 1.3 2004/08/13 16:35:59 mlogan Exp $ +# +include ../build/otp.mk +include ../build/colors.mk +include ../build/buildtargets.mk + +# ---------------------------------------------------- +# Application version +# ---------------------------------------------------- + +include ../vsn.mk +APP_NAME=thrift +PFX=thrift +VSN=$(THRIFT_VSN) + +# ---------------------------------------------------- +# Install directory specification +# WARNING: INSTALL_DIR the command to install a directory. +# INSTALL_DST is the target directory +# ---------------------------------------------------- +INSTALL_DST = $(ERLANG_OTP)/lib/$(APP_NAME)-$(VSN) + +# ---------------------------------------------------- +# Target Specs +# ---------------------------------------------------- + + +MODULES = $(shell find . -name \*.erl | sed 's:^\./::' | sed 's/\.erl//') +MODULES_STRING_LIST = $(shell find . -name \*.erl | sed 's:^\./:":' | sed 's/\.erl/",/') + +HRL_FILES= +INTERNAL_HRL_FILES= $(APP_NAME).hrl +ERL_FILES= $(MODULES:%=%.erl) +DOC_FILES=$(ERL_FILES) + +APP_FILE= $(APP_NAME).app +APPUP_FILE= $(APP_NAME).appup + +APP_SRC= $(APP_FILE).src +APPUP_SRC= $(APPUP_FILE).src + +APP_TARGET= $(EBIN)/$(APP_FILE) +APPUP_TARGET= $(EBIN)/$(APPUP_FILE) + +BEAMS= $(MODULES:%=$(EBIN)/%.$(EMULATOR)) +TARGET_FILES= $(BEAMS) $(APP_TARGET) $(APPUP_TARGET) + +WEB_TARGET=/var/yaws/www/$(APP_NAME) + +# ---------------------------------------------------- +# FLAGS +# ---------------------------------------------------- + +ERL_FLAGS += +ERL_INCLUDE = -I../include -I../../fslib/include -I../../system_status/include +ERL_COMPILE_FLAGS += $(ERL_INCLUDE) + +# ---------------------------------------------------- +# Targets +# ---------------------------------------------------- + +all debug opt: $(EBIN) $(TARGET_FILES) + +#$(EBIN)/rm_logger.beam: $(APP_NAME).hrl +include ../build/docs.mk + +# Note: In the open-source build clean must not destroy the preloaded +# beam files. +clean: + rm -f $(TARGET_FILES) + rm -f *~ + rm -f core + rm -rf $(EBIN) + rm -rf *html + +$(EBIN): + mkdir $(EBIN) + +dialyzer: $(TARGET_FILES) + dialyzer --src -r . $(ERL_INCLUDE) + +# ---------------------------------------------------- +# Special Build Targets +# ---------------------------------------------------- + +$(APP_TARGET): $(APP_SRC) ../vsn.mk $(BEAMS) + sed -e 's;%VSN%;$(VSN);' \ + -e 's;%PFX%;$(PFX);' \ + -e 's;%APP_NAME%;$(APP_NAME);' \ + -e 's;%MODULES%;%MODULES%$(MODULES_STRING_LIST);' \ + $< > $<".tmp" + sed -e 's/%MODULES%\(.*\),/\1/' \ + $<".tmp" > $@ + rm $<".tmp" + +$(APPUP_TARGET): $(APPUP_SRC) ../vsn.mk + sed -e 's;%VSN%;$(VSN);' $< > $@ + +$(WEB_TARGET): ../markup/* + rm -rf $(WEB_TARGET) + mkdir $(WEB_TARGET) + cp -r ../markup/ $(WEB_TARGET) + cp -r ../skins/ $(WEB_TARGET) + +# ---------------------------------------------------- +# Install Target +# ---------------------------------------------------- + +install: all $(WEB_TARGET) +# $(INSTALL_DIR) $(INSTALL_DST)/src +# $(INSTALL_DATA) $(ERL_FILES) $(INSTALL_DST)/src +# $(INSTALL_DATA) $(INTERNAL_HRL_FILES) $(INSTALL_DST)/src +# $(INSTALL_DIR) $(INSTALL_DST)/include +# $(INSTALL_DATA) $(HRL_FILES) $(INSTALL_DST)/include +# $(INSTALL_DIR) $(INSTALL_DST)/ebin +# $(INSTALL_DATA) $(TARGET_FILES) $(INSTALL_DST)/ebin diff --git a/lib/erl/src/test_handler.erl b/lib/erl/src/test_handler.erl new file mode 100644 index 000000000..28a3acd30 --- /dev/null +++ b/lib/erl/src/test_handler.erl @@ -0,0 +1,26 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(test_handler). + +-export([handle_function/2]). + +handle_function(add, Params = {A, B}) -> + io:format("Got params: ~p~n", [Params]), + {reply, A + B}. diff --git a/lib/erl/src/test_service.erl b/lib/erl/src/test_service.erl new file mode 100644 index 000000000..7aa4827f8 --- /dev/null +++ b/lib/erl/src/test_service.erl @@ -0,0 +1,29 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(test_service). +% +% Test service definition + +-export([function_info/2]). + +function_info(add, params_type) -> + {struct, [{1, i32}, + {2, i32}]}; +function_info(add, reply_type) -> i32. diff --git a/lib/erl/src/thrift.app.src b/lib/erl/src/thrift.app.src new file mode 100644 index 000000000..681b3eb3e --- /dev/null +++ b/lib/erl/src/thrift.app.src @@ -0,0 +1,44 @@ +%%% -*- mode:erlang -*- +{application, %APP_NAME%, + [ + % A quick description of the application. + {description, "Thrift bindings"}, + + % The version of the applicaton + {vsn, "%VSN%"}, + + % All modules used by the application. + {modules, [ + %MODULES% + ]}, + + % All of the registered names the application uses. This can be ignored. + {registered, []}, + + % Applications that are to be started prior to this one. This can be ignored + % leave it alone unless you understand it well and let the .rel files in + % your release handle this. + {applications, + [ + kernel, + stdlib + ]}, + + % OTP application loader will load, but not start, included apps. Again + % this can be ignored as well. To load but not start an application it + % is easier to include it in the .rel file followed by the atom 'none' + {included_applications, []}, + + % configuration parameters similar to those in the config file specified + % on the command line. can be fetched with gas:get_env + {env, [ + % If an error/crash occurs during processing of a function, + % should the TApplicationException serialized back to the client + % include the erlang backtrace? + {exceptions_include_traces, true} + ]}, + + % The Module and Args used to start this application. + {mod, {thrift_app, []}} + ] +}. diff --git a/lib/erl/src/thrift.appup.src b/lib/erl/src/thrift.appup.src new file mode 100644 index 000000000..54a63833e --- /dev/null +++ b/lib/erl/src/thrift.appup.src @@ -0,0 +1 @@ +{"%VSN%",[],[]}. diff --git a/lib/erl/src/thrift_base64_transport.erl b/lib/erl/src/thrift_base64_transport.erl new file mode 100644 index 000000000..9d13151c6 --- /dev/null +++ b/lib/erl/src/thrift_base64_transport.erl @@ -0,0 +1,64 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_base64_transport). + +-behaviour(thrift_transport). + +%% API +-export([new/1, new_transport_factory/1]). + +%% thrift_transport callbacks +-export([write/2, read/2, flush/1, close/1]). + +%% State +-record(b64_transport, {wrapped}). + +new(Wrapped) -> + State = #b64_transport{wrapped = Wrapped}, + thrift_transport:new(?MODULE, State). + + +write(#b64_transport{wrapped = Wrapped}, Data) -> + thrift_transport:write(Wrapped, base64:encode(iolist_to_binary(Data))). + + +%% base64 doesn't support reading quite yet since it would involve +%% nasty buffering and such +read(#b64_transport{wrapped = Wrapped}, Data) -> + {error, no_reads_allowed}. + + +flush(#b64_transport{wrapped = Wrapped}) -> + thrift_transport:write(Wrapped, <<"\n">>), + thrift_transport:flush(Wrapped). + + +close(Me = #b64_transport{wrapped = Wrapped}) -> + flush(Me), + thrift_transport:close(Wrapped). + + +%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +new_transport_factory(WrapFactory) -> + F = fun() -> + {ok, Wrapped} = WrapFactory(), + new(Wrapped) + end, + {ok, F}. diff --git a/lib/erl/src/thrift_binary_protocol.erl b/lib/erl/src/thrift_binary_protocol.erl new file mode 100644 index 000000000..ad5338422 --- /dev/null +++ b/lib/erl/src/thrift_binary_protocol.erl @@ -0,0 +1,325 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_binary_protocol). + +-behavior(thrift_protocol). + +-include("thrift_constants.hrl"). +-include("thrift_protocol.hrl"). + +-export([new/1, new/2, + read/2, + write/2, + flush_transport/1, + close_transport/1, + + new_protocol_factory/2 + ]). + +-record(binary_protocol, {transport, + strict_read=true, + strict_write=true + }). + +-define(VERSION_MASK, 16#FFFF0000). +-define(VERSION_1, 16#80010000). +-define(TYPE_MASK, 16#000000ff). + +new(Transport) -> + new(Transport, _Options = []). + +new(Transport, Options) -> + State = #binary_protocol{transport = Transport}, + State1 = parse_options(Options, State), + thrift_protocol:new(?MODULE, State1). + +parse_options([], State) -> + State; +parse_options([{strict_read, Bool} | Rest], State) when is_boolean(Bool) -> + parse_options(Rest, State#binary_protocol{strict_read=Bool}); +parse_options([{strict_write, Bool} | Rest], State) when is_boolean(Bool) -> + parse_options(Rest, State#binary_protocol{strict_write=Bool}). + + +flush_transport(#binary_protocol{transport = Transport}) -> + thrift_transport:flush(Transport). + +close_transport(#binary_protocol{transport = Transport}) -> + thrift_transport:close(Transport). + +%%% +%%% instance methods +%%% + +write(This, #protocol_message_begin{ + name = Name, + type = Type, + seqid = Seqid}) -> + case This#binary_protocol.strict_write of + true -> + write(This, {i32, ?VERSION_1 bor Type}), + write(This, {string, Name}), + write(This, {i32, Seqid}); + false -> + write(This, {string, Name}), + write(This, {byte, Type}), + write(This, {i32, Seqid}) + end, + ok; + +write(This, message_end) -> ok; + +write(This, #protocol_field_begin{ + name = _Name, + type = Type, + id = Id}) -> + write(This, {byte, Type}), + write(This, {i16, Id}), + ok; + +write(This, field_stop) -> + write(This, {byte, ?tType_STOP}), + ok; + +write(This, field_end) -> ok; + +write(This, #protocol_map_begin{ + ktype = Ktype, + vtype = Vtype, + size = Size}) -> + write(This, {byte, Ktype}), + write(This, {byte, Vtype}), + write(This, {i32, Size}), + ok; + +write(This, map_end) -> ok; + +write(This, #protocol_list_begin{ + etype = Etype, + size = Size}) -> + write(This, {byte, Etype}), + write(This, {i32, Size}), + ok; + +write(This, list_end) -> ok; + +write(This, #protocol_set_begin{ + etype = Etype, + size = Size}) -> + write(This, {byte, Etype}), + write(This, {i32, Size}), + ok; + +write(This, set_end) -> ok; + +write(This, #protocol_struct_begin{}) -> ok; +write(This, struct_end) -> ok; + +write(This, {bool, true}) -> write(This, {byte, 1}); +write(This, {bool, false}) -> write(This, {byte, 0}); + +write(This, {byte, Byte}) -> + write(This, <<Byte:8/big-signed>>); + +write(This, {i16, I16}) -> + write(This, <<I16:16/big-signed>>); + +write(This, {i32, I32}) -> + write(This, <<I32:32/big-signed>>); + +write(This, {i64, I64}) -> + write(This, <<I64:64/big-signed>>); + +write(This, {double, Double}) -> + write(This, <<Double:64/big-signed-float>>); + +write(This, {string, Str}) when is_list(Str) -> + write(This, {i32, length(Str)}), + write(This, list_to_binary(Str)); + +write(This, {string, Bin}) when is_binary(Bin) -> + write(This, {i32, size(Bin)}), + write(This, Bin); + +%% Data :: iolist() +write(This, Data) -> + thrift_transport:write(This#binary_protocol.transport, Data). + +%% + +read(This, message_begin) -> + case read(This, ui32) of + {ok, Sz} when Sz band ?VERSION_MASK =:= ?VERSION_1 -> + %% we're at version 1 + {ok, Name} = read(This, string), + Type = Sz band ?TYPE_MASK, + {ok, SeqId} = read(This, i32), + #protocol_message_begin{name = binary_to_list(Name), + type = Type, + seqid = SeqId}; + + {ok, Sz} when Sz < 0 -> + %% there's a version number but it's unexpected + {error, {bad_binary_protocol_version, Sz}}; + + {ok, Sz} when This#binary_protocol.strict_read =:= true -> + %% strict_read is true and there's no version header; that's an error + {error, no_binary_protocol_version}; + + {ok, Sz} when This#binary_protocol.strict_read =:= false -> + %% strict_read is false, so just read the old way + {ok, Name} = read(This, Sz), + {ok, Type} = read(This, byte), + {ok, SeqId} = read(This, i32), + #protocol_message_begin{name = binary_to_list(Name), + type = Type, + seqid = SeqId}; + + Err = {error, closed} -> Err; + Err = {error, timeout}-> Err; + Err = {error, ebadf} -> Err + end; + +read(This, message_end) -> ok; + +read(This, struct_begin) -> ok; +read(This, struct_end) -> ok; + +read(This, field_begin) -> + case read(This, byte) of + {ok, Type = ?tType_STOP} -> + #protocol_field_begin{type = Type}; + {ok, Type} -> + {ok, Id} = read(This, i16), + #protocol_field_begin{type = Type, + id = Id} + end; + +read(This, field_end) -> ok; + +read(This, map_begin) -> + {ok, Ktype} = read(This, byte), + {ok, Vtype} = read(This, byte), + {ok, Size} = read(This, i32), + #protocol_map_begin{ktype = Ktype, + vtype = Vtype, + size = Size}; +read(This, map_end) -> ok; + +read(This, list_begin) -> + {ok, Etype} = read(This, byte), + {ok, Size} = read(This, i32), + #protocol_list_begin{etype = Etype, + size = Size}; +read(This, list_end) -> ok; + +read(This, set_begin) -> + {ok, Etype} = read(This, byte), + {ok, Size} = read(This, i32), + #protocol_set_begin{etype = Etype, + size = Size}; +read(This, set_end) -> ok; + +read(This, field_stop) -> + {ok, ?tType_STOP} = read(This, byte), + ok; + +%% + +read(This, bool) -> + case read(This, byte) of + {ok, Byte} -> {ok, Byte /= 0}; + Else -> Else + end; + +read(This, byte) -> + case read(This, 1) of + {ok, <<Val:8/integer-signed-big, _/binary>>} -> {ok, Val}; + Else -> Else + end; + +read(This, i16) -> + case read(This, 2) of + {ok, <<Val:16/integer-signed-big, _/binary>>} -> {ok, Val}; + Else -> Else + end; + +read(This, i32) -> + case read(This, 4) of + {ok, <<Val:32/integer-signed-big, _/binary>>} -> {ok, Val}; + Else -> Else + end; + +%% unsigned ints aren't used by thrift itself, but it's used for the parsing +%% of the packet version header. Without this special function BEAM works fine +%% but hipe thinks it received a bad version header. +read(This, ui32) -> + case read(This, 4) of + {ok, <<Val:32/integer-unsigned-big, _/binary>>} -> {ok, Val}; + Else -> Else + end; + +read(This, i64) -> + case read(This, 8) of + {ok, <<Val:64/integer-signed-big, _/binary>>} -> {ok, Val}; + Else -> Else + end; + +read(This, double) -> + case read(This, 8) of + {ok, <<Val:64/float-signed-big, _/binary>>} -> {ok, Val}; + Else -> Else + end; + +% returns a binary directly, call binary_to_list if necessary +read(This, string) -> + {ok, Sz} = read(This, i32), + {ok, Bin} = read(This, Sz); + +read(This, 0) -> {ok, <<>>}; +read(This, Len) when is_integer(Len), Len >= 0 -> + thrift_transport:read(This#binary_protocol.transport, Len). + + +%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +-record(tbp_opts, {strict_read = true, + strict_write = true}). + +parse_factory_options([], Opts) -> + Opts; +parse_factory_options([{strict_read, Bool} | Rest], Opts) when is_boolean(Bool) -> + parse_factory_options(Rest, Opts#tbp_opts{strict_read=Bool}); +parse_factory_options([{strict_write, Bool} | Rest], Opts) when is_boolean(Bool) -> + parse_factory_options(Rest, Opts#tbp_opts{strict_write=Bool}). + + +%% returns a (fun() -> thrift_protocol()) +new_protocol_factory(TransportFactory, Options) -> + ParsedOpts = parse_factory_options(Options, #tbp_opts{}), + F = fun() -> + {ok, Transport} = TransportFactory(), + thrift_binary_protocol:new( + Transport, + [{strict_read, ParsedOpts#tbp_opts.strict_read}, + {strict_write, ParsedOpts#tbp_opts.strict_write}]) + end, + {ok, F}. + diff --git a/lib/erl/src/thrift_buffered_transport.erl b/lib/erl/src/thrift_buffered_transport.erl new file mode 100644 index 000000000..ebc16bd65 --- /dev/null +++ b/lib/erl/src/thrift_buffered_transport.erl @@ -0,0 +1,180 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_buffered_transport). + +-behaviour(gen_server). +-behaviour(thrift_transport). + +%% API +-export([new/1, new_transport_factory/1]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +%% thrift_transport callbacks +-export([write/2, read/2, flush/1, close/1]). + +-record(buffered_transport, {wrapped, % a thrift_transport + write_buffer % iolist() + }). + +%%==================================================================== +%% API +%%==================================================================== +%%-------------------------------------------------------------------- +%% Function: start_link() -> {ok,Pid} | ignore | {error,Error} +%% Description: Starts the server +%%-------------------------------------------------------------------- +new(WrappedTransport) -> + case gen_server:start_link(?MODULE, [WrappedTransport], []) of + {ok, Pid} -> + thrift_transport:new(?MODULE, Pid); + Else -> + Else + end. + + + +%%-------------------------------------------------------------------- +%% Function: write(Transport, Data) -> ok +%% +%% Data = iolist() +%% +%% Description: Writes data into the buffer +%%-------------------------------------------------------------------- +write(Transport, Data) -> + gen_server:call(Transport, {write, Data}). + +%%-------------------------------------------------------------------- +%% Function: flush(Transport) -> ok +%% +%% Description: Flushes the buffer through to the wrapped transport +%%-------------------------------------------------------------------- +flush(Transport) -> + gen_server:call(Transport, flush). + +%%-------------------------------------------------------------------- +%% Function: close(Transport) -> ok +%% +%% Description: Closes the transport and the wrapped transport +%%-------------------------------------------------------------------- +close(Transport) -> + gen_server:cast(Transport, close). + +%%-------------------------------------------------------------------- +%% Function: Read(Transport, Len) -> {ok, Data} +%% +%% Data = binary() +%% +%% Description: Reads data through from the wrapped transoprt +%%-------------------------------------------------------------------- +read(Transport, Len) when is_integer(Len) -> + gen_server:call(Transport, {read, Len}, _Timeout=10000). + +%%==================================================================== +%% gen_server callbacks +%%==================================================================== + +%%-------------------------------------------------------------------- +%% Function: init(Args) -> {ok, State} | +%% {ok, State, Timeout} | +%% ignore | +%% {stop, Reason} +%% Description: Initiates the server +%%-------------------------------------------------------------------- +init([Wrapped]) -> + {ok, #buffered_transport{wrapped = Wrapped, + write_buffer = []}}. + +%%-------------------------------------------------------------------- +%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} | +%% {reply, Reply, State, Timeout} | +%% {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, Reply, State} | +%% {stop, Reason, State} +%% Description: Handling call messages +%%-------------------------------------------------------------------- +handle_call({write, Data}, _From, State = #buffered_transport{write_buffer = WBuf}) -> + {reply, ok, State#buffered_transport{write_buffer = [WBuf, Data]}}; + +handle_call({read, Len}, _From, State = #buffered_transport{wrapped = Wrapped}) -> + Response = thrift_transport:read(Wrapped, Len), + {reply, Response, State}; + +handle_call(flush, _From, State = #buffered_transport{write_buffer = WBuf, + wrapped = Wrapped}) -> + Response = thrift_transport:write(Wrapped, WBuf), + thrift_transport:flush(Wrapped), + {reply, Response, State#buffered_transport{write_buffer = []}}. + +%%-------------------------------------------------------------------- +%% Function: handle_cast(Msg, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling cast messages +%%-------------------------------------------------------------------- +handle_cast(close, State = #buffered_transport{write_buffer = WBuf, + wrapped = Wrapped}) -> + thrift_transport:write(Wrapped, WBuf), + %% Wrapped is closed by terminate/2 + %% error_logger:info_msg("thrift_buffered_transport ~p: closing", [self()]), + {stop, normal, State}; +handle_cast(Msg, State=#buffered_transport{}) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_info(Info, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling all non call/cast messages +%%-------------------------------------------------------------------- +handle_info(_Info, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: terminate(Reason, State) -> void() +%% Description: This function is called by a gen_server when it is about to +%% terminate. It should be the opposite of Module:init/1 and do any necessary +%% cleaning up. When it returns, the gen_server terminates with Reason. +%% The return value is ignored. +%%-------------------------------------------------------------------- +terminate(_Reason, State = #buffered_transport{wrapped=Wrapped}) -> + thrift_transport:close(Wrapped), + ok. + +%%-------------------------------------------------------------------- +%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} +%% Description: Convert process state when code is changed +%%-------------------------------------------------------------------- +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%%% Internal functions +%%-------------------------------------------------------------------- +%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +new_transport_factory(WrapFactory) -> + F = fun() -> + {ok, Wrapped} = WrapFactory(), + new(Wrapped) + end, + {ok, F}. diff --git a/lib/erl/src/thrift_client.erl b/lib/erl/src/thrift_client.erl new file mode 100644 index 000000000..5ba8aee6d --- /dev/null +++ b/lib/erl/src/thrift_client.erl @@ -0,0 +1,330 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_client). + +-behaviour(gen_server). + +%% API +-export([start_link/2, start_link/3, start_link/4, call/3, send_call/3, close/1]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + + +-include("thrift_constants.hrl"). +-include("thrift_protocol.hrl"). + +-record(state, {service, protocol, seqid}). + +%%==================================================================== +%% API +%%==================================================================== +%%-------------------------------------------------------------------- +%% Function: start_link() -> {ok,Pid} | ignore | {error,Error} +%% Description: Starts the server +%%-------------------------------------------------------------------- +start_link(Host, Port, Service) when is_integer(Port), is_atom(Service) -> + start_link(Host, Port, Service, []). + + +%% +%% Splits client options into protocol options and transport options +%% +%% split_options([Options...]) -> {ProtocolOptions, TransportOptions} +%% +split_options(Options) -> + split_options(Options, [], []). + +split_options([], ProtoIn, TransIn) -> + {ProtoIn, TransIn}; + +split_options([Opt = {OptKey, _} | Rest], ProtoIn, TransIn) + when OptKey =:= strict_read; + OptKey =:= strict_write -> + split_options(Rest, [Opt | ProtoIn], TransIn); + +split_options([Opt = {OptKey, _} | Rest], ProtoIn, TransIn) + when OptKey =:= framed; + OptKey =:= connect_timeout; + OptKey =:= sockopts -> + split_options(Rest, ProtoIn, [Opt | TransIn]). + + +%% Backwards-compatible starter for the common-case of socket transports +start_link(Host, Port, Service, Options) + when is_integer(Port), is_atom(Service), is_list(Options) -> + {ProtoOpts, TransOpts} = split_options(Options), + + {ok, TransportFactory} = + thrift_socket_transport:new_transport_factory(Host, Port, TransOpts), + + {ok, ProtocolFactory} = thrift_binary_protocol:new_protocol_factory( + TransportFactory, ProtoOpts), + + start_link(ProtocolFactory, Service). + + +%% ProtocolFactory :: fun() -> thrift_protocol() +start_link(ProtocolFactory, Service) + when is_function(ProtocolFactory), is_atom(Service) -> + case gen_server:start_link(?MODULE, [Service], []) of + {ok, Pid} -> + case gen_server:call(Pid, {connect, ProtocolFactory}) of + ok -> + {ok, Pid}; + Error -> + Error + end; + Else -> + Else + end. + +call(Client, Function, Args) + when is_pid(Client), is_atom(Function), is_list(Args) -> + case gen_server:call(Client, {call, Function, Args}) of + R = {ok, _} -> R; + R = {error, _} -> R; + {exception, Exception} -> throw(Exception) + end. + +cast(Client, Function, Args) + when is_pid(Client), is_atom(Function), is_list(Args) -> + gen_server:cast(Client, {call, Function, Args}). + +%% Sends a function call but does not read the result. This is useful +%% if you're trying to log non-oneway function calls to write-only +%% transports like thrift_disk_log_transport. +send_call(Client, Function, Args) + when is_pid(Client), is_atom(Function), is_list(Args) -> + gen_server:call(Client, {send_call, Function, Args}). + +close(Client) when is_pid(Client) -> + gen_server:cast(Client, close). + +%%==================================================================== +%% gen_server callbacks +%%==================================================================== + +%%-------------------------------------------------------------------- +%% Function: init(Args) -> {ok, State} | +%% {ok, State, Timeout} | +%% ignore | +%% {stop, Reason} +%% Description: Initiates the server +%%-------------------------------------------------------------------- +init([Service]) -> + {ok, #state{service = Service}}. + +%%-------------------------------------------------------------------- +%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} | +%% {reply, Reply, State, Timeout} | +%% {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, Reply, State} | +%% {stop, Reason, State} +%% Description: Handling call messages +%%-------------------------------------------------------------------- +handle_call({connect, ProtocolFactory}, _From, + State = #state{service = Service}) -> + case ProtocolFactory() of + {ok, Protocol} -> + {reply, ok, State#state{protocol = Protocol, + seqid = 0}}; + Error -> + {stop, normal, Error, State} + end; + +handle_call({call, Function, Args}, _From, State = #state{service = Service}) -> + Result = catch_function_exceptions( + fun() -> + ok = send_function_call(State, Function, Args), + receive_function_result(State, Function) + end, + Service), + {reply, Result, State}; + + +handle_call({send_call, Function, Args}, _From, State = #state{service = Service}) -> + Result = catch_function_exceptions( + fun() -> + send_function_call(State, Function, Args) + end, + Service), + {reply, Result, State}. + + +%% Helper function that catches exceptions thrown by sending or receiving +%% a function and returns the correct response for call or send_only above. +catch_function_exceptions(Fun, Service) -> + try + Fun() + catch + throw:{return, Return} -> + Return; + error:function_clause -> + ST = erlang:get_stacktrace(), + case hd(ST) of + {Service, function_info, [Function, _]} -> + {error, {no_function, Function}}; + _ -> throw({error, {function_clause, ST}}) + end + end. + + +%%-------------------------------------------------------------------- +%% Function: handle_cast(Msg, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling cast messages +%%-------------------------------------------------------------------- +handle_cast({call, Function, Args}, State = #state{service = Service, + protocol = Protocol, + seqid = SeqId}) -> + _Result = + try + ok = send_function_call(State, Function, Args), + receive_function_result(State, Function) + catch + Class:Reason -> + error_logger:error_msg("error ignored in handle_cast({cast,...},...): ~p:~p~n", [Class, Reason]) + end, + + {noreply, State}; + +handle_cast(close, State=#state{protocol = Protocol}) -> +%% error_logger:info_msg("thrift_client ~p received close", [self()]), + {stop,normal,State}; +handle_cast(_Msg, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_info(Info, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling all non call/cast messages +%%-------------------------------------------------------------------- +handle_info(_Info, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: terminate(Reason, State) -> void() +%% Description: This function is called by a gen_server when it is about to +%% terminate. It should be the opposite of Module:init/1 and do any necessary +%% cleaning up. When it returns, the gen_server terminates with Reason. +%% The return value is ignored. +%%-------------------------------------------------------------------- +terminate(Reason, State = #state{protocol=undefined}) -> + ok; +terminate(Reason, State = #state{protocol=Protocol}) -> + thrift_protocol:close_transport(Protocol), + ok. + +%%-------------------------------------------------------------------- +%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} +%% Description: Convert process state when code is changed +%%-------------------------------------------------------------------- +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%%% Internal functions +%%-------------------------------------------------------------------- +send_function_call(#state{protocol = Proto, + service = Service, + seqid = SeqId}, + Function, + Args) -> + Params = Service:function_info(Function, params_type), + {struct, PList} = Params, + if + length(PList) =/= length(Args) -> + throw({return, {error, {bad_args, Function, Args}}}); + true -> ok + end, + + Begin = #protocol_message_begin{name = atom_to_list(Function), + type = ?tMessageType_CALL, + seqid = SeqId}, + ok = thrift_protocol:write(Proto, Begin), + ok = thrift_protocol:write(Proto, {Params, list_to_tuple([Function | Args])}), + ok = thrift_protocol:write(Proto, message_end), + thrift_protocol:flush_transport(Proto), + ok. + +receive_function_result(State = #state{protocol = Proto, + service = Service}, + Function) -> + ResultType = Service:function_info(Function, reply_type), + read_result(State, Function, ResultType). + +read_result(_State, + _Function, + oneway_void) -> + {ok, ok}; + +read_result(State = #state{protocol = Proto, + seqid = SeqId}, + Function, + ReplyType) -> + case thrift_protocol:read(Proto, message_begin) of + #protocol_message_begin{seqid = RetSeqId} when RetSeqId =/= SeqId -> + {error, {bad_seq_id, SeqId}}; + + #protocol_message_begin{type = ?tMessageType_EXCEPTION} -> + handle_application_exception(State); + + #protocol_message_begin{type = ?tMessageType_REPLY} -> + handle_reply(State, Function, ReplyType) + end. + +handle_reply(State = #state{protocol = Proto, + service = Service}, + Function, + ReplyType) -> + {struct, ExceptionFields} = Service:function_info(Function, exceptions), + ReplyStructDef = {struct, [{0, ReplyType}] ++ ExceptionFields}, + {ok, Reply} = thrift_protocol:read(Proto, ReplyStructDef), + ReplyList = tuple_to_list(Reply), + true = length(ReplyList) == length(ExceptionFields) + 1, + ExceptionVals = tl(ReplyList), + Thrown = [X || X <- ExceptionVals, + X =/= undefined], + Result = + case Thrown of + [] when ReplyType == {struct, []} -> + {ok, ok}; + [] -> + {ok, hd(ReplyList)}; + [Exception] -> + {exception, Exception} + end, + ok = thrift_protocol:read(Proto, message_end), + Result. + +handle_application_exception(State = #state{protocol = Proto}) -> + {ok, Exception} = thrift_protocol:read(Proto, + ?TApplicationException_Structure), + ok = thrift_protocol:read(Proto, message_end), + XRecord = list_to_tuple( + ['TApplicationException' | tuple_to_list(Exception)]), + error_logger:error_msg("X: ~p~n", [XRecord]), + true = is_record(XRecord, 'TApplicationException'), + {exception, XRecord}. diff --git a/lib/erl/src/thrift_disk_log_transport.erl b/lib/erl/src/thrift_disk_log_transport.erl new file mode 100644 index 000000000..761fa3097 --- /dev/null +++ b/lib/erl/src/thrift_disk_log_transport.erl @@ -0,0 +1,118 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +%%% Todo: this might be better off as a gen_server type of transport +%%% that handles stuff like group commit, similar to TFileTransport +%%% in cpp land +-module(thrift_disk_log_transport). + +-behaviour(thrift_transport). + +%% API +-export([new/2, new_transport_factory/2, new_transport_factory/3]). + +%% thrift_transport callbacks +-export([read/2, write/2, force_flush/1, flush/1, close/1]). + +%% state +-record(dl_transport, {log, + close_on_close = false, + sync_every = infinity, + sync_tref}). + + +%% Create a transport attached to an already open log. +%% If you'd like this transport to close the disk_log using disk_log:lclose() +%% when the transport is closed, pass a {close_on_close, true} tuple in the +%% Opts list. +new(LogName, Opts) when is_atom(LogName), is_list(Opts) -> + State = parse_opts(Opts, #dl_transport{log = LogName}), + + State2 = + case State#dl_transport.sync_every of + N when is_integer(N), N > 0 -> + {ok, TRef} = timer:apply_interval(N, ?MODULE, force_flush, State), + State#dl_transport{sync_tref = TRef}; + _ -> State + end, + + thrift_transport:new(?MODULE, State2). + + +parse_opts([], State) -> + State; +parse_opts([{close_on_close, Bool} | Rest], State) when is_boolean(Bool) -> + State#dl_transport{close_on_close = Bool}; +parse_opts([{sync_every, Int} | Rest], State) when is_integer(Int), Int > 0 -> + State#dl_transport{sync_every = Int}. + + +%%%% TRANSPORT IMPLENTATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +%% disk_log_transport is write-only +read(_State, Len) -> + {error, no_read_from_disk_log}. + +write(#dl_transport{log = Log}, Data) -> + disk_log:balog(Log, erlang:iolist_to_binary(Data)). + +force_flush(#dl_transport{log = Log}) -> + error_logger:info_msg("~p syncing~n", [?MODULE]), + disk_log:sync(Log). + +flush(#dl_transport{log = Log, sync_every = SE}) -> + case SE of + undefined -> % no time-based sync + disk_log:sync(Log); + _Else -> % sync will happen automagically + ok + end. + + +%% On close, close the underlying log if we're configured to do so. +close(#dl_transport{close_on_close = false}) -> + ok; +close(#dl_transport{log = Log}) -> + disk_log:lclose(Log). + + +%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +new_transport_factory(Name, ExtraLogOpts) -> + new_transport_factory(Name, ExtraLogOpts, [{close_on_close, true}, + {sync_every, 500}]). + +new_transport_factory(Name, ExtraLogOpts, TransportOpts) -> + F = fun() -> factory_impl(Name, ExtraLogOpts, TransportOpts) end, + {ok, F}. + +factory_impl(Name, ExtraLogOpts, TransportOpts) -> + LogOpts = [{name, Name}, + {format, external}, + {type, wrap} | + ExtraLogOpts], + Log = + case disk_log:open(LogOpts) of + {ok, Log} -> + Log; + {repaired, Log, Info1, Info2} -> + error_logger:info_msg("Disk log ~p repaired: ~p, ~p~n", [Log, Info1, Info2]), + Log + end, + new(Log, TransportOpts). diff --git a/lib/erl/src/thrift_file_transport.erl b/lib/erl/src/thrift_file_transport.erl new file mode 100644 index 000000000..5ac2dbe1f --- /dev/null +++ b/lib/erl/src/thrift_file_transport.erl @@ -0,0 +1,87 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_file_transport). + +-behaviour(thrift_transport). + +-export([new_reader/1, + new/1, + new/2, + write/2, read/2, flush/1, close/1]). + +-record(t_file_transport, {device, + should_close = true, + mode = write}). + +%%%% CONSTRUCTION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +new_reader(Filename) -> + case file:open(Filename, [read, binary, {read_ahead, 1024*1024}]) of + {ok, IODevice} -> + new(IODevice, [{should_close, true}, {mode, read}]); + Error -> Error + end. + +new(Device) -> + new(Device, []). + +%% Device :: io_device() +%% +%% Device should be opened in raw and binary mode. +new(Device, Opts) when is_list(Opts) -> + State = parse_opts(Opts, #t_file_transport{device = Device}), + thrift_transport:new(?MODULE, State). + + +%% Parse options +parse_opts([{should_close, Bool} | Rest], State) when is_boolean(Bool) -> + parse_opts(Rest, State#t_file_transport{should_close = Bool}); +parse_opts([{mode, Mode} | Rest], State) + when Mode =:= write; + Mode =:= read -> + parse_opts(Rest, State#t_file_transport{mode = Mode}); +parse_opts([], State) -> + State. + + +%%%% TRANSPORT IMPL %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +write(#t_file_transport{device = Device, mode = write}, Data) -> + file:write(Device, Data); +write(_T, _D) -> + {error, read_mode}. + + +read(#t_file_transport{device = Device, mode = read}, Len) + when is_integer(Len), Len >= 0 -> + file:read(Device, Len); +read(_T, _D) -> + {error, read_mode}. + +flush(#t_file_transport{device = Device, mode = write}) -> + file:sync(Device). + +close(#t_file_transport{device = Device, should_close = SC}) -> + case SC of + true -> + file:close(Device); + false -> + ok + end. diff --git a/lib/erl/src/thrift_framed_transport.erl b/lib/erl/src/thrift_framed_transport.erl new file mode 100644 index 000000000..01bab70ba --- /dev/null +++ b/lib/erl/src/thrift_framed_transport.erl @@ -0,0 +1,208 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_framed_transport). + +-behaviour(gen_server). +-behaviour(thrift_transport). + +%% API +-export([new/1]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +%% thrift_transport callbacks +-export([write/2, read/2, flush/1, close/1]). + +-record(framed_transport, {wrapped, % a thrift_transport + read_buffer, % iolist() + write_buffer % iolist() + }). + +%%==================================================================== +%% API +%%==================================================================== +%%-------------------------------------------------------------------- +%% Function: start_link() -> {ok,Pid} | ignore | {error,Error} +%% Description: Starts the server +%%-------------------------------------------------------------------- +new(WrappedTransport) -> + case gen_server:start_link(?MODULE, [WrappedTransport], []) of + {ok, Pid} -> + thrift_transport:new(?MODULE, Pid); + Else -> + Else + end. + +%%-------------------------------------------------------------------- +%% Function: write(Transport, Data) -> ok +%% +%% Data = iolist() +%% +%% Description: Writes data into the buffer +%%-------------------------------------------------------------------- +write(Transport, Data) -> + gen_server:call(Transport, {write, Data}). + +%%-------------------------------------------------------------------- +%% Function: flush(Transport) -> ok +%% +%% Description: Flushes the buffer through to the wrapped transport +%%-------------------------------------------------------------------- +flush(Transport) -> + gen_server:call(Transport, flush). + +%%-------------------------------------------------------------------- +%% Function: close(Transport) -> ok +%% +%% Description: Closes the transport and the wrapped transport +%%-------------------------------------------------------------------- +close(Transport) -> + gen_server:cast(Transport, close). + +%%-------------------------------------------------------------------- +%% Function: Read(Transport, Len) -> {ok, Data} +%% +%% Data = binary() +%% +%% Description: Reads data through from the wrapped transoprt +%%-------------------------------------------------------------------- +read(Transport, Len) when is_integer(Len) -> + gen_server:call(Transport, {read, Len}). + +%%==================================================================== +%% gen_server callbacks +%%==================================================================== + +%%-------------------------------------------------------------------- +%% Function: init(Args) -> {ok, State} | +%% {ok, State, Timeout} | +%% ignore | +%% {stop, Reason} +%% Description: Initiates the server +%%-------------------------------------------------------------------- +init([Wrapped]) -> + {ok, #framed_transport{wrapped = Wrapped, + read_buffer = [], + write_buffer = []}}. + +%%-------------------------------------------------------------------- +%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} | +%% {reply, Reply, State, Timeout} | +%% {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, Reply, State} | +%% {stop, Reason, State} +%% Description: Handling call messages +%%-------------------------------------------------------------------- +handle_call({write, Data}, _From, State = #framed_transport{write_buffer = WBuf}) -> + {reply, ok, State#framed_transport{write_buffer = [WBuf, Data]}}; + +handle_call({read, Len}, _From, State = #framed_transport{wrapped = Wrapped, + read_buffer = RBuf}) -> + {RBuf1, RBuf1Size} = + %% if the read buffer is empty, read another frame + %% otherwise, just read from what's left in the buffer + case iolist_size(RBuf) of + 0 -> + %% read the frame length + {ok, <<FrameLen:32/integer-signed-big, _/binary>>} = + thrift_transport:read(Wrapped, 4), + %% then read the data + {ok, Bin} = + thrift_transport:read(Wrapped, FrameLen), + {Bin, erlang:byte_size(Bin)}; + Sz -> + {RBuf, Sz} + end, + + %% pull off Give bytes, return them to the user, leave the rest in the buffer + Give = min(RBuf1Size, Len), + <<Data:Give/binary, RBuf2/binary>> = iolist_to_binary(RBuf1), + + Response = {ok, Data}, + State1 = State#framed_transport{read_buffer=RBuf2}, + + {reply, Response, State1}; + +handle_call(flush, _From, State) -> + {Response, State1} = do_flush(State), + {reply, Response, State1}. + +%%-------------------------------------------------------------------- +%% Function: handle_cast(Msg, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling cast messages +%%-------------------------------------------------------------------- +handle_cast(close, State) -> + {_, State1} = do_flush(State), + %% Wrapped is closed by terminate/2 + %% error_logger:info_msg("thrift_framed_transport ~p: closing", [self()]), + {stop, normal, State}; +handle_cast(Msg, State=#framed_transport{}) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_info(Info, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling all non call/cast messages +%%-------------------------------------------------------------------- +handle_info(_Info, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: terminate(Reason, State) -> void() +%% Description: This function is called by a gen_server when it is about to +%% terminate. It should be the opposite of Module:init/1 and do any necessary +%% cleaning up. When it returns, the gen_server terminates with Reason. +%% The return value is ignored. +%%-------------------------------------------------------------------- +terminate(_Reason, State = #framed_transport{wrapped=Wrapped}) -> + thrift_transport:close(Wrapped), + ok. + +%%-------------------------------------------------------------------- +%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} +%% Description: Convert process state when code is changed +%%-------------------------------------------------------------------- +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%%% Internal functions +%%-------------------------------------------------------------------- +do_flush(State = #framed_transport{write_buffer = Buffer, + wrapped = Wrapped}) -> + FrameLen = iolist_size(Buffer), + Data = [<<FrameLen:32/integer-signed-big>>, Buffer], + + Response = thrift_transport:write(Wrapped, Data), + + thrift_transport:flush(Wrapped), + + State1 = State#framed_transport{write_buffer = []}, + {Response, State1}. + +min(A,B) when A<B -> A; +min(_,B) -> B. + diff --git a/lib/erl/src/thrift_http_transport.erl b/lib/erl/src/thrift_http_transport.erl new file mode 100644 index 000000000..f8c182773 --- /dev/null +++ b/lib/erl/src/thrift_http_transport.erl @@ -0,0 +1,199 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_http_transport). + +-behaviour(gen_server). +-behaviour(thrift_transport). + +%% API +-export([new/2, new/3]). + +%% gen_server callbacks +-export([init/1, + handle_call/3, + handle_cast/2, + handle_info/2, + terminate/2, + code_change/3]). + +%% thrift_transport callbacks +-export([write/2, read/2, flush/1, close/1]). + +-record(http_transport, {host, % string() + path, % string() + read_buffer, % iolist() + write_buffer, % iolist() + http_options, % see http(3) + extra_headers % [{str(), str()}, ...] + }). + +%%==================================================================== +%% API +%%==================================================================== +%%-------------------------------------------------------------------- +%% Function: new() -> {ok, Transport} | ignore | {error,Error} +%% Description: Starts the server +%%-------------------------------------------------------------------- +new(Host, Path) -> + new(Host, Path, _Options = []). + +%%-------------------------------------------------------------------- +%% Options include: +%% {http_options, HttpOptions} = See http(3) +%% {extra_headers, ExtraHeaders} = List of extra HTTP headers +%%-------------------------------------------------------------------- +new(Host, Path, Options) -> + case gen_server:start_link(?MODULE, {Host, Path, Options}, []) of + {ok, Pid} -> + thrift_transport:new(?MODULE, Pid); + Else -> + Else + end. + +%%-------------------------------------------------------------------- +%% Function: write(Transport, Data) -> ok +%% +%% Data = iolist() +%% +%% Description: Writes data into the buffer +%%-------------------------------------------------------------------- +write(Transport, Data) -> + gen_server:call(Transport, {write, Data}). + +%%-------------------------------------------------------------------- +%% Function: flush(Transport) -> ok +%% +%% Description: Flushes the buffer, making a request +%%-------------------------------------------------------------------- +flush(Transport) -> + gen_server:call(Transport, flush). + +%%-------------------------------------------------------------------- +%% Function: close(Transport) -> ok +%% +%% Description: Closes the transport +%%-------------------------------------------------------------------- +close(Transport) -> + gen_server:cast(Transport, close). + +%%-------------------------------------------------------------------- +%% Function: Read(Transport, Len) -> {ok, Data} +%% +%% Data = binary() +%% +%% Description: Reads data through from the wrapped transoprt +%%-------------------------------------------------------------------- +read(Transport, Len) when is_integer(Len) -> + gen_server:call(Transport, {read, Len}). + +%%==================================================================== +%% gen_server callbacks +%%==================================================================== + +init({Host, Path, Options}) -> + State1 = #http_transport{host = Host, + path = Path, + read_buffer = [], + write_buffer = [], + http_options = [], + extra_headers = []}, + ApplyOption = + fun + ({http_options, HttpOpts}, State = #http_transport{}) -> + State#http_transport{http_options = HttpOpts}; + ({extra_headers, ExtraHeaders}, State = #http_transport{}) -> + State#http_transport{extra_headers = ExtraHeaders}; + (Other, #http_transport{}) -> + {invalid_option, Other}; + (_, Error) -> + Error + end, + case lists:foldl(ApplyOption, State1, Options) of + State2 = #http_transport{} -> + {ok, State2}; + Else -> + {stop, Else} + end. + +handle_call({write, Data}, _From, State = #http_transport{write_buffer = WBuf}) -> + {reply, ok, State#http_transport{write_buffer = [WBuf, Data]}}; + +handle_call({read, Len}, _From, State = #http_transport{read_buffer = RBuf}) -> + %% Pull off Give bytes, return them to the user, leave the rest in the buffer. + Give = min(iolist_size(RBuf), Len), + case iolist_to_binary(RBuf) of + <<Data:Give/binary, RBuf1/binary>> -> + Response = {ok, Data}, + State1 = State#http_transport{read_buffer=RBuf1}, + {reply, Response, State1}; + _ -> + {reply, {error, 'EOF'}, State} + end; + +handle_call(flush, _From, State) -> + {Response, State1} = do_flush(State), + {reply, Response, State1}. + +handle_cast(close, State) -> + {_, State1} = do_flush(State), + {stop, normal, State1}; + +handle_cast(_Msg, State=#http_transport{}) -> + {noreply, State}. + +handle_info(_Info, State) -> + {noreply, State}. + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%%% Internal functions +%%-------------------------------------------------------------------- +do_flush(State = #http_transport{host = Host, + path = Path, + read_buffer = Rbuf, + write_buffer = Wbuf, + http_options = HttpOptions, + extra_headers = ExtraHeaders}) -> + case iolist_to_binary(Wbuf) of + <<>> -> + %% Don't bother flushing empty buffers. + {ok, State}; + WBinary -> + {ok, {{_Version, 200, _ReasonPhrase}, _Headers, Body}} = + http:request(post, + {"http://" ++ Host ++ Path, + [{"User-Agent", "Erlang/thrift_http_transport"} | ExtraHeaders], + "application/x-thrift", + WBinary}, + HttpOptions, + [{body_format, binary}]), + + State1 = State#http_transport{read_buffer = [Rbuf, Body], + write_buffer = []}, + {ok, State1} + end. + +min(A,B) when A<B -> A; +min(_,B) -> B. diff --git a/lib/erl/src/thrift_memory_buffer.erl b/lib/erl/src/thrift_memory_buffer.erl new file mode 100644 index 000000000..b4f607a95 --- /dev/null +++ b/lib/erl/src/thrift_memory_buffer.erl @@ -0,0 +1,164 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_memory_buffer). + +-behaviour(gen_server). +-behaviour(thrift_transport). + +%% API +-export([new/0, new_transport_factory/0]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +%% thrift_transport callbacks +-export([write/2, read/2, flush/1, close/1]). + +-record(memory_buffer, {buffer}). + +%%==================================================================== +%% API +%%==================================================================== +new() -> + case gen_server:start_link(?MODULE, [], []) of + {ok, Pid} -> + thrift_transport:new(?MODULE, Pid); + Else -> + Else + end. + +new_transport_factory() -> + {ok, fun() -> new() end}. + +%%-------------------------------------------------------------------- +%% Function: write(Transport, Data) -> ok +%% +%% Data = iolist() +%% +%% Description: Writes data into the buffer +%%-------------------------------------------------------------------- +write(Transport, Data) -> + gen_server:call(Transport, {write, Data}). + +%%-------------------------------------------------------------------- +%% Function: flush(Transport) -> ok +%% +%% Description: Flushes the buffer through to the wrapped transport +%%-------------------------------------------------------------------- +flush(Transport) -> + gen_server:call(Transport, flush). + +%%-------------------------------------------------------------------- +%% Function: close(Transport) -> ok +%% +%% Description: Closes the transport and the wrapped transport +%%-------------------------------------------------------------------- +close(Transport) -> + gen_server:cast(Transport, close). + +%%-------------------------------------------------------------------- +%% Function: Read(Transport, Len) -> {ok, Data} +%% +%% Data = binary() +%% +%% Description: Reads data through from the wrapped transoprt +%%-------------------------------------------------------------------- +read(Transport, Len) when is_integer(Len) -> + gen_server:call(Transport, {read, Len}). + +%%==================================================================== +%% gen_server callbacks +%%==================================================================== + +%%-------------------------------------------------------------------- +%% Function: init(Args) -> {ok, State} | +%% {ok, State, Timeout} | +%% ignore | +%% {stop, Reason} +%% Description: Initiates the server +%%-------------------------------------------------------------------- +init([]) -> + {ok, #memory_buffer{buffer = []}}. + +%%-------------------------------------------------------------------- +%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} | +%% {reply, Reply, State, Timeout} | +%% {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, Reply, State} | +%% {stop, Reason, State} +%% Description: Handling call messages +%%-------------------------------------------------------------------- +handle_call({write, Data}, _From, State = #memory_buffer{buffer = Buf}) -> + {reply, ok, State#memory_buffer{buffer = [Buf, Data]}}; + +handle_call({read, Len}, _From, State = #memory_buffer{buffer = Buf}) -> + Binary = iolist_to_binary(Buf), + Give = min(iolist_size(Binary), Len), + {Result, Remaining} = split_binary(Binary, Give), + {reply, {ok, Result}, State#memory_buffer{buffer = Remaining}}; + +handle_call(flush, _From, State) -> + {reply, ok, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_cast(Msg, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling cast messages +%%-------------------------------------------------------------------- +handle_cast(close, State) -> + {stop, normal, State}; +handle_cast(Msg, State=#memory_buffer{}) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_info(Info, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling all non call/cast messages +%%-------------------------------------------------------------------- +handle_info(_Info, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: terminate(Reason, State) -> void() +%% Description: This function is called by a gen_server when it is about to +%% terminate. It should be the opposite of Module:init/1 and do any necessary +%% cleaning up. When it returns, the gen_server terminates with Reason. +%% The return value is ignored. +%%-------------------------------------------------------------------- +terminate(_Reason, _State) -> + ok. + +%%-------------------------------------------------------------------- +%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} +%% Description: Convert process state when code is changed +%%-------------------------------------------------------------------- +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%%% Internal functions +%%-------------------------------------------------------------------- +min(A,B) when A<B -> A; +min(_,B) -> B. + diff --git a/lib/erl/src/thrift_processor.erl b/lib/erl/src/thrift_processor.erl new file mode 100644 index 000000000..e26fb3303 --- /dev/null +++ b/lib/erl/src/thrift_processor.erl @@ -0,0 +1,188 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_processor). + +-export([init/1]). + +-include("thrift_constants.hrl"). +-include("thrift_protocol.hrl"). + +-record(thrift_processor, {handler, in_protocol, out_protocol, service}). + +init({Server, ProtoGen, Service, Handler}) when is_function(ProtoGen, 0) -> + {ok, IProt, OProt} = ProtoGen(), + loop(#thrift_processor{in_protocol = IProt, + out_protocol = OProt, + service = Service, + handler = Handler}). + +loop(State = #thrift_processor{in_protocol = IProto, + out_protocol = OProto}) -> + case thrift_protocol:read(IProto, message_begin) of + #protocol_message_begin{name = Function, + type = ?tMessageType_CALL} -> + ok = handle_function(State, list_to_atom(Function)), + loop(State); + #protocol_message_begin{name = Function, + type = ?tMessageType_ONEWAY} -> + ok = handle_function(State, list_to_atom(Function)), + loop(State); + {error, timeout} -> + thrift_protocol:close_transport(OProto), + ok; + {error, closed} -> + %% error_logger:info_msg("Client disconnected~n"), + thrift_protocol:close_transport(OProto), + exit(shutdown) + end. + +handle_function(State=#thrift_processor{in_protocol = IProto, + out_protocol = OProto, + handler = Handler, + service = Service}, + Function) -> + InParams = Service:function_info(Function, params_type), + + {ok, Params} = thrift_protocol:read(IProto, InParams), + + try + Result = Handler:handle_function(Function, Params), + %% {Micro, Result} = better_timer(Handler, handle_function, [Function, Params]), + %% error_logger:info_msg("Processed ~p(~p) in ~.4fms~n", + %% [Function, Params, Micro/1000.0]), + handle_success(State, Function, Result) + catch + Type:Data -> + handle_function_catch(State, Function, Type, Data) + end, + after_reply(OProto). + +handle_function_catch(State = #thrift_processor{service = Service}, + Function, ErrType, ErrData) -> + IsOneway = Service:function_info(Function, reply_type) =:= oneway_void, + + case {ErrType, ErrData} of + _ when IsOneway -> + Stack = erlang:get_stacktrace(), + error_logger:warning_msg( + "oneway void ~p threw error which must be ignored: ~p", + [Function, {ErrType, ErrData, Stack}]), + ok; + + {throw, Exception} when is_tuple(Exception), size(Exception) > 0 -> + error_logger:warning_msg("~p threw exception: ~p~n", [Function, Exception]), + handle_exception(State, Function, Exception), + ok; % we still want to accept more requests from this client + + {error, Error} -> + ok = handle_error(State, Function, Error) + end. + +handle_success(State = #thrift_processor{out_protocol = OProto, + service = Service}, + Function, + Result) -> + ReplyType = Service:function_info(Function, reply_type), + StructName = atom_to_list(Function) ++ "_result", + + ok = case Result of + {reply, ReplyData} -> + Reply = {{struct, [{0, ReplyType}]}, {StructName, ReplyData}}, + send_reply(OProto, Function, ?tMessageType_REPLY, Reply); + + ok when ReplyType == {struct, []} -> + send_reply(OProto, Function, ?tMessageType_REPLY, {ReplyType, {StructName}}); + + ok when ReplyType == oneway_void -> + %% no reply for oneway void + ok + end. + +handle_exception(State = #thrift_processor{out_protocol = OProto, + service = Service}, + Function, + Exception) -> + ExceptionType = element(1, Exception), + %% Fetch a structure like {struct, [{-2, {struct, {Module, Type}}}, + %% {-3, {struct, {Module, Type}}}]} + + ReplySpec = Service:function_info(Function, exceptions), + {struct, XInfo} = ReplySpec, + + true = is_list(XInfo), + + %% Assuming we had a type1 exception, we'd get: [undefined, Exception, undefined] + %% e.g.: [{-1, type0}, {-2, type1}, {-3, type2}] + ExceptionList = [case Type of + ExceptionType -> Exception; + _ -> undefined + end + || {_Fid, {struct, {_Module, Type}}} <- XInfo], + + ExceptionTuple = list_to_tuple([Function | ExceptionList]), + + % Make sure we got at least one defined + case lists:all(fun(X) -> X =:= undefined end, ExceptionList) of + true -> + ok = handle_unknown_exception(State, Function, Exception); + false -> + ok = send_reply(OProto, Function, ?tMessageType_REPLY, {ReplySpec, ExceptionTuple}) + end. + +%% +%% Called when an exception has been explicitly thrown by the service, but it was +%% not one of the exceptions that was defined for the function. +%% +handle_unknown_exception(State, Function, Exception) -> + handle_error(State, Function, {exception_not_declared_as_thrown, + Exception}). + +handle_error(#thrift_processor{out_protocol = OProto}, Function, Error) -> + Stack = erlang:get_stacktrace(), + error_logger:error_msg("~p had an error: ~p~n", [Function, {Error, Stack}]), + + Message = + case application:get_env(thrift, exceptions_include_traces) of + {ok, true} -> + lists:flatten(io_lib:format("An error occurred: ~p~n", + [{Error, Stack}])); + _ -> + "An unknown handler error occurred." + end, + Reply = {?TApplicationException_Structure, + #'TApplicationException'{ + message = Message, + type = ?TApplicationException_UNKNOWN}}, + send_reply(OProto, Function, ?tMessageType_EXCEPTION, Reply). + +send_reply(OProto, Function, ReplyMessageType, Reply) -> + ok = thrift_protocol:write(OProto, #protocol_message_begin{ + name = atom_to_list(Function), + type = ReplyMessageType, + seqid = 0}), + ok = thrift_protocol:write(OProto, Reply), + ok = thrift_protocol:write(OProto, message_end), + ok = thrift_protocol:flush_transport(OProto), + ok. + +after_reply(OProto) -> + ok = thrift_protocol:flush_transport(OProto) + %% ok = thrift_protocol:close_transport(OProto) + . diff --git a/lib/erl/src/thrift_protocol.erl b/lib/erl/src/thrift_protocol.erl new file mode 100644 index 000000000..1bfb0a426 --- /dev/null +++ b/lib/erl/src/thrift_protocol.erl @@ -0,0 +1,356 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_protocol). + +-export([new/2, + write/2, + read/2, + read/3, + skip/2, + flush_transport/1, + close_transport/1, + typeid_to_atom/1 + ]). + +-export([behaviour_info/1]). + +-include("thrift_constants.hrl"). +-include("thrift_protocol.hrl"). + +-record(protocol, {module, data}). + +behaviour_info(callbacks) -> + [ + {read, 2}, + {write, 2}, + {flush_transport, 1}, + {close_transport, 1} + ]; +behaviour_info(_Else) -> undefined. + +new(Module, Data) when is_atom(Module) -> + {ok, #protocol{module = Module, + data = Data}}. + +flush_transport(#protocol{module = Module, + data = Data}) -> + Module:flush_transport(Data). + +close_transport(#protocol{module = Module, + data = Data}) -> + Module:close_transport(Data). + +typeid_to_atom(?tType_STOP) -> field_stop; +typeid_to_atom(?tType_VOID) -> void; +typeid_to_atom(?tType_BOOL) -> bool; +typeid_to_atom(?tType_BYTE) -> byte; +typeid_to_atom(?tType_DOUBLE) -> double; +typeid_to_atom(?tType_I16) -> i16; +typeid_to_atom(?tType_I32) -> i32; +typeid_to_atom(?tType_I64) -> i64; +typeid_to_atom(?tType_STRING) -> string; +typeid_to_atom(?tType_STRUCT) -> struct; +typeid_to_atom(?tType_MAP) -> map; +typeid_to_atom(?tType_SET) -> set; +typeid_to_atom(?tType_LIST) -> list. + +term_to_typeid(void) -> ?tType_VOID; +term_to_typeid(bool) -> ?tType_BOOL; +term_to_typeid(byte) -> ?tType_BYTE; +term_to_typeid(double) -> ?tType_DOUBLE; +term_to_typeid(i16) -> ?tType_I16; +term_to_typeid(i32) -> ?tType_I32; +term_to_typeid(i64) -> ?tType_I64; +term_to_typeid(string) -> ?tType_STRING; +term_to_typeid({struct, _}) -> ?tType_STRUCT; +term_to_typeid({map, _, _}) -> ?tType_MAP; +term_to_typeid({set, _}) -> ?tType_SET; +term_to_typeid({list, _}) -> ?tType_LIST. + +%% Structure is like: +%% [{Fid, Type}, ...] +read(IProto, {struct, Structure}, Tag) + when is_list(Structure), is_atom(Tag) -> + + % If we want a tagged tuple, we need to offset all the tuple indices + % by 1 to avoid overwriting the tag. + Offset = if Tag =/= undefined -> 1; true -> 0 end, + IndexList = case length(Structure) of + N when N > 0 -> lists:seq(1 + Offset, N + Offset); + _ -> [] + end, + + SWithIndices = [{Fid, {Type, Index}} || + {{Fid, Type}, Index} <- + lists:zip(Structure, IndexList)], + % Fid -> {Type, Index} + SDict = dict:from_list(SWithIndices), + + ok = read(IProto, struct_begin), + RTuple0 = erlang:make_tuple(length(Structure) + Offset, undefined), + RTuple1 = if Tag =/= undefined -> setelement(1, RTuple0, Tag); + true -> RTuple0 + end, + + RTuple2 = read_struct_loop(IProto, SDict, RTuple1), + {ok, RTuple2}. + +read(IProto, {struct, {Module, StructureName}}) when is_atom(Module), + is_atom(StructureName) -> + read(IProto, Module:struct_info(StructureName), StructureName); + +read(IProto, S={struct, Structure}) when is_list(Structure) -> + read(IProto, S, undefined); + +read(IProto, {list, Type}) -> + #protocol_list_begin{etype = EType, size = Size} = + read(IProto, list_begin), + List = [Result || {ok, Result} <- + [read(IProto, Type) || _X <- lists:duplicate(Size, 0)]], + ok = read(IProto, list_end), + {ok, List}; + +read(IProto, {map, KeyType, ValType}) -> + #protocol_map_begin{size = Size} = + read(IProto, map_begin), + + List = [{Key, Val} || {{ok, Key}, {ok, Val}} <- + [{read(IProto, KeyType), + read(IProto, ValType)} || _X <- lists:duplicate(Size, 0)]], + ok = read(IProto, map_end), + {ok, dict:from_list(List)}; + +read(IProto, {set, Type}) -> + #protocol_set_begin{etype = _EType, + size = Size} = + read(IProto, set_begin), + List = [Result || {ok, Result} <- + [read(IProto, Type) || _X <- lists:duplicate(Size, 0)]], + ok = read(IProto, set_end), + {ok, sets:from_list(List)}; + +read(#protocol{module = Module, + data = ModuleData}, ProtocolType) -> + Module:read(ModuleData, ProtocolType). + +read_struct_loop(IProto, SDict, RTuple) -> + #protocol_field_begin{type = FType, id = Fid, name = Name} = + thrift_protocol:read(IProto, field_begin), + case {FType, Fid} of + {?tType_STOP, _} -> + RTuple; + _Else -> + case dict:find(Fid, SDict) of + {ok, {Type, Index}} -> + case term_to_typeid(Type) of + FType -> + {ok, Val} = read(IProto, Type), + thrift_protocol:read(IProto, field_end), + NewRTuple = setelement(Index, RTuple, Val), + read_struct_loop(IProto, SDict, NewRTuple); + Expected -> + error_logger:info_msg( + "Skipping field ~p with wrong type (~p != ~p)~n", + [Fid, FType, Expected]), + skip_field(FType, IProto, SDict, RTuple) + end; + _Else2 -> + error_logger:info_msg("Skipping field ~p with unknown fid~n", [Fid]), + skip_field(FType, IProto, SDict, RTuple) + end + end. + +skip_field(FType, IProto, SDict, RTuple) -> + FTypeAtom = thrift_protocol:typeid_to_atom(FType), + thrift_protocol:skip(IProto, FTypeAtom), + read(IProto, field_end), + read_struct_loop(IProto, SDict, RTuple). + + +skip(Proto, struct) -> + ok = read(Proto, struct_begin), + ok = skip_struct_loop(Proto), + ok = read(Proto, struct_end); + +skip(Proto, map) -> + Map = read(Proto, map_begin), + ok = skip_map_loop(Proto, Map), + ok = read(Proto, map_end); + +skip(Proto, set) -> + Set = read(Proto, set_begin), + ok = skip_set_loop(Proto, Set), + ok = read(Proto, set_end); + +skip(Proto, list) -> + List = read(Proto, list_begin), + ok = skip_list_loop(Proto, List), + ok = read(Proto, list_end); + +skip(Proto, Type) when is_atom(Type) -> + _Ignore = read(Proto, Type), + ok. + + +skip_struct_loop(Proto) -> + #protocol_field_begin{type = Type} = read(Proto, field_begin), + case Type of + ?tType_STOP -> + ok; + _Else -> + skip(Proto, Type), + ok = read(Proto, field_end), + skip_struct_loop(Proto) + end. + +skip_map_loop(Proto, Map = #protocol_map_begin{ktype = Ktype, + vtype = Vtype, + size = Size}) -> + case Size of + N when N > 0 -> + skip(Proto, Ktype), + skip(Proto, Vtype), + skip_map_loop(Proto, + Map#protocol_map_begin{size = Size - 1}); + 0 -> ok + end. + +skip_set_loop(Proto, Map = #protocol_set_begin{etype = Etype, + size = Size}) -> + case Size of + N when N > 0 -> + skip(Proto, Etype), + skip_set_loop(Proto, + Map#protocol_set_begin{size = Size - 1}); + 0 -> ok + end. + +skip_list_loop(Proto, Map = #protocol_list_begin{etype = Etype, + size = Size}) -> + case Size of + N when N > 0 -> + skip(Proto, Etype), + skip_list_loop(Proto, + Map#protocol_list_begin{size = Size - 1}); + 0 -> ok + end. + + +%%-------------------------------------------------------------------- +%% Function: write(OProto, {Type, Data}) -> ok +%% +%% Type = {struct, StructDef} | +%% {list, Type} | +%% {map, KeyType, ValType} | +%% {set, Type} | +%% BaseType +%% +%% Data = +%% tuple() -- for struct +%% | list() -- for list +%% | dictionary() -- for map +%% | set() -- for set +%% | term() -- for base types +%% +%% Description: +%%-------------------------------------------------------------------- +write(Proto, {{struct, StructDef}, Data}) + when is_list(StructDef), is_tuple(Data), length(StructDef) == size(Data) - 1 -> + + [StructName | Elems] = tuple_to_list(Data), + ok = write(Proto, #protocol_struct_begin{name = StructName}), + ok = struct_write_loop(Proto, StructDef, Elems), + ok = write(Proto, struct_end), + ok; + +write(Proto, {{struct, {Module, StructureName}}, Data}) + when is_atom(Module), + is_atom(StructureName), + element(1, Data) =:= StructureName -> + StructType = Module:struct_info(StructureName), + write(Proto, {Module:struct_info(StructureName), Data}); + +write(Proto, {{list, Type}, Data}) + when is_list(Data) -> + ok = write(Proto, + #protocol_list_begin{ + etype = term_to_typeid(Type), + size = length(Data) + }), + lists:foreach(fun(Elem) -> + ok = write(Proto, {Type, Elem}) + end, + Data), + ok = write(Proto, list_end), + ok; + +write(Proto, {{map, KeyType, ValType}, Data}) -> + ok = write(Proto, + #protocol_map_begin{ + ktype = term_to_typeid(KeyType), + vtype = term_to_typeid(ValType), + size = dict:size(Data) + }), + dict:fold(fun(KeyData, ValData, _Acc) -> + ok = write(Proto, {KeyType, KeyData}), + ok = write(Proto, {ValType, ValData}) + end, + _AccO = ok, + Data), + ok = write(Proto, map_end), + ok; + +write(Proto, {{set, Type}, Data}) -> + true = sets:is_set(Data), + ok = write(Proto, + #protocol_set_begin{ + etype = term_to_typeid(Type), + size = sets:size(Data) + }), + sets:fold(fun(Elem, _Acc) -> + ok = write(Proto, {Type, Elem}) + end, + _Acc0 = ok, + Data), + ok = write(Proto, set_end), + ok; + +write(#protocol{module = Module, + data = ModuleData}, Data) -> + Module:write(ModuleData, Data). + +struct_write_loop(Proto, [{Fid, Type} | RestStructDef], [Data | RestData]) -> + case Data of + undefined -> + % null fields are skipped in response + skip; + _ -> + ok = write(Proto, + #protocol_field_begin{ + type = term_to_typeid(Type), + id = Fid + }), + ok = write(Proto, {Type, Data}), + ok = write(Proto, field_end) + end, + struct_write_loop(Proto, RestStructDef, RestData); +struct_write_loop(Proto, [], []) -> + ok = write(Proto, field_stop), + ok. diff --git a/lib/erl/src/thrift_server.erl b/lib/erl/src/thrift_server.erl new file mode 100644 index 000000000..5d0012ba2 --- /dev/null +++ b/lib/erl/src/thrift_server.erl @@ -0,0 +1,183 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_server). + +-behaviour(gen_server). + +%% API +-export([start_link/3, stop/1, take_socket/2]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +-define(SERVER, ?MODULE). + +-record(state, {listen_socket, acceptor_ref, service, handler}). + +%%==================================================================== +%% API +%%==================================================================== +%%-------------------------------------------------------------------- +%% Function: start_link() -> {ok,Pid} | ignore | {error,Error} +%% Description: Starts the server +%%-------------------------------------------------------------------- +start_link(Port, Service, HandlerModule) when is_integer(Port), is_atom(HandlerModule) -> + gen_server:start_link({local, ?SERVER}, ?MODULE, {Port, Service, HandlerModule}, []). + +%%-------------------------------------------------------------------- +%% Function: stop(Pid) -> ok, {error, Reason} +%% Description: Stops the server. +%%-------------------------------------------------------------------- +stop(Pid) when is_pid(Pid) -> + gen_server:call(Pid, stop). + + +take_socket(Server, Socket) -> + gen_server:call(Server, {take_socket, Socket}). + + +%%==================================================================== +%% gen_server callbacks +%%==================================================================== + +%%-------------------------------------------------------------------- +%% Function: init(Args) -> {ok, State} | +%% {ok, State, Timeout} | +%% ignore | +%% {stop, Reason} +%% Description: Initiates the server +%%-------------------------------------------------------------------- +init({Port, Service, Handler}) -> + {ok, Socket} = gen_tcp:listen(Port, + [binary, + {packet, 0}, + {active, false}, + {nodelay, true}, + {reuseaddr, true}]), + {ok, Ref} = prim_inet:async_accept(Socket, -1), + {ok, #state{listen_socket = Socket, + acceptor_ref = Ref, + service = Service, + handler = Handler}}. + +%%-------------------------------------------------------------------- +%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} | +%% {reply, Reply, State, Timeout} | +%% {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, Reply, State} | +%% {stop, Reason, State} +%% Description: Handling call messages +%%-------------------------------------------------------------------- +handle_call(stop, _From, State) -> + {stop, stopped, ok, State}; + +handle_call({take_socket, Socket}, {FromPid, _Tag}, State) -> + Result = gen_tcp:controlling_process(Socket, FromPid), + {reply, Result, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_cast(Msg, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling cast messages +%%-------------------------------------------------------------------- +handle_cast(_Msg, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_info(Info, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling all non call/cast messages +%%-------------------------------------------------------------------- +handle_info({inet_async, ListenSocket, Ref, {ok, ClientSocket}}, + State = #state{listen_socket = ListenSocket, + acceptor_ref = Ref, + service = Service, + handler = Handler}) -> + case set_sockopt(ListenSocket, ClientSocket) of + ok -> + %% New client connected - start processor + start_processor(ClientSocket, Service, Handler), + {ok, NewRef} = prim_inet:async_accept(ListenSocket, -1), + {noreply, State#state{acceptor_ref = NewRef}}; + {error, Reason} -> + error_logger:error_msg("Couldn't set socket opts: ~p~n", + [Reason]), + {stop, Reason, State} + end; + +handle_info({inet_async, ListenSocket, Ref, Error}, State) -> + error_logger:error_msg("Error in acceptor: ~p~n", [Error]), + {stop, Error, State}; + +handle_info(_Info, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: terminate(Reason, State) -> void() +%% Description: This function is called by a gen_server when it is about to +%% terminate. It should be the opposite of Module:init/1 and do any necessary +%% cleaning up. When it returns, the gen_server terminates with Reason. +%% The return value is ignored. +%%-------------------------------------------------------------------- +terminate(_Reason, _State) -> + ok. + +%%-------------------------------------------------------------------- +%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} +%% Description: Convert process state when code is changed +%%-------------------------------------------------------------------- +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%%% Internal functions +%%-------------------------------------------------------------------- +set_sockopt(ListenSocket, ClientSocket) -> + true = inet_db:register_socket(ClientSocket, inet_tcp), + case prim_inet:getopts(ListenSocket, + [active, nodelay, keepalive, delay_send, priority, tos]) of + {ok, Opts} -> + case prim_inet:setopts(ClientSocket, Opts) of + ok -> ok; + Error -> gen_tcp:close(ClientSocket), + Error + end; + Error -> + gen_tcp:close(ClientSocket), + Error + end. + +start_processor(Socket, Service, Handler) -> + Server = self(), + + ProtoGen = fun() -> + % Become the controlling process + ok = take_socket(Server, Socket), + {ok, SocketTransport} = thrift_socket_transport:new(Socket), + {ok, BufferedTransport} = thrift_buffered_transport:new(SocketTransport), + {ok, Protocol} = thrift_binary_protocol:new(BufferedTransport), + {ok, Protocol, Protocol} + end, + + spawn(thrift_processor, init, [{Server, ProtoGen, Service, Handler}]). diff --git a/lib/erl/src/thrift_service.erl b/lib/erl/src/thrift_service.erl new file mode 100644 index 000000000..2ed7b57b0 --- /dev/null +++ b/lib/erl/src/thrift_service.erl @@ -0,0 +1,25 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_service). + +-export([behaviour_info/1]). + +behaviour_info(callbacks) -> + [{function_info, 2}]. diff --git a/lib/erl/src/thrift_socket_server.erl b/lib/erl/src/thrift_socket_server.erl new file mode 100644 index 000000000..62bdfdaf6 --- /dev/null +++ b/lib/erl/src/thrift_socket_server.erl @@ -0,0 +1,249 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_socket_server). + +-behaviour(gen_server). + +-export([start/1, stop/1]). + +-export([init/1, handle_call/3, handle_cast/2, terminate/2, code_change/3, + handle_info/2]). + +-export([acceptor_loop/1]). + +-record(thrift_socket_server, + {port, + service, + handler, + name, + max=2048, + ip=any, + listen=null, + acceptor=null, + socket_opts=[{recv_timeout, 500}] + }). + +start(State=#thrift_socket_server{}) -> + start_server(State); +start(Options) -> + start(parse_options(Options)). + +stop(Name) when is_atom(Name) -> + gen_server:cast(Name, stop); +stop(Pid) when is_pid(Pid) -> + gen_server:cast(Pid, stop); +stop({local, Name}) -> + stop(Name); +stop({global, Name}) -> + stop(Name); +stop(Options) -> + State = parse_options(Options), + stop(State#thrift_socket_server.name). + +%% Internal API + +parse_options(Options) -> + parse_options(Options, #thrift_socket_server{}). + +parse_options([], State) -> + State; +parse_options([{name, L} | Rest], State) when is_list(L) -> + Name = {local, list_to_atom(L)}, + parse_options(Rest, State#thrift_socket_server{name=Name}); +parse_options([{name, A} | Rest], State) when is_atom(A) -> + Name = {local, A}, + parse_options(Rest, State#thrift_socket_server{name=Name}); +parse_options([{name, Name} | Rest], State) -> + parse_options(Rest, State#thrift_socket_server{name=Name}); +parse_options([{port, L} | Rest], State) when is_list(L) -> + Port = list_to_integer(L), + parse_options(Rest, State#thrift_socket_server{port=Port}); +parse_options([{port, Port} | Rest], State) -> + parse_options(Rest, State#thrift_socket_server{port=Port}); +parse_options([{ip, Ip} | Rest], State) -> + ParsedIp = case Ip of + any -> + any; + Ip when is_tuple(Ip) -> + Ip; + Ip when is_list(Ip) -> + {ok, IpTuple} = inet_parse:address(Ip), + IpTuple + end, + parse_options(Rest, State#thrift_socket_server{ip=ParsedIp}); +parse_options([{socket_opts, L} | Rest], State) when is_list(L), length(L) > 0 -> + parse_options(Rest, State#thrift_socket_server{socket_opts=L}); +parse_options([{handler, Handler} | Rest], State) -> + parse_options(Rest, State#thrift_socket_server{handler=Handler}); +parse_options([{service, Service} | Rest], State) -> + parse_options(Rest, State#thrift_socket_server{service=Service}); +parse_options([{max, Max} | Rest], State) -> + MaxInt = case Max of + Max when is_list(Max) -> + list_to_integer(Max); + Max when is_integer(Max) -> + Max + end, + parse_options(Rest, State#thrift_socket_server{max=MaxInt}). + +start_server(State=#thrift_socket_server{name=Name}) -> + case Name of + undefined -> + gen_server:start_link(?MODULE, State, []); + _ -> + gen_server:start_link(Name, ?MODULE, State, []) + end. + +init(State=#thrift_socket_server{ip=Ip, port=Port}) -> + process_flag(trap_exit, true), + BaseOpts = [binary, + {reuseaddr, true}, + {packet, 0}, + {backlog, 4096}, + {recbuf, 8192}, + {active, false}], + Opts = case Ip of + any -> + BaseOpts; + Ip -> + [{ip, Ip} | BaseOpts] + end, + case gen_tcp_listen(Port, Opts, State) of + {stop, eacces} -> + %% fdsrv module allows another shot to bind + %% ports which require root access + case Port < 1024 of + true -> + case fdsrv:start() of + {ok, _} -> + case fdsrv:bind_socket(tcp, Port) of + {ok, Fd} -> + gen_tcp_listen(Port, [{fd, Fd} | Opts], State); + _ -> + {stop, fdsrv_bind_failed} + end; + _ -> + {stop, fdsrv_start_failed} + end; + false -> + {stop, eacces} + end; + Other -> + error_logger:info_msg("thrift service listening on port ~p", [Port]), + Other + end. + +gen_tcp_listen(Port, Opts, State) -> + case gen_tcp:listen(Port, Opts) of + {ok, Listen} -> + {ok, ListenPort} = inet:port(Listen), + {ok, new_acceptor(State#thrift_socket_server{listen=Listen, + port=ListenPort})}; + {error, Reason} -> + {stop, Reason} + end. + +new_acceptor(State=#thrift_socket_server{max=0}) -> + error_logger:error_msg("Not accepting new connections"), + State#thrift_socket_server{acceptor=null}; +new_acceptor(State=#thrift_socket_server{acceptor=OldPid, listen=Listen, + service=Service, handler=Handler, + socket_opts=Opts + }) -> + Pid = proc_lib:spawn_link(?MODULE, acceptor_loop, + [{self(), Listen, Service, Handler, Opts}]), +%% error_logger:info_msg("Spawning new acceptor: ~p => ~p", [OldPid, Pid]), + State#thrift_socket_server{acceptor=Pid}. + +acceptor_loop({Server, Listen, Service, Handler, SocketOpts}) + when is_pid(Server), is_list(SocketOpts) -> + case catch gen_tcp:accept(Listen) of % infinite timeout + {ok, Socket} -> + gen_server:cast(Server, {accepted, self()}), + ProtoGen = fun() -> + {ok, SocketTransport} = thrift_socket_transport:new(Socket, SocketOpts), + {ok, BufferedTransport} = thrift_buffered_transport:new(SocketTransport), + {ok, Protocol} = thrift_binary_protocol:new(BufferedTransport), + {ok, IProt=Protocol, OProt=Protocol} + end, + thrift_processor:init({Server, ProtoGen, Service, Handler}); + {error, closed} -> + exit({error, closed}); + Other -> + error_logger:error_report( + [{application, thrift}, + "Accept failed error", + lists:flatten(io_lib:format("~p", [Other]))]), + exit({error, accept_failed}) + end. + +handle_call({get, port}, _From, State=#thrift_socket_server{port=Port}) -> + {reply, Port, State}; +handle_call(_Message, _From, State) -> + Res = error, + {reply, Res, State}. + +handle_cast({accepted, Pid}, + State=#thrift_socket_server{acceptor=Pid, max=Max}) -> + % io:format("accepted ~p~n", [Pid]), + State1 = State#thrift_socket_server{max=Max - 1}, + {noreply, new_acceptor(State1)}; +handle_cast(stop, State) -> + {stop, normal, State}. + +terminate(_Reason, #thrift_socket_server{listen=Listen, port=Port}) -> + gen_tcp:close(Listen), + case Port < 1024 of + true -> + catch fdsrv:stop(), + ok; + false -> + ok + end. + +code_change(_OldVsn, State, _Extra) -> + State. + +handle_info({'EXIT', Pid, normal}, + State=#thrift_socket_server{acceptor=Pid}) -> + {noreply, new_acceptor(State)}; +handle_info({'EXIT', Pid, Reason}, + State=#thrift_socket_server{acceptor=Pid}) -> + error_logger:error_report({?MODULE, ?LINE, + {acceptor_error, Reason}}), + timer:sleep(100), + {noreply, new_acceptor(State)}; +handle_info({'EXIT', _LoopPid, Reason}, + State=#thrift_socket_server{acceptor=Pid, max=Max}) -> + case Reason of + normal -> ok; + shutdown -> ok; + _ -> error_logger:error_report({?MODULE, ?LINE, + {child_error, Reason, erlang:get_stacktrace()}}) + end, + State1 = State#thrift_socket_server{max=Max + 1}, + State2 = case Pid of + null -> new_acceptor(State1); + _ -> State1 + end, + {noreply, State2}; +handle_info(Info, State) -> + error_logger:info_report([{'INFO', Info}, {'State', State}]), + {noreply, State}. diff --git a/lib/erl/src/thrift_socket_transport.erl b/lib/erl/src/thrift_socket_transport.erl new file mode 100644 index 000000000..fcd69449e --- /dev/null +++ b/lib/erl/src/thrift_socket_transport.erl @@ -0,0 +1,119 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_socket_transport). + +-behaviour(thrift_transport). + +-export([new/1, + new/2, + write/2, read/2, flush/1, close/1, + + new_transport_factory/3]). + +-record(data, {socket, + recv_timeout=infinity}). + +new(Socket) -> + new(Socket, []). + +new(Socket, Opts) when is_list(Opts) -> + State = + case lists:keysearch(recv_timeout, 1, Opts) of + {value, {recv_timeout, Timeout}} + when is_integer(Timeout), Timeout > 0 -> + #data{socket=Socket, recv_timeout=Timeout}; + _ -> + #data{socket=Socket} + end, + thrift_transport:new(?MODULE, State). + +%% Data :: iolist() +write(#data{socket = Socket}, Data) -> + gen_tcp:send(Socket, Data). + +read(#data{socket=Socket, recv_timeout=Timeout}, Len) + when is_integer(Len), Len >= 0 -> + case gen_tcp:recv(Socket, Len, Timeout) of + Err = {error, timeout} -> + error_logger:info_msg("read timeout: peer conn ~p", [inet:peername(Socket)]), + gen_tcp:close(Socket), + Err; + Data -> Data + end. + +%% We can't really flush - everything is flushed when we write +flush(_) -> + ok. + +close(#data{socket = Socket}) -> + gen_tcp:close(Socket). + + +%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + + +%% The following "local" record is filled in by parse_factory_options/2 +%% below. These options can be passed to new_protocol_factory/3 in a +%% proplists-style option list. They're parsed like this so it is an O(n) +%% operation instead of O(n^2) +-record(factory_opts, {connect_timeout = infinity, + sockopts = [], + framed = false}). + +parse_factory_options([], Opts) -> + Opts; +parse_factory_options([{framed, Bool} | Rest], Opts) when is_boolean(Bool) -> + parse_factory_options(Rest, Opts#factory_opts{framed=Bool}); +parse_factory_options([{sockopts, OptList} | Rest], Opts) when is_list(OptList) -> + parse_factory_options(Rest, Opts#factory_opts{sockopts=OptList}); +parse_factory_options([{connect_timeout, TO} | Rest], Opts) when TO =:= infinity; is_integer(TO) -> + parse_factory_options(Rest, Opts#factory_opts{connect_timeout=TO}). + + +%% +%% Generates a "transport factory" function - a fun which returns a thrift_transport() +%% instance. +%% This can be passed into a protocol factory to generate a connection to a +%% thrift server over a socket. +%% +new_transport_factory(Host, Port, Options) -> + ParsedOpts = parse_factory_options(Options, #factory_opts{}), + + F = fun() -> + SockOpts = [binary, + {packet, 0}, + {active, false}, + {nodelay, true} | + ParsedOpts#factory_opts.sockopts], + case catch gen_tcp:connect(Host, Port, SockOpts, + ParsedOpts#factory_opts.connect_timeout) of + {ok, Sock} -> + {ok, Transport} = thrift_socket_transport:new(Sock), + {ok, BufTransport} = + case ParsedOpts#factory_opts.framed of + true -> thrift_framed_transport:new(Transport); + false -> thrift_buffered_transport:new(Transport) + end, + {ok, BufTransport}; + Error -> + Error + end + end, + {ok, F}. diff --git a/lib/erl/src/thrift_transport.erl b/lib/erl/src/thrift_transport.erl new file mode 100644 index 000000000..20c4b5dc3 --- /dev/null +++ b/lib/erl/src/thrift_transport.erl @@ -0,0 +1,57 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_transport). + +-export([behaviour_info/1]). + +-export([new/2, + write/2, + read/2, + flush/1, + close/1 + ]). + +behaviour_info(callbacks) -> + [{read, 2}, + {write, 2}, + {flush, 1}, + {close, 1} + ]. + +-record(transport, {module, data}). + +new(Module, Data) when is_atom(Module) -> + {ok, #transport{module = Module, + data = Data}}. + +%% Data :: iolist() +write(Transport, Data) -> + Module = Transport#transport.module, + Module:write(Transport#transport.data, Data). + +read(Transport, Len) when is_integer(Len) -> + Module = Transport#transport.module, + Module:read(Transport#transport.data, Len). + +flush(#transport{module = Module, data = Data}) -> + Module:flush(Data). + +close(#transport{module = Module, data = Data}) -> + Module:close(Data). diff --git a/lib/erl/vsn.mk b/lib/erl/vsn.mk new file mode 100644 index 000000000..d9b40014b --- /dev/null +++ b/lib/erl/vsn.mk @@ -0,0 +1 @@ +THRIFT_VSN=0.1 diff --git a/lib/hs/.gitignore b/lib/hs/.gitignore new file mode 100644 index 000000000..849ddff3b --- /dev/null +++ b/lib/hs/.gitignore @@ -0,0 +1 @@ +dist/ diff --git a/lib/hs/README b/lib/hs/README new file mode 100644 index 000000000..e58c8c93f --- /dev/null +++ b/lib/hs/README @@ -0,0 +1,82 @@ +Haskell Thrift Bindings + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Running +======= + +You need -fglasgow-exts. Use Cabal to compile and install. If you're trying to +manually compile or load via ghci, and you're using ghc 6.10 (or really if your +default base package has major version number 4), you must specify a version of +the base package with major version number 3. Furthermore if you have the syb +package installed you need to hide that package to avoid import conflicts. +Here's an example of what I'm talking about: + + ghci -fglasgow-exts -package base-3.0.3.0 -hide-package syb -isrc Thrift.hs + +To determine which versions of the base package you have installed use the +following command: + + ghc-pkg list base + +All of this is taken care of for you if you use Cabal. + + +Enums +===== + +become haskell data types. Use fromEnum to get out the int value. + +Structs +======= + +become records. Field labels are ugly, of the form f_STRUCTNAME_FIELDNAME. All +fields are Maybe types. + +Exceptions +========== + +identical to structs. Throw them with throwDyn. Catch them with catchDyn. + +Client +====== + +just a bunch of functions. You may have to import a bunch of client files to +deal with inheritance. + +Interface +========= + +You should only have to import the last one in the chain of inheritors. To make +an interface, declare a label: + + data MyIface = MyIface + +and then declare it an instance of each iface class, starting with the superest +class and proceding down (all the while defining the methods). Then pass your +label to process as the handler. + +Processor +========= + +Just a function that takes a handler label, protocols. It calls the +superclasses process if there is a superclass. + diff --git a/lib/hs/Setup.lhs b/lib/hs/Setup.lhs new file mode 100644 index 000000000..c9e6d9705 --- /dev/null +++ b/lib/hs/Setup.lhs @@ -0,0 +1,23 @@ +#!/usr/bin/env runhaskell + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +> import Distribution.Simple +> main = defaultMain diff --git a/lib/hs/TODO b/lib/hs/TODO new file mode 100644 index 000000000..136817321 --- /dev/null +++ b/lib/hs/TODO @@ -0,0 +1,2 @@ +The library could stand to be built up more. +Many modules need export lists. diff --git a/lib/hs/Thrift.cabal b/lib/hs/Thrift.cabal new file mode 100644 index 000000000..4cef4de6d --- /dev/null +++ b/lib/hs/Thrift.cabal @@ -0,0 +1,20 @@ +Name: Thrift +Version: 0.1.0 +Cabal-Version: >= 1.2 +License: Apache2 +Category: Foreign +Build-Type: Simple +Synopsis: Thrift library package + +Library + Hs-Source-Dirs: + src + Build-Depends: + base >=4, network, ghc-prim + ghc-options: + -fglasgow-exts + Extensions: + DeriveDataTypeable + Exposed-Modules: + Thrift, Thrift.Protocol, Thrift.Transport, Thrift.Protocol.Binary + Thrift.Transport.Handle, Thrift.Server diff --git a/lib/hs/src/Thrift.hs b/lib/hs/src/Thrift.hs new file mode 100644 index 000000000..291bcae53 --- /dev/null +++ b/lib/hs/src/Thrift.hs @@ -0,0 +1,111 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift + ( module Thrift.Transport + , module Thrift.Protocol + , AppExnType(..) + , AppExn(..) + , readAppExn + , writeAppExn + , ThriftException(..) + ) where + +import Control.Monad ( when ) +import Control.Exception + +import Data.Typeable ( Typeable ) + +import Thrift.Transport +import Thrift.Protocol + + +data ThriftException = ThriftException + deriving ( Show, Typeable ) +instance Exception ThriftException + +data AppExnType + = AE_UNKNOWN + | AE_UNKNOWN_METHOD + | AE_INVALID_MESSAGE_TYPE + | AE_WRONG_METHOD_NAME + | AE_BAD_SEQUENCE_ID + | AE_MISSING_RESULT + deriving ( Eq, Show, Typeable ) + +instance Enum AppExnType where + toEnum 0 = AE_UNKNOWN + toEnum 1 = AE_UNKNOWN_METHOD + toEnum 2 = AE_INVALID_MESSAGE_TYPE + toEnum 3 = AE_WRONG_METHOD_NAME + toEnum 4 = AE_BAD_SEQUENCE_ID + toEnum 5 = AE_MISSING_RESULT + + fromEnum AE_UNKNOWN = 0 + fromEnum AE_UNKNOWN_METHOD = 1 + fromEnum AE_INVALID_MESSAGE_TYPE = 2 + fromEnum AE_WRONG_METHOD_NAME = 3 + fromEnum AE_BAD_SEQUENCE_ID = 4 + fromEnum AE_MISSING_RESULT = 5 + +data AppExn = AppExn { ae_type :: AppExnType, ae_message :: String } + deriving ( Show, Typeable ) +instance Exception AppExn + +writeAppExn :: (Protocol p, Transport t) => p t -> AppExn -> IO () +writeAppExn pt ae = do + writeStructBegin pt "TApplicationException" + + when (ae_message ae /= "") $ do + writeFieldBegin pt ("message", T_STRING , 1) + writeString pt (ae_message ae) + writeFieldEnd pt + + writeFieldBegin pt ("type", T_I32, 2); + writeI32 pt (fromEnum (ae_type ae)) + writeFieldEnd pt + writeFieldStop pt + writeStructEnd pt + +readAppExn :: (Protocol p, Transport t) => p t -> IO AppExn +readAppExn pt = do + readStructBegin pt + rec <- readAppExnFields pt (AppExn {ae_type = undefined, ae_message = undefined}) + readStructEnd pt + return rec + +readAppExnFields pt rec = do + (n, ft, id) <- readFieldBegin pt + if ft == T_STOP + then return rec + else case id of + 1 -> if ft == T_STRING then + do s <- readString pt + readAppExnFields pt rec{ae_message = s} + else do skip pt ft + readAppExnFields pt rec + 2 -> if ft == T_I32 then + do i <- readI32 pt + readAppExnFields pt rec{ae_type = (toEnum i)} + else do skip pt ft + readAppExnFields pt rec + _ -> do skip pt ft + readFieldEnd pt + readAppExnFields pt rec + diff --git a/lib/hs/src/Thrift/Protocol.hs b/lib/hs/src/Thrift/Protocol.hs new file mode 100644 index 000000000..8fa060ea5 --- /dev/null +++ b/lib/hs/src/Thrift/Protocol.hs @@ -0,0 +1,191 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift.Protocol + ( Protocol(..) + , skip + , MessageType(..) + , ThriftType(..) + , ProtocolExn(..) + , ProtocolExnType(..) + ) where + +import Control.Monad ( replicateM_, unless ) +import Control.Exception + +import Data.Typeable ( Typeable ) +import Data.Int + +import Thrift.Transport + + +data ThriftType + = T_STOP + | T_VOID + | T_BOOL + | T_BYTE + | T_DOUBLE + | T_I16 + | T_I32 + | T_I64 + | T_STRING + | T_STRUCT + | T_MAP + | T_SET + | T_LIST + deriving ( Eq ) + +instance Enum ThriftType where + fromEnum T_STOP = 0 + fromEnum T_VOID = 1 + fromEnum T_BOOL = 2 + fromEnum T_BYTE = 3 + fromEnum T_DOUBLE = 4 + fromEnum T_I16 = 6 + fromEnum T_I32 = 8 + fromEnum T_I64 = 10 + fromEnum T_STRING = 11 + fromEnum T_STRUCT = 12 + fromEnum T_MAP = 13 + fromEnum T_SET = 14 + fromEnum T_LIST = 15 + + toEnum 0 = T_STOP + toEnum 1 = T_VOID + toEnum 2 = T_BOOL + toEnum 3 = T_BYTE + toEnum 4 = T_DOUBLE + toEnum 6 = T_I16 + toEnum 8 = T_I32 + toEnum 10 = T_I64 + toEnum 11 = T_STRING + toEnum 12 = T_STRUCT + toEnum 13 = T_MAP + toEnum 14 = T_SET + toEnum 15 = T_LIST + +data MessageType + = M_CALL + | M_REPLY + | M_EXCEPTION + deriving ( Eq ) + +instance Enum MessageType where + fromEnum M_CALL = 1 + fromEnum M_REPLY = 2 + fromEnum M_EXCEPTION = 3 + + toEnum 1 = M_CALL + toEnum 2 = M_REPLY + toEnum 3 = M_EXCEPTION + + +class Protocol a where + getTransport :: Transport t => a t -> t + + writeMessageBegin :: Transport t => a t -> (String, MessageType, Int) -> IO () + writeMessageEnd :: Transport t => a t -> IO () + + writeStructBegin :: Transport t => a t -> String -> IO () + writeStructEnd :: Transport t => a t -> IO () + writeFieldBegin :: Transport t => a t -> (String, ThriftType, Int) -> IO () + writeFieldEnd :: Transport t => a t -> IO () + writeFieldStop :: Transport t => a t -> IO () + writeMapBegin :: Transport t => a t -> (ThriftType, ThriftType, Int) -> IO () + writeMapEnd :: Transport t => a t -> IO () + writeListBegin :: Transport t => a t -> (ThriftType, Int) -> IO () + writeListEnd :: Transport t => a t -> IO () + writeSetBegin :: Transport t => a t -> (ThriftType, Int) -> IO () + writeSetEnd :: Transport t => a t -> IO () + + writeBool :: Transport t => a t -> Bool -> IO () + writeByte :: Transport t => a t -> Int -> IO () + writeI16 :: Transport t => a t -> Int -> IO () + writeI32 :: Transport t => a t -> Int -> IO () + writeI64 :: Transport t => a t -> Int64 -> IO () + writeDouble :: Transport t => a t -> Double -> IO () + writeString :: Transport t => a t -> String -> IO () + writeBinary :: Transport t => a t -> String -> IO () + + + readMessageBegin :: Transport t => a t -> IO (String, MessageType, Int) + readMessageEnd :: Transport t => a t -> IO () + + readStructBegin :: Transport t => a t -> IO String + readStructEnd :: Transport t => a t -> IO () + readFieldBegin :: Transport t => a t -> IO (String, ThriftType, Int) + readFieldEnd :: Transport t => a t -> IO () + readMapBegin :: Transport t => a t -> IO (ThriftType, ThriftType, Int) + readMapEnd :: Transport t => a t -> IO () + readListBegin :: Transport t => a t -> IO (ThriftType, Int) + readListEnd :: Transport t => a t -> IO () + readSetBegin :: Transport t => a t -> IO (ThriftType, Int) + readSetEnd :: Transport t => a t -> IO () + + readBool :: Transport t => a t -> IO Bool + readByte :: Transport t => a t -> IO Int + readI16 :: Transport t => a t -> IO Int + readI32 :: Transport t => a t -> IO Int + readI64 :: Transport t => a t -> IO Int64 + readDouble :: Transport t => a t -> IO Double + readString :: Transport t => a t -> IO String + readBinary :: Transport t => a t -> IO String + + +skip :: (Protocol p, Transport t) => p t -> ThriftType -> IO () +skip p T_STOP = return () +skip p T_VOID = return () +skip p T_BOOL = readBool p >> return () +skip p T_BYTE = readByte p >> return () +skip p T_I16 = readI16 p >> return () +skip p T_I32 = readI32 p >> return () +skip p T_I64 = readI64 p >> return () +skip p T_DOUBLE = readDouble p >> return () +skip p T_STRING = readString p >> return () +skip p T_STRUCT = do readStructBegin p + skipFields p + readStructEnd p +skip p T_MAP = do (k, v, s) <- readMapBegin p + replicateM_ s (skip p k >> skip p v) + readMapEnd p +skip p T_SET = do (t, n) <- readSetBegin p + replicateM_ n (skip p t) + readSetEnd p +skip p T_LIST = do (t, n) <- readListBegin p + replicateM_ n (skip p t) + readListEnd p + + +skipFields :: (Protocol p, Transport t) => p t -> IO () +skipFields p = do + (_, t, _) <- readFieldBegin p + unless (t == T_STOP) (skip p t >> readFieldEnd p >> skipFields p) + + +data ProtocolExnType + = PE_UNKNOWN + | PE_INVALID_DATA + | PE_NEGATIVE_SIZE + | PE_SIZE_LIMIT + | PE_BAD_VERSION + deriving ( Eq, Show, Typeable ) + +data ProtocolExn = ProtocolExn ProtocolExnType String + deriving ( Show, Typeable ) +instance Exception ProtocolExn diff --git a/lib/hs/src/Thrift/Protocol/Binary.hs b/lib/hs/src/Thrift/Protocol/Binary.hs new file mode 100644 index 000000000..3f798ceea --- /dev/null +++ b/lib/hs/src/Thrift/Protocol/Binary.hs @@ -0,0 +1,147 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift.Protocol.Binary + ( module Thrift.Protocol + , BinaryProtocol(..) + ) where + +import Control.Exception ( throw ) + +import Data.Bits +import Data.Int +import Data.List ( foldl' ) + +import GHC.Exts +import GHC.Word + +import Thrift.Protocol +import Thrift.Transport + + +version_mask = 0xffff0000 +version_1 = 0x80010000 + +data BinaryProtocol a = Transport a => BinaryProtocol a + + +instance Protocol BinaryProtocol where + getTransport (BinaryProtocol t) = t + + writeMessageBegin p (n, t, s) = do + writeI32 p (version_1 .|. (fromEnum t)) + writeString p n + writeI32 p s + writeMessageEnd _ = return () + + writeStructBegin _ _ = return () + writeStructEnd _ = return () + writeFieldBegin p (_, t, i) = writeType p t >> writeI16 p i + writeFieldEnd _ = return () + writeFieldStop p = writeType p T_STOP + writeMapBegin p (k, v, n) = writeType p k >> writeType p v >> writeI32 p n + writeMapEnd p = return () + writeListBegin p (t, n) = writeType p t >> writeI32 p n + writeListEnd _ = return () + writeSetBegin p (t, n) = writeType p t >> writeI32 p n + writeSetEnd _ = return () + + writeBool p b = tWrite (getTransport p) [toEnum $ if b then 1 else 0] + writeByte p b = tWrite (getTransport p) (getBytes b 1) + writeI16 p b = tWrite (getTransport p) (getBytes b 2) + writeI32 p b = tWrite (getTransport p) (getBytes b 4) + writeI64 p b = tWrite (getTransport p) (getBytes b 8) + writeDouble p d = writeI64 p (fromIntegral $ floatBits d) + writeString p s = writeI32 p (length s) >> tWrite (getTransport p) s + writeBinary = writeString + + readMessageBegin p = do + ver <- readI32 p + if (ver .&. version_mask /= version_1) + then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier" + else do + s <- readString p + sz <- readI32 p + return (s, toEnum $ ver .&. 0xFF, sz) + readMessageEnd _ = return () + readStructBegin _ = return "" + readStructEnd _ = return () + readFieldBegin p = do + t <- readType p + n <- if t /= T_STOP then readI16 p else return 0 + return ("", t, n) + readFieldEnd _ = return () + readMapBegin p = do + kt <- readType p + vt <- readType p + n <- readI32 p + return (kt, vt, n) + readMapEnd _ = return () + readListBegin p = do + t <- readType p + n <- readI32 p + return (t, n) + readListEnd _ = return () + readSetBegin p = do + t <- readType p + n <- readI32 p + return (t, n) + readSetEnd _ = return () + + readBool p = (== 1) `fmap` readByte p + readByte p = do + bs <- tReadAll (getTransport p) 1 + return $ fromIntegral (composeBytes bs :: Int8) + readI16 p = do + bs <- tReadAll (getTransport p) 2 + return $ fromIntegral (composeBytes bs :: Int16) + readI32 p = composeBytes `fmap` tReadAll (getTransport p) 4 + readI64 p = composeBytes `fmap` tReadAll (getTransport p) 8 + readDouble p = do + bs <- readI64 p + return $ floatOfBits $ fromIntegral bs + readString p = readI32 p >>= tReadAll (getTransport p) + readBinary = readString + + +-- | Write a type as a byte +writeType :: (Protocol p, Transport t) => p t -> ThriftType -> IO () +writeType p t = writeByte p (fromEnum t) + +-- | Read a byte as though it were a ThriftType +readType :: (Protocol p, Transport t) => p t -> IO ThriftType +readType p = toEnum `fmap` readByte p + +composeBytes :: (Bits b, Enum t) => [t] -> b +composeBytes = (foldl' fn 0) . (map $ fromIntegral . fromEnum) + where fn acc b = (acc `shiftL` 8) .|. b + +getByte :: Bits a => a -> Int -> a +getByte i n = 255 .&. (i `shiftR` (8 * n)) + +getBytes :: (Bits a, Integral a) => a -> Int -> String +getBytes i 0 = [] +getBytes i n = (toEnum $ fromIntegral $ getByte i (n-1)):(getBytes i (n-1)) + +floatBits :: Double -> Word64 +floatBits (D# d#) = W64# (unsafeCoerce# d#) + +floatOfBits :: Word64 -> Double +floatOfBits (W64# b#) = D# (unsafeCoerce# b#) + diff --git a/lib/hs/src/Thrift/Server.hs b/lib/hs/src/Thrift/Server.hs new file mode 100644 index 000000000..770965f1e --- /dev/null +++ b/lib/hs/src/Thrift/Server.hs @@ -0,0 +1,65 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift.Server + ( runBasicServer + , runThreadedServer + ) where + +import Control.Concurrent ( forkIO ) +import Control.Exception +import Control.Monad ( forever, when ) + +import Network + +import System.IO + +import Thrift +import Thrift.Transport.Handle +import Thrift.Protocol.Binary + + +-- | A threaded sever that is capable of using any Transport or Protocol +-- instances. +runThreadedServer :: (Transport t, Protocol i, Protocol o) + => (Socket -> IO (i t, o t)) + -> h + -> (h -> (i t, o t) -> IO Bool) + -> PortID + -> IO a +runThreadedServer accepter hand proc port = do + socket <- listenOn port + acceptLoop (accepter socket) (proc hand) + +-- | A basic threaded binary protocol socket server. +runBasicServer :: h + -> (h -> (BinaryProtocol Handle, BinaryProtocol Handle) -> IO Bool) + -> PortNumber + -> IO a +runBasicServer hand proc port = runThreadedServer binaryAccept hand proc (PortNumber port) + where binaryAccept s = do + (h, _, _) <- accept s + return (BinaryProtocol h, BinaryProtocol h) + +acceptLoop :: IO t -> (t -> IO Bool) -> IO a +acceptLoop accepter proc = forever $ + do ps <- accepter + forkIO $ handle (\(e :: SomeException) -> return ()) + (loop $ proc ps) + where loop m = do { continue <- m; when continue (loop m) } diff --git a/lib/hs/src/Thrift/Transport.hs b/lib/hs/src/Thrift/Transport.hs new file mode 100644 index 000000000..29f50d07a --- /dev/null +++ b/lib/hs/src/Thrift/Transport.hs @@ -0,0 +1,60 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift.Transport + ( Transport(..) + , TransportExn(..) + , TransportExnType(..) + ) where + +import Control.Monad ( when ) +import Control.Exception ( Exception, throw ) + +import Data.Typeable ( Typeable ) + + +class Transport a where + tIsOpen :: a -> IO Bool + tClose :: a -> IO () + tRead :: a -> Int -> IO String + tWrite :: a -> String ->IO () + tFlush :: a -> IO () + tReadAll :: a -> Int -> IO String + + tReadAll a 0 = return [] + tReadAll a len = do + result <- tRead a len + let rlen = length result + when (rlen == 0) (throw $ TransportExn "Cannot read. Remote side has closed." TE_UNKNOWN) + if len <= rlen + then return result + else (result ++) `fmap` (tReadAll a (len - rlen)) + +data TransportExn = TransportExn String TransportExnType + deriving ( Show, Typeable ) +instance Exception TransportExn + +data TransportExnType + = TE_UNKNOWN + | TE_NOT_OPEN + | TE_ALREADY_OPEN + | TE_TIMED_OUT + | TE_END_OF_FILE + deriving ( Eq, Show, Typeable ) + diff --git a/lib/hs/src/Thrift/Transport/Handle.hs b/lib/hs/src/Thrift/Transport/Handle.hs new file mode 100644 index 000000000..e49456b5b --- /dev/null +++ b/lib/hs/src/Thrift/Transport/Handle.hs @@ -0,0 +1,58 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift.Transport.Handle + ( module Thrift.Transport + , HandleSource(..) + ) where + +import Control.Exception ( throw ) +import Control.Monad ( replicateM ) + +import Network + +import System.IO +import System.IO.Error ( isEOFError ) + +import Thrift.Transport + + +instance Transport Handle where + tIsOpen = hIsOpen + tClose h = hClose h + tRead h n = replicateM n (hGetChar h) `catch` handleEOF + tWrite h s = mapM_ (hPutChar h) s + tFlush = hFlush + + +-- | Type class for all types that can open a Handle. This class is used to +-- replace tOpen in the Transport type class. +class HandleSource s where + hOpen :: s -> IO Handle + +instance HandleSource FilePath where + hOpen s = openFile s ReadWriteMode + +instance HandleSource (HostName, PortID) where + hOpen = uncurry connectTo + + +handleEOF e = if isEOFError e + then return [] + else throw $ TransportExn "TChannelTransport: Could not read" TE_UNKNOWN diff --git a/lib/java/Makefile.am b/lib/java/Makefile.am new file mode 100644 index 000000000..0a40496d4 --- /dev/null +++ b/lib/java/Makefile.am @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +EXTRA_DIST = build.xml ivy.xml src test + +all-local: + $(ANT) + +install-exec-hook: + $(ANT) install -Dinstall.path=$(DESTDIR)$(JAVA_PREFIX) \ + -Dinstall.javadoc.path=$(DESTDIR)$(docdir)/java + +# Make sure this doesn't fail if ant is not configured. +clean-local: + ANT=$(ANT) ; if test -z "$$ANT" ; then ANT=: ; fi ; \ + $$ANT clean + +check-local: all + $(ANT) test + diff --git a/lib/java/README b/lib/java/README new file mode 100644 index 000000000..6b8d351b6 --- /dev/null +++ b/lib/java/README @@ -0,0 +1,43 @@ +Thrift Java Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Using Thrift with Java +====================== + +The Thrift Java source is not build using the GNU tools, but rather uses +the Apache Ant build system, which tends to be predominant amongst Java +developers. + +To compile the Java Thrift libraries, simply do the following: + +ant + +Yep, that's easy. Look for libthrift.jar in the base directory. + +To include Thrift in your applications simply add libthrift.jar to your +classpath, or install if in your default system classpath of choice. + +Dependencies +============ + +Apache Ant +http://ant.apache.org/ diff --git a/lib/java/build.xml b/lib/java/build.xml new file mode 100644 index 000000000..0a7c8944d --- /dev/null +++ b/lib/java/build.xml @@ -0,0 +1,192 @@ +<?xml version="1.0"?> +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> +<project name="libthrift" default="dist" basedir="." + xmlns:ivy="antlib:org.apache.ivy.ant"> + + <description>Thrift Build File</description> + + <property name="gen" location="gen-java" /> + <property name="genbean" location="gen-javabean" /> + + <property name="src" location="src" /> + <property name="build" location="build" /> + <property name="javadoc" location="${build}/javadoc" /> + <property name="install.path" value="/usr/local/lib" /> + <property name="src.test" location="test" /> + <property name="build.test" location="${build}/test" /> + <property name="test.thrift.home" location="../../test"/> + + <property file="${user.home}/.thrift-build.properties" /> + + <!-- ivy properties --> + <property name="ivy.version" value="2.0.0-rc2" /> + <property name="ivy.dir" location="${build}/ivy" /> + <property name="ivy.jar" location="${ivy.dir}/ivy-${ivy.version}.jar"/> + <property name="ivy.lib.dir" location="${ivy.dir}/lib" /> + <property name="ivy_repo_url" value="http://repo2.maven.org/maven2/org/apache/ivy/ivy/${ivy.version}/ivy-${ivy.version}.jar"/> + <property name="ivysettings.xml" location="${ivy.dir}/ivysettings.xml" /> + + <path id="compile.classpath"> + <fileset dir="${ivy.lib.dir}"> + <include name="**/*.jar" /> + </fileset> + </path> + + <path id="test.classpath"> + <path refid="compile.classpath" /> + <pathelement location="build/test" /> + <pathelement location="libthrift.jar" /> + </path> + + <target name="init"> + <tstamp /> + <mkdir dir="${build}"/> + <mkdir dir="${build.test}" /> + <!-- + Allow Ivy to be disabled with "-Dnoivy=". + It is kind of a hack to pretend that we already found it, + but Ant doesn't provide an easy way of blocking dependencies + from executing or checking multiple conditions. + --> + <condition property="ivy.found"><isset property="noivy" /></condition> + <condition property="offline"><isset property="noivy" /></condition> + </target> + + <target name="ivy-init-dirs"> + <mkdir dir="${ivy.dir}" /> + <mkdir dir="${ivy.lib.dir}" /> + </target> + + <target name="ivy-download" depends="ivy-init-dirs" description="To download ivy" unless="offline"> + <get src="${ivy_repo_url}" dest="${ivy.jar}" usetimestamp="true"/> + </target> + + <target name="ivy-probe-antlib"> + <condition property="ivy.found"> + <typefound uri="antlib:org.apache.ivy.ant" name="cleancache"/> + </condition> + </target> + + <target name="ivy-init-antlib" depends="ivy-download,ivy-probe-antlib" unless="ivy.found"> + <typedef uri="antlib:org.apache.ivy.ant" onerror="fail" + loaderRef="ivyLoader"> + <classpath> + <pathelement location="${ivy.jar}"/> + </classpath> + </typedef> + <fail> + <condition > + <not> + <typefound uri="antlib:org.apache.ivy.ant" name="cleancache"/> + </not> + </condition> + You need Apache Ivy 2.0 or later from http://ant.apache.org/ + It could not be loaded from ${ivy_repo_url} + </fail> + </target> + + <target name="resolve" depends="ivy-init-antlib" description="retrieve dependencies with ivy" unless="noivy"> + <ivy:retrieve /> + </target> + + <target name="compile" depends="init,resolve"> + <javac srcdir="${src}" destdir="${build}" source="1.5" debug="true" classpathref="compile.classpath" /> + </target> + + <target name="javadoc" depends="init"> + <javadoc sourcepath="${src}" + destdir="${javadoc}" + version="true" + windowtitle="Thrift Java API" + doctitle="Thrift Java API"> + </javadoc> + </target> + + <target name="dist" depends="compile"> + <jar jarfile="libthrift.jar"> + <fileset dir="${build}"> + <include name="**/*.class" /> + </fileset> + <fileset dir="src"> + <include name="**/*.java" /> + </fileset> + </jar> + </target> + + <target name="install" depends="dist,javadoc"> + <exec executable="install"> + <arg line="libthrift.jar ${install.path}" /> + </exec> + <copy todir="${install.javadoc.path}"> + <fileset dir="${javadoc}"> + <include name="**/*" /> + </fileset> + </copy> + </target> + + <target name="clean"> + <delete dir="${build}" /> + <delete dir="${gen}"/> + <delete dir="${genbean}"/> + <delete dir="${javadoc}"/> + <delete file="libthrift.jar" /> + </target> + + <target name="compile-test" description="Build the test suite classes" depends="generate,dist"> + <javac debug="true" srcdir="${gen}" destdir="${build.test}" classpathref="test.classpath" /> + <javac debug="true" srcdir="${genbean}" destdir="${build.test}" classpathref="test.classpath" /> + <javac debug="true" srcdir="${src.test}" destdir="${build.test}" classpathref="test.classpath" /> + </target> + + <target name="test" description="Run the full test suite" depends="compile-test"> + <java classname="org.apache.thrift.test.JSONProtoTest" + classpathref="test.classpath" failonerror="true" /> + <java classname="org.apache.thrift.test.TCompactProtocolTest" + classpathref="test.classpath" failonerror="true" /> + <java classname="org.apache.thrift.test.IdentityTest" + classpathref="test.classpath" failonerror="true" /> + <java classname="org.apache.thrift.test.EqualityTest" + classpathref="test.classpath" failonerror="true" /> + <java classname="org.apache.thrift.test.ToStringTest" + classpathref="test.classpath" failonerror="true" /> + <java classname="org.apache.thrift.test.DeepCopyTest" + classpathref="test.classpath" failonerror="true" /> + <java classname="org.apache.thrift.test.MetaDataTest" + classpathref="test.classpath" failonerror="true" /> + <java classname="org.apache.thrift.test.JavaBeansTest" + classpathref="test.classpath" failonerror="true" /> + </target> + + <target name="generate"> + <exec executable="../../compiler/cpp/thrift"> + <arg line="--gen java:hashcode ${test.thrift.home}/ThriftTest.thrift" /> + </exec> + <exec executable="../../compiler/cpp/thrift"> + <arg line="--gen java:hashcode ${test.thrift.home}/DebugProtoTest.thrift" /> + </exec> + <exec executable="../../compiler/cpp/thrift"> + <arg line="--gen java:hashcode ${test.thrift.home}/OptionalRequiredTest.thrift" /> + </exec> + <exec executable="../../compiler/cpp/thrift"> + <arg line="--gen java:beans,nocamel ${test.thrift.home}/JavaBeansTest.thrift" /> + </exec> + </target> + +</project> diff --git a/lib/java/ivy.xml b/lib/java/ivy.xml new file mode 100644 index 000000000..0b1be5d81 --- /dev/null +++ b/lib/java/ivy.xml @@ -0,0 +1,7 @@ +<ivy-module version="1.0"> + <info organisation="jayasoft" module="hello-ivy" /> + <dependencies> + <dependency org="log4j" name="log4j" rev="1.2.15" conf="default->master"/> + <dependency org="commons-lang" name="commons-lang" rev="2.4" conf="* -> *,!sources,!javadoc"/> + </dependencies> +</ivy-module> diff --git a/lib/java/src/org/apache/thrift/IntRangeSet.java b/lib/java/src/org/apache/thrift/IntRangeSet.java new file mode 100644 index 000000000..5430134de --- /dev/null +++ b/lib/java/src/org/apache/thrift/IntRangeSet.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Set; + +/** + * IntRangeSet is a specialized Set<Integer> implementation designed + * specifically to make the generated validate() method calls faster. It groups + * the set values into ranges, and in the contains() call, it does + * num ranges * 2 comparisons max. For the common case, which is a single, + * contiguous range, this approach is about 60% faster than using a HashSet. If + * you had a very ragged value set, like all the odd numbers, for instance, + * then you would end up with pretty poor running time. + */ +public class IntRangeSet implements Set<Integer> { + /** + * This array keeps the bounds of each extent in alternating cells, always + * increasing. Example: [0,5,10,15], which corresponds to 0-5, 10-15. + */ + private int[] extents; + + /** + * We'll keep a duplicate, real HashSet around internally to satisfy some of + * the other set operations. + */ + private Set<Integer> realSet = new HashSet<Integer>(); + + public IntRangeSet(int... values) { + Arrays.sort(values); + + List<Integer> extent_list = new ArrayList<Integer>(); + + int ext_start = values[0]; + int ext_end_so_far = values[0]; + for (int i = 1; i < values.length; i++) { + realSet.add(values[i]); + + if (values[i] == ext_end_so_far + 1) { + // advance the end so far + ext_end_so_far = values[i]; + } else { + // create an extent for everything we saw so far, move on to the next one + extent_list.add(ext_start); + extent_list.add(ext_end_so_far); + ext_start = values[i]; + ext_end_so_far = values[i]; + } + } + extent_list.add(ext_start); + extent_list.add(ext_end_so_far); + + extents = new int[extent_list.size()]; + for (int i = 0; i < extent_list.size(); i++) { + extents[i] = extent_list.get(i); + } + } + + public boolean add(Integer i) { + throw new UnsupportedOperationException(); + } + + public void clear() { + throw new UnsupportedOperationException(); + } + + public boolean addAll(Collection<? extends Integer> arg0) { + throw new UnsupportedOperationException(); + } + + /** + * While this method is here for Set interface compatibility, you should avoid + * using it. It incurs boxing overhead! Use the int method directly, instead. + */ + public boolean contains(Object arg0) { + return contains(((Integer)arg0).intValue()); + } + + /** + * This is much faster, since it doesn't stop at Integer on the way through. + * @param val the value you want to check set membership for + * @return true if val was found, false otherwise + */ + public boolean contains(int val) { + for (int i = 0; i < extents.length / 2; i++) { + if (val < extents[i*2]) { + return false; + } else if (val <= extents[i*2+1]) { + return true; + } + } + + return false; + } + + public boolean containsAll(Collection<?> arg0) { + for (Object o : arg0) { + if (!contains(o)) { + return false; + } + } + return true; + } + + public boolean isEmpty() { + return realSet.isEmpty(); + } + + public Iterator<Integer> iterator() { + return realSet.iterator(); + } + + public boolean remove(Object arg0) { + throw new UnsupportedOperationException(); + } + + public boolean removeAll(Collection<?> arg0) { + throw new UnsupportedOperationException(); + } + + public boolean retainAll(Collection<?> arg0) { + throw new UnsupportedOperationException(); + } + + public int size() { + return realSet.size(); + } + + public Object[] toArray() { + return realSet.toArray(); + } + + public <T> T[] toArray(T[] arg0) { + return realSet.toArray(arg0); + } + + @Override + public String toString() { + String buf = ""; + for (int i = 0; i < extents.length / 2; i++) { + if (i != 0) { + buf += ", "; + } + buf += "[" + extents[i*2] + "," + extents[i*2+1] + "]"; + } + return buf; + } +} diff --git a/lib/java/src/org/apache/thrift/TApplicationException.java b/lib/java/src/org/apache/thrift/TApplicationException.java new file mode 100644 index 000000000..a85e3705e --- /dev/null +++ b/lib/java/src/org/apache/thrift/TApplicationException.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import org.apache.thrift.protocol.TField; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolUtil; +import org.apache.thrift.protocol.TStruct; +import org.apache.thrift.protocol.TType; + +/** + * Application level exception + * + */ +public class TApplicationException extends TException { + + private static final TStruct TAPPLICATION_EXCEPTION_STRUCT = new TStruct("TApplicationException"); + private static final TField MESSAGE_FIELD = new TField("message", TType.STRING, (short)1); + private static final TField TYPE_FIELD = new TField("type", TType.I32, (short)2); + + private static final long serialVersionUID = 1L; + + public static final int UNKNOWN = 0; + public static final int UNKNOWN_METHOD = 1; + public static final int INVALID_MESSAGE_TYPE = 2; + public static final int WRONG_METHOD_NAME = 3; + public static final int BAD_SEQUENCE_ID = 4; + public static final int MISSING_RESULT = 5; + + protected int type_ = UNKNOWN; + + public TApplicationException() { + super(); + } + + public TApplicationException(int type) { + super(); + type_ = type; + } + + public TApplicationException(int type, String message) { + super(message); + type_ = type; + } + + public TApplicationException(String message) { + super(message); + } + + public int getType() { + return type_; + } + + public static TApplicationException read(TProtocol iprot) throws TException { + TField field; + iprot.readStructBegin(); + + String message = null; + int type = UNKNOWN; + + while (true) { + field = iprot.readFieldBegin(); + if (field.type == TType.STOP) { + break; + } + switch (field.id) { + case 1: + if (field.type == TType.STRING) { + message = iprot.readString(); + } else { + TProtocolUtil.skip(iprot, field.type); + } + break; + case 2: + if (field.type == TType.I32) { + type = iprot.readI32(); + } else { + TProtocolUtil.skip(iprot, field.type); + } + break; + default: + TProtocolUtil.skip(iprot, field.type); + break; + } + iprot.readFieldEnd(); + } + iprot.readStructEnd(); + + return new TApplicationException(type, message); + } + + public void write(TProtocol oprot) throws TException { + oprot.writeStructBegin(TAPPLICATION_EXCEPTION_STRUCT); + if (getMessage() != null) { + oprot.writeFieldBegin(MESSAGE_FIELD); + oprot.writeString(getMessage()); + oprot.writeFieldEnd(); + } + oprot.writeFieldBegin(TYPE_FIELD); + oprot.writeI32(type_); + oprot.writeFieldEnd(); + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } +} diff --git a/lib/java/src/org/apache/thrift/TBase.java b/lib/java/src/org/apache/thrift/TBase.java new file mode 100644 index 000000000..7c8978a2b --- /dev/null +++ b/lib/java/src/org/apache/thrift/TBase.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import org.apache.thrift.protocol.TProtocol; + +/** + * Generic base interface for generated Thrift objects. + * + */ +public interface TBase extends Cloneable { + + /** + * Reads the TObject from the given input protocol. + * + * @param iprot Input protocol + */ + public void read(TProtocol iprot) throws TException; + + /** + * Writes the objects out to the protocol + * + * @param oprot Output protocol + */ + public void write(TProtocol oprot) throws TException; + + /** + * Check if a field is currently set or unset. + * + * @param fieldId The field's id tag as found in the IDL. + */ + public boolean isSet(int fieldId); + + /** + * Get a field's value by id. Primitive types will be wrapped in the + * appropriate "boxed" types. + * + * @param fieldId The field's id tag as found in the IDL. + */ + public Object getFieldValue(int fieldId); + + /** + * Set a field's value by id. Primitive types must be "boxed" in the + * appropriate object wrapper type. + * + * @param fieldId The field's id tag as found in the IDL. + */ + public void setFieldValue(int fieldId, Object value); +} diff --git a/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java b/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java new file mode 100644 index 000000000..e35fbcb73 --- /dev/null +++ b/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import java.io.ByteArrayOutputStream; + +/** + * Class that allows access to the underlying buf without doing deep + * copies on it. + * + */ +public class TByteArrayOutputStream extends ByteArrayOutputStream { + public TByteArrayOutputStream(int size) { + super(size); + } + + public TByteArrayOutputStream() { + super(); + } + + + public byte[] get() { + return buf; + } + + public int len() { + return count; + } +} diff --git a/lib/java/src/org/apache/thrift/TDeserializer.java b/lib/java/src/org/apache/thrift/TDeserializer.java new file mode 100644 index 000000000..d6dd5d4b1 --- /dev/null +++ b/lib/java/src/org/apache/thrift/TDeserializer.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import java.io.ByteArrayInputStream; +import java.io.UnsupportedEncodingException; + +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TIOStreamTransport; + +/** + * Generic utility for easily deserializing objects from a byte array or Java + * String. + * + */ +public class TDeserializer { + private final TProtocolFactory protocolFactory_; + + /** + * Create a new TDeserializer that uses the TBinaryProtocol by default. + */ + public TDeserializer() { + this(new TBinaryProtocol.Factory()); + } + + /** + * Create a new TDeserializer. It will use the TProtocol specified by the + * factory that is passed in. + * + * @param protocolFactory Factory to create a protocol + */ + public TDeserializer(TProtocolFactory protocolFactory) { + protocolFactory_ = protocolFactory; + } + + /** + * Deserialize the Thrift object from a byte array. + * + * @param base The object to read into + * @param bytes The array to read from + */ + public void deserialize(TBase base, byte[] bytes) throws TException { + base.read( + protocolFactory_.getProtocol( + new TIOStreamTransport( + new ByteArrayInputStream(bytes)))); + } + + /** + * Deserialize the Thrift object from a Java string, using a specified + * character set for decoding. + * + * @param base The object to read into + * @param data The string to read from + * @param charset Valid JVM charset + */ + public void deserialize(TBase base, String data, String charset) throws TException { + try { + deserialize(base, data.getBytes(charset)); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT ENCODING: " + charset); + } + } + + /** + * Deserialize the Thrift object from a Java string, using the default JVM + * charset encoding. + * + * @param base The object to read into + * @param data The string to read from + */ + public void toString(TBase base, String data) throws TException { + deserialize(base, data.getBytes()); + } +} + diff --git a/lib/java/src/org/apache/thrift/TException.java b/lib/java/src/org/apache/thrift/TException.java new file mode 100644 index 000000000..f84f4812e --- /dev/null +++ b/lib/java/src/org/apache/thrift/TException.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +/** + * Generic exception class for Thrift. + * + */ +public class TException extends Exception { + + private static final long serialVersionUID = 1L; + + public TException() { + super(); + } + + public TException(String message) { + super(message); + } + + public TException(Throwable cause) { + super(cause); + } + + public TException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/lib/java/src/org/apache/thrift/TFieldRequirementType.java b/lib/java/src/org/apache/thrift/TFieldRequirementType.java new file mode 100644 index 000000000..74bac4eff --- /dev/null +++ b/lib/java/src/org/apache/thrift/TFieldRequirementType.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +/** + * Requirement type constants. + * + */ +public final class TFieldRequirementType { + public static final byte REQUIRED = 1; + public static final byte OPTIONAL = 2; + public static final byte DEFAULT = 3; +} diff --git a/lib/java/src/org/apache/thrift/TProcessor.java b/lib/java/src/org/apache/thrift/TProcessor.java new file mode 100644 index 000000000..d79522c3e --- /dev/null +++ b/lib/java/src/org/apache/thrift/TProcessor.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import org.apache.thrift.protocol.TProtocol; + +/** + * A processor is a generic object which operates upon an input stream and + * writes to some output stream. + * + */ +public interface TProcessor { + public boolean process(TProtocol in, TProtocol out) + throws TException; +} diff --git a/lib/java/src/org/apache/thrift/TProcessorFactory.java b/lib/java/src/org/apache/thrift/TProcessorFactory.java new file mode 100644 index 000000000..bcd8a38fd --- /dev/null +++ b/lib/java/src/org/apache/thrift/TProcessorFactory.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import org.apache.thrift.transport.TTransport; + +/** + * The default processor factory just returns a singleton + * instance. + */ +public class TProcessorFactory { + + private final TProcessor processor_; + + public TProcessorFactory(TProcessor processor) { + processor_ = processor; + } + + public TProcessor getProcessor(TTransport trans) { + return processor_; + } +} diff --git a/lib/java/src/org/apache/thrift/TSerializer.java b/lib/java/src/org/apache/thrift/TSerializer.java new file mode 100644 index 000000000..4e1ce6129 --- /dev/null +++ b/lib/java/src/org/apache/thrift/TSerializer.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import java.io.ByteArrayOutputStream; +import java.io.UnsupportedEncodingException; + +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TIOStreamTransport; + +/** + * Generic utility for easily serializing objects into a byte array or Java + * String. + * + */ +public class TSerializer { + + /** + * This is the byte array that data is actually serialized into + */ + private final ByteArrayOutputStream baos_ = new ByteArrayOutputStream(); + + /** + * This transport wraps that byte array + */ + private final TIOStreamTransport transport_ = new TIOStreamTransport(baos_); + + /** + * Internal protocol used for serializing objects. + */ + private TProtocol protocol_; + + /** + * Create a new TSerializer that uses the TBinaryProtocol by default. + */ + public TSerializer() { + this(new TBinaryProtocol.Factory()); + } + + /** + * Create a new TSerializer. It will use the TProtocol specified by the + * factory that is passed in. + * + * @param protocolFactory Factory to create a protocol + */ + public TSerializer(TProtocolFactory protocolFactory) { + protocol_ = protocolFactory.getProtocol(transport_); + } + + /** + * Serialize the Thrift object into a byte array. The process is simple, + * just clear the byte array output, write the object into it, and grab the + * raw bytes. + * + * @param base The object to serialize + * @return Serialized object in byte[] format + */ + public byte[] serialize(TBase base) throws TException { + baos_.reset(); + base.write(protocol_); + return baos_.toByteArray(); + } + + /** + * Serialize the Thrift object into a Java string, using a specified + * character set for encoding. + * + * @param base The object to serialize + * @param charset Valid JVM charset + * @return Serialized object as a String + */ + public String toString(TBase base, String charset) throws TException { + try { + return new String(serialize(base), charset); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT ENCODING: " + charset); + } + } + + /** + * Serialize the Thrift object into a Java string, using the default JVM + * charset encoding. + * + * @param base The object to serialize + * @return Serialized object as a String + */ + public String toString(TBase base) throws TException { + return new String(serialize(base)); + } +} + diff --git a/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java b/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java new file mode 100644 index 000000000..3e90a8b9e --- /dev/null +++ b/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.meta_data; + +import java.util.HashMap; +import java.util.Map; +import org.apache.thrift.TBase; + +/** + * This class is used to store meta data about thrift fields. Every field in a + * a struct should have a corresponding instance of this class describing it. + * + */ +public class FieldMetaData implements java.io.Serializable { + public final String fieldName; + public final byte requirementType; + public final FieldValueMetaData valueMetaData; + private static Map<Class<? extends TBase>, Map<Integer, FieldMetaData>> structMap; + + static { + structMap = new HashMap<Class<? extends TBase>, Map<Integer, FieldMetaData>>(); + } + + public FieldMetaData(String name, byte req, FieldValueMetaData vMetaData){ + this.fieldName = name; + this.requirementType = req; + this.valueMetaData = vMetaData; + } + + public static void addStructMetaDataMap(Class<? extends TBase> sClass, Map<Integer, FieldMetaData> map){ + structMap.put(sClass, map); + } + + /** + * Returns a map with metadata (i.e. instances of FieldMetaData) that + * describe the fields of the given class. + * + * @param sClass The TBase class for which the metadata map is requested + */ + public static Map<Integer, FieldMetaData> getStructMetaDataMap(Class<? extends TBase> sClass){ + if (!structMap.containsKey(sClass)){ // Load class if it hasn't been loaded + try{ + sClass.newInstance(); + } catch (InstantiationException e){ + throw new RuntimeException("InstantiationException for TBase class: " + sClass.getName() + ", message: " + e.getMessage()); + } catch (IllegalAccessException e){ + throw new RuntimeException("IllegalAccessException for TBase class: " + sClass.getName() + ", message: " + e.getMessage()); + } + } + return structMap.get(sClass); + } +} diff --git a/lib/java/src/org/apache/thrift/meta_data/FieldValueMetaData.java b/lib/java/src/org/apache/thrift/meta_data/FieldValueMetaData.java new file mode 100644 index 000000000..f72da0cd2 --- /dev/null +++ b/lib/java/src/org/apache/thrift/meta_data/FieldValueMetaData.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.meta_data; + +import org.apache.thrift.protocol.TType; + +/** + * FieldValueMetaData and collection of subclasses to store metadata about + * the value(s) of a field + */ +public class FieldValueMetaData implements java.io.Serializable { + public final byte type; + + public FieldValueMetaData(byte type){ + this.type = type; + } + + public boolean isStruct() { + return type == TType.STRUCT; + } + + public boolean isContainer() { + return type == TType.LIST || type == TType.MAP || type == TType.SET; + } +} diff --git a/lib/java/src/org/apache/thrift/meta_data/ListMetaData.java b/lib/java/src/org/apache/thrift/meta_data/ListMetaData.java new file mode 100644 index 000000000..8e7073bf5 --- /dev/null +++ b/lib/java/src/org/apache/thrift/meta_data/ListMetaData.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.meta_data; + +public class ListMetaData extends FieldValueMetaData { + public final FieldValueMetaData elemMetaData; + + public ListMetaData(byte type, FieldValueMetaData eMetaData){ + super(type); + this.elemMetaData = eMetaData; + } +} diff --git a/lib/java/src/org/apache/thrift/meta_data/MapMetaData.java b/lib/java/src/org/apache/thrift/meta_data/MapMetaData.java new file mode 100644 index 000000000..e7c408c78 --- /dev/null +++ b/lib/java/src/org/apache/thrift/meta_data/MapMetaData.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.meta_data; + +public class MapMetaData extends FieldValueMetaData { + public final FieldValueMetaData keyMetaData; + public final FieldValueMetaData valueMetaData; + + public MapMetaData(byte type, FieldValueMetaData kMetaData, FieldValueMetaData vMetaData){ + super(type); + this.keyMetaData = kMetaData; + this.valueMetaData = vMetaData; + } +} diff --git a/lib/java/src/org/apache/thrift/meta_data/SetMetaData.java b/lib/java/src/org/apache/thrift/meta_data/SetMetaData.java new file mode 100644 index 000000000..cf4b96aab --- /dev/null +++ b/lib/java/src/org/apache/thrift/meta_data/SetMetaData.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.meta_data; + +public class SetMetaData extends FieldValueMetaData { + public final FieldValueMetaData elemMetaData; + + public SetMetaData(byte type, FieldValueMetaData eMetaData){ + super(type); + this.elemMetaData = eMetaData; + } +} diff --git a/lib/java/src/org/apache/thrift/meta_data/StructMetaData.java b/lib/java/src/org/apache/thrift/meta_data/StructMetaData.java new file mode 100644 index 000000000..b37d21dab --- /dev/null +++ b/lib/java/src/org/apache/thrift/meta_data/StructMetaData.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.meta_data; + +import org.apache.thrift.TBase; + +public class StructMetaData extends FieldValueMetaData { + public final Class<? extends TBase> structClass; + + public StructMetaData(byte type, Class<? extends TBase> sClass){ + super(type); + this.structClass = sClass; + } +} diff --git a/lib/java/src/org/apache/thrift/protocol/TBase64Utils.java b/lib/java/src/org/apache/thrift/protocol/TBase64Utils.java new file mode 100644 index 000000000..37a9fd9f9 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TBase64Utils.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Class for encoding and decoding Base64 data. + * + * This class is kept at package level because the interface does no input + * validation and is therefore too low-level for generalized reuse. + * + * Note also that the encoding does not pad with equal signs , as discussed in + * section 2.2 of the RFC (http://www.faqs.org/rfcs/rfc3548.html). Furthermore, + * bad data encountered when decoding is neither rejected or ignored but simply + * results in bad decoded data -- this is not in compliance with the RFC but is + * done in the interest of performance. + * + */ +class TBase64Utils { + + private static final String ENCODE_TABLE = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + /** + * Encode len bytes of data in src at offset srcOff, storing the result into + * dst at offset dstOff. len must be 1, 2, or 3. dst must have at least len+1 + * bytes of space at dstOff. src and dst should not be the same object. This + * method does no validation of the input values in the interest of + * performance. + * + * @param src the source of bytes to encode + * @param srcOff the offset into the source to read the unencoded bytes + * @param len the number of bytes to encode (must be 1, 2, or 3). + * @param dst the destination for the encoding + * @param dstOff the offset into the destination to place the encoded bytes + */ + static final void encode(byte[] src, int srcOff, int len, byte[] dst, + int dstOff) { + dst[dstOff] = (byte)ENCODE_TABLE.charAt((src[srcOff] >> 2) & 0x3F); + if (len == 3) { + dst[dstOff + 1] = + (byte)ENCODE_TABLE.charAt( + ((src[srcOff] << 4) + (src[srcOff+1] >> 4)) & 0x3F); + dst[dstOff + 2] = + (byte)ENCODE_TABLE.charAt( + ((src[srcOff+1] << 2) + (src[srcOff+2] >> 6)) & 0x3F); + dst[dstOff + 3] = + (byte)ENCODE_TABLE.charAt(src[srcOff+2] & 0x3F); + } + else if (len == 2) { + dst[dstOff+1] = + (byte)ENCODE_TABLE.charAt( + ((src[srcOff] << 4) + (src[srcOff+1] >> 4)) & 0x3F); + dst[dstOff + 2] = + (byte)ENCODE_TABLE.charAt((src[srcOff+1] << 2) & 0x3F); + + } + else { // len == 1) { + dst[dstOff + 1] = + (byte)ENCODE_TABLE.charAt((src[srcOff] << 4) & 0x3F); + } + } + + private static final byte[] DECODE_TABLE = { + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,62,-1,-1,-1,63, + 52,53,54,55,56,57,58,59,60,61,-1,-1,-1,-1,-1,-1, + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14, + 15,16,17,18,19,20,21,22,23,24,25,-1,-1,-1,-1,-1, + -1,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40, + 41,42,43,44,45,46,47,48,49,50,51,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + }; + + /** + * Decode len bytes of data in src at offset srcOff, storing the result into + * dst at offset dstOff. len must be 2, 3, or 4. dst must have at least len-1 + * bytes of space at dstOff. src and dst may be the same object as long as + * dstoff <= srcOff. This method does no validation of the input values in + * the interest of performance. + * + * @param src the source of bytes to decode + * @param srcOff the offset into the source to read the encoded bytes + * @param len the number of bytes to decode (must be 2, 3, or 4) + * @param dst the destination for the decoding + * @param dstOff the offset into the destination to place the decoded bytes + */ + static final void decode(byte[] src, int srcOff, int len, byte[] dst, + int dstOff) { + dst[dstOff] = (byte) + ((DECODE_TABLE[src[srcOff] & 0x0FF] << 2) | + (DECODE_TABLE[src[srcOff+1] & 0x0FF] >> 4)); + if (len > 2) { + dst[dstOff+1] = (byte) + (((DECODE_TABLE[src[srcOff+1] & 0x0FF] << 4) & 0xF0) | + (DECODE_TABLE[src[srcOff+2] & 0x0FF] >> 2)); + if (len > 3) { + dst[dstOff+2] = (byte) + (((DECODE_TABLE[src[srcOff+2] & 0x0FF] << 6) & 0xC0) | + DECODE_TABLE[src[srcOff+3] & 0x0FF]); + } + } + } +} diff --git a/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java b/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java new file mode 100644 index 000000000..e9bd8b796 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import java.io.UnsupportedEncodingException; + +import org.apache.thrift.TException; +import org.apache.thrift.transport.TTransport; + +/** + * Binary protocol implementation for thrift. + * + */ +public class TBinaryProtocol extends TProtocol { + private static final TStruct ANONYMOUS_STRUCT = new TStruct(); + + protected static final int VERSION_MASK = 0xffff0000; + protected static final int VERSION_1 = 0x80010000; + + protected boolean strictRead_ = false; + protected boolean strictWrite_ = true; + + protected int readLength_; + protected boolean checkReadLength_ = false; + + /** + * Factory + */ + public static class Factory implements TProtocolFactory { + protected boolean strictRead_ = false; + protected boolean strictWrite_ = true; + + public Factory() { + this(false, true); + } + + public Factory(boolean strictRead, boolean strictWrite) { + strictRead_ = strictRead; + strictWrite_ = strictWrite; + } + + public TProtocol getProtocol(TTransport trans) { + return new TBinaryProtocol(trans, strictRead_, strictWrite_); + } + } + + /** + * Constructor + */ + public TBinaryProtocol(TTransport trans) { + this(trans, false, true); + } + + public TBinaryProtocol(TTransport trans, boolean strictRead, boolean strictWrite) { + super(trans); + strictRead_ = strictRead; + strictWrite_ = strictWrite; + } + + public void writeMessageBegin(TMessage message) throws TException { + if (strictWrite_) { + int version = VERSION_1 | message.type; + writeI32(version); + writeString(message.name); + writeI32(message.seqid); + } else { + writeString(message.name); + writeByte(message.type); + writeI32(message.seqid); + } + } + + public void writeMessageEnd() {} + + public void writeStructBegin(TStruct struct) {} + + public void writeStructEnd() {} + + public void writeFieldBegin(TField field) throws TException { + writeByte(field.type); + writeI16(field.id); + } + + public void writeFieldEnd() {} + + public void writeFieldStop() throws TException { + writeByte(TType.STOP); + } + + public void writeMapBegin(TMap map) throws TException { + writeByte(map.keyType); + writeByte(map.valueType); + writeI32(map.size); + } + + public void writeMapEnd() {} + + public void writeListBegin(TList list) throws TException { + writeByte(list.elemType); + writeI32(list.size); + } + + public void writeListEnd() {} + + public void writeSetBegin(TSet set) throws TException { + writeByte(set.elemType); + writeI32(set.size); + } + + public void writeSetEnd() {} + + public void writeBool(boolean b) throws TException { + writeByte(b ? (byte)1 : (byte)0); + } + + private byte [] bout = new byte[1]; + public void writeByte(byte b) throws TException { + bout[0] = b; + trans_.write(bout, 0, 1); + } + + private byte[] i16out = new byte[2]; + public void writeI16(short i16) throws TException { + i16out[0] = (byte)(0xff & (i16 >> 8)); + i16out[1] = (byte)(0xff & (i16)); + trans_.write(i16out, 0, 2); + } + + private byte[] i32out = new byte[4]; + public void writeI32(int i32) throws TException { + i32out[0] = (byte)(0xff & (i32 >> 24)); + i32out[1] = (byte)(0xff & (i32 >> 16)); + i32out[2] = (byte)(0xff & (i32 >> 8)); + i32out[3] = (byte)(0xff & (i32)); + trans_.write(i32out, 0, 4); + } + + private byte[] i64out = new byte[8]; + public void writeI64(long i64) throws TException { + i64out[0] = (byte)(0xff & (i64 >> 56)); + i64out[1] = (byte)(0xff & (i64 >> 48)); + i64out[2] = (byte)(0xff & (i64 >> 40)); + i64out[3] = (byte)(0xff & (i64 >> 32)); + i64out[4] = (byte)(0xff & (i64 >> 24)); + i64out[5] = (byte)(0xff & (i64 >> 16)); + i64out[6] = (byte)(0xff & (i64 >> 8)); + i64out[7] = (byte)(0xff & (i64)); + trans_.write(i64out, 0, 8); + } + + public void writeDouble(double dub) throws TException { + writeI64(Double.doubleToLongBits(dub)); + } + + public void writeString(String str) throws TException { + try { + byte[] dat = str.getBytes("UTF-8"); + writeI32(dat.length); + trans_.write(dat, 0, dat.length); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + + public void writeBinary(byte[] bin) throws TException { + writeI32(bin.length); + trans_.write(bin, 0, bin.length); + } + + /** + * Reading methods. + */ + + public TMessage readMessageBegin() throws TException { + int size = readI32(); + if (size < 0) { + int version = size & VERSION_MASK; + if (version != VERSION_1) { + throw new TProtocolException(TProtocolException.BAD_VERSION, "Bad version in readMessageBegin"); + } + return new TMessage(readString(), (byte)(size & 0x000000ff), readI32()); + } else { + if (strictRead_) { + throw new TProtocolException(TProtocolException.BAD_VERSION, "Missing version in readMessageBegin, old client?"); + } + return new TMessage(readStringBody(size), readByte(), readI32()); + } + } + + public void readMessageEnd() {} + + public TStruct readStructBegin() { + return ANONYMOUS_STRUCT; + } + + public void readStructEnd() {} + + public TField readFieldBegin() throws TException { + byte type = readByte(); + short id = type == TType.STOP ? 0 : readI16(); + return new TField("", type, id); + } + + public void readFieldEnd() {} + + public TMap readMapBegin() throws TException { + return new TMap(readByte(), readByte(), readI32()); + } + + public void readMapEnd() {} + + public TList readListBegin() throws TException { + return new TList(readByte(), readI32()); + } + + public void readListEnd() {} + + public TSet readSetBegin() throws TException { + return new TSet(readByte(), readI32()); + } + + public void readSetEnd() {} + + public boolean readBool() throws TException { + return (readByte() == 1); + } + + private byte[] bin = new byte[1]; + public byte readByte() throws TException { + readAll(bin, 0, 1); + return bin[0]; + } + + private byte[] i16rd = new byte[2]; + public short readI16() throws TException { + readAll(i16rd, 0, 2); + return + (short) + (((i16rd[0] & 0xff) << 8) | + ((i16rd[1] & 0xff))); + } + + private byte[] i32rd = new byte[4]; + public int readI32() throws TException { + readAll(i32rd, 0, 4); + return + ((i32rd[0] & 0xff) << 24) | + ((i32rd[1] & 0xff) << 16) | + ((i32rd[2] & 0xff) << 8) | + ((i32rd[3] & 0xff)); + } + + private byte[] i64rd = new byte[8]; + public long readI64() throws TException { + readAll(i64rd, 0, 8); + return + ((long)(i64rd[0] & 0xff) << 56) | + ((long)(i64rd[1] & 0xff) << 48) | + ((long)(i64rd[2] & 0xff) << 40) | + ((long)(i64rd[3] & 0xff) << 32) | + ((long)(i64rd[4] & 0xff) << 24) | + ((long)(i64rd[5] & 0xff) << 16) | + ((long)(i64rd[6] & 0xff) << 8) | + ((long)(i64rd[7] & 0xff)); + } + + public double readDouble() throws TException { + return Double.longBitsToDouble(readI64()); + } + + public String readString() throws TException { + int size = readI32(); + return readStringBody(size); + } + + public String readStringBody(int size) throws TException { + try { + checkReadLength(size); + byte[] buf = new byte[size]; + trans_.readAll(buf, 0, size); + return new String(buf, "UTF-8"); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + + public byte[] readBinary() throws TException { + int size = readI32(); + checkReadLength(size); + byte[] buf = new byte[size]; + trans_.readAll(buf, 0, size); + return buf; + } + + private int readAll(byte[] buf, int off, int len) throws TException { + checkReadLength(len); + return trans_.readAll(buf, off, len); + } + + public void setReadLength(int readLength) { + readLength_ = readLength; + checkReadLength_ = true; + } + + protected void checkReadLength(int length) throws TException { + if (checkReadLength_) { + readLength_ -= length; + if (readLength_ < 0) { + throw new TException("Message length exceeded: " + length); + } + } + } + +} diff --git a/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java new file mode 100755 index 000000000..e2d0bfdc1 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java @@ -0,0 +1,741 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.protocol; + +import java.util.Stack; +import java.io.UnsupportedEncodingException; + +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.TException; + +/** + * TCompactProtocol2 is the Java implementation of the compact protocol specified + * in THRIFT-110. The fundamental approach to reducing the overhead of + * structures is a) use variable-length integers all over the place and b) make + * use of unused bits wherever possible. Your savings will obviously vary + * based on the specific makeup of your structs, but in general, the more + * fields, nested structures, short strings and collections, and low-value i32 + * and i64 fields you have, the more benefit you'll see. + */ +public final class TCompactProtocol extends TProtocol { + + private final static TStruct ANONYMOUS_STRUCT = new TStruct(""); + private final static TField TSTOP = new TField("", TType.STOP, (short)0); + + private final static byte[] ttypeToCompactType = new byte[16]; + + static { + ttypeToCompactType[TType.STOP] = TType.STOP; + ttypeToCompactType[TType.BOOL] = Types.BOOLEAN_TRUE; + ttypeToCompactType[TType.BYTE] = Types.BYTE; + ttypeToCompactType[TType.I16] = Types.I16; + ttypeToCompactType[TType.I32] = Types.I32; + ttypeToCompactType[TType.I64] = Types.I64; + ttypeToCompactType[TType.DOUBLE] = Types.DOUBLE; + ttypeToCompactType[TType.STRING] = Types.BINARY; + ttypeToCompactType[TType.LIST] = Types.LIST; + ttypeToCompactType[TType.SET] = Types.SET; + ttypeToCompactType[TType.MAP] = Types.MAP; + ttypeToCompactType[TType.STRUCT] = Types.STRUCT; + } + + /** + * TProtocolFactory that produces TCompactProtocols. + */ + public static class Factory implements TProtocolFactory { + public Factory() {} + + public TProtocol getProtocol(TTransport trans) { + return new TCompactProtocol(trans); + } + } + + private static final byte PROTOCOL_ID = (byte)0x82; + private static final byte VERSION = 1; + private static final byte VERSION_MASK = 0x1f; // 0001 1111 + private static final byte TYPE_MASK = (byte)0xE0; // 1110 0000 + private static final int TYPE_SHIFT_AMOUNT = 5; + + /** + * All of the on-wire type codes. + */ + private static class Types { + public static final byte BOOLEAN_TRUE = 0x01; + public static final byte BOOLEAN_FALSE = 0x02; + public static final byte BYTE = 0x03; + public static final byte I16 = 0x04; + public static final byte I32 = 0x05; + public static final byte I64 = 0x06; + public static final byte DOUBLE = 0x07; + public static final byte BINARY = 0x08; + public static final byte LIST = 0x09; + public static final byte SET = 0x0A; + public static final byte MAP = 0x0B; + public static final byte STRUCT = 0x0C; + } + + /** + * Used to keep track of the last field for the current and previous structs, + * so we can do the delta stuff. + */ + private Stack<Short> lastField_ = new Stack<Short>(); + + private short lastFieldId_ = 0; + + /** + * If we encounter a boolean field begin, save the TField here so it can + * have the value incorporated. + */ + private TField booleanField_ = null; + + /** + * If we read a field header, and it's a boolean field, save the boolean + * value here so that readBool can use it. + */ + private Boolean boolValue_ = null; + + /** + * Create a TCompactProtocol. + * + * @param transport the TTransport object to read from or write to. + */ + public TCompactProtocol(TTransport transport) { + super(transport); + } + + + // + // Public Writing methods. + // + + /** + * Write a message header to the wire. Compact Protocol messages contain the + * protocol version so we can migrate forwards in the future if need be. + */ + public void writeMessageBegin(TMessage message) throws TException { + writeByteDirect(PROTOCOL_ID); + writeByteDirect((VERSION & VERSION_MASK) | ((message.type << TYPE_SHIFT_AMOUNT) & TYPE_MASK)); + writeVarint32(message.seqid); + writeString(message.name); + } + + /** + * Write a struct begin. This doesn't actually put anything on the wire. We + * use it as an opportunity to put special placeholder markers on the field + * stack so we can get the field id deltas correct. + */ + public void writeStructBegin(TStruct struct) throws TException { + lastField_.push(lastFieldId_); + lastFieldId_ = 0; + } + + /** + * Write a struct end. This doesn't actually put anything on the wire. We use + * this as an opportunity to pop the last field from the current struct off + * of the field stack. + */ + public void writeStructEnd() throws TException { + lastFieldId_ = lastField_.pop(); + } + + /** + * Write a field header containing the field id and field type. If the + * difference between the current field id and the last one is small (< 15), + * then the field id will be encoded in the 4 MSB as a delta. Otherwise, the + * field id will follow the type header as a zigzag varint. + */ + public void writeFieldBegin(TField field) throws TException { + if (field.type == TType.BOOL) { + // we want to possibly include the value, so we'll wait. + booleanField_ = field; + } else { + writeFieldBeginInternal(field, (byte)-1); + } + } + + /** + * The workhorse of writeFieldBegin. It has the option of doing a + * 'type override' of the type header. This is used specifically in the + * boolean field case. + */ + private void writeFieldBeginInternal(TField field, byte typeOverride) throws TException { + // short lastField = lastField_.pop(); + + // if there's a type override, use that. + byte typeToWrite = typeOverride == -1 ? getCompactType(field.type) : typeOverride; + + // check if we can use delta encoding for the field id + if (field.id > lastFieldId_ && field.id - lastFieldId_ <= 15) { + // write them together + writeByteDirect((field.id - lastFieldId_) << 4 | typeToWrite); + } else { + // write them separate + writeByteDirect(typeToWrite); + writeI16(field.id); + } + + lastFieldId_ = field.id; + // lastField_.push(field.id); + } + + /** + * Write the STOP symbol so we know there are no more fields in this struct. + */ + public void writeFieldStop() throws TException { + writeByteDirect(TType.STOP); + } + + /** + * Write a map header. If the map is empty, omit the key and value type + * headers, as we don't need any additional information to skip it. + */ + public void writeMapBegin(TMap map) throws TException { + if (map.size == 0) { + writeByteDirect(0); + } else { + writeVarint32(map.size); + writeByteDirect(getCompactType(map.keyType) << 4 | getCompactType(map.valueType)); + } + } + + /** + * Write a list header. + */ + public void writeListBegin(TList list) throws TException { + writeCollectionBegin(list.elemType, list.size); + } + + /** + * Write a set header. + */ + public void writeSetBegin(TSet set) throws TException { + writeCollectionBegin(set.elemType, set.size); + } + + /** + * Write a boolean value. Potentially, this could be a boolean field, in + * which case the field header info isn't written yet. If so, decide what the + * right type header is for the value and then write the field header. + * Otherwise, write a single byte. + */ + public void writeBool(boolean b) throws TException { + if (booleanField_ != null) { + // we haven't written the field header yet + writeFieldBeginInternal(booleanField_, b ? Types.BOOLEAN_TRUE : Types.BOOLEAN_FALSE); + booleanField_ = null; + } else { + // we're not part of a field, so just write the value. + writeByteDirect(b ? Types.BOOLEAN_TRUE : Types.BOOLEAN_FALSE); + } + } + + /** + * Write a byte. Nothing to see here! + */ + public void writeByte(byte b) throws TException { + writeByteDirect(b); + } + + /** + * Write an I16 as a zigzag varint. + */ + public void writeI16(short i16) throws TException { + writeVarint32(intToZigZag(i16)); + } + + /** + * Write an i32 as a zigzag varint. + */ + public void writeI32(int i32) throws TException { + writeVarint32(intToZigZag(i32)); + } + + /** + * Write an i64 as a zigzag varint. + */ + public void writeI64(long i64) throws TException { + writeVarint64(longToZigzag(i64)); + } + + /** + * Write a double to the wire as 8 bytes. + */ + public void writeDouble(double dub) throws TException { + byte[] data = new byte[]{0, 0, 0, 0, 0, 0, 0, 0}; + fixedLongToBytes(Double.doubleToLongBits(dub), data, 0); + trans_.write(data); + } + + /** + * Write a string to the wire with a varint size preceeding. + */ + public void writeString(String str) throws TException { + try { + writeBinary(str.getBytes("UTF-8")); + } catch (UnsupportedEncodingException e) { + throw new TException("UTF-8 not supported!"); + } + } + + /** + * Write a byte array, using a varint for the size. + */ + public void writeBinary(byte[] bin) throws TException { + writeVarint32(bin.length); + trans_.write(bin); + } + + // + // These methods are called by structs, but don't actually have any wire + // output or purpose. + // + + public void writeMessageEnd() throws TException {} + public void writeMapEnd() throws TException {} + public void writeListEnd() throws TException {} + public void writeSetEnd() throws TException {} + public void writeFieldEnd() throws TException {} + + // + // Internal writing methods + // + + /** + * Abstract method for writing the start of lists and sets. List and sets on + * the wire differ only by the type indicator. + */ + protected void writeCollectionBegin(byte elemType, int size) throws TException { + if (size <= 14) { + writeByteDirect(size << 4 | getCompactType(elemType)); + } else { + writeByteDirect(0xf0 | getCompactType(elemType)); + writeVarint32(size); + } + } + + /** + * Write an i32 as a varint. Results in 1-5 bytes on the wire. + * TODO: make a permanent buffer like writeVarint64? + */ + byte[] i32buf = new byte[5]; + private void writeVarint32(int n) throws TException { + int idx = 0; + while (true) { + if ((n & ~0x7F) == 0) { + i32buf[idx++] = (byte)n; + // writeByteDirect((byte)n); + break; + // return; + } else { + i32buf[idx++] = (byte)((n & 0x7F) | 0x80); + // writeByteDirect((byte)((n & 0x7F) | 0x80)); + n >>>= 7; + } + } + trans_.write(i32buf, 0, idx); + } + + /** + * Write an i64 as a varint. Results in 1-10 bytes on the wire. + */ + byte[] varint64out = new byte[10]; + private void writeVarint64(long n) throws TException { + int idx = 0; + while (true) { + if ((n & ~0x7FL) == 0) { + varint64out[idx++] = (byte)n; + break; + } else { + varint64out[idx++] = ((byte)((n & 0x7F) | 0x80)); + n >>>= 7; + } + } + trans_.write(varint64out, 0, idx); + } + + /** + * Convert l into a zigzag long. This allows negative numbers to be + * represented compactly as a varint. + */ + private long longToZigzag(long l) { + return (l << 1) ^ (l >> 63); + } + + /** + * Convert n into a zigzag int. This allows negative numbers to be + * represented compactly as a varint. + */ + private int intToZigZag(int n) { + return (n << 1) ^ (n >> 31); + } + + /** + * Convert a long into little-endian bytes in buf starting at off and going + * until off+7. + */ + private void fixedLongToBytes(long n, byte[] buf, int off) { + buf[off+0] = (byte)( n & 0xff); + buf[off+1] = (byte)((n >> 8 ) & 0xff); + buf[off+2] = (byte)((n >> 16) & 0xff); + buf[off+3] = (byte)((n >> 24) & 0xff); + buf[off+4] = (byte)((n >> 32) & 0xff); + buf[off+5] = (byte)((n >> 40) & 0xff); + buf[off+6] = (byte)((n >> 48) & 0xff); + buf[off+7] = (byte)((n >> 56) & 0xff); + } + + /** + * Writes a byte without any possiblity of all that field header nonsense. + * Used internally by other writing methods that know they need to write a byte. + */ + private byte[] byteDirectBuffer = new byte[1]; + private void writeByteDirect(byte b) throws TException { + byteDirectBuffer[0] = b; + trans_.write(byteDirectBuffer); + } + + /** + * Writes a byte without any possiblity of all that field header nonsense. + */ + private void writeByteDirect(int n) throws TException { + writeByteDirect((byte)n); + } + + + // + // Reading methods. + // + + /** + * Read a message header. + */ + public TMessage readMessageBegin() throws TException { + byte protocolId = readByte(); + if (protocolId != PROTOCOL_ID) { + throw new TProtocolException("Expected protocol id " + Integer.toHexString(PROTOCOL_ID) + " but got " + Integer.toHexString(protocolId)); + } + byte versionAndType = readByte(); + byte version = (byte)(versionAndType & VERSION_MASK); + if (version != VERSION) { + throw new TProtocolException("Expected version " + VERSION + " but got " + version); + } + byte type = (byte)((versionAndType >> TYPE_SHIFT_AMOUNT) & 0x03); + int seqid = readVarint32(); + String messageName = readString(); + return new TMessage(messageName, type, seqid); + } + + /** + * Read a struct begin. There's nothing on the wire for this, but it is our + * opportunity to push a new struct begin marker onto the field stack. + */ + public TStruct readStructBegin() throws TException { + lastField_.push(lastFieldId_); + lastFieldId_ = 0; + return ANONYMOUS_STRUCT; + } + + /** + * Doesn't actually consume any wire data, just removes the last field for + * this struct from the field stack. + */ + public void readStructEnd() throws TException { + // consume the last field we read off the wire. + lastFieldId_ = lastField_.pop(); + } + + /** + * Read a field header off the wire. + */ + public TField readFieldBegin() throws TException { + byte type = readByte(); + + // if it's a stop, then we can return immediately, as the struct is over. + if ((type & 0x0f) == TType.STOP) { + return TSTOP; + } + + short fieldId; + + // mask off the 4 MSB of the type header. it could contain a field id delta. + short modifier = (short)((type & 0xf0) >> 4); + if (modifier == 0) { + // not a delta. look ahead for the zigzag varint field id. + fieldId = readI16(); + } else { + // has a delta. add the delta to the last read field id. + fieldId = (short)(lastFieldId_ + modifier); + } + + TField field = new TField("", getTType((byte)(type & 0x0f)), fieldId); + + // if this happens to be a boolean field, the value is encoded in the type + if (isBoolType(type)) { + // save the boolean value in a special instance variable. + boolValue_ = (byte)(type & 0x0f) == Types.BOOLEAN_TRUE ? Boolean.TRUE : Boolean.FALSE; + } + + // push the new field onto the field stack so we can keep the deltas going. + lastFieldId_ = field.id; + return field; + } + + /** + * Read a map header off the wire. If the size is zero, skip reading the key + * and value type. This means that 0-length maps will yield TMaps without the + * "correct" types. + */ + public TMap readMapBegin() throws TException { + int size = readVarint32(); + byte keyAndValueType = size == 0 ? 0 : readByte(); + return new TMap(getTType((byte)(keyAndValueType >> 4)), getTType((byte)(keyAndValueType & 0xf)), size); + } + + /** + * Read a list header off the wire. If the list size is 0-14, the size will + * be packed into the element type header. If it's a longer list, the 4 MSB + * of the element type header will be 0xF, and a varint will follow with the + * true size. + */ + public TList readListBegin() throws TException { + byte size_and_type = readByte(); + int size = (size_and_type >> 4) & 0x0f; + if (size == 15) { + size = readVarint32(); + } + byte type = getTType(size_and_type); + return new TList(type, size); + } + + /** + * Read a set header off the wire. If the set size is 0-14, the size will + * be packed into the element type header. If it's a longer set, the 4 MSB + * of the element type header will be 0xF, and a varint will follow with the + * true size. + */ + public TSet readSetBegin() throws TException { + return new TSet(readListBegin()); + } + + /** + * Read a boolean off the wire. If this is a boolean field, the value should + * already have been read during readFieldBegin, so we'll just consume the + * pre-stored value. Otherwise, read a byte. + */ + public boolean readBool() throws TException { + if (boolValue_ != null) { + boolean result = boolValue_.booleanValue(); + boolValue_ = null; + return result; + } + return readByte() == Types.BOOLEAN_TRUE; + } + + byte[] byteRawBuf = new byte[1]; + /** + * Read a single byte off the wire. Nothing interesting here. + */ + public byte readByte() throws TException { + trans_.read(byteRawBuf, 0, 1); + return byteRawBuf[0]; + } + + /** + * Read an i16 from the wire as a zigzag varint. + */ + public short readI16() throws TException { + return (short)zigzagToInt(readVarint32()); + } + + /** + * Read an i32 from the wire as a zigzag varint. + */ + public int readI32() throws TException { + return zigzagToInt(readVarint32()); + } + + /** + * Read an i64 from the wire as a zigzag varint. + */ + public long readI64() throws TException { + return zigzagToLong(readVarint64()); + } + + /** + * No magic here - just read a double off the wire. + */ + public double readDouble() throws TException { + byte[] longBits = new byte[8]; + trans_.read(longBits, 0, 8); + return Double.longBitsToDouble(bytesToLong(longBits)); + } + + /** + * Reads a byte[] (via readBinary), and then UTF-8 decodes it. + */ + public String readString() throws TException { + try { + return new String(readBinary(), "UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new TException("UTF-8 not supported!"); + } + } + + /** + * Read a byte[] from the wire. + */ + public byte[] readBinary() throws TException { + int length = readVarint32(); + if (length == 0) return new byte[0]; + + byte[] buf = new byte[length]; + trans_.read(buf, 0, length); + return buf; + } + + + // + // These methods are here for the struct to call, but don't have any wire + // encoding. + // + public void readMessageEnd() throws TException {} + public void readFieldEnd() throws TException {} + public void readMapEnd() throws TException {} + public void readListEnd() throws TException {} + public void readSetEnd() throws TException {} + + // + // Internal reading methods + // + + /** + * Read an i32 from the wire as a varint. The MSB of each byte is set + * if there is another byte to follow. This can read up to 5 bytes. + */ + private int readVarint32() throws TException { + // if the wire contains the right stuff, this will just truncate the i64 we + // read and get us the right sign. + return (int)readVarint64(); + } + + /** + * Read an i64 from the wire as a proper varint. The MSB of each byte is set + * if there is another byte to follow. This can read up to 10 bytes. + */ + private long readVarint64() throws TException { + int shift = 0; + long result = 0; + while (true) { + byte b = readByte(); + result |= (long) (b & 0x7f) << shift; + if ((b & 0x80) != 0x80) break; + shift +=7; + } + return result; + } + + // + // encoding helpers + // + + /** + * Convert from zigzag int to int. + */ + private int zigzagToInt(int n) { + return (n >>> 1) ^ -(n & 1); + } + + /** + * Convert from zigzag long to long. + */ + private long zigzagToLong(long n) { + return (n >>> 1) ^ -(n & 1); + } + + /** + * Note that it's important that the mask bytes are long literals, + * otherwise they'll default to ints, and when you shift an int left 56 bits, + * you just get a messed up int. + */ + private long bytesToLong(byte[] bytes) { + return + ((bytes[7] & 0xffL) << 56) | + ((bytes[6] & 0xffL) << 48) | + ((bytes[5] & 0xffL) << 40) | + ((bytes[4] & 0xffL) << 32) | + ((bytes[3] & 0xffL) << 24) | + ((bytes[2] & 0xffL) << 16) | + ((bytes[1] & 0xffL) << 8) | + ((bytes[0] & 0xffL)); + } + + // + // type testing and converting + // + + private boolean isBoolType(byte b) { + return (b & 0x0f) == Types.BOOLEAN_TRUE || (b & 0x0f) == Types.BOOLEAN_FALSE; + } + + /** + * Given a TCompactProtocol.Types constant, convert it to its corresponding + * TType value. + */ + private byte getTType(byte type) { + switch ((byte)(type & 0x0f)) { + case TType.STOP: + return TType.STOP; + case Types.BOOLEAN_FALSE: + case Types.BOOLEAN_TRUE: + return TType.BOOL; + case Types.BYTE: + return TType.BYTE; + case Types.I16: + return TType.I16; + case Types.I32: + return TType.I32; + case Types.I64: + return TType.I64; + case Types.DOUBLE: + return TType.DOUBLE; + case Types.BINARY: + return TType.STRING; + case Types.LIST: + return TType.LIST; + case Types.SET: + return TType.SET; + case Types.MAP: + return TType.MAP; + case Types.STRUCT: + return TType.STRUCT; + default: + throw new RuntimeException("don't know what type: " + (byte)(type & 0x0f)); + } + } + + /** + * Given a TType value, find the appropriate TCompactProtocol.Types constant. + */ + private byte getCompactType(byte ttype) { + return ttypeToCompactType[ttype]; + } + +} diff --git a/lib/java/src/org/apache/thrift/protocol/TField.java b/lib/java/src/org/apache/thrift/protocol/TField.java new file mode 100644 index 000000000..03affdaa1 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TField.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Helper class that encapsulates field metadata. + * + */ +public class TField { + public TField() { + this("", TType.STOP, (short)0); + } + + public TField(String n, byte t, short i) { + name = n; + type = t; + id = i; + } + + public final String name; + public final byte type; + public final short id; + + public String toString() { + return "<TField name:'" + name + "' type:" + type + " field-id:" + id + ">"; + } + + public boolean equals(TField otherField) { + return type == otherField.type && id == otherField.id; + } +} diff --git a/lib/java/src/org/apache/thrift/protocol/TJSONProtocol.java b/lib/java/src/org/apache/thrift/protocol/TJSONProtocol.java new file mode 100644 index 000000000..631c6a5b5 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TJSONProtocol.java @@ -0,0 +1,927 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import org.apache.thrift.TException; +import org.apache.thrift.TByteArrayOutputStream; +import org.apache.thrift.transport.TTransport; +import java.io.UnsupportedEncodingException; +import java.util.Stack; + +/** + * JSON protocol implementation for thrift. + * + * This is a full-featured protocol supporting write and read. + * + * Please see the C++ class header for a detailed description of the + * protocol's wire format. + * + */ +public class TJSONProtocol extends TProtocol { + + /** + * Factory for JSON protocol objects + */ + public static class Factory implements TProtocolFactory { + + public TProtocol getProtocol(TTransport trans) { + return new TJSONProtocol(trans); + } + + } + + private static final byte[] COMMA = new byte[] {','}; + private static final byte[] COLON = new byte[] {':'}; + private static final byte[] LBRACE = new byte[] {'{'}; + private static final byte[] RBRACE = new byte[] {'}'}; + private static final byte[] LBRACKET = new byte[] {'['}; + private static final byte[] RBRACKET = new byte[] {']'}; + private static final byte[] QUOTE = new byte[] {'"'}; + private static final byte[] BACKSLASH = new byte[] {'\\'}; + private static final byte[] ZERO = new byte[] {'0'}; + + private static final byte[] ESCSEQ = new byte[] {'\\','u','0','0'}; + + private static final long VERSION = 1; + + private static final byte[] JSON_CHAR_TABLE = { + /* 0 1 2 3 4 5 6 7 8 9 A B C D E F */ + 0, 0, 0, 0, 0, 0, 0, 0,'b','t','n', 0,'f','r', 0, 0, // 0 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 1 + 1, 1,'"', 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 2 + }; + + private static final String ESCAPE_CHARS = "\"\\bfnrt"; + + private static final byte[] ESCAPE_CHAR_VALS = { + '"', '\\', '\b', '\f', '\n', '\r', '\t', + }; + + private static final int DEF_STRING_SIZE = 16; + + private static final byte[] NAME_BOOL = new byte[] {'t', 'f'}; + private static final byte[] NAME_BYTE = new byte[] {'i','8'}; + private static final byte[] NAME_I16 = new byte[] {'i','1','6'}; + private static final byte[] NAME_I32 = new byte[] {'i','3','2'}; + private static final byte[] NAME_I64 = new byte[] {'i','6','4'}; + private static final byte[] NAME_DOUBLE = new byte[] {'d','b','l'}; + private static final byte[] NAME_STRUCT = new byte[] {'r','e','c'}; + private static final byte[] NAME_STRING = new byte[] {'s','t','r'}; + private static final byte[] NAME_MAP = new byte[] {'m','a','p'}; + private static final byte[] NAME_LIST = new byte[] {'l','s','t'}; + private static final byte[] NAME_SET = new byte[] {'s','e','t'}; + + private static final TStruct ANONYMOUS_STRUCT = new TStruct(); + + private static final byte[] getTypeNameForTypeID(byte typeID) + throws TException { + switch (typeID) { + case TType.BOOL: + return NAME_BOOL; + case TType.BYTE: + return NAME_BYTE; + case TType.I16: + return NAME_I16; + case TType.I32: + return NAME_I32; + case TType.I64: + return NAME_I64; + case TType.DOUBLE: + return NAME_DOUBLE; + case TType.STRING: + return NAME_STRING; + case TType.STRUCT: + return NAME_STRUCT; + case TType.MAP: + return NAME_MAP; + case TType.SET: + return NAME_SET; + case TType.LIST: + return NAME_LIST; + default: + throw new TProtocolException(TProtocolException.NOT_IMPLEMENTED, + "Unrecognized type"); + } + } + + private static final byte getTypeIDForTypeName(byte[] name) + throws TException { + byte result = TType.STOP; + if (name.length > 1) { + switch (name[0]) { + case 'd': + result = TType.DOUBLE; + break; + case 'i': + switch (name[1]) { + case '8': + result = TType.BYTE; + break; + case '1': + result = TType.I16; + break; + case '3': + result = TType.I32; + break; + case '6': + result = TType.I64; + break; + } + break; + case 'l': + result = TType.LIST; + break; + case 'm': + result = TType.MAP; + break; + case 'r': + result = TType.STRUCT; + break; + case 's': + if (name[1] == 't') { + result = TType.STRING; + } + else if (name[1] == 'e') { + result = TType.SET; + } + break; + case 't': + result = TType.BOOL; + break; + } + } + if (result == TType.STOP) { + throw new TProtocolException(TProtocolException.NOT_IMPLEMENTED, + "Unrecognized type"); + } + return result; + } + + // Base class for tracking JSON contexts that may require inserting/reading + // additional JSON syntax characters + // This base context does nothing. + protected class JSONBaseContext { + protected void write() throws TException {} + + protected void read() throws TException {} + + protected boolean escapeNum() { return false; } + } + + // Context for JSON lists. Will insert/read commas before each item except + // for the first one + protected class JSONListContext extends JSONBaseContext { + private boolean first_ = true; + + @Override + protected void write() throws TException { + if (first_) { + first_ = false; + } else { + trans_.write(COMMA); + } + } + + @Override + protected void read() throws TException { + if (first_) { + first_ = false; + } else { + readJSONSyntaxChar(COMMA); + } + } + } + + // Context for JSON records. Will insert/read colons before the value portion + // of each record pair, and commas before each key except the first. In + // addition, will indicate that numbers in the key position need to be + // escaped in quotes (since JSON keys must be strings). + protected class JSONPairContext extends JSONBaseContext { + private boolean first_ = true; + private boolean colon_ = true; + + @Override + protected void write() throws TException { + if (first_) { + first_ = false; + colon_ = true; + } else { + trans_.write(colon_ ? COLON : COMMA); + colon_ = !colon_; + } + } + + @Override + protected void read() throws TException { + if (first_) { + first_ = false; + colon_ = true; + } else { + readJSONSyntaxChar(colon_ ? COLON : COMMA); + colon_ = !colon_; + } + } + + @Override + protected boolean escapeNum() { + return colon_; + } + } + + // Holds up to one byte from the transport + protected class LookaheadReader { + + private boolean hasData_; + private byte[] data_ = new byte[1]; + + // Return and consume the next byte to be read, either taking it from the + // data buffer if present or getting it from the transport otherwise. + protected byte read() throws TException { + if (hasData_) { + hasData_ = false; + } + else { + trans_.readAll(data_, 0, 1); + } + return data_[0]; + } + + // Return the next byte to be read without consuming, filling the data + // buffer if it has not been filled already. + protected byte peek() throws TException { + if (!hasData_) { + trans_.readAll(data_, 0, 1); + } + hasData_ = true; + return data_[0]; + } + } + + // Stack of nested contexts that we may be in + private Stack<JSONBaseContext> contextStack_ = new Stack<JSONBaseContext>(); + + // Current context that we are in + private JSONBaseContext context_ = new JSONBaseContext(); + + // Reader that manages a 1-byte buffer + private LookaheadReader reader_ = new LookaheadReader(); + + // Push a new JSON context onto the stack. + private void pushContext(JSONBaseContext c) { + contextStack_.push(context_); + context_ = c; + } + + // Pop the last JSON context off the stack + private void popContext() { + context_ = contextStack_.pop(); + } + + /** + * Constructor + */ + public TJSONProtocol(TTransport trans) { + super(trans); + } + + // Temporary buffer used by several methods + private byte[] tmpbuf_ = new byte[4]; + + // Read a byte that must match b[0]; otherwise an excpetion is thrown. + // Marked protected to avoid synthetic accessor in JSONListContext.read + // and JSONPairContext.read + protected void readJSONSyntaxChar(byte[] b) throws TException { + byte ch = reader_.read(); + if (ch != b[0]) { + throw new TProtocolException(TProtocolException.INVALID_DATA, + "Unexpected character:" + (char)ch); + } + } + + // Convert a byte containing a hex char ('0'-'9' or 'a'-'f') into its + // corresponding hex value + private static final byte hexVal(byte ch) throws TException { + if ((ch >= '0') && (ch <= '9')) { + return (byte)((char)ch - '0'); + } + else if ((ch >= 'a') && (ch <= 'f')) { + return (byte)((char)ch - 'a'); + } + else { + throw new TProtocolException(TProtocolException.INVALID_DATA, + "Expected hex character"); + } + } + + // Convert a byte containing a hex value to its corresponding hex character + private static final byte hexChar(byte val) { + val &= 0x0F; + if (val < 10) { + return (byte)((char)val + '0'); + } + else { + return (byte)((char)val + 'a'); + } + } + + // Write the bytes in array buf as a JSON characters, escaping as needed + private void writeJSONString(byte[] b) throws TException { + context_.write(); + trans_.write(QUOTE); + int len = b.length; + for (int i = 0; i < len; i++) { + if ((b[i] & 0x00FF) >= 0x30) { + if (b[i] == BACKSLASH[0]) { + trans_.write(BACKSLASH); + trans_.write(BACKSLASH); + } + else { + trans_.write(b, i, 1); + } + } + else { + tmpbuf_[0] = JSON_CHAR_TABLE[b[i]]; + if (tmpbuf_[0] == 1) { + trans_.write(b, i, 1); + } + else if (tmpbuf_[0] > 1) { + trans_.write(BACKSLASH); + trans_.write(tmpbuf_, 0, 1); + } + else { + trans_.write(ESCSEQ); + tmpbuf_[0] = hexChar((byte)(b[i] >> 4)); + tmpbuf_[1] = hexChar(b[i]); + trans_.write(tmpbuf_, 0, 2); + } + } + } + trans_.write(QUOTE); + } + + // Write out number as a JSON value. If the context dictates so, it will be + // wrapped in quotes to output as a JSON string. + private void writeJSONInteger(long num) throws TException { + context_.write(); + String str = Long.toString(num); + boolean escapeNum = context_.escapeNum(); + if (escapeNum) { + trans_.write(QUOTE); + } + try { + byte[] buf = str.getBytes("UTF-8"); + trans_.write(buf); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + if (escapeNum) { + trans_.write(QUOTE); + } + } + + // Write out a double as a JSON value. If it is NaN or infinity or if the + // context dictates escaping, write out as JSON string. + private void writeJSONDouble(double num) throws TException { + context_.write(); + String str = Double.toString(num); + boolean special = false; + switch (str.charAt(0)) { + case 'N': // NaN + case 'I': // Infinity + special = true; + break; + case '-': + if (str.charAt(1) == 'I') { // -Infinity + special = true; + } + break; + } + + boolean escapeNum = special || context_.escapeNum(); + if (escapeNum) { + trans_.write(QUOTE); + } + try { + byte[] b = str.getBytes("UTF-8"); + trans_.write(b, 0, b.length); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + if (escapeNum) { + trans_.write(QUOTE); + } + } + + // Write out contents of byte array b as a JSON string with base-64 encoded + // data + private void writeJSONBase64(byte[] b) throws TException { + context_.write(); + trans_.write(QUOTE); + int len = b.length; + int off = 0; + while (len >= 3) { + // Encode 3 bytes at a time + TBase64Utils.encode(b, off, 3, tmpbuf_, 0); + trans_.write(tmpbuf_, 0, 4); + off += 3; + len -= 3; + } + if (len > 0) { + // Encode remainder + TBase64Utils.encode(b, off, len, tmpbuf_, 0); + trans_.write(tmpbuf_, 0, len + 1); + } + trans_.write(QUOTE); + } + + private void writeJSONObjectStart() throws TException { + context_.write(); + trans_.write(LBRACE); + pushContext(new JSONPairContext()); + } + + private void writeJSONObjectEnd() throws TException { + popContext(); + trans_.write(RBRACE); + } + + private void writeJSONArrayStart() throws TException { + context_.write(); + trans_.write(LBRACKET); + pushContext(new JSONListContext()); + } + + private void writeJSONArrayEnd() throws TException { + popContext(); + trans_.write(RBRACKET); + } + + @Override + public void writeMessageBegin(TMessage message) throws TException { + writeJSONArrayStart(); + writeJSONInteger(VERSION); + try { + byte[] b = message.name.getBytes("UTF-8"); + writeJSONString(b); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + writeJSONInteger(message.type); + writeJSONInteger(message.seqid); + } + + @Override + public void writeMessageEnd() throws TException { + writeJSONArrayEnd(); + } + + @Override + public void writeStructBegin(TStruct struct) throws TException { + writeJSONObjectStart(); + } + + @Override + public void writeStructEnd() throws TException { + writeJSONObjectEnd(); + } + + @Override + public void writeFieldBegin(TField field) throws TException { + writeJSONInteger(field.id); + writeJSONObjectStart(); + writeJSONString(getTypeNameForTypeID(field.type)); + } + + @Override + public void writeFieldEnd() throws TException { + writeJSONObjectEnd(); + } + + @Override + public void writeFieldStop() {} + + @Override + public void writeMapBegin(TMap map) throws TException { + writeJSONArrayStart(); + writeJSONString(getTypeNameForTypeID(map.keyType)); + writeJSONString(getTypeNameForTypeID(map.valueType)); + writeJSONInteger(map.size); + writeJSONObjectStart(); + } + + @Override + public void writeMapEnd() throws TException { + writeJSONObjectEnd(); + writeJSONArrayEnd(); + } + + @Override + public void writeListBegin(TList list) throws TException { + writeJSONArrayStart(); + writeJSONString(getTypeNameForTypeID(list.elemType)); + writeJSONInteger(list.size); + } + + @Override + public void writeListEnd() throws TException { + writeJSONArrayEnd(); + } + + @Override + public void writeSetBegin(TSet set) throws TException { + writeJSONArrayStart(); + writeJSONString(getTypeNameForTypeID(set.elemType)); + writeJSONInteger(set.size); + } + + @Override + public void writeSetEnd() throws TException { + writeJSONArrayEnd(); + } + + @Override + public void writeBool(boolean b) throws TException { + writeJSONInteger(b ? (long)1 : (long)0); + } + + @Override + public void writeByte(byte b) throws TException { + writeJSONInteger((long)b); + } + + @Override + public void writeI16(short i16) throws TException { + writeJSONInteger((long)i16); + } + + @Override + public void writeI32(int i32) throws TException { + writeJSONInteger((long)i32); + } + + @Override + public void writeI64(long i64) throws TException { + writeJSONInteger(i64); + } + + @Override + public void writeDouble(double dub) throws TException { + writeJSONDouble(dub); + } + + @Override + public void writeString(String str) throws TException { + try { + byte[] b = str.getBytes("UTF-8"); + writeJSONString(b); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + + @Override + public void writeBinary(byte[] bin) throws TException { + writeJSONBase64(bin); + } + + /** + * Reading methods. + */ + + // Read in a JSON string, unescaping as appropriate.. Skip reading from the + // context if skipContext is true. + private TByteArrayOutputStream readJSONString(boolean skipContext) + throws TException { + TByteArrayOutputStream arr = new TByteArrayOutputStream(DEF_STRING_SIZE); + if (!skipContext) { + context_.read(); + } + readJSONSyntaxChar(QUOTE); + while (true) { + byte ch = reader_.read(); + if (ch == QUOTE[0]) { + break; + } + if (ch == ESCSEQ[0]) { + ch = reader_.read(); + if (ch == ESCSEQ[1]) { + readJSONSyntaxChar(ZERO); + readJSONSyntaxChar(ZERO); + trans_.readAll(tmpbuf_, 0, 2); + ch = (byte)((hexVal((byte)tmpbuf_[0]) << 4) + hexVal(tmpbuf_[1])); + } + else { + int off = ESCAPE_CHARS.indexOf(ch); + if (off == -1) { + throw new TProtocolException(TProtocolException.INVALID_DATA, + "Expected control char"); + } + ch = ESCAPE_CHAR_VALS[off]; + } + } + arr.write(ch); + } + return arr; + } + + // Return true if the given byte could be a valid part of a JSON number. + private boolean isJSONNumeric(byte b) { + switch (b) { + case '+': + case '-': + case '.': + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case 'E': + case 'e': + return true; + } + return false; + } + + // Read in a sequence of characters that are all valid in JSON numbers. Does + // not do a complete regex check to validate that this is actually a number. + private String readJSONNumericChars() throws TException { + StringBuilder strbld = new StringBuilder(); + while (true) { + byte ch = reader_.peek(); + if (!isJSONNumeric(ch)) { + break; + } + strbld.append((char)reader_.read()); + } + return strbld.toString(); + } + + // Read in a JSON number. If the context dictates, read in enclosing quotes. + private long readJSONInteger() throws TException { + context_.read(); + if (context_.escapeNum()) { + readJSONSyntaxChar(QUOTE); + } + String str = readJSONNumericChars(); + if (context_.escapeNum()) { + readJSONSyntaxChar(QUOTE); + } + try { + return Long.valueOf(str); + } + catch (NumberFormatException ex) { + throw new TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data"); + } + } + + // Read in a JSON double value. Throw if the value is not wrapped in quotes + // when expected or if wrapped in quotes when not expected. + private double readJSONDouble() throws TException { + context_.read(); + if (reader_.peek() == QUOTE[0]) { + TByteArrayOutputStream arr = readJSONString(true); + try { + double dub = Double.valueOf(arr.toString("UTF-8")); + if (!context_.escapeNum() && !Double.isNaN(dub) && + !Double.isInfinite(dub)) { + // Throw exception -- we should not be in a string in this case + throw new TProtocolException(TProtocolException.INVALID_DATA, + "Numeric data unexpectedly quoted"); + } + return dub; + } + catch (UnsupportedEncodingException ex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + else { + if (context_.escapeNum()) { + // This will throw - we should have had a quote if escapeNum == true + readJSONSyntaxChar(QUOTE); + } + try { + return Double.valueOf(readJSONNumericChars()); + } + catch (NumberFormatException ex) { + throw new TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data"); + } + } + } + + // Read in a JSON string containing base-64 encoded data and decode it. + private byte[] readJSONBase64() throws TException { + TByteArrayOutputStream arr = readJSONString(false); + byte[] b = arr.get(); + int len = arr.len(); + int off = 0; + int size = 0; + while (len >= 4) { + // Decode 4 bytes at a time + TBase64Utils.decode(b, off, 4, b, size); // NB: decoded in place + off += 4; + len -= 4; + size += 3; + } + // Don't decode if we hit the end or got a single leftover byte (invalid + // base64 but legal for skip of regular string type) + if (len > 1) { + // Decode remainder + TBase64Utils.decode(b, off, len, b, size); // NB: decoded in place + size += len - 1; + } + // Sadly we must copy the byte[] (any way around this?) + byte [] result = new byte[size]; + System.arraycopy(b, 0, result, 0, size); + return result; + } + + private void readJSONObjectStart() throws TException { + context_.read(); + readJSONSyntaxChar(LBRACE); + pushContext(new JSONPairContext()); + } + + private void readJSONObjectEnd() throws TException { + readJSONSyntaxChar(RBRACE); + popContext(); + } + + private void readJSONArrayStart() throws TException { + context_.read(); + readJSONSyntaxChar(LBRACKET); + pushContext(new JSONListContext()); + } + + private void readJSONArrayEnd() throws TException { + readJSONSyntaxChar(RBRACKET); + popContext(); + } + + @Override + public TMessage readMessageBegin() throws TException { + readJSONArrayStart(); + if (readJSONInteger() != VERSION) { + throw new TProtocolException(TProtocolException.BAD_VERSION, + "Message contained bad version."); + } + String name; + try { + name = readJSONString(false).toString("UTF-8"); + } + catch (UnsupportedEncodingException ex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + byte type = (byte) readJSONInteger(); + int seqid = (int) readJSONInteger(); + return new TMessage(name, type, seqid); + } + + @Override + public void readMessageEnd() throws TException { + readJSONArrayEnd(); + } + + @Override + public TStruct readStructBegin() throws TException { + readJSONObjectStart(); + return ANONYMOUS_STRUCT; + } + + @Override + public void readStructEnd() throws TException { + readJSONObjectEnd(); + } + + @Override + public TField readFieldBegin() throws TException { + byte ch = reader_.peek(); + byte type; + short id = 0; + if (ch == RBRACE[0]) { + type = TType.STOP; + } + else { + id = (short) readJSONInteger(); + readJSONObjectStart(); + type = getTypeIDForTypeName(readJSONString(false).get()); + } + return new TField("", type, id); + } + + @Override + public void readFieldEnd() throws TException { + readJSONObjectEnd(); + } + + @Override + public TMap readMapBegin() throws TException { + readJSONArrayStart(); + byte keyType = getTypeIDForTypeName(readJSONString(false).get()); + byte valueType = getTypeIDForTypeName(readJSONString(false).get()); + int size = (int)readJSONInteger(); + readJSONObjectStart(); + return new TMap(keyType, valueType, size); + } + + @Override + public void readMapEnd() throws TException { + readJSONObjectEnd(); + readJSONArrayEnd(); + } + + @Override + public TList readListBegin() throws TException { + readJSONArrayStart(); + byte elemType = getTypeIDForTypeName(readJSONString(false).get()); + int size = (int)readJSONInteger(); + return new TList(elemType, size); + } + + @Override + public void readListEnd() throws TException { + readJSONArrayEnd(); + } + + @Override + public TSet readSetBegin() throws TException { + readJSONArrayStart(); + byte elemType = getTypeIDForTypeName(readJSONString(false).get()); + int size = (int)readJSONInteger(); + return new TSet(elemType, size); + } + + @Override + public void readSetEnd() throws TException { + readJSONArrayEnd(); + } + + @Override + public boolean readBool() throws TException { + return (readJSONInteger() == 0 ? false : true); + } + + @Override + public byte readByte() throws TException { + return (byte) readJSONInteger(); + } + + @Override + public short readI16() throws TException { + return (short) readJSONInteger(); + } + + @Override + public int readI32() throws TException { + return (int) readJSONInteger(); + } + + @Override + public long readI64() throws TException { + return (long) readJSONInteger(); + } + + @Override + public double readDouble() throws TException { + return readJSONDouble(); + } + + @Override + public String readString() throws TException { + try { + return readJSONString(false).toString("UTF-8"); + } + catch (UnsupportedEncodingException ex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + + @Override + public byte[] readBinary() throws TException { + return readJSONBase64(); + } + +} diff --git a/lib/java/src/org/apache/thrift/protocol/TList.java b/lib/java/src/org/apache/thrift/protocol/TList.java new file mode 100644 index 000000000..0d36e83d9 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TList.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Helper class that encapsulates list metadata. + * + */ +public final class TList { + public TList() { + this(TType.STOP, 0); + } + + public TList(byte t, int s) { + elemType = t; + size = s; + } + + public final byte elemType; + public final int size; +} diff --git a/lib/java/src/org/apache/thrift/protocol/TMap.java b/lib/java/src/org/apache/thrift/protocol/TMap.java new file mode 100644 index 000000000..20881f7ac --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TMap.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Helper class that encapsulates map metadata. + * + */ +public final class TMap { + public TMap() { + this(TType.STOP, TType.STOP, 0); + } + + public TMap(byte k, byte v, int s) { + keyType = k; + valueType = v; + size = s; + } + + public final byte keyType; + public final byte valueType; + public final int size; +} diff --git a/lib/java/src/org/apache/thrift/protocol/TMessage.java b/lib/java/src/org/apache/thrift/protocol/TMessage.java new file mode 100644 index 000000000..cd56964da --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TMessage.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Helper class that encapsulates struct metadata. + * + */ +public final class TMessage { + public TMessage() { + this("", TType.STOP, 0); + } + + public TMessage(String n, byte t, int s) { + name = n; + type = t; + seqid = s; + } + + public final String name; + public final byte type; + public final int seqid; + + public String toString() { + return "<TMessage name:'" + name + "' type: " + type + " seqid:" + seqid + ">"; + } + + public boolean equals(TMessage other) { + return name.equals(other.name) && type == other.type && seqid == other.seqid; + } +} diff --git a/lib/java/src/org/apache/thrift/protocol/TMessageType.java b/lib/java/src/org/apache/thrift/protocol/TMessageType.java new file mode 100644 index 000000000..aa3f93177 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TMessageType.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Message type constants in the Thrift protocol. + * + */ +public final class TMessageType { + public static final byte CALL = 1; + public static final byte REPLY = 2; + public static final byte EXCEPTION = 3; + public static final byte ONEWAY = 4; +} diff --git a/lib/java/src/org/apache/thrift/protocol/TProtocol.java b/lib/java/src/org/apache/thrift/protocol/TProtocol.java new file mode 100644 index 000000000..50d6683df --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TProtocol.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import org.apache.thrift.TException; +import org.apache.thrift.transport.TTransport; + +/** + * Protocol interface definition. + * + */ +public abstract class TProtocol { + + /** + * Prevent direct instantiation + */ + @SuppressWarnings("unused") + private TProtocol() {} + + /** + * Transport + */ + protected TTransport trans_; + + /** + * Constructor + */ + protected TProtocol(TTransport trans) { + trans_ = trans; + } + + /** + * Transport accessor + */ + public TTransport getTransport() { + return trans_; + } + + /** + * Writing methods. + */ + + public abstract void writeMessageBegin(TMessage message) throws TException; + + public abstract void writeMessageEnd() throws TException; + + public abstract void writeStructBegin(TStruct struct) throws TException; + + public abstract void writeStructEnd() throws TException; + + public abstract void writeFieldBegin(TField field) throws TException; + + public abstract void writeFieldEnd() throws TException; + + public abstract void writeFieldStop() throws TException; + + public abstract void writeMapBegin(TMap map) throws TException; + + public abstract void writeMapEnd() throws TException; + + public abstract void writeListBegin(TList list) throws TException; + + public abstract void writeListEnd() throws TException; + + public abstract void writeSetBegin(TSet set) throws TException; + + public abstract void writeSetEnd() throws TException; + + public abstract void writeBool(boolean b) throws TException; + + public abstract void writeByte(byte b) throws TException; + + public abstract void writeI16(short i16) throws TException; + + public abstract void writeI32(int i32) throws TException; + + public abstract void writeI64(long i64) throws TException; + + public abstract void writeDouble(double dub) throws TException; + + public abstract void writeString(String str) throws TException; + + public abstract void writeBinary(byte[] bin) throws TException; + + /** + * Reading methods. + */ + + public abstract TMessage readMessageBegin() throws TException; + + public abstract void readMessageEnd() throws TException; + + public abstract TStruct readStructBegin() throws TException; + + public abstract void readStructEnd() throws TException; + + public abstract TField readFieldBegin() throws TException; + + public abstract void readFieldEnd() throws TException; + + public abstract TMap readMapBegin() throws TException; + + public abstract void readMapEnd() throws TException; + + public abstract TList readListBegin() throws TException; + + public abstract void readListEnd() throws TException; + + public abstract TSet readSetBegin() throws TException; + + public abstract void readSetEnd() throws TException; + + public abstract boolean readBool() throws TException; + + public abstract byte readByte() throws TException; + + public abstract short readI16() throws TException; + + public abstract int readI32() throws TException; + + public abstract long readI64() throws TException; + + public abstract double readDouble() throws TException; + + public abstract String readString() throws TException; + + public abstract byte[] readBinary() throws TException; + +} diff --git a/lib/java/src/org/apache/thrift/protocol/TProtocolException.java b/lib/java/src/org/apache/thrift/protocol/TProtocolException.java new file mode 100644 index 000000000..248815bec --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TProtocolException.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import org.apache.thrift.TException; + +/** + * Protocol exceptions. + * + */ +public class TProtocolException extends TException { + + + private static final long serialVersionUID = 1L; + public static final int UNKNOWN = 0; + public static final int INVALID_DATA = 1; + public static final int NEGATIVE_SIZE = 2; + public static final int SIZE_LIMIT = 3; + public static final int BAD_VERSION = 4; + public static final int NOT_IMPLEMENTED = 5; + + protected int type_ = UNKNOWN; + + public TProtocolException() { + super(); + } + + public TProtocolException(int type) { + super(); + type_ = type; + } + + public TProtocolException(int type, String message) { + super(message); + type_ = type; + } + + public TProtocolException(String message) { + super(message); + } + + public TProtocolException(int type, Throwable cause) { + super(cause); + type_ = type; + } + + public TProtocolException(Throwable cause) { + super(cause); + } + + public TProtocolException(String message, Throwable cause) { + super(message, cause); + } + + public TProtocolException(int type, String message, Throwable cause) { + super(message, cause); + type_ = type; + } + + public int getType() { + return type_; + } + +} diff --git a/lib/java/src/org/apache/thrift/protocol/TProtocolFactory.java b/lib/java/src/org/apache/thrift/protocol/TProtocolFactory.java new file mode 100644 index 000000000..afa502b70 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TProtocolFactory.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import org.apache.thrift.transport.TTransport; + +/** + * Factory interface for constructing protocol instances. + * + */ +public interface TProtocolFactory { + public TProtocol getProtocol(TTransport trans); +} diff --git a/lib/java/src/org/apache/thrift/protocol/TProtocolUtil.java b/lib/java/src/org/apache/thrift/protocol/TProtocolUtil.java new file mode 100644 index 000000000..9bf10f67e --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TProtocolUtil.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import org.apache.thrift.TException; + +/** + * Utility class with static methods for interacting with protocol data + * streams. + * + */ +public class TProtocolUtil { + + /** + * The maximum recursive depth the skip() function will traverse before + * throwing a TException. + */ + private static int maxSkipDepth = Integer.MAX_VALUE; + + /** + * Specifies the maximum recursive depth that the skip function will + * traverse before throwing a TException. This is a global setting, so + * any call to skip in this JVM will enforce this value. + * + * @param depth the maximum recursive depth. A value of 2 would allow + * the skip function to skip a structure or collection with basic children, + * but it would not permit skipping a struct that had a field containing + * a child struct. A value of 1 would only allow skipping of simple + * types and empty structs/collections. + */ + public static void setMaxSkipDepth(int depth) { + maxSkipDepth = depth; + } + + /** + * Skips over the next data element from the provided input TProtocol object. + * + * @param prot the protocol object to read from + * @param type the next value will be intepreted as this TType value. + */ + public static void skip(TProtocol prot, byte type) + throws TException { + skip(prot, type, maxSkipDepth); + } + + /** + * Skips over the next data element from the provided input TProtocol object. + * + * @param prot the protocol object to read from + * @param type the next value will be intepreted as this TType value. + * @param maxDepth this function will only skip complex objects to this + * recursive depth, to prevent Java stack overflow. + */ + public static void skip(TProtocol prot, byte type, int maxDepth) + throws TException { + if (maxDepth <= 0) { + throw new TException("Maximum skip depth exceeded"); + } + switch (type) { + case TType.BOOL: + { + prot.readBool(); + break; + } + case TType.BYTE: + { + prot.readByte(); + break; + } + case TType.I16: + { + prot.readI16(); + break; + } + case TType.I32: + { + prot.readI32(); + break; + } + case TType.I64: + { + prot.readI64(); + break; + } + case TType.DOUBLE: + { + prot.readDouble(); + break; + } + case TType.STRING: + { + prot.readBinary(); + break; + } + case TType.STRUCT: + { + prot.readStructBegin(); + while (true) { + TField field = prot.readFieldBegin(); + if (field.type == TType.STOP) { + break; + } + skip(prot, field.type, maxDepth - 1); + prot.readFieldEnd(); + } + prot.readStructEnd(); + break; + } + case TType.MAP: + { + TMap map = prot.readMapBegin(); + for (int i = 0; i < map.size; i++) { + skip(prot, map.keyType, maxDepth - 1); + skip(prot, map.valueType, maxDepth - 1); + } + prot.readMapEnd(); + break; + } + case TType.SET: + { + TSet set = prot.readSetBegin(); + for (int i = 0; i < set.size; i++) { + skip(prot, set.elemType, maxDepth - 1); + } + prot.readSetEnd(); + break; + } + case TType.LIST: + { + TList list = prot.readListBegin(); + for (int i = 0; i < list.size; i++) { + skip(prot, list.elemType, maxDepth - 1); + } + prot.readListEnd(); + break; + } + default: + break; + } + } +} diff --git a/lib/java/src/org/apache/thrift/protocol/TSet.java b/lib/java/src/org/apache/thrift/protocol/TSet.java new file mode 100644 index 000000000..38be9a991 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TSet.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Helper class that encapsulates set metadata. + * + */ +public final class TSet { + public TSet() { + this(TType.STOP, 0); + } + + public TSet(byte t, int s) { + elemType = t; + size = s; + } + + public TSet(TList list) { + this(list.elemType, list.size); + } + + public final byte elemType; + public final int size; +} diff --git a/lib/java/src/org/apache/thrift/protocol/TSimpleJSONProtocol.java b/lib/java/src/org/apache/thrift/protocol/TSimpleJSONProtocol.java new file mode 100644 index 000000000..a60bdf407 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TSimpleJSONProtocol.java @@ -0,0 +1,384 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import java.io.UnsupportedEncodingException; +import java.util.Stack; + +import org.apache.thrift.TException; +import org.apache.thrift.transport.TTransport; + +/** + * JSON protocol implementation for thrift. + * + * This protocol is write-only and produces a simple output format + * suitable for parsing by scripting languages. It should not be + * confused with the full-featured TJSONProtocol. + * + */ +public class TSimpleJSONProtocol extends TProtocol { + + /** + * Factory + */ + public static class Factory implements TProtocolFactory { + public TProtocol getProtocol(TTransport trans) { + return new TSimpleJSONProtocol(trans); + } + } + + public static final byte[] COMMA = new byte[] {','}; + public static final byte[] COLON = new byte[] {':'}; + public static final byte[] LBRACE = new byte[] {'{'}; + public static final byte[] RBRACE = new byte[] {'}'}; + public static final byte[] LBRACKET = new byte[] {'['}; + public static final byte[] RBRACKET = new byte[] {']'}; + public static final char QUOTE = '"'; + + private static final TStruct ANONYMOUS_STRUCT = new TStruct(); + private static final TField ANONYMOUS_FIELD = new TField(); + private static final TMessage EMPTY_MESSAGE = new TMessage(); + private static final TSet EMPTY_SET = new TSet(); + private static final TList EMPTY_LIST = new TList(); + private static final TMap EMPTY_MAP = new TMap(); + + protected class Context { + protected void write() throws TException {} + } + + protected class ListContext extends Context { + protected boolean first_ = true; + + protected void write() throws TException { + if (first_) { + first_ = false; + } else { + trans_.write(COMMA); + } + } + } + + protected class StructContext extends Context { + protected boolean first_ = true; + protected boolean colon_ = true; + + protected void write() throws TException { + if (first_) { + first_ = false; + colon_ = true; + } else { + trans_.write(colon_ ? COLON : COMMA); + colon_ = !colon_; + } + } + } + + protected final Context BASE_CONTEXT = new Context(); + + /** + * Stack of nested contexts that we may be in. + */ + protected Stack<Context> writeContextStack_ = new Stack<Context>(); + + /** + * Current context that we are in + */ + protected Context writeContext_ = BASE_CONTEXT; + + /** + * Push a new write context onto the stack. + */ + protected void pushWriteContext(Context c) { + writeContextStack_.push(writeContext_); + writeContext_ = c; + } + + /** + * Pop the last write context off the stack + */ + protected void popWriteContext() { + writeContext_ = writeContextStack_.pop(); + } + + /** + * Constructor + */ + public TSimpleJSONProtocol(TTransport trans) { + super(trans); + } + + public void writeMessageBegin(TMessage message) throws TException { + trans_.write(LBRACKET); + pushWriteContext(new ListContext()); + writeString(message.name); + writeByte(message.type); + writeI32(message.seqid); + } + + public void writeMessageEnd() throws TException { + popWriteContext(); + trans_.write(RBRACKET); + } + + public void writeStructBegin(TStruct struct) throws TException { + writeContext_.write(); + trans_.write(LBRACE); + pushWriteContext(new StructContext()); + } + + public void writeStructEnd() throws TException { + popWriteContext(); + trans_.write(RBRACE); + } + + public void writeFieldBegin(TField field) throws TException { + // Note that extra type information is omitted in JSON! + writeString(field.name); + } + + public void writeFieldEnd() {} + + public void writeFieldStop() {} + + public void writeMapBegin(TMap map) throws TException { + writeContext_.write(); + trans_.write(LBRACE); + pushWriteContext(new StructContext()); + // No metadata! + } + + public void writeMapEnd() throws TException { + popWriteContext(); + trans_.write(RBRACE); + } + + public void writeListBegin(TList list) throws TException { + writeContext_.write(); + trans_.write(LBRACKET); + pushWriteContext(new ListContext()); + // No metadata! + } + + public void writeListEnd() throws TException { + popWriteContext(); + trans_.write(RBRACKET); + } + + public void writeSetBegin(TSet set) throws TException { + writeContext_.write(); + trans_.write(LBRACKET); + pushWriteContext(new ListContext()); + // No metadata! + } + + public void writeSetEnd() throws TException { + popWriteContext(); + trans_.write(RBRACKET); + } + + public void writeBool(boolean b) throws TException { + writeByte(b ? (byte)1 : (byte)0); + } + + public void writeByte(byte b) throws TException { + writeI32(b); + } + + public void writeI16(short i16) throws TException { + writeI32(i16); + } + + public void writeI32(int i32) throws TException { + writeContext_.write(); + _writeStringData(Integer.toString(i32)); + } + + public void _writeStringData(String s) throws TException { + try { + byte[] b = s.getBytes("UTF-8"); + trans_.write(b); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + + public void writeI64(long i64) throws TException { + writeContext_.write(); + _writeStringData(Long.toString(i64)); + } + + public void writeDouble(double dub) throws TException { + writeContext_.write(); + _writeStringData(Double.toString(dub)); + } + + public void writeString(String str) throws TException { + writeContext_.write(); + int length = str.length(); + StringBuffer escape = new StringBuffer(length + 16); + escape.append(QUOTE); + for (int i = 0; i < length; ++i) { + char c = str.charAt(i); + switch (c) { + case '"': + case '\\': + escape.append('\\'); + escape.append(c); + break; + case '\b': + escape.append('\\'); + escape.append('b'); + break; + case '\f': + escape.append('\\'); + escape.append('f'); + break; + case '\n': + escape.append('\\'); + escape.append('n'); + break; + case '\r': + escape.append('\\'); + escape.append('r'); + break; + case '\t': + escape.append('\\'); + escape.append('t'); + break; + default: + // Control characeters! According to JSON RFC u0020 (space) + if (c < ' ') { + String hex = Integer.toHexString(c); + escape.append('\\'); + escape.append('u'); + for (int j = 4; j > hex.length(); --j) { + escape.append('0'); + } + escape.append(hex); + } else { + escape.append(c); + } + break; + } + } + escape.append(QUOTE); + _writeStringData(escape.toString()); + } + + public void writeBinary(byte[] bin) throws TException { + try { + // TODO(mcslee): Fix this + writeString(new String(bin, "UTF-8")); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + + /** + * Reading methods. + */ + + public TMessage readMessageBegin() throws TException { + // TODO(mcslee): implement + return EMPTY_MESSAGE; + } + + public void readMessageEnd() {} + + public TStruct readStructBegin() { + // TODO(mcslee): implement + return ANONYMOUS_STRUCT; + } + + public void readStructEnd() {} + + public TField readFieldBegin() throws TException { + // TODO(mcslee): implement + return ANONYMOUS_FIELD; + } + + public void readFieldEnd() {} + + public TMap readMapBegin() throws TException { + // TODO(mcslee): implement + return EMPTY_MAP; + } + + public void readMapEnd() {} + + public TList readListBegin() throws TException { + // TODO(mcslee): implement + return EMPTY_LIST; + } + + public void readListEnd() {} + + public TSet readSetBegin() throws TException { + // TODO(mcslee): implement + return EMPTY_SET; + } + + public void readSetEnd() {} + + public boolean readBool() throws TException { + return (readByte() == 1); + } + + public byte readByte() throws TException { + // TODO(mcslee): implement + return 0; + } + + public short readI16() throws TException { + // TODO(mcslee): implement + return 0; + } + + public int readI32() throws TException { + // TODO(mcslee): implement + return 0; + } + + public long readI64() throws TException { + // TODO(mcslee): implement + return 0; + } + + public double readDouble() throws TException { + // TODO(mcslee): implement + return 0; + } + + public String readString() throws TException { + // TODO(mcslee): implement + return ""; + } + + public String readStringBody(int size) throws TException { + // TODO(mcslee): implement + return ""; + } + + public byte[] readBinary() throws TException { + // TODO(mcslee): implement + return new byte[0]; + } + +} diff --git a/lib/java/src/org/apache/thrift/protocol/TStruct.java b/lib/java/src/org/apache/thrift/protocol/TStruct.java new file mode 100644 index 000000000..a0f79012a --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TStruct.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Helper class that encapsulates struct metadata. + * + */ +public final class TStruct { + public TStruct() { + this(""); + } + + public TStruct(String n) { + name = n; + } + + public final String name; +} diff --git a/lib/java/src/org/apache/thrift/protocol/TType.java b/lib/java/src/org/apache/thrift/protocol/TType.java new file mode 100644 index 000000000..dbdc3caa8 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TType.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Type constants in the Thrift protocol. + * + */ +public final class TType { + public static final byte STOP = 0; + public static final byte VOID = 1; + public static final byte BOOL = 2; + public static final byte BYTE = 3; + public static final byte DOUBLE = 4; + public static final byte I16 = 6; + public static final byte I32 = 8; + public static final byte I64 = 10; + public static final byte STRING = 11; + public static final byte STRUCT = 12; + public static final byte MAP = 13; + public static final byte SET = 14; + public static final byte LIST = 15; +} diff --git a/lib/java/src/org/apache/thrift/server/THsHaServer.java b/lib/java/src/org/apache/thrift/server/THsHaServer.java new file mode 100644 index 000000000..8bf096ed6 --- /dev/null +++ b/lib/java/src/org/apache/thrift/server/THsHaServer.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.server; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +import org.apache.thrift.TProcessor; +import org.apache.thrift.TProcessorFactory; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TFramedTransport; +import org.apache.thrift.transport.TNonblockingServerTransport; + +/** + * An extension of the TNonblockingServer to a Half-Sync/Half-Async server. + * Like TNonblockingServer, it relies on the use of TFramedTransport. + */ +public class THsHaServer extends TNonblockingServer { + + // This wraps all the functionality of queueing and thread pool management + // for the passing of Invocations from the Selector to workers. + private ExecutorService invoker; + + protected final int MIN_WORKER_THREADS; + protected final int MAX_WORKER_THREADS; + protected final int STOP_TIMEOUT_VAL; + protected final TimeUnit STOP_TIMEOUT_UNIT; + + /** + * Create server with given processor, and server transport. Default server + * options, TBinaryProtocol for the protocol, and TFramedTransport.Factory on + * both input and output transports. A TProcessorFactory will be created that + * always returns the specified processor. + */ + public THsHaServer( TProcessor processor, + TNonblockingServerTransport serverTransport) { + this(processor, serverTransport, new Options()); + } + + /** + * Create server with given processor, server transport, and server options + * using TBinaryProtocol for the protocol, and TFramedTransport.Factory on + * both input and output transports. A TProcessorFactory will be created that + * always returns the specified processor. + */ + public THsHaServer( TProcessor processor, + TNonblockingServerTransport serverTransport, + Options options) { + this(new TProcessorFactory(processor), serverTransport, options); + } + + /** + * Create server with specified processor factory and server transport. Uses + * default options. TBinaryProtocol is assumed. TFramedTransport.Factory is + * used on both input and output transports. + */ + public THsHaServer( TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport) { + this(processorFactory, serverTransport, new Options()); + } + + /** + * Create server with specified processor factory, server transport, and server + * options. TBinaryProtocol is assumed. TFramedTransport.Factory is used on + * both input and output transports. + */ + public THsHaServer( TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + Options options) { + this(processorFactory, serverTransport, new TFramedTransport.Factory(), + new TBinaryProtocol.Factory(), options); + } + + /** + * Server with specified processor, server transport, and in/out protocol + * factory. Defaults will be used for in/out transport factory and server + * options. + */ + public THsHaServer( TProcessor processor, + TNonblockingServerTransport serverTransport, + TProtocolFactory protocolFactory) { + this(processor, serverTransport, protocolFactory, new Options()); + } + + /** + * Server with specified processor, server transport, and in/out protocol + * factory. Defaults will be used for in/out transport factory and server + * options. + */ + public THsHaServer( TProcessor processor, + TNonblockingServerTransport serverTransport, + TProtocolFactory protocolFactory, + Options options) { + this(processor, serverTransport, new TFramedTransport.Factory(), + protocolFactory); + } + + /** + * Create server with specified processor, server transport, in/out + * transport factory, in/out protocol factory, and default server options. A + * processor factory will be created that always returns the specified + * processor. + */ + public THsHaServer( TProcessor processor, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory transportFactory, + TProtocolFactory protocolFactory) { + this(new TProcessorFactory(processor), serverTransport, + transportFactory, protocolFactory); + } + + /** + * Create server with specified processor factory, server transport, in/out + * transport factory, in/out protocol factory, and default server options. + */ + public THsHaServer( TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory transportFactory, + TProtocolFactory protocolFactory) { + this(processorFactory, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory, new Options()); + } + + /** + * Create server with specified processor factory, server transport, in/out + * transport factory, in/out protocol factory, and server options. + */ + public THsHaServer( TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory transportFactory, + TProtocolFactory protocolFactory, + Options options) { + this(processorFactory, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory, + options); + } + + /** + * Create server with everything specified, except use default server options. + */ + public THsHaServer( TProcessor processor, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory inputTransportFactory, + TFramedTransport.Factory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + this(new TProcessorFactory(processor), serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + } + + /** + * Create server with everything specified, except use default server options. + */ + public THsHaServer( TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory inputTransportFactory, + TFramedTransport.Factory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) + { + this(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory, new Options()); + } + + /** + * Create server with every option fully specified. + */ + public THsHaServer( TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory inputTransportFactory, + TFramedTransport.Factory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + Options options) + { + super(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory, + options); + + MIN_WORKER_THREADS = options.minWorkerThreads; + MAX_WORKER_THREADS = options.maxWorkerThreads; + STOP_TIMEOUT_VAL = options.stopTimeoutVal; + STOP_TIMEOUT_UNIT = options.stopTimeoutUnit; + } + + /** @inheritDoc */ + @Override + public void serve() { + if (!startInvokerPool()) { + return; + } + + // start listening, or exit + if (!startListening()) { + return; + } + + // start the selector, or exit + if (!startSelectorThread()) { + return; + } + + // this will block while we serve + joinSelector(); + + gracefullyShutdownInvokerPool(); + + // do a little cleanup + stopListening(); + + // ungracefully shut down the invoker pool? + } + + protected boolean startInvokerPool() { + // start the invoker pool + LinkedBlockingQueue<Runnable> queue = new LinkedBlockingQueue<Runnable>(); + invoker = new ThreadPoolExecutor(MIN_WORKER_THREADS, MAX_WORKER_THREADS, + STOP_TIMEOUT_VAL, STOP_TIMEOUT_UNIT, queue); + + return true; + } + + protected void gracefullyShutdownInvokerPool() { + // try to gracefully shut down the executor service + invoker.shutdown(); + + // Loop until awaitTermination finally does return without a interrupted + // exception. If we don't do this, then we'll shut down prematurely. We want + // to let the executorService clear it's task queue, closing client sockets + // appropriately. + long timeoutMS = 10000; + long now = System.currentTimeMillis(); + while (timeoutMS >= 0) { + try { + invoker.awaitTermination(timeoutMS, TimeUnit.MILLISECONDS); + break; + } catch (InterruptedException ix) { + long newnow = System.currentTimeMillis(); + timeoutMS -= (newnow - now); + now = newnow; + } + } + } + + /** + * We override the standard invoke method here to queue the invocation for + * invoker service instead of immediately invoking. The thread pool takes care of the rest. + */ + @Override + protected void requestInvoke(FrameBuffer frameBuffer) { + invoker.execute(new Invocation(frameBuffer)); + } + + /** + * An Invocation represents a method call that is prepared to execute, given + * an idle worker thread. It contains the input and output protocols the + * thread's processor should use to perform the usual Thrift invocation. + */ + private class Invocation implements Runnable { + + private final FrameBuffer frameBuffer; + + public Invocation(final FrameBuffer frameBuffer) { + this.frameBuffer = frameBuffer; + } + + public void run() { + frameBuffer.invoke(); + } + } + + public static class Options extends TNonblockingServer.Options { + public int minWorkerThreads = 5; + public int maxWorkerThreads = Integer.MAX_VALUE; + public int stopTimeoutVal = 60; + public TimeUnit stopTimeoutUnit = TimeUnit.SECONDS; + } +} diff --git a/lib/java/src/org/apache/thrift/server/TNonblockingServer.java b/lib/java/src/org/apache/thrift/server/TNonblockingServer.java new file mode 100644 index 000000000..95d81e223 --- /dev/null +++ b/lib/java/src/org/apache/thrift/server/TNonblockingServer.java @@ -0,0 +1,769 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.server; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.spi.SelectorProvider; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; + +import org.apache.log4j.Logger; + +import org.apache.thrift.TByteArrayOutputStream; +import org.apache.thrift.TException; +import org.apache.thrift.TProcessor; +import org.apache.thrift.TProcessorFactory; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TFramedTransport; +import org.apache.thrift.transport.TIOStreamTransport; +import org.apache.thrift.transport.TNonblockingServerTransport; +import org.apache.thrift.transport.TNonblockingTransport; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; + +/** + * A nonblocking TServer implementation. This allows for fairness amongst all + * connected clients in terms of invocations. + * + * This server is inherently single-threaded. If you want a limited thread pool + * coupled with invocation-fairness, see THsHaServer. + * + * To use this server, you MUST use a TFramedTransport at the outermost + * transport, otherwise this server will be unable to determine when a whole + * method call has been read off the wire. Clients must also use TFramedTransport. + */ +public class TNonblockingServer extends TServer { + private static final Logger LOGGER = + Logger.getLogger(TNonblockingServer.class.getName()); + + // Flag for stopping the server + private volatile boolean stopped_; + + private SelectThread selectThread_; + + /** + * The maximum amount of memory we will allocate to client IO buffers at a + * time. Without this limit, the server will gladly allocate client buffers + * right into an out of memory exception, rather than waiting. + */ + private final long MAX_READ_BUFFER_BYTES; + + protected final Options options_; + + /** + * How many bytes are currently allocated to read buffers. + */ + private long readBufferBytesAllocated = 0; + + /** + * Create server with given processor and server transport, using + * TBinaryProtocol for the protocol, TFramedTransport.Factory on both input + * and output transports. A TProcessorFactory will be created that always + * returns the specified processor. + */ + public TNonblockingServer(TProcessor processor, + TNonblockingServerTransport serverTransport) { + this(new TProcessorFactory(processor), serverTransport); + } + + /** + * Create server with specified processor factory and server transport. + * TBinaryProtocol is assumed. TFramedTransport.Factory is used on both input + * and output transports. + */ + public TNonblockingServer(TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport) { + this(processorFactory, serverTransport, + new TFramedTransport.Factory(), new TFramedTransport.Factory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory()); + } + + public TNonblockingServer(TProcessor processor, + TNonblockingServerTransport serverTransport, + TProtocolFactory protocolFactory) { + this(processor, serverTransport, + new TFramedTransport.Factory(), new TFramedTransport.Factory(), + protocolFactory, protocolFactory); + } + + public TNonblockingServer(TProcessor processor, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory transportFactory, + TProtocolFactory protocolFactory) { + this(processor, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory); + } + + public TNonblockingServer(TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory transportFactory, + TProtocolFactory protocolFactory) { + this(processorFactory, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory); + } + + public TNonblockingServer(TProcessor processor, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory inputTransportFactory, + TFramedTransport.Factory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + this(new TProcessorFactory(processor), serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + } + + public TNonblockingServer(TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory inputTransportFactory, + TFramedTransport.Factory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + this(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory, + new Options()); + } + + public TNonblockingServer(TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory inputTransportFactory, + TFramedTransport.Factory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + Options options) { + super(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + options_ = options; + options_.validate(); + MAX_READ_BUFFER_BYTES = options.maxReadBufferBytes; + } + + /** + * Begin accepting connections and processing invocations. + */ + public void serve() { + // start listening, or exit + if (!startListening()) { + return; + } + + // start the selector, or exit + if (!startSelectorThread()) { + return; + } + + // this will block while we serve + joinSelector(); + + // do a little cleanup + stopListening(); + } + + /** + * Have the server transport start accepting connections. + * + * @return true if we started listening successfully, false if something went + * wrong. + */ + protected boolean startListening() { + try { + serverTransport_.listen(); + return true; + } catch (TTransportException ttx) { + LOGGER.error("Failed to start listening on server socket!", ttx); + return false; + } + } + + /** + * Stop listening for conections. + */ + protected void stopListening() { + serverTransport_.close(); + } + + /** + * Start the selector thread running to deal with clients. + * + * @return true if everything went ok, false if we couldn't start for some + * reason. + */ + protected boolean startSelectorThread() { + // start the selector + try { + selectThread_ = new SelectThread((TNonblockingServerTransport)serverTransport_); + selectThread_.start(); + return true; + } catch (IOException e) { + LOGGER.error("Failed to start selector thread!", e); + return false; + } + } + + /** + * Block until the selector exits. + */ + protected void joinSelector() { + // wait until the selector thread exits + try { + selectThread_.join(); + } catch (InterruptedException e) { + // for now, just silently ignore. technically this means we'll have less of + // a graceful shutdown as a result. + } + } + + /** + * Stop serving and shut everything down. + */ + public void stop() { + stopped_ = true; + selectThread_.wakeupSelector(); + } + + /** + * Perform an invocation. This method could behave several different ways + * - invoke immediately inline, queue for separate execution, etc. + */ + protected void requestInvoke(FrameBuffer frameBuffer) { + frameBuffer.invoke(); + } + + /** + * A FrameBuffer wants to change its selection preferences, but might not be + * in the select thread. + */ + protected void requestSelectInterestChange(FrameBuffer frameBuffer) { + selectThread_.requestSelectInterestChange(frameBuffer); + } + + /** + * The thread that will be doing all the selecting, managing new connections + * and those that still need to be read. + */ + protected class SelectThread extends Thread { + + private final TNonblockingServerTransport serverTransport; + private final Selector selector; + + // List of FrameBuffers that want to change their selection interests. + private final Set<FrameBuffer> selectInterestChanges = + new HashSet<FrameBuffer>(); + + /** + * Set up the SelectorThread. + */ + public SelectThread(final TNonblockingServerTransport serverTransport) + throws IOException { + this.serverTransport = serverTransport; + this.selector = SelectorProvider.provider().openSelector(); + serverTransport.registerSelector(selector); + } + + /** + * The work loop. Handles both selecting (all IO operations) and managing + * the selection preferences of all existing connections. + */ + public void run() { + while (!stopped_) { + select(); + processInterestChanges(); + } + } + + /** + * If the selector is blocked, wake it up. + */ + public void wakeupSelector() { + selector.wakeup(); + } + + /** + * Add FrameBuffer to the list of select interest changes and wake up the + * selector if it's blocked. When the select() call exits, it'll give the + * FrameBuffer a chance to change its interests. + */ + public void requestSelectInterestChange(FrameBuffer frameBuffer) { + synchronized (selectInterestChanges) { + selectInterestChanges.add(frameBuffer); + } + // wakeup the selector, if it's currently blocked. + selector.wakeup(); + } + + /** + * Select and process IO events appropriately: + * If there are connections to be accepted, accept them. + * If there are existing connections with data waiting to be read, read it, + * bufferring until a whole frame has been read. + * If there are any pending responses, buffer them until their target client + * is available, and then send the data. + */ + private void select() { + try { + // wait for io events. + selector.select(); + + // process the io events we received + Iterator<SelectionKey> selectedKeys = selector.selectedKeys().iterator(); + while (!stopped_ && selectedKeys.hasNext()) { + SelectionKey key = selectedKeys.next(); + selectedKeys.remove(); + + // skip if not valid + if (!key.isValid()) { + cleanupSelectionkey(key); + continue; + } + + // if the key is marked Accept, then it has to be the server + // transport. + if (key.isAcceptable()) { + handleAccept(); + } else if (key.isReadable()) { + // deal with reads + handleRead(key); + } else if (key.isWritable()) { + // deal with writes + handleWrite(key); + } else { + LOGGER.warn("Unexpected state in select! " + key.interestOps()); + } + } + } catch (IOException e) { + LOGGER.warn("Got an IOException while selecting!", e); + } + } + + /** + * Check to see if there are any FrameBuffers that have switched their + * interest type from read to write or vice versa. + */ + private void processInterestChanges() { + synchronized (selectInterestChanges) { + for (FrameBuffer fb : selectInterestChanges) { + fb.changeSelectInterests(); + } + selectInterestChanges.clear(); + } + } + + /** + * Accept a new connection. + */ + private void handleAccept() throws IOException { + SelectionKey clientKey = null; + TNonblockingTransport client = null; + try { + // accept the connection + client = (TNonblockingTransport)serverTransport.accept(); + clientKey = client.registerSelector(selector, SelectionKey.OP_READ); + + // add this key to the map + FrameBuffer frameBuffer = new FrameBuffer(client, clientKey); + clientKey.attach(frameBuffer); + } catch (TTransportException tte) { + // something went wrong accepting. + LOGGER.warn("Exception trying to accept!", tte); + tte.printStackTrace(); + if (clientKey != null) cleanupSelectionkey(clientKey); + if (client != null) client.close(); + } + } + + /** + * Do the work required to read from a readable client. If the frame is + * fully read, then invoke the method call. + */ + private void handleRead(SelectionKey key) { + FrameBuffer buffer = (FrameBuffer)key.attachment(); + if (buffer.read()) { + // if the buffer's frame read is complete, invoke the method. + if (buffer.isFrameFullyRead()) { + requestInvoke(buffer); + } + } else { + cleanupSelectionkey(key); + } + } + + /** + * Let a writable client get written, if there's data to be written. + */ + private void handleWrite(SelectionKey key) { + FrameBuffer buffer = (FrameBuffer)key.attachment(); + if (!buffer.write()) { + cleanupSelectionkey(key); + } + } + + /** + * Do connection-close cleanup on a given SelectionKey. + */ + private void cleanupSelectionkey(SelectionKey key) { + // remove the records from the two maps + FrameBuffer buffer = (FrameBuffer)key.attachment(); + if (buffer != null) { + // close the buffer + buffer.close(); + } + // cancel the selection key + key.cancel(); + } + } // SelectorThread + + /** + * Class that implements a sort of state machine around the interaction with + * a client and an invoker. It manages reading the frame size and frame data, + * getting it handed off as wrapped transports, and then the writing of + * reponse data back to the client. In the process it manages flipping the + * read and write bits on the selection key for its client. + */ + protected class FrameBuffer { + // + // Possible states for the FrameBuffer state machine. + // + // in the midst of reading the frame size off the wire + private static final int READING_FRAME_SIZE = 1; + // reading the actual frame data now, but not all the way done yet + private static final int READING_FRAME = 2; + // completely read the frame, so an invocation can now happen + private static final int READ_FRAME_COMPLETE = 3; + // waiting to get switched to listening for write events + private static final int AWAITING_REGISTER_WRITE = 4; + // started writing response data, not fully complete yet + private static final int WRITING = 6; + // another thread wants this framebuffer to go back to reading + private static final int AWAITING_REGISTER_READ = 7; + // we want our transport and selection key invalidated in the selector thread + private static final int AWAITING_CLOSE = 8; + + // + // Instance variables + // + + // the actual transport hooked up to the client. + private final TNonblockingTransport trans_; + + // the SelectionKey that corresponds to our transport + private final SelectionKey selectionKey_; + + // where in the process of reading/writing are we? + private int state_ = READING_FRAME_SIZE; + + // the ByteBuffer we'll be using to write and read, depending on the state + private ByteBuffer buffer_; + + private TByteArrayOutputStream response_; + + public FrameBuffer( final TNonblockingTransport trans, + final SelectionKey selectionKey) { + trans_ = trans; + selectionKey_ = selectionKey; + buffer_ = ByteBuffer.allocate(4); + } + + /** + * Give this FrameBuffer a chance to read. The selector loop should have + * received a read event for this FrameBuffer. + * + * @return true if the connection should live on, false if it should be + * closed + */ + public boolean read() { + if (state_ == READING_FRAME_SIZE) { + // try to read the frame size completely + if (!internalRead()) { + return false; + } + + // if the frame size has been read completely, then prepare to read the + // actual frame. + if (buffer_.remaining() == 0) { + // pull out the frame size as an integer. + int frameSize = buffer_.getInt(0); + if (frameSize <= 0) { + LOGGER.error("Read an invalid frame size of " + frameSize + + ". Are you using TFramedTransport on the client side?"); + return false; + } + + // if this frame will always be too large for this server, log the + // error and close the connection. + if (frameSize + 4 > MAX_READ_BUFFER_BYTES) { + LOGGER.error("Read a frame size of " + frameSize + + ", which is bigger than the maximum allowable buffer size for ALL connections."); + return false; + } + + // if this frame will push us over the memory limit, then return. + // with luck, more memory will free up the next time around. + if (readBufferBytesAllocated + frameSize + 4 > MAX_READ_BUFFER_BYTES) { + return true; + } + + // incremement the amount of memory allocated to read buffers + readBufferBytesAllocated += frameSize + 4; + + // reallocate the readbuffer as a frame-sized buffer + buffer_ = ByteBuffer.allocate(frameSize + 4); + // put the frame size at the head of the buffer + buffer_.putInt(frameSize); + + state_ = READING_FRAME; + } else { + // this skips the check of READING_FRAME state below, since we can't + // possibly go on to that state if there's data left to be read at + // this one. + return true; + } + } + + // it is possible to fall through from the READING_FRAME_SIZE section + // to READING_FRAME if there's already some frame data available once + // READING_FRAME_SIZE is complete. + + if (state_ == READING_FRAME) { + if (!internalRead()) { + return false; + } + + // since we're already in the select loop here for sure, we can just + // modify our selection key directly. + if (buffer_.remaining() == 0) { + // get rid of the read select interests + selectionKey_.interestOps(0); + state_ = READ_FRAME_COMPLETE; + } + + return true; + } + + // if we fall through to this point, then the state must be invalid. + LOGGER.error("Read was called but state is invalid (" + state_ + ")"); + return false; + } + + /** + * Give this FrameBuffer a chance to write its output to the final client. + */ + public boolean write() { + if (state_ == WRITING) { + try { + if (trans_.write(buffer_) < 0) { + return false; + } + } catch (IOException e) { + LOGGER.warn("Got an IOException during write!", e); + return false; + } + + // we're done writing. now we need to switch back to reading. + if (buffer_.remaining() == 0) { + prepareRead(); + } + return true; + } + + LOGGER.error("Write was called, but state is invalid (" + state_ + ")"); + return false; + } + + /** + * Give this FrameBuffer a chance to set its interest to write, once data + * has come in. + */ + public void changeSelectInterests() { + if (state_ == AWAITING_REGISTER_WRITE) { + // set the OP_WRITE interest + selectionKey_.interestOps(SelectionKey.OP_WRITE); + state_ = WRITING; + } else if (state_ == AWAITING_REGISTER_READ) { + prepareRead(); + } else if (state_ == AWAITING_CLOSE){ + close(); + selectionKey_.cancel(); + } else { + LOGGER.error( + "changeSelectInterest was called, but state is invalid (" + + state_ + ")"); + } + } + + /** + * Shut the connection down. + */ + public void close() { + // if we're being closed due to an error, we might have allocated a + // buffer that we need to subtract for our memory accounting. + if (state_ == READING_FRAME || state_ == READ_FRAME_COMPLETE) { + readBufferBytesAllocated -= buffer_.array().length; + } + trans_.close(); + } + + /** + * Check if this FrameBuffer has a full frame read. + */ + public boolean isFrameFullyRead() { + return state_ == READ_FRAME_COMPLETE; + } + + /** + * After the processor has processed the invocation, whatever thread is + * managing invocations should call this method on this FrameBuffer so we + * know it's time to start trying to write again. Also, if it turns out + * that there actually isn't any data in the response buffer, we'll skip + * trying to write and instead go back to reading. + */ + public void responseReady() { + // the read buffer is definitely no longer in use, so we will decrement + // our read buffer count. we do this here as well as in close because + // we'd like to free this read memory up as quickly as possible for other + // clients. + readBufferBytesAllocated -= buffer_.array().length; + + if (response_.len() == 0) { + // go straight to reading again. this was probably an oneway method + state_ = AWAITING_REGISTER_READ; + buffer_ = null; + } else { + buffer_ = ByteBuffer.wrap(response_.get(), 0, response_.len()); + + // set state that we're waiting to be switched to write. we do this + // asynchronously through requestSelectInterestChange() because there is a + // possibility that we're not in the main thread, and thus currently + // blocked in select(). (this functionality is in place for the sake of + // the HsHa server.) + state_ = AWAITING_REGISTER_WRITE; + } + requestSelectInterestChange(); + } + + /** + * Actually invoke the method signified by this FrameBuffer. + */ + public void invoke() { + TTransport inTrans = getInputTransport(); + TProtocol inProt = inputProtocolFactory_.getProtocol(inTrans); + TProtocol outProt = outputProtocolFactory_.getProtocol(getOutputTransport()); + + try { + processorFactory_.getProcessor(inTrans).process(inProt, outProt); + responseReady(); + return; + } catch (TException te) { + LOGGER.warn("Exception while invoking!", te); + } catch (Exception e) { + LOGGER.error("Unexpected exception while invoking!", e); + } + // This will only be reached when there is an exception. + state_ = AWAITING_CLOSE; + requestSelectInterestChange(); + } + + /** + * Wrap the read buffer in a memory-based transport so a processor can read + * the data it needs to handle an invocation. + */ + private TTransport getInputTransport() { + return inputTransportFactory_.getTransport(new TIOStreamTransport( + new ByteArrayInputStream(buffer_.array()))); + } + + /** + * Get the transport that should be used by the invoker for responding. + */ + private TTransport getOutputTransport() { + response_ = new TByteArrayOutputStream(); + return outputTransportFactory_.getTransport(new TIOStreamTransport(response_)); + } + + /** + * Perform a read into buffer. + * + * @return true if the read succeeded, false if there was an error or the + * connection closed. + */ + private boolean internalRead() { + try { + if (trans_.read(buffer_) < 0) { + return false; + } + return true; + } catch (IOException e) { + LOGGER.warn("Got an IOException in internalRead!", e); + return false; + } + } + + /** + * We're done writing, so reset our interest ops and change state accordingly. + */ + private void prepareRead() { + // we can set our interest directly without using the queue because + // we're in the select thread. + selectionKey_.interestOps(SelectionKey.OP_READ); + // get ready for another go-around + buffer_ = ByteBuffer.allocate(4); + state_ = READING_FRAME_SIZE; + } + + /** + * When this FrameBuffer needs to change it's select interests and execution + * might not be in the select thread, then this method will make sure the + * interest change gets done when the select thread wakes back up. When the + * current thread is the select thread, then it just does the interest change + * immediately. + */ + private void requestSelectInterestChange() { + if (Thread.currentThread() == selectThread_) { + changeSelectInterests(); + } else { + TNonblockingServer.this.requestSelectInterestChange(this); + } + } + } // FrameBuffer + + + public static class Options { + public long maxReadBufferBytes = Long.MAX_VALUE; + + public Options() {} + + public void validate() { + if (maxReadBufferBytes <= 1024) { + throw new IllegalArgumentException("You must allocate at least 1KB to the read buffer."); + } + } + } +} diff --git a/lib/java/src/org/apache/thrift/server/TServer.java b/lib/java/src/org/apache/thrift/server/TServer.java new file mode 100644 index 000000000..eafe0c173 --- /dev/null +++ b/lib/java/src/org/apache/thrift/server/TServer.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.server; + +import org.apache.thrift.TProcessorFactory; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TServerTransport; +import org.apache.thrift.transport.TTransportFactory; + +/** + * Generic interface for a Thrift server. + * + */ +public abstract class TServer { + + /** + * Core processor + */ + protected TProcessorFactory processorFactory_; + + /** + * Server transport + */ + protected TServerTransport serverTransport_; + + /** + * Input Transport Factory + */ + protected TTransportFactory inputTransportFactory_; + + /** + * Output Transport Factory + */ + protected TTransportFactory outputTransportFactory_; + + /** + * Input Protocol Factory + */ + protected TProtocolFactory inputProtocolFactory_; + + /** + * Output Protocol Factory + */ + protected TProtocolFactory outputProtocolFactory_; + + /** + * Default constructors. + */ + + protected TServer(TProcessorFactory processorFactory, + TServerTransport serverTransport) { + this(processorFactory, + serverTransport, + new TTransportFactory(), + new TTransportFactory(), + new TBinaryProtocol.Factory(), + new TBinaryProtocol.Factory()); + } + + protected TServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory transportFactory) { + this(processorFactory, + serverTransport, + transportFactory, + transportFactory, + new TBinaryProtocol.Factory(), + new TBinaryProtocol.Factory()); + } + + protected TServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) { + this(processorFactory, + serverTransport, + transportFactory, + transportFactory, + protocolFactory, + protocolFactory); + } + + protected TServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + processorFactory_ = processorFactory; + serverTransport_ = serverTransport; + inputTransportFactory_ = inputTransportFactory; + outputTransportFactory_ = outputTransportFactory; + inputProtocolFactory_ = inputProtocolFactory; + outputProtocolFactory_ = outputProtocolFactory; + } + + /** + * The run method fires up the server and gets things going. + */ + public abstract void serve(); + + /** + * Stop the server. This is optional on a per-implementation basis. Not + * all servers are required to be cleanly stoppable. + */ + public void stop() {} + +} diff --git a/lib/java/src/org/apache/thrift/server/TSimpleServer.java b/lib/java/src/org/apache/thrift/server/TSimpleServer.java new file mode 100644 index 000000000..b3ee5ad6e --- /dev/null +++ b/lib/java/src/org/apache/thrift/server/TSimpleServer.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.server; + +import org.apache.thrift.TException; +import org.apache.thrift.TProcessor; +import org.apache.thrift.TProcessorFactory; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TServerTransport; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportFactory; +import org.apache.thrift.transport.TTransportException; +import org.apache.log4j.Logger; + +/** + * Simple singlethreaded server for testing. + * + */ +public class TSimpleServer extends TServer { + + private static final Logger LOGGER = Logger.getLogger(TSimpleServer.class.getName()); + + private boolean stopped_ = false; + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport) { + super(new TProcessorFactory(processor), serverTransport); + } + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) { + super(new TProcessorFactory(processor), serverTransport, transportFactory, protocolFactory); + } + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + super(new TProcessorFactory(processor), serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + } + + public TSimpleServer(TProcessorFactory processorFactory, + TServerTransport serverTransport) { + super(processorFactory, serverTransport); + } + + public TSimpleServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) { + super(processorFactory, serverTransport, transportFactory, protocolFactory); + } + + public TSimpleServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + super(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + } + + + public void serve() { + stopped_ = false; + try { + serverTransport_.listen(); + } catch (TTransportException ttx) { + LOGGER.error("Error occurred during listening.", ttx); + return; + } + + while (!stopped_) { + TTransport client = null; + TProcessor processor = null; + TTransport inputTransport = null; + TTransport outputTransport = null; + TProtocol inputProtocol = null; + TProtocol outputProtocol = null; + try { + client = serverTransport_.accept(); + if (client != null) { + processor = processorFactory_.getProcessor(client); + inputTransport = inputTransportFactory_.getTransport(client); + outputTransport = outputTransportFactory_.getTransport(client); + inputProtocol = inputProtocolFactory_.getProtocol(inputTransport); + outputProtocol = outputProtocolFactory_.getProtocol(outputTransport); + while (processor.process(inputProtocol, outputProtocol)) {} + } + } catch (TTransportException ttx) { + // Client died, just move on + } catch (TException tx) { + if (!stopped_) { + LOGGER.error("Thrift error occurred during processing of message.", tx); + } + } catch (Exception x) { + if (!stopped_) { + LOGGER.error("Error occurred during processing of message.", x); + } + } + + if (inputTransport != null) { + inputTransport.close(); + } + + if (outputTransport != null) { + outputTransport.close(); + } + + } + } + + public void stop() { + stopped_ = true; + serverTransport_.interrupt(); + } +} diff --git a/lib/java/src/org/apache/thrift/server/TThreadPoolServer.java b/lib/java/src/org/apache/thrift/server/TThreadPoolServer.java new file mode 100644 index 000000000..ebc5a9be6 --- /dev/null +++ b/lib/java/src/org/apache/thrift/server/TThreadPoolServer.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.server; + +import org.apache.thrift.TException; +import org.apache.thrift.TProcessor; +import org.apache.thrift.TProcessorFactory; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.transport.TServerTransport; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; +import org.apache.thrift.transport.TTransportFactory; +import org.apache.log4j.Logger; +import org.apache.log4j.Level; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + + +/** + * Server which uses Java's built in ThreadPool management to spawn off + * a worker pool that + * + */ +public class TThreadPoolServer extends TServer { + + private static final Logger LOGGER = Logger.getLogger(TThreadPoolServer.class.getName()); + + // Executor service for handling client connections + private ExecutorService executorService_; + + // Flag for stopping the server + private volatile boolean stopped_; + + // Server options + private Options options_; + + // Customizable server options + public static class Options { + public int minWorkerThreads = 5; + public int maxWorkerThreads = Integer.MAX_VALUE; + public int stopTimeoutVal = 60; + public TimeUnit stopTimeoutUnit = TimeUnit.SECONDS; + } + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport) { + this(processor, serverTransport, + new TTransportFactory(), new TTransportFactory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory()); + } + + public TThreadPoolServer(TProcessorFactory processorFactory, + TServerTransport serverTransport) { + this(processorFactory, serverTransport, + new TTransportFactory(), new TTransportFactory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory()); + } + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport, + TProtocolFactory protocolFactory) { + this(processor, serverTransport, + new TTransportFactory(), new TTransportFactory(), + protocolFactory, protocolFactory); + } + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) { + this(processor, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory); + } + + public TThreadPoolServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) { + this(processorFactory, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory); + } + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + this(new TProcessorFactory(processor), serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + } + + public TThreadPoolServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + super(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + options_ = new Options(); + executorService_ = Executors.newCachedThreadPool(); + } + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + Options options) { + this(new TProcessorFactory(processor), serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory, + options); + } + + public TThreadPoolServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + Options options) { + super(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + + executorService_ = null; + + SynchronousQueue<Runnable> executorQueue = + new SynchronousQueue<Runnable>(); + + executorService_ = new ThreadPoolExecutor(options.minWorkerThreads, + options.maxWorkerThreads, + 60, + TimeUnit.SECONDS, + executorQueue); + + options_ = options; + } + + + public void serve() { + try { + serverTransport_.listen(); + } catch (TTransportException ttx) { + LOGGER.error("Error occurred during listening.", ttx); + return; + } + + stopped_ = false; + while (!stopped_) { + int failureCount = 0; + try { + TTransport client = serverTransport_.accept(); + WorkerProcess wp = new WorkerProcess(client); + executorService_.execute(wp); + } catch (TTransportException ttx) { + if (!stopped_) { + ++failureCount; + LOGGER.warn("Transport error occurred during acceptance of message.", ttx); + } + } + } + + executorService_.shutdown(); + + // Loop until awaitTermination finally does return without a interrupted + // exception. If we don't do this, then we'll shut down prematurely. We want + // to let the executorService clear it's task queue, closing client sockets + // appropriately. + long timeoutMS = options_.stopTimeoutUnit.toMillis(options_.stopTimeoutVal); + long now = System.currentTimeMillis(); + while (timeoutMS >= 0) { + try { + executorService_.awaitTermination(timeoutMS, TimeUnit.MILLISECONDS); + break; + } catch (InterruptedException ix) { + long newnow = System.currentTimeMillis(); + timeoutMS -= (newnow - now); + now = newnow; + } + } + } + + public void stop() { + stopped_ = true; + serverTransport_.interrupt(); + } + + private class WorkerProcess implements Runnable { + + /** + * Client that this services. + */ + private TTransport client_; + + /** + * Default constructor. + * + * @param client Transport to process + */ + private WorkerProcess(TTransport client) { + client_ = client; + } + + /** + * Loops on processing a client forever + */ + public void run() { + TProcessor processor = null; + TTransport inputTransport = null; + TTransport outputTransport = null; + TProtocol inputProtocol = null; + TProtocol outputProtocol = null; + try { + processor = processorFactory_.getProcessor(client_); + inputTransport = inputTransportFactory_.getTransport(client_); + outputTransport = outputTransportFactory_.getTransport(client_); + inputProtocol = inputProtocolFactory_.getProtocol(inputTransport); + outputProtocol = outputProtocolFactory_.getProtocol(outputTransport); + // we check stopped_ first to make sure we're not supposed to be shutting + // down. this is necessary for graceful shutdown. + while (!stopped_ && processor.process(inputProtocol, outputProtocol)) {} + } catch (TTransportException ttx) { + // Assume the client died and continue silently + } catch (TException tx) { + LOGGER.error("Thrift error occurred during processing of message.", tx); + } catch (Exception x) { + LOGGER.error("Error occurred during processing of message.", x); + } + + if (inputTransport != null) { + inputTransport.close(); + } + + if (outputTransport != null) { + outputTransport.close(); + } + } + } +} diff --git a/lib/java/src/org/apache/thrift/transport/TFramedTransport.java b/lib/java/src/org/apache/thrift/transport/TFramedTransport.java new file mode 100644 index 000000000..c83748ad2 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TFramedTransport.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.ByteArrayInputStream; + +import org.apache.thrift.TByteArrayOutputStream; + +/** + * Socket implementation of the TTransport interface. To be commented soon! + * + */ +public class TFramedTransport extends TTransport { + + /** + * Underlying transport + */ + private TTransport transport_ = null; + + /** + * Buffer for output + */ + private final TByteArrayOutputStream writeBuffer_ = + new TByteArrayOutputStream(1024); + + /** + * Buffer for input + */ + private ByteArrayInputStream readBuffer_ = null; + + public static class Factory extends TTransportFactory { + public Factory() { + } + + public TTransport getTransport(TTransport base) { + return new TFramedTransport(base); + } + } + + /** + * Constructor wraps around another tranpsort + */ + public TFramedTransport(TTransport transport) { + transport_ = transport; + } + + public void open() throws TTransportException { + transport_.open(); + } + + public boolean isOpen() { + return transport_.isOpen(); + } + + public void close() { + transport_.close(); + } + + public int read(byte[] buf, int off, int len) throws TTransportException { + if (readBuffer_ != null) { + int got = readBuffer_.read(buf, off, len); + if (got > 0) { + return got; + } + } + + // Read another frame of data + readFrame(); + + return readBuffer_.read(buf, off, len); + } + + private void readFrame() throws TTransportException { + byte[] i32rd = new byte[4]; + transport_.readAll(i32rd, 0, 4); + int size = + ((i32rd[0] & 0xff) << 24) | + ((i32rd[1] & 0xff) << 16) | + ((i32rd[2] & 0xff) << 8) | + ((i32rd[3] & 0xff)); + + byte[] buff = new byte[size]; + transport_.readAll(buff, 0, size); + readBuffer_ = new ByteArrayInputStream(buff); + } + + public void write(byte[] buf, int off, int len) throws TTransportException { + writeBuffer_.write(buf, off, len); + } + + public void flush() throws TTransportException { + byte[] buf = writeBuffer_.get(); + int len = writeBuffer_.len(); + writeBuffer_.reset(); + + byte[] i32out = new byte[4]; + i32out[0] = (byte)(0xff & (len >> 24)); + i32out[1] = (byte)(0xff & (len >> 16)); + i32out[2] = (byte)(0xff & (len >> 8)); + i32out[3] = (byte)(0xff & (len)); + transport_.write(i32out, 0, 4); + transport_.write(buf, 0, len); + transport_.flush(); + } +} diff --git a/lib/java/src/org/apache/thrift/transport/THttpClient.java b/lib/java/src/org/apache/thrift/transport/THttpClient.java new file mode 100644 index 000000000..419235310 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/THttpClient.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.io.IOException; + +import java.net.URL; +import java.net.HttpURLConnection; +import java.util.HashMap; +import java.util.Map; + +/** + * HTTP implementation of the TTransport interface. Used for working with a + * Thrift web services implementation. + * + */ +public class THttpClient extends TTransport { + + private URL url_ = null; + + private final ByteArrayOutputStream requestBuffer_ = + new ByteArrayOutputStream(); + + private InputStream inputStream_ = null; + + private int connectTimeout_ = 0; + + private int readTimeout_ = 0; + + private Map<String,String> customHeaders_ = null; + + public THttpClient(String url) throws TTransportException { + try { + url_ = new URL(url); + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public void setConnectTimeout(int timeout) { + connectTimeout_ = timeout; + } + + public void setReadTimeout(int timeout) { + readTimeout_ = timeout; + } + + public void setCustomHeaders(Map<String,String> headers) { + customHeaders_ = headers; + } + + public void setCustomHeader(String key, String value) { + if (customHeaders_ == null) { + customHeaders_ = new HashMap<String, String>(); + } + customHeaders_.put(key, value); + } + + public void open() {} + + public void close() { + if (null != inputStream_) { + try { + inputStream_.close(); + } catch (IOException ioe) { + ; + } + inputStream_ = null; + } + } + + public boolean isOpen() { + return true; + } + + public int read(byte[] buf, int off, int len) throws TTransportException { + if (inputStream_ == null) { + throw new TTransportException("Response buffer is empty, no request."); + } + try { + int ret = inputStream_.read(buf, off, len); + if (ret == -1) { + throw new TTransportException("No more data available."); + } + return ret; + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public void write(byte[] buf, int off, int len) { + requestBuffer_.write(buf, off, len); + } + + public void flush() throws TTransportException { + // Extract request and reset buffer + byte[] data = requestBuffer_.toByteArray(); + requestBuffer_.reset(); + + try { + // Create connection object + HttpURLConnection connection = (HttpURLConnection)url_.openConnection(); + + // Timeouts, only if explicitly set + if (connectTimeout_ > 0) { + connection.setConnectTimeout(connectTimeout_); + } + if (readTimeout_ > 0) { + connection.setReadTimeout(readTimeout_); + } + + // Make the request + connection.setRequestMethod("POST"); + connection.setRequestProperty("Content-Type", "application/x-thrift"); + connection.setRequestProperty("Accept", "application/x-thrift"); + connection.setRequestProperty("User-Agent", "Java/THttpClient"); + if (customHeaders_ != null) { + for (Map.Entry<String, String> header : customHeaders_.entrySet()) { + connection.setRequestProperty(header.getKey(), header.getValue()); + } + } + connection.setDoOutput(true); + connection.connect(); + connection.getOutputStream().write(data); + + int responseCode = connection.getResponseCode(); + if (responseCode != HttpURLConnection.HTTP_OK) { + throw new TTransportException("HTTP Response code: " + responseCode); + } + + // Read the responses + inputStream_ = connection.getInputStream(); + + } catch (IOException iox) { + throw new TTransportException(iox); + } + } +} diff --git a/lib/java/src/org/apache/thrift/transport/TIOStreamTransport.java b/lib/java/src/org/apache/thrift/transport/TIOStreamTransport.java new file mode 100644 index 000000000..89cdb5828 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TIOStreamTransport.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.log4j.Logger; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * This is the most commonly used base transport. It takes an InputStream + * and an OutputStream and uses those to perform all transport operations. + * This allows for compatibility with all the nice constructs Java already + * has to provide a variety of types of streams. + * + */ +public class TIOStreamTransport extends TTransport { + + private static final Logger LOGGER = Logger.getLogger(TIOStreamTransport.class.getName()); + + /** Underlying inputStream */ + protected InputStream inputStream_ = null; + + /** Underlying outputStream */ + protected OutputStream outputStream_ = null; + + /** + * Subclasses can invoke the default constructor and then assign the input + * streams in the open method. + */ + protected TIOStreamTransport() {} + + /** + * Input stream constructor. + * + * @param is Input stream to read from + */ + public TIOStreamTransport(InputStream is) { + inputStream_ = is; + } + + /** + * Output stream constructor. + * + * @param os Output stream to read from + */ + public TIOStreamTransport(OutputStream os) { + outputStream_ = os; + } + + /** + * Two-way stream constructor. + * + * @param is Input stream to read from + * @param os Output stream to read from + */ + public TIOStreamTransport(InputStream is, OutputStream os) { + inputStream_ = is; + outputStream_ = os; + } + + /** + * The streams must already be open at construction time, so this should + * always return true. + * + * @return true + */ + public boolean isOpen() { + return true; + } + + /** + * The streams must already be open. This method does nothing. + */ + public void open() throws TTransportException {} + + /** + * Closes both the input and output streams. + */ + public void close() { + if (inputStream_ != null) { + try { + inputStream_.close(); + } catch (IOException iox) { + LOGGER.warn("Error closing input stream.", iox); + } + inputStream_ = null; + } + if (outputStream_ != null) { + try { + outputStream_.close(); + } catch (IOException iox) { + LOGGER.warn("Error closing output stream.", iox); + } + outputStream_ = null; + } + } + + /** + * Reads from the underlying input stream if not null. + */ + public int read(byte[] buf, int off, int len) throws TTransportException { + if (inputStream_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot read from null inputStream"); + } + try { + return inputStream_.read(buf, off, len); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } + + /** + * Writes to the underlying output stream if not null. + */ + public void write(byte[] buf, int off, int len) throws TTransportException { + if (outputStream_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot write to null outputStream"); + } + try { + outputStream_.write(buf, off, len); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } + + /** + * Flushes the underlying output stream if not null. + */ + public void flush() throws TTransportException { + if (outputStream_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot flush null outputStream"); + } + try { + outputStream_.flush(); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } +} diff --git a/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java b/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java new file mode 100644 index 000000000..886fcbf62 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.thrift.TByteArrayOutputStream; +import java.io.UnsupportedEncodingException; + +/** + * Memory buffer-based implementation of the TTransport interface. + * + */ +public class TMemoryBuffer extends TTransport { + + /** + * + */ + public TMemoryBuffer(int size) { + arr_ = new TByteArrayOutputStream(size); + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public void open() { + /* Do nothing */ + } + + @Override + public void close() { + /* Do nothing */ + } + + @Override + public int read(byte[] buf, int off, int len) { + byte[] src = arr_.get(); + int amtToRead = (len > arr_.len() - pos_ ? arr_.len() - pos_ : len); + if (amtToRead > 0) { + System.arraycopy(src, pos_, buf, off, amtToRead); + pos_ += amtToRead; + } + return amtToRead; + } + + @Override + public void write(byte[] buf, int off, int len) { + arr_.write(buf, off, len); + } + + /** + * Output the contents of the memory buffer as a String, using the supplied + * encoding + * @param enc the encoding to use + * @return the contents of the memory buffer as a String + */ + public String toString(String enc) throws UnsupportedEncodingException { + return arr_.toString(enc); + } + + public String inspect() { + String buf = ""; + byte[] bytes = arr_.toByteArray(); + for (int i = 0; i < bytes.length; i++) { + buf += (pos_ == i ? "==>" : "" ) + Integer.toHexString(bytes[i] & 0xff) + " "; + } + return buf; + } + + // The contents of the buffer + private TByteArrayOutputStream arr_; + + // Position to read next byte from + private int pos_; + + public int length() { + return arr_.size(); + } +} + diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java b/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java new file mode 100644 index 000000000..571adbff3 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.transport; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.SocketException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; + +/** + * Wrapper around ServerSocketChannel + */ +public class TNonblockingServerSocket extends TNonblockingServerTransport { + + /** + * This channel is where all the nonblocking magic happens. + */ + private ServerSocketChannel serverSocketChannel = null; + + /** + * Underlying serversocket object + */ + private ServerSocket serverSocket_ = null; + + /** + * Port to listen on + */ + private int port_ = 0; + + /** + * Timeout for client sockets from accept + */ + private int clientTimeout_ = 0; + + /** + * Creates a server socket from underlying socket object + */ + // public TNonblockingServerSocket(ServerSocket serverSocket) { + // this(serverSocket, 0); + // } + + /** + * Creates a server socket from underlying socket object + */ + // public TNonblockingServerSocket(ServerSocket serverSocket, int clientTimeout) { + // serverSocket_ = serverSocket; + // clientTimeout_ = clientTimeout; + // } + + /** + * Creates just a port listening server socket + */ + public TNonblockingServerSocket(int port) throws TTransportException { + this(port, 0); + } + + /** + * Creates just a port listening server socket + */ + public TNonblockingServerSocket(int port, int clientTimeout) throws TTransportException { + port_ = port; + clientTimeout_ = clientTimeout; + try { + serverSocketChannel = ServerSocketChannel.open(); + serverSocketChannel.configureBlocking(false); + + // Make server socket + serverSocket_ = serverSocketChannel.socket(); + // Prevent 2MSL delay problem on server restarts + serverSocket_.setReuseAddress(true); + // Bind to listening port + serverSocket_.bind(new InetSocketAddress(port_)); + } catch (IOException ioe) { + serverSocket_ = null; + throw new TTransportException("Could not create ServerSocket on port " + port + "."); + } + } + + public void listen() throws TTransportException { + // Make sure not to block on accept + if (serverSocket_ != null) { + try { + serverSocket_.setSoTimeout(0); + } catch (SocketException sx) { + sx.printStackTrace(); + } + } + } + + protected TNonblockingSocket acceptImpl() throws TTransportException { + if (serverSocket_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "No underlying server socket."); + } + try { + SocketChannel socketChannel = serverSocketChannel.accept(); + if (socketChannel == null) { + return null; + } + + TNonblockingSocket tsocket = new TNonblockingSocket(socketChannel); + tsocket.setTimeout(clientTimeout_); + return tsocket; + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public void registerSelector(Selector selector) { + try { + // Register the server socket channel, indicating an interest in + // accepting new connections + serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT); + } catch (ClosedChannelException e) { + // this shouldn't happen, ideally... + // TODO: decide what to do with this. + } + } + + public void close() { + if (serverSocket_ != null) { + try { + serverSocket_.close(); + } catch (IOException iox) { + System.err.println("WARNING: Could not close server socket: " + + iox.getMessage()); + } + serverSocket_ = null; + } + } + + public void interrupt() { + // The thread-safeness of this is dubious, but Java documentation suggests + // that it is safe to do this from a different thread context + close(); + } + +} diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java b/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java new file mode 100644 index 000000000..ba45b09dc --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.transport; + +import java.nio.channels.Selector; + +/** + * Server transport that can be operated in a nonblocking fashion. + */ +public abstract class TNonblockingServerTransport extends TServerTransport { + + public abstract void registerSelector(Selector selector); +} diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java b/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java new file mode 100644 index 000000000..bc2d53969 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.transport; + +import java.io.IOException; +import java.net.Socket; +import java.net.SocketException; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; + +/** + * Socket implementation of the TTransport interface. To be commented soon! + */ +public class TNonblockingSocket extends TNonblockingTransport { + + private SocketChannel socketChannel = null; + + /** + * Wrapped Socket object + */ + private Socket socket_ = null; + + /** + * Remote host + */ + private String host_ = null; + + /** + * Remote port + */ + private int port_ = 0; + + /** + * Socket timeout + */ + private int timeout_ = 0; + + /** + * Constructor that takes an already created socket. + * + * @param socketChannel Already created SocketChannel object + * @throws TTransportException if there is an error setting up the streams + */ + public TNonblockingSocket(SocketChannel socketChannel) throws TTransportException { + try { + // make it a nonblocking channel + socketChannel.configureBlocking(false); + } catch (IOException e) { + throw new TTransportException(e); + } + + this.socketChannel = socketChannel; + this.socket_ = socketChannel.socket(); + try { + socket_.setSoLinger(false, 0); + socket_.setTcpNoDelay(true); + } catch (SocketException sx) { + sx.printStackTrace(); + } + } + + /** + * Register this socket with the specified selector for both read and write + * operations. + * + * @param selector + * @return the selection key for this socket. + */ + public SelectionKey registerSelector(Selector selector, int interests) throws IOException { + // Register the new SocketChannel with our Selector, indicating + // we'd like to be notified when there's data waiting to be read + return socketChannel.register(selector, interests); + } + + /** + * Initializes the socket object + */ + private void initSocket() { + socket_ = new Socket(); + try { + socket_.setSoLinger(false, 0); + socket_.setTcpNoDelay(true); + socket_.setSoTimeout(timeout_); + } catch (SocketException sx) { + sx.printStackTrace(); + } + } + + /** + * Sets the socket timeout + * + * @param timeout Milliseconds timeout + */ + public void setTimeout(int timeout) { + timeout_ = timeout; + try { + socket_.setSoTimeout(timeout); + } catch (SocketException sx) { + sx.printStackTrace(); + } + } + + /** + * Returns a reference to the underlying socket. + */ + public Socket getSocket() { + if (socket_ == null) { + initSocket(); + } + return socket_; + } + + /** + * Checks whether the socket is connected. + */ + public boolean isOpen() { + if (socket_ == null) { + return false; + } + return socket_.isConnected(); + } + + /** + * Connects the socket, creating a new socket object if necessary. + */ + public void open() throws TTransportException { + throw new RuntimeException("Not implemented yet"); + } + + /** + * Perform a nonblocking read into buffer. + */ + public int read(ByteBuffer buffer) throws IOException { + return socketChannel.read(buffer); + } + + + /** + * Reads from the underlying input stream if not null. + */ + public int read(byte[] buf, int off, int len) throws TTransportException { + if ((socketChannel.validOps() & SelectionKey.OP_READ) != SelectionKey.OP_READ) { + throw new TTransportException(TTransportException.NOT_OPEN, + "Cannot read from write-only socket channel"); + } + try { + return socketChannel.read(ByteBuffer.wrap(buf, off, len)); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } + + /** + * Perform a nonblocking write of the data in buffer; + */ + public int write(ByteBuffer buffer) throws IOException { + return socketChannel.write(buffer); + } + + /** + * Writes to the underlying output stream if not null. + */ + public void write(byte[] buf, int off, int len) throws TTransportException { + if ((socketChannel.validOps() & SelectionKey.OP_WRITE) != SelectionKey.OP_WRITE) { + throw new TTransportException(TTransportException.NOT_OPEN, + "Cannot write to write-only socket channel"); + } + try { + socketChannel.write(ByteBuffer.wrap(buf, off, len)); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } + + /** + * Flushes the underlying output stream if not null. + */ + public void flush() throws TTransportException { + // Not supported by SocketChannel. + } + + /** + * Closes the socket. + */ + public void close() { + try { + socketChannel.close(); + } catch (IOException e) { + // silently ignore. + } + } + +} diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingTransport.java b/lib/java/src/org/apache/thrift/transport/TNonblockingTransport.java new file mode 100644 index 000000000..517eacb74 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TNonblockingTransport.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.IOException; +import java.nio.channels.Selector; +import java.nio.channels.SelectionKey; +import java.nio.ByteBuffer; + +public abstract class TNonblockingTransport extends TTransport { + public abstract SelectionKey registerSelector(Selector selector, int interests) throws IOException; + public abstract int read(ByteBuffer buffer) throws IOException; + public abstract int write(ByteBuffer buffer) throws IOException; +} diff --git a/lib/java/src/org/apache/thrift/transport/TServerSocket.java b/lib/java/src/org/apache/thrift/transport/TServerSocket.java new file mode 100644 index 000000000..796cd659c --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TServerSocket.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.log4j.Logger; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketException; + +/** + * Wrapper around ServerSocket for Thrift. + * + */ +public class TServerSocket extends TServerTransport { + + private static final Logger LOGGER = Logger.getLogger(TServerSocket.class.getName()); + + /** + * Underlying serversocket object + */ + private ServerSocket serverSocket_ = null; + + /** + * Port to listen on + */ + private int port_ = 0; + + /** + * Timeout for client sockets from accept + */ + private int clientTimeout_ = 0; + + /** + * Creates a server socket from underlying socket object + */ + public TServerSocket(ServerSocket serverSocket) { + this(serverSocket, 0); + } + + /** + * Creates a server socket from underlying socket object + */ + public TServerSocket(ServerSocket serverSocket, int clientTimeout) { + serverSocket_ = serverSocket; + clientTimeout_ = clientTimeout; + } + + /** + * Creates just a port listening server socket + */ + public TServerSocket(int port) throws TTransportException { + this(port, 0); + } + + /** + * Creates just a port listening server socket + */ + public TServerSocket(int port, int clientTimeout) throws TTransportException { + this(new InetSocketAddress(port), clientTimeout); + port_ = port; + } + + public TServerSocket(InetSocketAddress bindAddr) throws TTransportException { + this(bindAddr, 0); + } + + public TServerSocket(InetSocketAddress bindAddr, int clientTimeout) throws TTransportException { + clientTimeout_ = clientTimeout; + try { + // Make server socket + serverSocket_ = new ServerSocket(); + // Prevent 2MSL delay problem on server restarts + serverSocket_.setReuseAddress(true); + // Bind to listening port + serverSocket_.bind(bindAddr); + } catch (IOException ioe) { + serverSocket_ = null; + throw new TTransportException("Could not create ServerSocket on address " + bindAddr.toString() + "."); + } + } + + public void listen() throws TTransportException { + // Make sure not to block on accept + if (serverSocket_ != null) { + try { + serverSocket_.setSoTimeout(0); + } catch (SocketException sx) { + LOGGER.error("Could not set socket timeout.", sx); + } + } + } + + protected TSocket acceptImpl() throws TTransportException { + if (serverSocket_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "No underlying server socket."); + } + try { + Socket result = serverSocket_.accept(); + TSocket result2 = new TSocket(result); + result2.setTimeout(clientTimeout_); + return result2; + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public void close() { + if (serverSocket_ != null) { + try { + serverSocket_.close(); + } catch (IOException iox) { + LOGGER.warn("Could not close server socket.", iox); + } + serverSocket_ = null; + } + } + + public void interrupt() { + // The thread-safeness of this is dubious, but Java documentation suggests + // that it is safe to do this from a different thread context + close(); + } + +} diff --git a/lib/java/src/org/apache/thrift/transport/TServerTransport.java b/lib/java/src/org/apache/thrift/transport/TServerTransport.java new file mode 100644 index 000000000..17ff86bec --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TServerTransport.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +/** + * Server transport. Object which provides client transports. + * + */ +public abstract class TServerTransport { + + public abstract void listen() throws TTransportException; + + public final TTransport accept() throws TTransportException { + TTransport transport = acceptImpl(); + if (transport == null) { + throw new TTransportException("accept() may not return NULL"); + } + return transport; + } + + public abstract void close(); + + protected abstract TTransport acceptImpl() throws TTransportException; + + /** + * Optional method implementation. This signals to the server transport + * that it should break out of any accept() or listen() that it is currently + * blocked on. This method, if implemented, MUST be thread safe, as it may + * be called from a different thread context than the other TServerTransport + * methods. + */ + public void interrupt() {} + +} diff --git a/lib/java/src/org/apache/thrift/transport/TSocket.java b/lib/java/src/org/apache/thrift/transport/TSocket.java new file mode 100644 index 000000000..cdf1bcc4b --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TSocket.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.log4j.Logger; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketException; + +/** + * Socket implementation of the TTransport interface. To be commented soon! + * + */ +public class TSocket extends TIOStreamTransport { + + private static final Logger LOGGER = Logger.getLogger(TSocket.class.getName()); + + /** + * Wrapped Socket object + */ + private Socket socket_ = null; + + /** + * Remote host + */ + private String host_ = null; + + /** + * Remote port + */ + private int port_ = 0; + + /** + * Socket timeout + */ + private int timeout_ = 0; + + /** + * Constructor that takes an already created socket. + * + * @param socket Already created socket object + * @throws TTransportException if there is an error setting up the streams + */ + public TSocket(Socket socket) throws TTransportException { + socket_ = socket; + try { + socket_.setSoLinger(false, 0); + socket_.setTcpNoDelay(true); + } catch (SocketException sx) { + LOGGER.warn("Could not configure socket.", sx); + } + + if (isOpen()) { + try { + inputStream_ = new BufferedInputStream(socket_.getInputStream(), 1024); + outputStream_ = new BufferedOutputStream(socket_.getOutputStream(), 1024); + } catch (IOException iox) { + close(); + throw new TTransportException(TTransportException.NOT_OPEN, iox); + } + } + } + + /** + * Creates a new unconnected socket that will connect to the given host + * on the given port. + * + * @param host Remote host + * @param port Remote port + */ + public TSocket(String host, int port) { + this(host, port, 0); + } + + /** + * Creates a new unconnected socket that will connect to the given host + * on the given port. + * + * @param host Remote host + * @param port Remote port + * @param timeout Socket timeout + */ + public TSocket(String host, int port, int timeout) { + host_ = host; + port_ = port; + timeout_ = timeout; + initSocket(); + } + + /** + * Initializes the socket object + */ + private void initSocket() { + socket_ = new Socket(); + try { + socket_.setSoLinger(false, 0); + socket_.setTcpNoDelay(true); + socket_.setSoTimeout(timeout_); + } catch (SocketException sx) { + LOGGER.error("Could not configure socket.", sx); + } + } + + /** + * Sets the socket timeout + * + * @param timeout Milliseconds timeout + */ + public void setTimeout(int timeout) { + timeout_ = timeout; + try { + socket_.setSoTimeout(timeout); + } catch (SocketException sx) { + LOGGER.warn("Could not set socket timeout.", sx); + } + } + + /** + * Returns a reference to the underlying socket. + */ + public Socket getSocket() { + if (socket_ == null) { + initSocket(); + } + return socket_; + } + + /** + * Checks whether the socket is connected. + */ + public boolean isOpen() { + if (socket_ == null) { + return false; + } + return socket_.isConnected(); + } + + /** + * Connects the socket, creating a new socket object if necessary. + */ + public void open() throws TTransportException { + if (isOpen()) { + throw new TTransportException(TTransportException.ALREADY_OPEN, "Socket already connected."); + } + + if (host_.length() == 0) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot open null host."); + } + if (port_ <= 0) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot open without port."); + } + + if (socket_ == null) { + initSocket(); + } + + try { + socket_.connect(new InetSocketAddress(host_, port_)); + inputStream_ = new BufferedInputStream(socket_.getInputStream(), 1024); + outputStream_ = new BufferedOutputStream(socket_.getOutputStream(), 1024); + } catch (IOException iox) { + close(); + throw new TTransportException(TTransportException.NOT_OPEN, iox); + } + } + + /** + * Closes the socket. + */ + public void close() { + // Close the underlying streams + super.close(); + + // Close the socket + if (socket_ != null) { + try { + socket_.close(); + } catch (IOException iox) { + LOGGER.warn("Could not close socket.", iox); + } + socket_ = null; + } + } + +} diff --git a/lib/java/src/org/apache/thrift/transport/TTransport.java b/lib/java/src/org/apache/thrift/transport/TTransport.java new file mode 100644 index 000000000..a6c047bb5 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TTransport.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +/** + * Generic class that encapsulates the I/O layer. This is basically a thin + * wrapper around the combined functionality of Java input/output streams. + * + */ +public abstract class TTransport { + + /** + * Queries whether the transport is open. + * + * @return True if the transport is open. + */ + public abstract boolean isOpen(); + + /** + * Is there more data to be read? + * + * @return True if the remote side is still alive and feeding us + */ + public boolean peek() { + return isOpen(); + } + + /** + * Opens the transport for reading/writing. + * + * @throws TTransportException if the transport could not be opened + */ + public abstract void open() + throws TTransportException; + + /** + * Closes the transport. + */ + public abstract void close(); + + /** + * Reads up to len bytes into buffer buf, starting att offset off. + * + * @param buf Array to read into + * @param off Index to start reading at + * @param len Maximum number of bytes to read + * @return The number of bytes actually read + * @throws TTransportException if there was an error reading data + */ + public abstract int read(byte[] buf, int off, int len) + throws TTransportException; + + /** + * Guarantees that all of len bytes are actually read off the transport. + * + * @param buf Array to read into + * @param off Index to start reading at + * @param len Maximum number of bytes to read + * @return The number of bytes actually read, which must be equal to len + * @throws TTransportException if there was an error reading data + */ + public int readAll(byte[] buf, int off, int len) + throws TTransportException { + int got = 0; + int ret = 0; + while (got < len) { + ret = read(buf, off+got, len-got); + if (ret <= 0) { + throw new TTransportException("Cannot read. Remote side has closed. Tried to read " + len + " bytes, but only got " + got + " bytes."); + } + got += ret; + } + return got; + } + + /** + * Writes the buffer to the output + * + * @param buf The output data buffer + * @throws TTransportException if an error occurs writing data + */ + public void write(byte[] buf) throws TTransportException { + write(buf, 0, buf.length); + } + + /** + * Writes up to len bytes from the buffer. + * + * @param buf The output data buffer + * @param off The offset to start writing from + * @param len The number of bytes to write + * @throws TTransportException if there was an error writing data + */ + public abstract void write(byte[] buf, int off, int len) + throws TTransportException; + + /** + * Flush any pending data out of a transport buffer. + * + * @throws TTransportException if there was an error writing out data. + */ + public void flush() + throws TTransportException {} +} diff --git a/lib/java/src/org/apache/thrift/transport/TTransportException.java b/lib/java/src/org/apache/thrift/transport/TTransportException.java new file mode 100644 index 000000000..d08f3b02b --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TTransportException.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.thrift.TException; + +/** + * Transport exceptions. + * + */ +public class TTransportException extends TException { + + private static final long serialVersionUID = 1L; + + public static final int UNKNOWN = 0; + public static final int NOT_OPEN = 1; + public static final int ALREADY_OPEN = 2; + public static final int TIMED_OUT = 3; + public static final int END_OF_FILE = 4; + + protected int type_ = UNKNOWN; + + public TTransportException() { + super(); + } + + public TTransportException(int type) { + super(); + type_ = type; + } + + public TTransportException(int type, String message) { + super(message); + type_ = type; + } + + public TTransportException(String message) { + super(message); + } + + public TTransportException(int type, Throwable cause) { + super(cause); + type_ = type; + } + + public TTransportException(Throwable cause) { + super(cause); + } + + public TTransportException(String message, Throwable cause) { + super(message, cause); + } + + public TTransportException(int type, String message, Throwable cause) { + super(message, cause); + type_ = type; + } + + public int getType() { + return type_; + } + +} diff --git a/lib/java/src/org/apache/thrift/transport/TTransportFactory.java b/lib/java/src/org/apache/thrift/transport/TTransportFactory.java new file mode 100644 index 000000000..3e71630ae --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TTransportFactory.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +/** + * Factory class used to create wrapped instance of Transports. + * This is used primarily in servers, which get Transports from + * a ServerTransport and then may want to mutate them (i.e. create + * a BufferedTransport from the underlying base transport) + * + */ +public class TTransportFactory { + + /** + * Return a wrapped instance of the base Transport. + * + * @param trans The base transport + * @return Wrapped Transport + */ + public TTransport getTransport(TTransport trans) { + return trans; + } + +} diff --git a/lib/java/test/TestClient b/lib/java/test/TestClient new file mode 100755 index 000000000..bd3c996fc --- /dev/null +++ b/lib/java/test/TestClient @@ -0,0 +1,22 @@ +#!/bin/bash -v + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +java -cp thrifttest.jar:../../lib/java/libthrift.jar org.apache.thrift.test.TestClient $* diff --git a/lib/java/test/TestNonblockingServer b/lib/java/test/TestNonblockingServer new file mode 100644 index 000000000..070991c7c --- /dev/null +++ b/lib/java/test/TestNonblockingServer @@ -0,0 +1,22 @@ +#!/bin/bash -v + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +java -server -Xmx256m -cp thrifttest.jar:../../lib/java/libthrift.jar org.apache.thrift.test.TestNonblockingServer $* diff --git a/lib/java/test/TestServer b/lib/java/test/TestServer new file mode 100755 index 000000000..0d36b58ea --- /dev/null +++ b/lib/java/test/TestServer @@ -0,0 +1,22 @@ +#!/bin/bash -v + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +java -server -cp thrifttest.jar:../../lib/java/libthrift.jar org.apache.thrift.test.TestServer $* diff --git a/lib/java/test/org/apache/thrift/test/DeepCopyTest.java b/lib/java/test/org/apache/thrift/test/DeepCopyTest.java new file mode 100644 index 000000000..a171cab35 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/DeepCopyTest.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.test; + +import org.apache.thrift.TDeserializer; +import org.apache.thrift.TSerializer; +import org.apache.thrift.protocol.TBinaryProtocol; +import thrift.test.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; + +public class DeepCopyTest { + + private static final byte[] kUnicodeBytes = { + (byte)0xd3, (byte)0x80, (byte)0xe2, (byte)0x85, (byte)0xae, (byte)0xce, + (byte)0x9d, (byte)0x20, (byte)0xd0, (byte)0x9d, (byte)0xce, (byte)0xbf, + (byte)0xe2, (byte)0x85, (byte)0xbf, (byte)0xd0, (byte)0xbe, (byte)0xc9, + (byte)0xa1, (byte)0xd0, (byte)0xb3, (byte)0xd0, (byte)0xb0, (byte)0xcf, + (byte)0x81, (byte)0xe2, (byte)0x84, (byte)0x8e, (byte)0x20, (byte)0xce, + (byte)0x91, (byte)0x74, (byte)0x74, (byte)0xce, (byte)0xb1, (byte)0xe2, + (byte)0x85, (byte)0xbd, (byte)0xce, (byte)0xba, (byte)0x83, (byte)0xe2, + (byte)0x80, (byte)0xbc + }; + + public static void main(String[] args) throws Exception { + TSerializer binarySerializer = new TSerializer(new TBinaryProtocol.Factory()); + TDeserializer binaryDeserializer = new TDeserializer(new TBinaryProtocol.Factory()); + + OneOfEach ooe = new OneOfEach(); + ooe.im_true = true; + ooe.im_false = false; + ooe.a_bite = (byte) 0xd6; + ooe.integer16 = 27000; + ooe.integer32 = 1 << 24; + ooe.integer64 = (long) 6000 * 1000 * 1000; + ooe.double_precision = Math.PI; + ooe.some_characters = "JSON THIS! \"\1"; + ooe.zomg_unicode = new String(kUnicodeBytes, "UTF-8"); + ooe.base64 = "string to bytes".getBytes(); + + Nesting n = new Nesting(new Bonk(), new OneOfEach()); + n.my_ooe.integer16 = 16; + n.my_ooe.integer32 = 32; + n.my_ooe.integer64 = 64; + n.my_ooe.double_precision = (Math.sqrt(5) + 1) / 2; + n.my_ooe.some_characters = ":R (me going \"rrrr\")"; + n.my_ooe.zomg_unicode = new String(kUnicodeBytes, "UTF-8"); + n.my_bonk.type = 31337; + n.my_bonk.message = "I am a bonk... xor!"; + + HolyMoley hm = new HolyMoley(); + + hm.big = new ArrayList<OneOfEach>(); + hm.big.add(ooe); + hm.big.add(n.my_ooe); + hm.big.get(0).a_bite = (byte) 0x22; + hm.big.get(1).a_bite = (byte) 0x23; + + hm.contain = new HashSet<List<String>>(); + ArrayList<String> stage1 = new ArrayList<String>(2); + stage1.add("and a one"); + stage1.add("and a two"); + hm.contain.add(stage1); + stage1 = new ArrayList<String>(3); + stage1.add("then a one, two"); + stage1.add("three!"); + stage1.add("FOUR!!"); + hm.contain.add(stage1); + stage1 = new ArrayList<String>(0); + hm.contain.add(stage1); + + ArrayList<Bonk> stage2 = new ArrayList<Bonk>(); + hm.bonks = new HashMap<String, List<Bonk>>(); + hm.bonks.put("nothing", stage2); + Bonk b = new Bonk(); + b.type = 1; + b.message = "Wait."; + stage2.add(b); + b = new Bonk(); + b.type = 2; + b.message = "What?"; + stage2.add(b); + stage2 = new ArrayList<Bonk>(); + hm.bonks.put("something", stage2); + b = new Bonk(); + b.type = 3; + b.message = "quoth"; + b = new Bonk(); + b.type = 4; + b.message = "the raven"; + b = new Bonk(); + b.type = 5; + b.message = "nevermore"; + hm.bonks.put("poe", stage2); + + + byte[] binaryCopy = binarySerializer.serialize(hm); + HolyMoley hmCopy = new HolyMoley(); + binaryDeserializer.deserialize(hmCopy, binaryCopy); + HolyMoley hmCopy2 = new HolyMoley(hm); + + if (!hm.equals(hmCopy)) + throw new RuntimeException("copy constructor modified the original object!"); + if (!hmCopy.equals(hmCopy2)) + throw new RuntimeException("copy constructor generated incorrect copy"); + + hm.big.get(0).base64[0]++; // change binary value in original object + if (hm.equals(hmCopy2)) // make sure the change didn't propagate to the copied object + throw new RuntimeException("Binary field not copied correctly!"); + hm.big.get(0).base64[0]--; // undo change + + hmCopy2.bonks.get("nothing").get(1).message = "What else?"; + + if (hm.equals(hmCopy2)) + throw new RuntimeException("A deep copy was not done!"); + + } +} diff --git a/lib/java/test/org/apache/thrift/test/EqualityTest.java b/lib/java/test/org/apache/thrift/test/EqualityTest.java new file mode 100644 index 000000000..f01378f74 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/EqualityTest.java @@ -0,0 +1,661 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* +This program was generated by the following Python script: + +#!/usr/bin/python2.5 + +# Remove this when Python 2.6 hits the streets. +from __future__ import with_statement + +import sys +import os.path + + +# Quines the easy way. +with open(sys.argv[0], 'r') as handle: + source = handle.read() + +with open(os.path.join(os.path.dirname(sys.argv[0]), 'EqualityTest.java'), 'w') as out: + print >> out, ("/""*" r""" + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + """ "*""/") + print >> out + print >> out, "/""*" + print >> out, "This program was generated by the following Python script:" + print >> out + out.write(source) + print >> out, "*""/" + + print >> out, r''' +package org.apache.thrift.test; + +// Generated code +import thrift.test.*; + +/'''r'''** + *'''r'''/ +public class EqualityTest { + public static void main(String[] args) throws Exception { + JavaTestHelper lhs, rhs; +''' + + vals = { + 'int': ("1", "2"), + 'obj': ("\"foo\"", "\"bar\""), + 'bin': ("new byte[]{1,2}", "new byte[]{3,4}"), + } + matrix = ( + (False,False), + (False,True ), + (True ,False), + (True ,True ), + ) + + for type in ('int', 'obj', 'bin'): + for option in ('req', 'opt'): + nulls = matrix[0:1] if type == 'int' else matrix[-1::-1] + issets = matrix + for is_null in nulls: + for is_set in issets: + # isset is implied for non-primitives, so only consider the case + # where isset and non-null match. + if type != 'int' and list(is_set) != [ not null for null in is_null ]: + continue + for equal in (True, False): + print >> out + print >> out, " lhs = new JavaTestHelper();" + print >> out, " rhs = new JavaTestHelper();" + print >> out, " lhs." + option + "_" + type, "=", vals[type][0] + ";" + print >> out, " rhs." + option + "_" + type, "=", vals[type][0 if equal else 1] + ";" + isset_setter = "set" + option[0].upper() + option[1:] + "_" + type + "IsSet" + if (type == 'int' and is_set[0]): print >> out, " lhs." + isset_setter + "(true);" + if (type == 'int' and is_set[1]): print >> out, " rhs." + isset_setter + "(true);" + if (is_null[0]): print >> out, " lhs." + option + "_" + type, "= null;" + if (is_null[1]): print >> out, " rhs." + option + "_" + type, "= null;" + this_present = not is_null[0] and (option == 'req' or is_set[0]) + that_present = not is_null[1] and (option == 'req' or is_set[1]) + print >> out, " // this_present = " + repr(this_present) + print >> out, " // that_present = " + repr(that_present) + is_equal = \ + (not this_present and not that_present) or \ + (this_present and that_present and equal) + eq_str = 'true' if is_equal else 'false' + + print >> out, " if (lhs.equals(rhs) != "+eq_str+")" + print >> out, " throw new RuntimeException(\"Failure\");" + if is_equal: + print >> out, " if (lhs.hashCode() != rhs.hashCode())" + print >> out, " throw new RuntimeException(\"Failure\");" + + print >> out, r''' + } +} +''' +*/ + +package org.apache.thrift.test; + +// Generated code +import thrift.test.*; + +/** + */ +public class EqualityTest { + public static void main(String[] args) throws Exception { + JavaTestHelper lhs, rhs; + + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 1; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 2; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 1; + rhs.setReq_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 2; + rhs.setReq_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 1; + lhs.setReq_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 2; + lhs.setReq_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 1; + lhs.setReq_intIsSet(true); + rhs.setReq_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 2; + lhs.setReq_intIsSet(true); + rhs.setReq_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 1; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 2; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 1; + rhs.setOpt_intIsSet(true); + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 2; + rhs.setOpt_intIsSet(true); + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 1; + lhs.setOpt_intIsSet(true); + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 2; + lhs.setOpt_intIsSet(true); + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 1; + lhs.setOpt_intIsSet(true); + rhs.setOpt_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 2; + lhs.setOpt_intIsSet(true); + rhs.setOpt_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "foo"; + lhs.req_obj = null; + rhs.req_obj = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "bar"; + lhs.req_obj = null; + rhs.req_obj = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "foo"; + lhs.req_obj = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "bar"; + lhs.req_obj = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "foo"; + rhs.req_obj = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "bar"; + rhs.req_obj = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "foo"; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "bar"; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "foo"; + lhs.opt_obj = null; + rhs.opt_obj = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "bar"; + lhs.opt_obj = null; + rhs.opt_obj = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "foo"; + lhs.opt_obj = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "bar"; + lhs.opt_obj = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "foo"; + rhs.opt_obj = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "bar"; + rhs.opt_obj = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "foo"; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "bar"; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{1,2}; + lhs.req_bin = null; + rhs.req_bin = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{3,4}; + lhs.req_bin = null; + rhs.req_bin = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{1,2}; + lhs.req_bin = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{3,4}; + lhs.req_bin = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{1,2}; + rhs.req_bin = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{3,4}; + rhs.req_bin = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{1,2}; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{3,4}; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{1,2}; + lhs.opt_bin = null; + rhs.opt_bin = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{3,4}; + lhs.opt_bin = null; + rhs.opt_bin = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{1,2}; + lhs.opt_bin = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{3,4}; + lhs.opt_bin = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{3,4}; + rhs.opt_bin = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{1,2}; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{3,4}; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + } +} + diff --git a/lib/java/test/org/apache/thrift/test/Fixtures.java b/lib/java/test/org/apache/thrift/test/Fixtures.java new file mode 100644 index 000000000..14ac44f77 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/Fixtures.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.test; + +import java.util.*; +import thrift.test.*; + +public class Fixtures { + + private static final byte[] kUnicodeBytes = { + (byte)0xd3, (byte)0x80, (byte)0xe2, (byte)0x85, (byte)0xae, (byte)0xce, + (byte)0x9d, (byte)0x20, (byte)0xd0, (byte)0x9d, (byte)0xce, (byte)0xbf, + (byte)0xe2, (byte)0x85, (byte)0xbf, (byte)0xd0, (byte)0xbe, (byte)0xc9, + (byte)0xa1, (byte)0xd0, (byte)0xb3, (byte)0xd0, (byte)0xb0, (byte)0xcf, + (byte)0x81, (byte)0xe2, (byte)0x84, (byte)0x8e, (byte)0x20, (byte)0xce, + (byte)0x91, (byte)0x74, (byte)0x74, (byte)0xce, (byte)0xb1, (byte)0xe2, + (byte)0x85, (byte)0xbd, (byte)0xce, (byte)0xba, (byte)0x83, (byte)0xe2, + (byte)0x80, (byte)0xbc + }; + + + public static final OneOfEach oneOfEach; + public static final Nesting nesting; + public static final HolyMoley holyMoley; + public static final CompactProtoTestStruct compactProtoTestStruct; + + static { + try { + oneOfEach = new OneOfEach(); + oneOfEach.im_true = true; + oneOfEach.im_false = false; + oneOfEach.a_bite = (byte) 0x03; + oneOfEach.integer16 = 27000; + oneOfEach.integer32 = 1 << 24; + oneOfEach.integer64 = (long) 6000 * 1000 * 1000; + oneOfEach.double_precision = Math.PI; + oneOfEach.some_characters = "JSON THIS! \"\1"; + oneOfEach.zomg_unicode = new String(kUnicodeBytes, "UTF-8"); + + nesting = new Nesting(new Bonk(), new OneOfEach()); + nesting.my_ooe.integer16 = 16; + nesting.my_ooe.integer32 = 32; + nesting.my_ooe.integer64 = 64; + nesting.my_ooe.double_precision = (Math.sqrt(5) + 1) / 2; + nesting.my_ooe.some_characters = ":R (me going \"rrrr\")"; + nesting.my_ooe.zomg_unicode = new String(kUnicodeBytes, "UTF-8"); + nesting.my_bonk.type = 31337; + nesting.my_bonk.message = "I am a bonk... xor!"; + + holyMoley = new HolyMoley(); + + holyMoley.big = new ArrayList<OneOfEach>(); + holyMoley.big.add(new OneOfEach(oneOfEach)); + holyMoley.big.add(nesting.my_ooe); + holyMoley.big.get(0).a_bite = (byte) 0x22; + holyMoley.big.get(1).a_bite = (byte) 0x23; + + holyMoley.contain = new HashSet<List<String>>(); + ArrayList<String> stage1 = new ArrayList<String>(2); + stage1.add("and a one"); + stage1.add("and a two"); + holyMoley.contain.add(stage1); + stage1 = new ArrayList<String>(3); + stage1.add("then a one, two"); + stage1.add("three!"); + stage1.add("FOUR!!"); + holyMoley.contain.add(stage1); + stage1 = new ArrayList<String>(0); + holyMoley.contain.add(stage1); + + ArrayList<Bonk> stage2 = new ArrayList<Bonk>(); + holyMoley.bonks = new HashMap<String, List<Bonk>>(); + // one empty + holyMoley.bonks.put("nothing", stage2); + + // one with two + stage2 = new ArrayList<Bonk>(); + Bonk b = new Bonk(); + b.type = 1; + b.message = "Wait."; + stage2.add(b); + b = new Bonk(); + b.type = 2; + b.message = "What?"; + stage2.add(b); + holyMoley.bonks.put("something", stage2); + + // one with three + stage2 = new ArrayList<Bonk>(); + b = new Bonk(); + b.type = 3; + b.message = "quoth"; + b = new Bonk(); + b.type = 4; + b.message = "the raven"; + b = new Bonk(); + b.type = 5; + b.message = "nevermore"; + holyMoley.bonks.put("poe", stage2); + + // superhuge compact proto test struct + compactProtoTestStruct = new CompactProtoTestStruct(thrift.test.Constants.COMPACT_TEST); + compactProtoTestStruct.a_binary = new byte[]{0,1,2,3,4,5,6,7,8}; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + +} diff --git a/lib/java/test/org/apache/thrift/test/IdentityTest.java b/lib/java/test/org/apache/thrift/test/IdentityTest.java new file mode 100644 index 000000000..c6453ce7c --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/IdentityTest.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +// Generated code +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; + +import org.apache.thrift.TDeserializer; +import org.apache.thrift.TSerializer; +import org.apache.thrift.protocol.TBinaryProtocol; + +import thrift.test.Bonk; +import thrift.test.HolyMoley; +import thrift.test.Nesting; +import thrift.test.OneOfEach; + +/** + * + */ +public class IdentityTest { + public static Object deepCopy(Object oldObj) throws Exception { + ObjectOutputStream oos = null; + ObjectInputStream ois = null; + try { + ByteArrayOutputStream bos = + new ByteArrayOutputStream(); + oos = new ObjectOutputStream(bos); + oos.writeObject(oldObj); + oos.flush(); + ByteArrayInputStream bis = + new ByteArrayInputStream(bos.toByteArray()); + ois = new ObjectInputStream(bis); + return ois.readObject(); + } finally { + oos.close(); + ois.close(); + } + } + + public static void main(String[] args) throws Exception { + TSerializer binarySerializer = new TSerializer(new TBinaryProtocol.Factory()); + TDeserializer binaryDeserializer = new TDeserializer(new TBinaryProtocol.Factory()); + + OneOfEach ooe = new OneOfEach(); + ooe.im_true = true; + ooe.im_false = false; + ooe.a_bite = (byte)0xd6; + ooe.integer16 = 27000; + ooe.integer32 = 1<<24; + ooe.integer64 = (long)6000 * 1000 * 1000; + ooe.double_precision = Math.PI; + ooe.some_characters = "JSON THIS! \"\u0001"; + ooe.base64 = new byte[]{1,2,3,(byte)255}; + + Nesting n = new Nesting(); + n.my_ooe = (OneOfEach)deepCopy(ooe); + n.my_ooe.integer16 = 16; + n.my_ooe.integer32 = 32; + n.my_ooe.integer64 = 64; + n.my_ooe.double_precision = (Math.sqrt(5)+1)/2; + n.my_ooe.some_characters = ":R (me going \"rrrr\")"; + n.my_ooe.zomg_unicode = "\u04c0\u216e\u039d\u0020\u041d\u03bf\u217f"+ + "\u043e\u0261\u0433\u0430\u03c1\u210e\u0020"+ + "\u0391\u0074\u0074\u03b1\u217d\u03ba\u01c3"+ + "\u203c"; + n.my_bonk = new Bonk(); + n.my_bonk.type = 31337; + n.my_bonk.message = "I am a bonk... xor!"; + + HolyMoley hm = new HolyMoley(); + hm.big = new ArrayList<OneOfEach>(); + hm.contain = new HashSet<List<String>>(); + hm.bonks = new HashMap<String,List<Bonk>>(); + + hm.big.add((OneOfEach)deepCopy(ooe)); + hm.big.add((OneOfEach)deepCopy(n.my_ooe)); + hm.big.get(0).a_bite = 0x22; + hm.big.get(1).a_bite = 0x33; + + List<String> stage1 = new ArrayList<String>(); + stage1.add("and a one"); + stage1.add("and a two"); + hm.contain.add(stage1); + stage1 = new ArrayList<String>(); + stage1.add("then a one, two"); + stage1.add("three!"); + stage1.add("FOUR!!"); + hm.contain.add(stage1); + stage1 = new ArrayList<String>(); + hm.contain.add(stage1); + + List<Bonk> stage2 = new ArrayList<Bonk>(); + hm.bonks.put("nothing", stage2); + stage2.add(new Bonk()); + stage2.get(0).type = 1; + stage2.get(0).message = "Wait."; + stage2.add(new Bonk()); + stage2.get(1).type = 2; + stage2.get(1).message = "What?"; + hm.bonks.put("something", stage2); + stage2 = new ArrayList<Bonk>(); + stage2.add(new Bonk()); + stage2.get(0).type = 3; + stage2.get(0).message = "quoth"; + stage2.add(new Bonk()); + stage2.get(1).type = 4; + stage2.get(1).message = "the raven"; + stage2.add(new Bonk()); + stage2.get(2).type = 5; + stage2.get(2).message = "nevermore"; + hm.bonks.put("poe", stage2); + + OneOfEach ooe2 = new OneOfEach(); + binaryDeserializer.deserialize( + ooe2, + binarySerializer.serialize(ooe)); + + if (!ooe.equals(ooe2)) { + throw new RuntimeException("Failure: ooe (equals)"); + } + if (ooe.hashCode() != ooe2.hashCode()) { + throw new RuntimeException("Failure: ooe (hash)"); + } + + + Nesting n2 = new Nesting(); + binaryDeserializer.deserialize( + n2, + binarySerializer.serialize(n)); + + if (!n.equals(n2)) { + throw new RuntimeException("Failure: n (equals)"); + } + if (n.hashCode() != n2.hashCode()) { + throw new RuntimeException("Failure: n (hash)"); + } + + HolyMoley hm2 = new HolyMoley(); + binaryDeserializer.deserialize( + hm2, + binarySerializer.serialize(hm)); + + if (!hm.equals(hm2)) { + throw new RuntimeException("Failure: hm (equals)"); + } + if (hm.hashCode() != hm2.hashCode()) { + throw new RuntimeException("Failure: hm (hash)"); + } + + } +} diff --git a/lib/java/test/org/apache/thrift/test/JSONProtoTest.java b/lib/java/test/org/apache/thrift/test/JSONProtoTest.java new file mode 100644 index 000000000..59f4ce182 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/JSONProtoTest.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +// Generated code +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; + +import org.apache.thrift.protocol.TJSONProtocol; +import org.apache.thrift.transport.TMemoryBuffer; + +import thrift.test.Base64; +import thrift.test.Bonk; +import thrift.test.HolyMoley; +import thrift.test.Nesting; +import thrift.test.OneOfEach; + +/** + * Tests for the Java implementation of TJSONProtocol. Mirrors the C++ version + * + */ +public class JSONProtoTest { + + private static final byte[] kUnicodeBytes = { + (byte)0xd3, (byte)0x80, (byte)0xe2, (byte)0x85, (byte)0xae, (byte)0xce, + (byte)0x9d, (byte)0x20, (byte)0xd0, (byte)0x9d, (byte)0xce, (byte)0xbf, + (byte)0xe2, (byte)0x85, (byte)0xbf, (byte)0xd0, (byte)0xbe, (byte)0xc9, + (byte)0xa1, (byte)0xd0, (byte)0xb3, (byte)0xd0, (byte)0xb0, (byte)0xcf, + (byte)0x81, (byte)0xe2, (byte)0x84, (byte)0x8e, (byte)0x20, (byte)0xce, + (byte)0x91, (byte)0x74, (byte)0x74, (byte)0xce, (byte)0xb1, (byte)0xe2, + (byte)0x85, (byte)0xbd, (byte)0xce, (byte)0xba, (byte)0x83, (byte)0xe2, + (byte)0x80, (byte)0xbc + }; + + public static void main(String [] args) throws Exception { + try { + System.out.println("In JSON Proto test"); + + OneOfEach ooe = new OneOfEach(); + ooe.im_true = true; + ooe.im_false = false; + ooe.a_bite = (byte)0xd6; + ooe.integer16 = 27000; + ooe.integer32 = 1<<24; + ooe.integer64 = (long)6000 * 1000 * 1000; + ooe.double_precision = Math.PI; + ooe.some_characters = "JSON THIS! \"\1"; + ooe.zomg_unicode = new String(kUnicodeBytes, "UTF-8"); + + + Nesting n = new Nesting(new Bonk(), new OneOfEach()); + n.my_ooe.integer16 = 16; + n.my_ooe.integer32 = 32; + n.my_ooe.integer64 = 64; + n.my_ooe.double_precision = (Math.sqrt(5)+1)/2; + n.my_ooe.some_characters = ":R (me going \"rrrr\")"; + n.my_ooe.zomg_unicode = new String(kUnicodeBytes, "UTF-8"); + n.my_bonk.type = 31337; + n.my_bonk.message = "I am a bonk... xor!"; + + HolyMoley hm = new HolyMoley(); + + hm.big = new ArrayList<OneOfEach>(); + hm.big.add(ooe); + hm.big.add(n.my_ooe); + hm.big.get(0).a_bite = (byte)0x22; + hm.big.get(1).a_bite = (byte)0x23; + + hm.contain = new HashSet<List<String>>(); + ArrayList<String> stage1 = new ArrayList<String>(2); + stage1.add("and a one"); + stage1.add("and a two"); + hm.contain.add(stage1); + stage1 = new ArrayList<String>(3); + stage1.add("then a one, two"); + stage1.add("three!"); + stage1.add("FOUR!!"); + hm.contain.add(stage1); + stage1 = new ArrayList<String>(0); + hm.contain.add(stage1); + + ArrayList<Bonk> stage2 = new ArrayList<Bonk>(); + hm.bonks = new HashMap<String, List<Bonk>>(); + hm.bonks.put("nothing", stage2); + Bonk b = new Bonk(); + b.type = 1; + b.message = "Wait."; + stage2.add(b); + b = new Bonk(); + b.type = 2; + b.message = "What?"; + stage2.add(b); + stage2 = new ArrayList<Bonk>(); + hm.bonks.put("something", stage2); + b = new Bonk(); + b.type = 3; + b.message = "quoth"; + b = new Bonk(); + b.type = 4; + b.message = "the raven"; + b = new Bonk(); + b.type = 5; + b.message = "nevermore"; + hm.bonks.put("poe", stage2); + + TMemoryBuffer buffer = new TMemoryBuffer(1024); + TJSONProtocol proto = new TJSONProtocol(buffer); + + System.out.println("Writing ooe"); + ooe.write(proto); + System.out.println("Reading ooe"); + OneOfEach ooe2 = new OneOfEach(); + ooe2.read(proto); + + System.out.println("Comparing ooe"); + if (!ooe.equals(ooe2)) { + throw new RuntimeException("ooe != ooe2"); + } + + System.out.println("Writing hm"); + hm.write(proto); + + System.out.println("Reading hm"); + HolyMoley hm2 = new HolyMoley(); + hm2.read(proto); + + System.out.println("Comparing hm"); + if (!hm.equals(hm2)) { + throw new RuntimeException("hm != hm2"); + } + + hm2.big.get(0).a_bite = (byte)0xFF; + if (hm.equals(hm2)) { + throw new RuntimeException("hm should not equal hm2"); + } + + Base64 base = new Base64(); + base.a = 123; + base.b1 = "1".getBytes("UTF-8"); + base.b2 = "12".getBytes("UTF-8"); + base.b3 = "123".getBytes("UTF-8"); + base.b4 = "1234".getBytes("UTF-8"); + base.b5 = "12345".getBytes("UTF-8"); + base.b6 = "123456".getBytes("UTF-8"); + + System.out.println("Writing base"); + base.write(proto); + + System.out.println("Reading base"); + Base64 base2 = new Base64(); + base2.read(proto); + + System.out.println("Comparing base"); + if (!base.equals(base2)) { + throw new RuntimeException("base != base2"); + } + + } catch (Exception ex) { + ex.printStackTrace(); + throw ex; + } + } + +} diff --git a/lib/java/test/org/apache/thrift/test/JavaBeansTest.java b/lib/java/test/org/apache/thrift/test/JavaBeansTest.java new file mode 100644 index 000000000..b72bd388c --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/JavaBeansTest.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +import java.util.LinkedList; +import thrift.test.OneOfEachBeans; + +public class JavaBeansTest { + public static void main(String[] args) throws Exception { + // Test isSet methods + OneOfEachBeans ooe = new OneOfEachBeans(); + + // Nothing should be set + if (ooe.is_set_a_bite()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_base64()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_byte_list()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_double_precision()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_i16_list()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_i64_list()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_boolean_field()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_integer16()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_integer32()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_integer64()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_some_characters()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + + for (int i = 1; i < 12; i++){ + if (ooe.isSet(i)) + throw new RuntimeException("isSet method error: unset field " + i + " returned as set!"); + } + + // Everything is set + ooe.set_a_bite((byte) 1); + ooe.set_base64("bytes".getBytes()); + ooe.set_byte_list(new LinkedList<Byte>()); + ooe.set_double_precision(1); + ooe.set_i16_list(new LinkedList<Short>()); + ooe.set_i64_list(new LinkedList<Long>()); + ooe.set_boolean_field(true); + ooe.set_integer16((short) 1); + ooe.set_integer32(1); + ooe.set_integer64(1); + ooe.set_some_characters("string"); + + if (!ooe.is_set_a_bite()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_base64()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_byte_list()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_double_precision()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_i16_list()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_i64_list()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_boolean_field()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_integer16()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_integer32()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_integer64()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_some_characters()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + + for (int i = 1; i < 12; i++){ + if (!ooe.isSet(i)) + throw new RuntimeException("isSet method error: set field " + i + " returned as unset!"); + } + + // Should throw exception when field doesn't exist + boolean exceptionThrown = false; + try{ + if (ooe.isSet(100)); + } catch (IllegalArgumentException e){ + exceptionThrown = true; + } + if (!exceptionThrown) + throw new RuntimeException("isSet method error: non-existent field provided as agument but no exception thrown!"); + } +} diff --git a/lib/java/test/org/apache/thrift/test/MetaDataTest.java b/lib/java/test/org/apache/thrift/test/MetaDataTest.java new file mode 100644 index 000000000..a0180348f --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/MetaDataTest.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.test; + +import java.util.Map; +import org.apache.thrift.TFieldRequirementType; +import org.apache.thrift.meta_data.FieldMetaData; +import org.apache.thrift.meta_data.ListMetaData; +import org.apache.thrift.meta_data.MapMetaData; +import org.apache.thrift.meta_data.SetMetaData; +import org.apache.thrift.meta_data.StructMetaData; +import org.apache.thrift.protocol.TType; +import thrift.test.*; + +public class MetaDataTest { + + public static void main(String[] args) throws Exception { + Map<Integer, FieldMetaData> mdMap = CrazyNesting.metaDataMap; + + // Check for struct fields existence + if (mdMap.size() != 3) + throw new RuntimeException("metadata map contains wrong number of entries!"); + if (!mdMap.containsKey(CrazyNesting.SET_FIELD) || !mdMap.containsKey(CrazyNesting.LIST_FIELD) || !mdMap.containsKey(CrazyNesting.STRING_FIELD)) + throw new RuntimeException("metadata map doesn't contain entry for a struct field!"); + + // Check for struct fields contents + if (!mdMap.get(CrazyNesting.STRING_FIELD).fieldName.equals("string_field") || + !mdMap.get(CrazyNesting.LIST_FIELD).fieldName.equals("list_field") || + !mdMap.get(CrazyNesting.SET_FIELD).fieldName.equals("set_field")) + throw new RuntimeException("metadata map contains a wrong fieldname"); + if (mdMap.get(CrazyNesting.STRING_FIELD).requirementType != TFieldRequirementType.DEFAULT || + mdMap.get(CrazyNesting.LIST_FIELD).requirementType != TFieldRequirementType.REQUIRED || + mdMap.get(CrazyNesting.SET_FIELD).requirementType != TFieldRequirementType.OPTIONAL) + throw new RuntimeException("metadata map contains the wrong requirement type for a field"); + if (mdMap.get(CrazyNesting.STRING_FIELD).valueMetaData.type != TType.STRING || + mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData.type != TType.LIST || + mdMap.get(CrazyNesting.SET_FIELD).valueMetaData.type != TType.SET) + throw new RuntimeException("metadata map contains the wrong requirement type for a field"); + + // Check nested structures + if (!mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData.isContainer()) + throw new RuntimeException("value metadata for a list is stored as non-container!"); + if (mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData.isStruct()) + throw new RuntimeException("value metadata for a list is stored as a struct!"); + if (((MapMetaData)((ListMetaData)((SetMetaData)((MapMetaData)((MapMetaData)((ListMetaData)mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData).elemMetaData).valueMetaData).valueMetaData).elemMetaData).elemMetaData).keyMetaData.type != TType.STRUCT) + throw new RuntimeException("metadata map contains wrong type for a value in a deeply nested structure"); + if (((StructMetaData)((MapMetaData)((ListMetaData)((SetMetaData)((MapMetaData)((MapMetaData)((ListMetaData)mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData).elemMetaData).valueMetaData).valueMetaData).elemMetaData).elemMetaData).keyMetaData).structClass != Insanity.class) + throw new RuntimeException("metadata map contains wrong class for a struct in a deeply nested structure"); + + // Check that FieldMetaData contains a map with metadata for all generated struct classes + if (FieldMetaData.getStructMetaDataMap(CrazyNesting.class) == null || + FieldMetaData.getStructMetaDataMap(Insanity.class) == null || + FieldMetaData.getStructMetaDataMap(Xtruct.class) == null) + throw new RuntimeException("global metadata map doesn't contain an entry for a known struct"); + if (FieldMetaData.getStructMetaDataMap(CrazyNesting.class) != CrazyNesting.metaDataMap || + FieldMetaData.getStructMetaDataMap(Insanity.class) != Insanity.metaDataMap) + throw new RuntimeException("global metadata map contains wrong entry for a loaded struct"); + } +} diff --git a/lib/java/test/org/apache/thrift/test/OverloadNonblockingServer.java b/lib/java/test/org/apache/thrift/test/OverloadNonblockingServer.java new file mode 100644 index 000000000..54d78e56d --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/OverloadNonblockingServer.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.test; + +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.transport.TSocket; + + +public class OverloadNonblockingServer { + + public static void main(String[] args) throws Exception { + int msg_size_mb = Integer.parseInt(args[0]); + int msg_size = msg_size_mb * 1024 * 1024; + + TSocket socket = new TSocket("localhost", 9090); + TBinaryProtocol binprot = new TBinaryProtocol(socket); + socket.open(); + binprot.writeI32(msg_size); + binprot.writeI32(1); + socket.flush(); + + System.in.read(); + // Thread.sleep(30000); + for (int i = 0; i < msg_size_mb; i++) { + binprot.writeBinary(new byte[1024 * 1024]); + } + + socket.close(); + } +} diff --git a/lib/java/test/org/apache/thrift/test/ReadStruct.java b/lib/java/test/org/apache/thrift/test/ReadStruct.java new file mode 100644 index 000000000..2dc042c56 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/ReadStruct.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +import java.io.BufferedInputStream; +import java.io.FileInputStream; + +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TIOStreamTransport; +import org.apache.thrift.transport.TTransport; + +import thrift.test.CompactProtoTestStruct; + +public class ReadStruct { + public static void main(String[] args) throws Exception { + if (args.length != 2) { + System.out.println("usage: java -cp build/classes org.apache.thrift.test.ReadStruct filename proto_factory_class"); + System.out.println("Read in an instance of CompactProtocolTestStruct from 'file', making sure that it is equivalent to Fixtures.compactProtoTestStruct. Use a protocol from 'proto_factory_class'."); + } + + TTransport trans = new TIOStreamTransport(new BufferedInputStream(new FileInputStream(args[0]))); + + TProtocolFactory factory = (TProtocolFactory)Class.forName(args[1]).newInstance(); + + TProtocol proto = factory.getProtocol(trans); + + CompactProtoTestStruct cpts = new CompactProtoTestStruct(); + + for (Integer fid : CompactProtoTestStruct.metaDataMap.keySet()) { + cpts.setFieldValue(fid, null); + } + + cpts.read(proto); + + if (cpts.equals(Fixtures.compactProtoTestStruct)) { + System.out.println("Object verified successfully!"); + } else { + System.out.println("Object failed verification!"); + System.out.println("Expected: " + Fixtures.compactProtoTestStruct + " but got " + cpts); + } + + } + +} diff --git a/lib/java/test/org/apache/thrift/test/SerializationBenchmark.java b/lib/java/test/org/apache/thrift/test/SerializationBenchmark.java new file mode 100644 index 000000000..b83b2f9b6 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/SerializationBenchmark.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.test; + +import java.io.ByteArrayInputStream; + +import org.apache.thrift.*; +import org.apache.thrift.protocol.*; +import org.apache.thrift.transport.*; + +import thrift.test.*; + +public class SerializationBenchmark { + private final static int HOW_MANY = 10000000; + + public static void main(String[] args) throws Exception { + TProtocolFactory factory = new TBinaryProtocol.Factory(); + + OneOfEach ooe = new OneOfEach(); + ooe.im_true = true; + ooe.im_false = false; + ooe.a_bite = (byte)0xd6; + ooe.integer16 = 27000; + ooe.integer32 = 1<<24; + ooe.integer64 = (long)6000 * 1000 * 1000; + ooe.double_precision = Math.PI; + ooe.some_characters = "JSON THIS! \"\u0001"; + ooe.base64 = new byte[]{1,2,3,(byte)255}; + + testSerialization(factory, ooe); + testDeserialization(factory, ooe, OneOfEach.class); + } + + public static void testSerialization(TProtocolFactory factory, TBase object) throws Exception { + TTransport trans = new TTransport() { + public void write(byte[] bin, int x, int y) throws TTransportException {} + public int read(byte[] bin, int x, int y) throws TTransportException {return 0;} + public void close() {} + public void open() {} + public boolean isOpen() {return true;} + }; + + TProtocol proto = factory.getProtocol(trans); + + long startTime = System.currentTimeMillis(); + for (int i = 0; i < HOW_MANY; i++) { + object.write(proto); + } + long endTime = System.currentTimeMillis(); + + System.out.println("Test time: " + (endTime - startTime) + " ms"); + } + + public static <T extends TBase> void testDeserialization(TProtocolFactory factory, T object, Class<T> klass) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + object.write(factory.getProtocol(buf)); + byte[] serialized = new byte[100*1024]; + buf.read(serialized, 0, 100*1024); + + long startTime = System.currentTimeMillis(); + for (int i = 0; i < HOW_MANY; i++) { + T o2 = klass.newInstance(); + o2.read(factory.getProtocol(new TIOStreamTransport(new ByteArrayInputStream(serialized)))); + } + long endTime = System.currentTimeMillis(); + + System.out.println("Test time: " + (endTime - startTime) + " ms"); + } +}
\ No newline at end of file diff --git a/lib/java/test/org/apache/thrift/test/TCompactProtocolTest.java b/lib/java/test/org/apache/thrift/test/TCompactProtocolTest.java new file mode 100755 index 000000000..5d1c3cba1 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/TCompactProtocolTest.java @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.test; + +import java.util.Arrays; +import java.util.List; + +import org.apache.thrift.TBase; +import org.apache.thrift.TException; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TCompactProtocol; +import org.apache.thrift.protocol.TField; +import org.apache.thrift.protocol.TMessage; +import org.apache.thrift.protocol.TMessageType; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.protocol.TStruct; +import org.apache.thrift.protocol.TType; +import org.apache.thrift.transport.TMemoryBuffer; + +import thrift.test.CompactProtoTestStruct; +import thrift.test.HolyMoley; +import thrift.test.Nesting; +import thrift.test.OneOfEach; +import thrift.test.Srv; + +public class TCompactProtocolTest { + + static TProtocolFactory factory = new TCompactProtocol.Factory(); + + public static void main(String[] args) throws Exception { + testNakedByte(); + for (int i = 0; i < 128; i++) { + testByteField((byte)i); + testByteField((byte)-i); + } + + testNakedI16((short)0); + testNakedI16((short)1); + testNakedI16((short)15000); + testNakedI16((short)0x7fff); + testNakedI16((short)-1); + testNakedI16((short)-15000); + testNakedI16((short)-0x7fff); + + testI16Field((short)0); + testI16Field((short)1); + testI16Field((short)7); + testI16Field((short)150); + testI16Field((short)15000); + testI16Field((short)0x7fff); + testI16Field((short)-1); + testI16Field((short)-7); + testI16Field((short)-150); + testI16Field((short)-15000); + testI16Field((short)-0x7fff); + + testNakedI32(0); + testNakedI32(1); + testNakedI32(15000); + testNakedI32(0xffff); + testNakedI32(-1); + testNakedI32(-15000); + testNakedI32(-0xffff); + + testI32Field(0); + testI32Field(1); + testI32Field(7); + testI32Field(150); + testI32Field(15000); + testI32Field(31337); + testI32Field(0xffff); + testI32Field(0xffffff); + testI32Field(-1); + testI32Field(-7); + testI32Field(-150); + testI32Field(-15000); + testI32Field(-0xffff); + testI32Field(-0xffffff); + + testNakedI64(0); + for (int i = 0; i < 62; i++) { + testNakedI64(1L << i); + testNakedI64(-(1L << i)); + } + + testI64Field(0); + for (int i = 0; i < 62; i++) { + testI64Field(1L << i); + testI64Field(-(1L << i)); + } + + testDouble(); + + testNakedString(""); + testNakedString("short"); + testNakedString("borderlinetiny"); + testNakedString("a bit longer than the smallest possible"); + + testStringField(""); + testStringField("short"); + testStringField("borderlinetiny"); + testStringField("a bit longer than the smallest possible"); + + testNakedBinary(new byte[]{}); + testNakedBinary(new byte[]{0,1,2,3,4,5,6,7,8,9,10}); + testNakedBinary(new byte[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14}); + testNakedBinary(new byte[128]); + + testBinaryField(new byte[]{}); + testBinaryField(new byte[]{0,1,2,3,4,5,6,7,8,9,10}); + testBinaryField(new byte[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14}); + testBinaryField(new byte[128]); + + testSerialization(OneOfEach.class, Fixtures.oneOfEach); + testSerialization(Nesting.class, Fixtures.nesting); + testSerialization(HolyMoley.class, Fixtures.holyMoley); + testSerialization(CompactProtoTestStruct.class, Fixtures.compactProtoTestStruct); + + testMessage(); + + testServerRequest(); + } + + public static void testNakedByte() throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + proto.writeByte((byte)123); + byte out = proto.readByte(); + if (out != 123) { + throw new RuntimeException("Byte was supposed to be " + (byte)123 + " but was " + out); + } + } + + public static void testByteField(final byte b) throws Exception { + testStructField(new StructFieldTestCase(TType.BYTE, (short)15) { + public void writeMethod(TProtocol proto) throws TException { + proto.writeByte(b); + } + + public void readMethod(TProtocol proto) throws TException { + byte result = proto.readByte(); + if (result != b) { + throw new RuntimeException("Byte was supposed to be " + (byte)b + " but was " + result); + } + } + }); + } + + public static void testNakedI16(short n) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + proto.writeI16(n); + // System.out.println(buf.inspect()); + int out = proto.readI16(); + if (out != n) { + throw new RuntimeException("I16 was supposed to be " + n + " but was " + out); + } + } + + public static void testI16Field(final short n) throws Exception { + testStructField(new StructFieldTestCase(TType.I16, (short)15) { + public void writeMethod(TProtocol proto) throws TException { + proto.writeI16(n); + } + + public void readMethod(TProtocol proto) throws TException { + short result = proto.readI16(); + if (result != n) { + throw new RuntimeException("I16 was supposed to be " + n + " but was " + result); + } + } + }); + } + + public static void testNakedI32(int n) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + proto.writeI32(n); + // System.out.println(buf.inspect()); + int out = proto.readI32(); + if (out != n) { + throw new RuntimeException("I32 was supposed to be " + n + " but was " + out); + } + } + + public static void testI32Field(final int n) throws Exception { + testStructField(new StructFieldTestCase(TType.I32, (short)15) { + public void writeMethod(TProtocol proto) throws TException { + proto.writeI32(n); + } + + public void readMethod(TProtocol proto) throws TException { + int result = proto.readI32(); + if (result != n) { + throw new RuntimeException("I32 was supposed to be " + n + " but was " + result); + } + } + }); + + } + + public static void testNakedI64(long n) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + proto.writeI64(n); + // System.out.println(buf.inspect()); + long out = proto.readI64(); + if (out != n) { + throw new RuntimeException("I64 was supposed to be " + n + " but was " + out); + } + } + + public static void testI64Field(final long n) throws Exception { + testStructField(new StructFieldTestCase(TType.I64, (short)15) { + public void writeMethod(TProtocol proto) throws TException { + proto.writeI64(n); + } + + public void readMethod(TProtocol proto) throws TException { + long result = proto.readI64(); + if (result != n) { + throw new RuntimeException("I64 was supposed to be " + n + " but was " + result); + } + } + }); + } + + public static void testDouble() throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(1000); + TProtocol proto = factory.getProtocol(buf); + proto.writeDouble(123.456); + double out = proto.readDouble(); + if (out != 123.456) { + throw new RuntimeException("Double was supposed to be " + 123.456 + " but was " + out); + } + } + + public static void testNakedString(String str) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + proto.writeString(str); + // System.out.println(buf.inspect()); + String out = proto.readString(); + if (!str.equals(out)) { + throw new RuntimeException("String was supposed to be '" + str + "' but was '" + out + "'"); + } + } + + public static void testStringField(final String str) throws Exception { + testStructField(new StructFieldTestCase(TType.STRING, (short)15) { + public void writeMethod(TProtocol proto) throws TException { + proto.writeString(str); + } + + public void readMethod(TProtocol proto) throws TException { + String result = proto.readString(); + if (!result.equals(str)) { + throw new RuntimeException("String was supposed to be " + str + " but was " + result); + } + } + }); + } + + public static void testNakedBinary(byte[] data) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + proto.writeBinary(data); + // System.out.println(buf.inspect()); + byte[] out = proto.readBinary(); + if (!Arrays.equals(data, out)) { + throw new RuntimeException("Binary was supposed to be '" + data + "' but was '" + out + "'"); + } + } + + public static void testBinaryField(final byte[] data) throws Exception { + testStructField(new StructFieldTestCase(TType.STRING, (short)15) { + public void writeMethod(TProtocol proto) throws TException { + proto.writeBinary(data); + } + + public void readMethod(TProtocol proto) throws TException { + byte[] result = proto.readBinary(); + if (!Arrays.equals(data, result)) { + throw new RuntimeException("Binary was supposed to be '" + bytesToString(data) + "' but was '" + bytesToString(result) + "'"); + } + } + }); + + } + + public static <T extends TBase> void testSerialization(Class<T> klass, T obj) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TBinaryProtocol binproto = new TBinaryProtocol(buf); + + try { + obj.write(binproto); + // System.out.println("Size in binary protocol: " + buf.length()); + + buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + + obj.write(proto); + System.out.println("Size in compact protocol: " + buf.length()); + // System.out.println(buf.inspect()); + + T objRead = klass.newInstance(); + objRead.read(proto); + if (!obj.equals(objRead)) { + System.out.println("Expected: " + obj.toString()); + System.out.println("Actual: " + objRead.toString()); + // System.out.println(buf.inspect()); + throw new RuntimeException("Objects didn't match!"); + } + } catch (Exception e) { + System.out.println(buf.inspect()); + throw e; + } + } + + public static void testMessage() throws Exception { + List<TMessage> msgs = Arrays.asList(new TMessage[]{ + new TMessage("short message name", TMessageType.CALL, 0), + new TMessage("1", TMessageType.REPLY, 12345), + new TMessage("loooooooooooooooooooooooooooooooooong", TMessageType.EXCEPTION, 1 << 16), + new TMessage("Janky", TMessageType.CALL, 0), + }); + + for (TMessage msg : msgs) { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + TMessage output = null; + + proto.writeMessageBegin(msg); + proto.writeMessageEnd(); + + output = proto.readMessageBegin(); + + if (!msg.equals(output)) { + throw new RuntimeException("Message was supposed to be " + msg + " but was " + output); + } + } + } + + public static void testServerRequest() throws Exception { + Srv.Iface handler = new Srv.Iface() { + public int Janky(int i32arg) throws TException { + return i32arg * 2; + } + + public int primitiveMethod() throws TException { + // TODO Auto-generated method stub + return 0; + } + + public CompactProtoTestStruct structMethod() throws TException { + // TODO Auto-generated method stub + return null; + } + + public void voidMethod() throws TException { + // TODO Auto-generated method stub + + } + }; + + Srv.Processor testProcessor = new Srv.Processor(handler); + + TMemoryBuffer clientOutTrans = new TMemoryBuffer(0); + TProtocol clientOutProto = factory.getProtocol(clientOutTrans); + TMemoryBuffer clientInTrans = new TMemoryBuffer(0); + TProtocol clientInProto = factory.getProtocol(clientInTrans); + + Srv.Client testClient = new Srv.Client(clientInProto, clientOutProto); + + testClient.send_Janky(1); + // System.out.println(clientOutTrans.inspect()); + testProcessor.process(clientOutProto, clientInProto); + // System.out.println(clientInTrans.inspect()); + int result = testClient.recv_Janky(); + if (result != 2) { + throw new RuntimeException("Got an unexpected result: " + result); + } + } + + // + // Helper methods + // + + private static String bytesToString(byte[] bytes) { + String s = ""; + for (int i = 0; i < bytes.length; i++) { + s += Integer.toHexString((int)bytes[i]) + " "; + } + return s; + } + + private static void testStructField(StructFieldTestCase testCase) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + + TField field = new TField("test_field", testCase.type_, testCase.id_); + proto.writeStructBegin(new TStruct("test_struct")); + proto.writeFieldBegin(field); + testCase.writeMethod(proto); + proto.writeFieldEnd(); + proto.writeStructEnd(); + + // System.out.println(buf.inspect()); + + proto.readStructBegin(); + TField readField = proto.readFieldBegin(); + // TODO: verify the field is as expected + if (!field.equals(readField)) { + throw new RuntimeException("Expected " + field + " but got " + readField); + } + testCase.readMethod(proto); + proto.readStructEnd(); + } + + public static abstract class StructFieldTestCase { + byte type_; + short id_; + public StructFieldTestCase(byte type, short id) { + type_ = type; + id_ = id; + } + + public abstract void writeMethod(TProtocol proto) throws TException; + public abstract void readMethod(TProtocol proto) throws TException; + } +}
\ No newline at end of file diff --git a/lib/java/test/org/apache/thrift/test/TestClient.java b/lib/java/test/org/apache/thrift/test/TestClient.java new file mode 100644 index 000000000..4a95f7a6d --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/TestClient.java @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +// Generated code +import thrift.test.*; + +import org.apache.thrift.TApplicationException; +import org.apache.thrift.TSerializer; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TSocket; +import org.apache.thrift.transport.THttpClient; +import org.apache.thrift.transport.TFramedTransport; +import org.apache.thrift.transport.TTransportException; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TSimpleJSONProtocol; + +import java.util.Map; +import java.util.HashMap; +import java.util.Set; +import java.util.HashSet; +import java.util.List; +import java.util.ArrayList; + +/** + * Test Java client for thrift. Essentially just a copy of the C++ version, + * this makes a variety of requests to enable testing for both performance and + * correctness of the output. + * + */ +public class TestClient { + public static void main(String [] args) { + try { + String host = "localhost"; + int port = 9090; + String url = null; + int numTests = 1; + boolean framed = false; + + int socketTimeout = 1000; + + try { + for (int i = 0; i < args.length; ++i) { + if (args[i].equals("-h")) { + String[] hostport = (args[++i]).split(":"); + host = hostport[0]; + port = Integer.valueOf(hostport[1]); + } else if (args[i].equals("-f") || args[i].equals("-framed")) { + framed = true; + } else if (args[i].equals("-u")) { + url = args[++i]; + } else if (args[i].equals("-n")) { + numTests = Integer.valueOf(args[++i]); + } else if (args[i].equals("-timeout")) { + socketTimeout = Integer.valueOf(args[++i]); + } + } + } catch (Exception x) { + x.printStackTrace(); + } + + TTransport transport; + + if (url != null) { + transport = new THttpClient(url); + } else { + TSocket socket = new TSocket(host, port); + socket.setTimeout(socketTimeout); + transport = socket; + if (framed) { + transport = new TFramedTransport(transport); + } + } + + TBinaryProtocol binaryProtocol = + new TBinaryProtocol(transport); + ThriftTest.Client testClient = + new ThriftTest.Client(binaryProtocol); + Insanity insane = new Insanity(); + + long timeMin = 0; + long timeMax = 0; + long timeTot = 0; + + for (int test = 0; test < numTests; ++test) { + + /** + * CONNECT TEST + */ + System.out.println("Test #" + (test+1) + ", " + "connect " + host + ":" + port); + try { + transport.open(); + } catch (TTransportException ttx) { + System.out.println("Connect failed: " + ttx.getMessage()); + continue; + } + + long start = System.nanoTime(); + + /** + * VOID TEST + */ + try { + System.out.print("testVoid()"); + testClient.testVoid(); + System.out.print(" = void\n"); + } catch (TApplicationException tax) { + tax.printStackTrace(); + } + + /** + * STRING TEST + */ + System.out.print("testString(\"Test\")"); + String s = testClient.testString("Test"); + System.out.print(" = \"" + s + "\"\n"); + + /** + * BYTE TEST + */ + System.out.print("testByte(1)"); + byte i8 = testClient.testByte((byte)1); + System.out.print(" = " + i8 + "\n"); + + /** + * I32 TEST + */ + System.out.print("testI32(-1)"); + int i32 = testClient.testI32(-1); + System.out.print(" = " + i32 + "\n"); + + /** + * I64 TEST + */ + System.out.print("testI64(-34359738368)"); + long i64 = testClient.testI64(-34359738368L); + System.out.print(" = " + i64 + "\n"); + + /** + * DOUBLE TEST + */ + System.out.print("testDouble(5.325098235)"); + double dub = testClient.testDouble(5.325098235); + System.out.print(" = " + dub + "\n"); + + /** + * STRUCT TEST + */ + System.out.print("testStruct({\"Zero\", 1, -3, -5})"); + Xtruct out = new Xtruct(); + out.string_thing = "Zero"; + out.byte_thing = (byte) 1; + out.i32_thing = -3; + out.i64_thing = -5; + Xtruct in = testClient.testStruct(out); + System.out.print(" = {" + "\"" + in.string_thing + "\", " + in.byte_thing + ", " + in.i32_thing + ", " + in.i64_thing + "}\n"); + + /** + * NESTED STRUCT TEST + */ + System.out.print("testNest({1, {\"Zero\", 1, -3, -5}), 5}"); + Xtruct2 out2 = new Xtruct2(); + out2.byte_thing = (short)1; + out2.struct_thing = out; + out2.i32_thing = 5; + Xtruct2 in2 = testClient.testNest(out2); + in = in2.struct_thing; + System.out.print(" = {" + in2.byte_thing + ", {" + "\"" + in.string_thing + "\", " + in.byte_thing + ", " + in.i32_thing + ", " + in.i64_thing + "}, " + in2.i32_thing + "}\n"); + + /** + * MAP TEST + */ + Map<Integer,Integer> mapout = new HashMap<Integer,Integer>(); + for (int i = 0; i < 5; ++i) { + mapout.put(i, i-10); + } + System.out.print("testMap({"); + boolean first = true; + for (int key : mapout.keySet()) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(key + " => " + mapout.get(key)); + } + System.out.print("})"); + Map<Integer,Integer> mapin = testClient.testMap(mapout); + System.out.print(" = {"); + first = true; + for (int key : mapin.keySet()) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(key + " => " + mapout.get(key)); + } + System.out.print("}\n"); + + /** + * SET TEST + */ + Set<Integer> setout = new HashSet<Integer>(); + for (int i = -2; i < 3; ++i) { + setout.add(i); + } + System.out.print("testSet({"); + first = true; + for (int elem : setout) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(elem); + } + System.out.print("})"); + Set<Integer> setin = testClient.testSet(setout); + System.out.print(" = {"); + first = true; + for (int elem : setin) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(elem); + } + System.out.print("}\n"); + + /** + * LIST TEST + */ + List<Integer> listout = new ArrayList<Integer>(); + for (int i = -2; i < 3; ++i) { + listout.add(i); + } + System.out.print("testList({"); + first = true; + for (int elem : listout) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(elem); + } + System.out.print("})"); + List<Integer> listin = testClient.testList(listout); + System.out.print(" = {"); + first = true; + for (int elem : listin) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(elem); + } + System.out.print("}\n"); + + /** + * ENUM TEST + */ + System.out.print("testEnum(ONE)"); + int ret = testClient.testEnum(Numberz.ONE); + System.out.print(" = " + ret + "\n"); + + System.out.print("testEnum(TWO)"); + ret = testClient.testEnum(Numberz.TWO); + System.out.print(" = " + ret + "\n"); + + System.out.print("testEnum(THREE)"); + ret = testClient.testEnum(Numberz.THREE); + System.out.print(" = " + ret + "\n"); + + System.out.print("testEnum(FIVE)"); + ret = testClient.testEnum(Numberz.FIVE); + System.out.print(" = " + ret + "\n"); + + System.out.print("testEnum(EIGHT)"); + ret = testClient.testEnum(Numberz.EIGHT); + System.out.print(" = " + ret + "\n"); + + /** + * TYPEDEF TEST + */ + System.out.print("testTypedef(309858235082523)"); + long uid = testClient.testTypedef(309858235082523L); + System.out.print(" = " + uid + "\n"); + + /** + * NESTED MAP TEST + */ + System.out.print("testMapMap(1)"); + Map<Integer,Map<Integer,Integer>> mm = + testClient.testMapMap(1); + System.out.print(" = {"); + for (int key : mm.keySet()) { + System.out.print(key + " => {"); + Map<Integer,Integer> m2 = mm.get(key); + for (int k2 : m2.keySet()) { + System.out.print(k2 + " => " + m2.get(k2) + ", "); + } + System.out.print("}, "); + } + System.out.print("}\n"); + + /** + * INSANITY TEST + */ + insane = new Insanity(); + insane.userMap = new HashMap<Integer, Long>(); + insane.userMap.put(Numberz.FIVE, (long)5000); + Xtruct truck = new Xtruct(); + truck.string_thing = "Truck"; + truck.byte_thing = (byte)8; + truck.i32_thing = 8; + truck.i64_thing = 8; + insane.xtructs = new ArrayList<Xtruct>(); + insane.xtructs.add(truck); + System.out.print("testInsanity()"); + Map<Long,Map<Integer,Insanity>> whoa = + testClient.testInsanity(insane); + System.out.print(" = {"); + for (long key : whoa.keySet()) { + Map<Integer,Insanity> val = whoa.get(key); + System.out.print(key + " => {"); + + for (int k2 : val.keySet()) { + Insanity v2 = val.get(k2); + System.out.print(k2 + " => {"); + Map<Integer, Long> userMap = v2.userMap; + System.out.print("{"); + if (userMap != null) { + for (int k3 : userMap.keySet()) { + System.out.print(k3 + " => " + userMap.get(k3) + ", "); + } + } + System.out.print("}, "); + + List<Xtruct> xtructs = v2.xtructs; + System.out.print("{"); + if (xtructs != null) { + for (Xtruct x : xtructs) { + System.out.print("{" + "\"" + x.string_thing + "\", " + x.byte_thing + ", " + x.i32_thing + ", "+ x.i64_thing + "}, "); + } + } + System.out.print("}"); + + System.out.print("}, "); + } + System.out.print("}, "); + } + System.out.print("}\n"); + + // Test oneway + System.out.print("testOneway(3)..."); + long startOneway = System.nanoTime(); + testClient.testOneway(3); + long onewayElapsedMillis = (System.nanoTime() - startOneway) / 1000000; + if (onewayElapsedMillis > 200) { + throw new Exception("Oneway test failed: took " + + Long.toString(onewayElapsedMillis) + + "ms"); + } else { + System.out.println("Success - took " + + Long.toString(onewayElapsedMillis) + + "ms"); + } + + + long stop = System.nanoTime(); + long tot = stop-start; + + System.out.println("Total time: " + tot/1000 + "us"); + + if (timeMin == 0 || tot < timeMin) { + timeMin = tot; + } + if (tot > timeMax) { + timeMax = tot; + } + timeTot += tot; + + transport.close(); + } + + long timeAvg = timeTot / numTests; + + System.out.println("Min time: " + timeMin/1000 + "us"); + System.out.println("Max time: " + timeMax/1000 + "us"); + System.out.println("Avg time: " + timeAvg/1000 + "us"); + + String json = (new TSerializer(new TSimpleJSONProtocol.Factory())).toString(insane); + + System.out.println("\nFor good meausre here is some JSON:\n" + json); + + } catch (Exception x) { + x.printStackTrace(); + } + + } + +} diff --git a/lib/java/test/org/apache/thrift/test/TestNonblockingServer.java b/lib/java/test/org/apache/thrift/test/TestNonblockingServer.java new file mode 100644 index 000000000..2e6a1780f --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/TestNonblockingServer.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +import org.apache.thrift.server.THsHaServer; +import org.apache.thrift.server.TNonblockingServer; +import org.apache.thrift.server.TServer; +import org.apache.thrift.transport.TNonblockingServerSocket; + +import thrift.test.ThriftTest; + + +public class TestNonblockingServer extends TestServer { + public static void main(String [] args) { + try { + int port = 9090; + boolean hsha = false; + + for (int i = 0; i < args.length; i++) { + if (args[i].equals("-p")) { + port = Integer.valueOf(args[i++]); + } else if (args[i].equals("-hsha")) { + hsha = true; + } + } + + // Processor + TestHandler testHandler = + new TestHandler(); + ThriftTest.Processor testProcessor = + new ThriftTest.Processor(testHandler); + + // Transport + TNonblockingServerSocket tServerSocket = + new TNonblockingServerSocket(port); + + TServer serverEngine; + + if (hsha) { + // HsHa Server + serverEngine = new THsHaServer(testProcessor, tServerSocket); + } else { + // Nonblocking Server + serverEngine = new TNonblockingServer(testProcessor, tServerSocket); + } + + // Run it + System.out.println("Starting the server on port " + port + "..."); + serverEngine.serve(); + + } catch (Exception x) { + x.printStackTrace(); + } + System.out.println("done."); + } +} diff --git a/lib/java/test/org/apache/thrift/test/TestServer.java b/lib/java/test/org/apache/thrift/test/TestServer.java new file mode 100644 index 000000000..986f88906 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/TestServer.java @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.server.TServer; +import org.apache.thrift.server.TThreadPoolServer; +import org.apache.thrift.transport.TServerSocket; + +import thrift.test.Insanity; +import thrift.test.Numberz; +import thrift.test.ThriftTest; +import thrift.test.Xception; +import thrift.test.Xception2; +import thrift.test.Xtruct; +import thrift.test.Xtruct2; + +public class TestServer { + + public static class TestHandler implements ThriftTest.Iface { + + public TestHandler() {} + + public void testVoid() { + System.out.print("testVoid()\n"); + } + + public String testString(String thing) { + System.out.print("testString(\"" + thing + "\")\n"); + return thing; + } + + public byte testByte(byte thing) { + System.out.print("testByte(" + thing + ")\n"); + return thing; + } + + public int testI32(int thing) { + System.out.print("testI32(" + thing + ")\n"); + return thing; + } + + public long testI64(long thing) { + System.out.print("testI64(" + thing + ")\n"); + return thing; + } + + public double testDouble(double thing) { + System.out.print("testDouble(" + thing + ")\n"); + return thing; + } + + public Xtruct testStruct(Xtruct thing) { + System.out.print("testStruct({" + + "\"" + thing.string_thing + "\", " + + thing.byte_thing + ", " + + thing.i32_thing + ", " + + thing.i64_thing + "})\n"); + return thing; + } + + public Xtruct2 testNest(Xtruct2 nest) { + Xtruct thing = nest.struct_thing; + System.out.print("testNest({" + + nest.byte_thing + ", {" + + "\"" + thing.string_thing + "\", " + + thing.byte_thing + ", " + + thing.i32_thing + ", " + + thing.i64_thing + "}, " + + nest.i32_thing + "})\n"); + return nest; + } + + public Map<Integer,Integer> testMap(Map<Integer,Integer> thing) { + System.out.print("testMap({"); + boolean first = true; + for (int key : thing.keySet()) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(key + " => " + thing.get(key)); + } + System.out.print("})\n"); + return thing; + } + + public Set<Integer> testSet(Set<Integer> thing) { + System.out.print("testSet({"); + boolean first = true; + for (int elem : thing) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(elem); + } + System.out.print("})\n"); + return thing; + } + + public List<Integer> testList(List<Integer> thing) { + System.out.print("testList({"); + boolean first = true; + for (int elem : thing) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(elem); + } + System.out.print("})\n"); + return thing; + } + + public int testEnum(int thing) { + System.out.print("testEnum(" + thing + ")\n"); + return thing; + } + + public long testTypedef(long thing) { + System.out.print("testTypedef(" + thing + ")\n"); + return thing; + } + + public Map<Integer,Map<Integer,Integer>> testMapMap(int hello) { + System.out.print("testMapMap(" + hello + ")\n"); + Map<Integer,Map<Integer,Integer>> mapmap = + new HashMap<Integer,Map<Integer,Integer>>(); + + HashMap<Integer,Integer> pos = new HashMap<Integer,Integer>(); + HashMap<Integer,Integer> neg = new HashMap<Integer,Integer>(); + for (int i = 1; i < 5; i++) { + pos.put(i, i); + neg.put(-i, -i); + } + + mapmap.put(4, pos); + mapmap.put(-4, neg); + + return mapmap; + } + + public Map<Long, Map<Integer,Insanity>> testInsanity(Insanity argument) { + System.out.print("testInsanity()\n"); + + Xtruct hello = new Xtruct(); + hello.string_thing = "Hello2"; + hello.byte_thing = 2; + hello.i32_thing = 2; + hello.i64_thing = 2; + + Xtruct goodbye = new Xtruct(); + goodbye.string_thing = "Goodbye4"; + goodbye.byte_thing = (byte)4; + goodbye.i32_thing = 4; + goodbye.i64_thing = (long)4; + + Insanity crazy = new Insanity(); + crazy.userMap = new HashMap<Integer, Long>(); + crazy.xtructs = new ArrayList<Xtruct>(); + + crazy.userMap.put(Numberz.EIGHT, (long)8); + crazy.xtructs.add(goodbye); + + Insanity looney = new Insanity(); + crazy.userMap.put(Numberz.FIVE, (long)5); + crazy.xtructs.add(hello); + + HashMap<Integer,Insanity> first_map = new HashMap<Integer, Insanity>(); + HashMap<Integer,Insanity> second_map = new HashMap<Integer, Insanity>();; + + first_map.put(Numberz.TWO, crazy); + first_map.put(Numberz.THREE, crazy); + + second_map.put(Numberz.SIX, looney); + + Map<Long,Map<Integer,Insanity>> insane = + new HashMap<Long, Map<Integer,Insanity>>(); + insane.put((long)1, first_map); + insane.put((long)2, second_map); + + return insane; + } + + public Xtruct testMulti(byte arg0, int arg1, long arg2, Map<Short,String> arg3, int arg4, long arg5) { + System.out.print("testMulti()\n"); + + Xtruct hello = new Xtruct();; + hello.string_thing = "Hello2"; + hello.byte_thing = arg0; + hello.i32_thing = arg1; + hello.i64_thing = arg2; + return hello; + } + + public void testException(String arg) throws Xception { + System.out.print("testException("+arg+")\n"); + if (arg.equals("Xception")) { + Xception x = new Xception(); + x.errorCode = 1001; + x.message = "This is an Xception"; + throw x; + } + return; + } + + public Xtruct testMultiException(String arg0, String arg1) throws Xception, Xception2 { + System.out.print("testMultiException(" + arg0 + ", " + arg1 + ")\n"); + if (arg0.equals("Xception")) { + Xception x = new Xception(); + x.errorCode = 1001; + x.message = "This is an Xception"; + throw x; + } else if (arg0.equals("Xception2")) { + Xception2 x = new Xception2(); + x.errorCode = 2002; + x.struct_thing = new Xtruct(); + x.struct_thing.string_thing = "This is an Xception2"; + throw x; + } + + Xtruct result = new Xtruct(); + result.string_thing = arg1; + return result; + } + + public void testOneway(int sleepFor) { + System.out.println("testOneway(" + Integer.toString(sleepFor) + + ") => sleeping..."); + try { + Thread.sleep(sleepFor * 1000); + System.out.println("Done sleeping!"); + } catch (InterruptedException ie) { + throw new RuntimeException(ie); + } + } + + } // class TestHandler + + public static void main(String [] args) { + try { + int port = 9090; + if (args.length > 1) { + port = Integer.valueOf(args[0]); + } + + // Processor + TestHandler testHandler = + new TestHandler(); + ThriftTest.Processor testProcessor = + new ThriftTest.Processor(testHandler); + + // Transport + TServerSocket tServerSocket = + new TServerSocket(port); + + // Protocol factory + TProtocolFactory tProtocolFactory = + new TBinaryProtocol.Factory(); + + TServer serverEngine; + + // Simple Server + // serverEngine = new TSimpleServer(testProcessor, tServerSocket); + + // ThreadPool Server + serverEngine = new TThreadPoolServer(testProcessor, tServerSocket, tProtocolFactory); + + // Run it + System.out.println("Starting the server on port " + port + "..."); + serverEngine.serve(); + + } catch (Exception x) { + x.printStackTrace(); + } + System.out.println("done."); + } +} diff --git a/lib/java/test/org/apache/thrift/test/ToStringTest.java b/lib/java/test/org/apache/thrift/test/ToStringTest.java new file mode 100644 index 000000000..569a61c47 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/ToStringTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +import thrift.test.*; + +/** + */ +public class ToStringTest { + public static void main(String[] args) throws Exception { + JavaTestHelper object = new JavaTestHelper(); + object.req_int = 0; + object.req_obj = ""; + + + object.req_bin = new byte[] { + 0, -1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, 12, -13, 14, -15, + 16, -17, 18, -19, 20, -21, 22, -23, 24, -25, 26, -27, 28, -29, + 30, -31, 32, -33, 34, -35, 36, -37, 38, -39, 40, -41, 42, -43, 44, + -45, 46, -47, 48, -49, 50, -51, 52, -53, 54, -55, 56, -57, 58, -59, + 60, -61, 62, -63, 64, -65, 66, -67, 68, -69, 70, -71, 72, -73, 74, + -75, 76, -77, 78, -79, 80, -81, 82, -83, 84, -85, 86, -87, 88, -89, + 90, -91, 92, -93, 94, -95, 96, -97, 98, -99, 100, -101, 102, -103, + 104, -105, 106, -107, 108, -109, 110, -111, 112, -113, 114, -115, + 116, -117, 118, -119, 120, -121, 122, -123, 124, -125, 126, -127, + }; + + if (!object.toString().equals( + "JavaTestHelper(req_int:0, req_obj:, req_bin:"+ + "00 FF 02 FD 04 FB 06 F9 08 F7 0A F5 0C F3 0E F1 10 EF 12 ED 14 "+ + "EB 16 E9 18 E7 1A E5 1C E3 1E E1 20 DF 22 DD 24 DB 26 D9 28 D7 "+ + "2A D5 2C D3 2E D1 30 CF 32 CD 34 CB 36 C9 38 C7 3A C5 3C C3 3E "+ + "C1 40 BF 42 BD 44 BB 46 B9 48 B7 4A B5 4C B3 4E B1 50 AF 52 AD "+ + "54 AB 56 A9 58 A7 5A A5 5C A3 5E A1 60 9F 62 9D 64 9B 66 99 68 "+ + "97 6A 95 6C 93 6E 91 70 8F 72 8D 74 8B 76 89 78 87 7A 85 7C 83 "+ + "7E 81)")) { + throw new RuntimeException(); + } + + object.req_bin = new byte[] { + 0, -1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, 12, -13, 14, -15, + 16, -17, 18, -19, 20, -21, 22, -23, 24, -25, 26, -27, 28, -29, + 30, -31, 32, -33, 34, -35, 36, -37, 38, -39, 40, -41, 42, -43, 44, + -45, 46, -47, 48, -49, 50, -51, 52, -53, 54, -55, 56, -57, 58, -59, + 60, -61, 62, -63, 64, -65, 66, -67, 68, -69, 70, -71, 72, -73, 74, + -75, 76, -77, 78, -79, 80, -81, 82, -83, 84, -85, 86, -87, 88, -89, + 90, -91, 92, -93, 94, -95, 96, -97, 98, -99, 100, -101, 102, -103, + 104, -105, 106, -107, 108, -109, 110, -111, 112, -113, 114, -115, + 116, -117, 118, -119, 120, -121, 122, -123, 124, -125, 126, -127, + 0, + }; + + if (!object.toString().equals( + "JavaTestHelper(req_int:0, req_obj:, req_bin:"+ + "00 FF 02 FD 04 FB 06 F9 08 F7 0A F5 0C F3 0E F1 10 EF 12 ED 14 "+ + "EB 16 E9 18 E7 1A E5 1C E3 1E E1 20 DF 22 DD 24 DB 26 D9 28 D7 "+ + "2A D5 2C D3 2E D1 30 CF 32 CD 34 CB 36 C9 38 C7 3A C5 3C C3 3E "+ + "C1 40 BF 42 BD 44 BB 46 B9 48 B7 4A B5 4C B3 4E B1 50 AF 52 AD "+ + "54 AB 56 A9 58 A7 5A A5 5C A3 5E A1 60 9F 62 9D 64 9B 66 99 68 "+ + "97 6A 95 6C 93 6E 91 70 8F 72 8D 74 8B 76 89 78 87 7A 85 7C 83 "+ + "7E 81 ...)")) { + throw new RuntimeException(); + } + + object.req_bin = new byte[] {}; + object.setOpt_binIsSet(true); + + + if (!object.toString().equals( + "JavaTestHelper(req_int:0, req_obj:, req_bin:)")) { + throw new RuntimeException(); + } + } +} + diff --git a/lib/java/test/org/apache/thrift/test/WriteStruct.java b/lib/java/test/org/apache/thrift/test/WriteStruct.java new file mode 100644 index 000000000..474c808e9 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/WriteStruct.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +import java.io.BufferedOutputStream; +import java.io.FileOutputStream; + +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TIOStreamTransport; +import org.apache.thrift.transport.TTransport; + +public class WriteStruct { + public static void main(String[] args) throws Exception { + if (args.length != 2) { + System.out.println("usage: java -cp build/classes org.apache.thrift.test.WriteStruct filename proto_factory_class"); + System.out.println("Write out an instance of Fixtures.compactProtocolTestStruct to 'file'. Use a protocol from 'proto_factory_class'."); + } + + TTransport trans = new TIOStreamTransport(new BufferedOutputStream(new FileOutputStream(args[0]))); + + TProtocolFactory factory = (TProtocolFactory)Class.forName(args[1]).newInstance(); + + TProtocol proto = factory.getProtocol(trans); + + Fixtures.compactProtoTestStruct.write(proto); + trans.flush(); + } + +} diff --git a/lib/ocaml/Makefile b/lib/ocaml/Makefile new file mode 100644 index 000000000..6abeee71f --- /dev/null +++ b/lib/ocaml/Makefile @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +all: + cd src; make; cd .. +clean: + cd src; make clean; cd .. diff --git a/lib/ocaml/OCamlMakefile b/lib/ocaml/OCamlMakefile new file mode 100644 index 000000000..b0b9252c9 --- /dev/null +++ b/lib/ocaml/OCamlMakefile @@ -0,0 +1,1231 @@ +########################################################################### +# OCamlMakefile +# Copyright (C) 1999-2007 Markus Mottl +# +# For updates see: +# http://www.ocaml.info/home/ocaml_sources.html +# +########################################################################### + +# Modified by damien for .glade.ml compilation + +# Set these variables to the names of the sources to be processed and +# the result variable. Order matters during linkage! + +ifndef SOURCES + SOURCES := foo.ml +endif +export SOURCES + +ifndef RES_CLIB_SUF + RES_CLIB_SUF := _stubs +endif +export RES_CLIB_SUF + +ifndef RESULT + RESULT := foo +endif +export RESULT := $(strip $(RESULT)) + +export LIB_PACK_NAME + +ifndef DOC_FILES + DOC_FILES := $(filter %.mli, $(SOURCES)) +endif +export DOC_FILES +FIRST_DOC_FILE := $(firstword $(DOC_FILES)) + +export BCSUFFIX +export NCSUFFIX + +ifndef TOPSUFFIX + TOPSUFFIX := .top +endif +export TOPSUFFIX + +# Eventually set include- and library-paths, libraries to link, +# additional compilation-, link- and ocamlyacc-flags +# Path- and library information needs not be written with "-I" and such... +# Define THREADS if you need it, otherwise leave it unset (same for +# USE_CAMLP4)! + +export THREADS +export VMTHREADS +export ANNOTATE +export USE_CAMLP4 + +export INCDIRS +export LIBDIRS +export EXTLIBDIRS +export RESULTDEPS +export OCAML_DEFAULT_DIRS + +export LIBS +export CLIBS +export CFRAMEWORKS + +export OCAMLFLAGS +export OCAMLNCFLAGS +export OCAMLBCFLAGS + +export OCAMLLDFLAGS +export OCAMLNLDFLAGS +export OCAMLBLDFLAGS + +export OCAMLMKLIB_FLAGS + +ifndef OCAMLCPFLAGS + OCAMLCPFLAGS := a +endif +export OCAMLCPFLAGS + +ifndef DOC_DIR + DOC_DIR := doc +endif +export DOC_DIR + +export PPFLAGS + +export LFLAGS +export YFLAGS +export IDLFLAGS + +export OCAMLDOCFLAGS + +export OCAMLFIND_INSTFLAGS + +export DVIPSFLAGS + +export STATIC + +# Add a list of optional trash files that should be deleted by "make clean" +export TRASH + +ECHO := echo + +ifdef REALLY_QUIET + export REALLY_QUIET + ECHO := true + LFLAGS := $(LFLAGS) -q + YFLAGS := $(YFLAGS) -q +endif + +#################### variables depending on your OCaml-installation + +ifdef MINGW + export MINGW + WIN32 := 1 + CFLAGS_WIN32 := -mno-cygwin +endif +ifdef MSVC + export MSVC + WIN32 := 1 + ifndef STATIC + CPPFLAGS_WIN32 := -DCAML_DLL + endif + CFLAGS_WIN32 += -nologo + EXT_OBJ := obj + EXT_LIB := lib + ifeq ($(CC),gcc) + # work around GNU Make default value + ifdef THREADS + CC := cl -MT + else + CC := cl + endif + endif + ifeq ($(CXX),g++) + # work around GNU Make default value + CXX := $(CC) + endif + CFLAG_O := -Fo +endif +ifdef WIN32 + EXT_CXX := cpp + EXE := .exe +endif + +ifndef EXT_OBJ + EXT_OBJ := o +endif +ifndef EXT_LIB + EXT_LIB := a +endif +ifndef EXT_CXX + EXT_CXX := cc +endif +ifndef EXE + EXE := # empty +endif +ifndef CFLAG_O + CFLAG_O := -o # do not delete this comment (preserves trailing whitespace)! +endif + +export CC +export CXX +export CFLAGS +export CXXFLAGS +export LDFLAGS +export CPPFLAGS + +ifndef RPATH_FLAG + ifdef ELF_RPATH_FLAG + RPATH_FLAG := $(ELF_RPATH_FLAG) + else + RPATH_FLAG := -R + endif +endif +export RPATH_FLAG + +ifndef MSVC +ifndef PIC_CFLAGS + PIC_CFLAGS := -fPIC +endif +ifndef PIC_CPPFLAGS + PIC_CPPFLAGS := -DPIC +endif +endif + +export PIC_CFLAGS +export PIC_CPPFLAGS + +BCRESULT := $(addsuffix $(BCSUFFIX), $(RESULT)) +NCRESULT := $(addsuffix $(NCSUFFIX), $(RESULT)) +TOPRESULT := $(addsuffix $(TOPSUFFIX), $(RESULT)) + +ifndef OCAMLFIND + OCAMLFIND := ocamlfind +endif +export OCAMLFIND + +ifndef OCAMLC + OCAMLC := ocamlc +endif +export OCAMLC + +ifndef OCAMLOPT + OCAMLOPT := ocamlopt +endif +export OCAMLOPT + +ifndef OCAMLMKTOP + OCAMLMKTOP := ocamlmktop +endif +export OCAMLMKTOP + +ifndef OCAMLCP + OCAMLCP := ocamlcp +endif +export OCAMLCP + +ifndef OCAMLDEP + OCAMLDEP := ocamldep +endif +export OCAMLDEP + +ifndef OCAMLLEX + OCAMLLEX := ocamllex +endif +export OCAMLLEX + +ifndef OCAMLYACC + OCAMLYACC := ocamlyacc +endif +export OCAMLYACC + +ifndef OCAMLMKLIB + OCAMLMKLIB := ocamlmklib +endif +export OCAMLMKLIB + +ifndef OCAML_GLADECC + OCAML_GLADECC := lablgladecc2 +endif +export OCAML_GLADECC + +ifndef OCAML_GLADECC_FLAGS + OCAML_GLADECC_FLAGS := +endif +export OCAML_GLADECC_FLAGS + +ifndef CAMELEON_REPORT + CAMELEON_REPORT := report +endif +export CAMELEON_REPORT + +ifndef CAMELEON_REPORT_FLAGS + CAMELEON_REPORT_FLAGS := +endif +export CAMELEON_REPORT_FLAGS + +ifndef CAMELEON_ZOGGY + CAMELEON_ZOGGY := camlp4o pa_zog.cma pr_o.cmo +endif +export CAMELEON_ZOGGY + +ifndef CAMELEON_ZOGGY_FLAGS + CAMELEON_ZOGGY_FLAGS := +endif +export CAMELEON_ZOGGY_FLAGS + +ifndef OXRIDL + OXRIDL := oxridl +endif +export OXRIDL + +ifndef CAMLIDL + CAMLIDL := camlidl +endif +export CAMLIDL + +ifndef CAMLIDLDLL + CAMLIDLDLL := camlidldll +endif +export CAMLIDLDLL + +ifndef NOIDLHEADER + MAYBE_IDL_HEADER := -header +endif +export NOIDLHEADER + +export NO_CUSTOM + +ifndef CAMLP4 + CAMLP4 := camlp4 +endif +export CAMLP4 + +ifndef REAL_OCAMLFIND + ifdef PACKS + ifndef CREATE_LIB + ifdef THREADS + PACKS += threads + endif + endif + empty := + space := $(empty) $(empty) + comma := , + ifdef PREDS + PRE_OCAML_FIND_PREDICATES := $(subst $(space),$(comma),$(PREDS)) + PRE_OCAML_FIND_PACKAGES := $(subst $(space),$(comma),$(PACKS)) + OCAML_FIND_PREDICATES := -predicates $(PRE_OCAML_FIND_PREDICATES) + # OCAML_DEP_PREDICATES := -syntax $(PRE_OCAML_FIND_PREDICATES) + OCAML_FIND_PACKAGES := $(OCAML_FIND_PREDICATES) -package $(PRE_OCAML_FIND_PACKAGES) + OCAML_DEP_PACKAGES := $(OCAML_DEP_PREDICATES) -package $(PRE_OCAML_FIND_PACKAGES) + else + OCAML_FIND_PACKAGES := -package $(subst $(space),$(comma),$(PACKS)) + OCAML_DEP_PACKAGES := + endif + OCAML_FIND_LINKPKG := -linkpkg + REAL_OCAMLFIND := $(OCAMLFIND) + endif +endif + +export OCAML_FIND_PACKAGES +export OCAML_DEP_PACKAGES +export OCAML_FIND_LINKPKG +export REAL_OCAMLFIND + +ifndef OCAMLDOC + OCAMLDOC := ocamldoc +endif +export OCAMLDOC + +ifndef LATEX + LATEX := latex +endif +export LATEX + +ifndef DVIPS + DVIPS := dvips +endif +export DVIPS + +ifndef PS2PDF + PS2PDF := ps2pdf +endif +export PS2PDF + +ifndef OCAMLMAKEFILE + OCAMLMAKEFILE := OCamlMakefile +endif +export OCAMLMAKEFILE + +ifndef OCAMLLIBPATH + OCAMLLIBPATH := \ + $(shell $(OCAMLC) 2>/dev/null -where || echo /usr/local/lib/ocaml) +endif +export OCAMLLIBPATH + +ifndef OCAML_LIB_INSTALL + OCAML_LIB_INSTALL := $(OCAMLLIBPATH)/contrib +endif +export OCAML_LIB_INSTALL + +########################################################################### + +#################### change following sections only if +#################### you know what you are doing! + +# delete target files when a build command fails +.PHONY: .DELETE_ON_ERROR +.DELETE_ON_ERROR: + +# for pedants using "--warn-undefined-variables" +export MAYBE_IDL +export REAL_RESULT +export CAMLIDLFLAGS +export THREAD_FLAG +export RES_CLIB +export MAKEDLL +export ANNOT_FLAG +export C_OXRIDL +export SUBPROJS +export CFLAGS_WIN32 +export CPPFLAGS_WIN32 + +INCFLAGS := + +SHELL := /bin/sh + +MLDEPDIR := ._d +BCDIDIR := ._bcdi +NCDIDIR := ._ncdi + +FILTER_EXTNS := %.mli %.ml %.mll %.mly %.idl %.oxridl %.c %.m %.$(EXT_CXX) %.rep %.zog %.glade + +FILTERED := $(filter $(FILTER_EXTNS), $(SOURCES)) +SOURCE_DIRS := $(filter-out ./, $(sort $(dir $(FILTERED)))) + +FILTERED_REP := $(filter %.rep, $(FILTERED)) +DEP_REP := $(FILTERED_REP:%.rep=$(MLDEPDIR)/%.d) +AUTO_REP := $(FILTERED_REP:.rep=.ml) + +FILTERED_ZOG := $(filter %.zog, $(FILTERED)) +DEP_ZOG := $(FILTERED_ZOG:%.zog=$(MLDEPDIR)/%.d) +AUTO_ZOG := $(FILTERED_ZOG:.zog=.ml) + +FILTERED_GLADE := $(filter %.glade, $(FILTERED)) +DEP_GLADE := $(FILTERED_GLADE:%.glade=$(MLDEPDIR)/%.d) +AUTO_GLADE := $(FILTERED_GLADE:.glade=.ml) + +FILTERED_ML := $(filter %.ml, $(FILTERED)) +DEP_ML := $(FILTERED_ML:%.ml=$(MLDEPDIR)/%.d) + +FILTERED_MLI := $(filter %.mli, $(FILTERED)) +DEP_MLI := $(FILTERED_MLI:.mli=.di) + +FILTERED_MLL := $(filter %.mll, $(FILTERED)) +DEP_MLL := $(FILTERED_MLL:%.mll=$(MLDEPDIR)/%.d) +AUTO_MLL := $(FILTERED_MLL:.mll=.ml) + +FILTERED_MLY := $(filter %.mly, $(FILTERED)) +DEP_MLY := $(FILTERED_MLY:%.mly=$(MLDEPDIR)/%.d) $(FILTERED_MLY:.mly=.di) +AUTO_MLY := $(FILTERED_MLY:.mly=.mli) $(FILTERED_MLY:.mly=.ml) + +FILTERED_IDL := $(filter %.idl, $(FILTERED)) +DEP_IDL := $(FILTERED_IDL:%.idl=$(MLDEPDIR)/%.d) $(FILTERED_IDL:.idl=.di) +C_IDL := $(FILTERED_IDL:%.idl=%_stubs.c) +ifndef NOIDLHEADER + C_IDL += $(FILTERED_IDL:.idl=.h) +endif +OBJ_C_IDL := $(FILTERED_IDL:%.idl=%_stubs.$(EXT_OBJ)) +AUTO_IDL := $(FILTERED_IDL:.idl=.mli) $(FILTERED_IDL:.idl=.ml) $(C_IDL) + +FILTERED_OXRIDL := $(filter %.oxridl, $(FILTERED)) +DEP_OXRIDL := $(FILTERED_OXRIDL:%.oxridl=$(MLDEPDIR)/%.d) $(FILTERED_OXRIDL:.oxridl=.di) +AUTO_OXRIDL := $(FILTERED_OXRIDL:.oxridl=.mli) $(FILTERED_OXRIDL:.oxridl=.ml) $(C_OXRIDL) + +FILTERED_C_CXX := $(filter %.c %.m %.$(EXT_CXX), $(FILTERED)) +OBJ_C_CXX := $(FILTERED_C_CXX:.c=.$(EXT_OBJ)) +OBJ_C_CXX := $(OBJ_C_CXX:.m=.$(EXT_OBJ)) +OBJ_C_CXX := $(OBJ_C_CXX:.$(EXT_CXX)=.$(EXT_OBJ)) + +PRE_TARGETS += $(AUTO_MLL) $(AUTO_MLY) $(AUTO_IDL) $(AUTO_OXRIDL) $(AUTO_ZOG) $(AUTO_REP) $(AUTO_GLADE) + +ALL_DEPS := $(DEP_ML) $(DEP_MLI) $(DEP_MLL) $(DEP_MLY) $(DEP_IDL) $(DEP_OXRIDL) $(DEP_ZOG) $(DEP_REP) $(DEP_GLADE) + +MLDEPS := $(filter %.d, $(ALL_DEPS)) +MLIDEPS := $(filter %.di, $(ALL_DEPS)) +BCDEPIS := $(MLIDEPS:%.di=$(BCDIDIR)/%.di) +NCDEPIS := $(MLIDEPS:%.di=$(NCDIDIR)/%.di) + +ALLML := $(filter %.mli %.ml %.mll %.mly %.idl %.oxridl %.rep %.zog %.glade, $(FILTERED)) + +IMPLO_INTF := $(ALLML:%.mli=%.mli.__) +IMPLO_INTF := $(foreach file, $(IMPLO_INTF), \ + $(basename $(file)).cmi $(basename $(file)).cmo) +IMPLO_INTF := $(filter-out %.mli.cmo, $(IMPLO_INTF)) +IMPLO_INTF := $(IMPLO_INTF:%.mli.cmi=%.cmi) + +IMPLX_INTF := $(IMPLO_INTF:.cmo=.cmx) + +INTF := $(filter %.cmi, $(IMPLO_INTF)) +IMPL_CMO := $(filter %.cmo, $(IMPLO_INTF)) +IMPL_CMX := $(IMPL_CMO:.cmo=.cmx) +IMPL_ASM := $(IMPL_CMO:.cmo=.asm) +IMPL_S := $(IMPL_CMO:.cmo=.s) + +OBJ_LINK := $(OBJ_C_IDL) $(OBJ_C_CXX) +OBJ_FILES := $(IMPL_CMO:.cmo=.$(EXT_OBJ)) $(OBJ_LINK) + +EXECS := $(addsuffix $(EXE), \ + $(sort $(TOPRESULT) $(BCRESULT) $(NCRESULT))) +ifdef WIN32 + EXECS += $(BCRESULT).dll $(NCRESULT).dll +endif + +CLIB_BASE := $(RESULT)$(RES_CLIB_SUF) +ifneq ($(strip $(OBJ_LINK)),) + RES_CLIB := lib$(CLIB_BASE).$(EXT_LIB) +endif + +ifdef WIN32 +DLLSONAME := $(CLIB_BASE).dll +else +DLLSONAME := dll$(CLIB_BASE).so +endif + +NONEXECS := $(INTF) $(IMPL_CMO) $(IMPL_CMX) $(IMPL_ASM) $(IMPL_S) \ + $(OBJ_FILES) $(PRE_TARGETS) $(BCRESULT).cma $(NCRESULT).cmxa \ + $(NCRESULT).$(EXT_LIB) $(BCRESULT).cmi $(BCRESULT).cmo \ + $(NCRESULT).cmi $(NCRESULT).cmx $(NCRESULT).o \ + $(RES_CLIB) $(IMPL_CMO:.cmo=.annot) \ + $(LIB_PACK_NAME).cmi $(LIB_PACK_NAME).cmo $(LIB_PACK_NAME).cmx $(LIB_PACK_NAME).o + +ifndef STATIC + NONEXECS += $(DLLSONAME) +endif + +ifndef LIBINSTALL_FILES + LIBINSTALL_FILES := $(RESULT).mli $(RESULT).cmi $(RESULT).cma \ + $(RESULT).cmxa $(RESULT).$(EXT_LIB) $(RES_CLIB) + ifndef STATIC + ifneq ($(strip $(OBJ_LINK)),) + LIBINSTALL_FILES += $(DLLSONAME) + endif + endif +endif + +export LIBINSTALL_FILES + +ifdef WIN32 + # some extra stuff is created while linking DLLs + NONEXECS += $(BCRESULT).$(EXT_LIB) $(BCRESULT).exp $(NCRESULT).exp $(CLIB_BASE).exp $(CLIB_BASE).lib +endif + +TARGETS := $(EXECS) $(NONEXECS) + +# If there are IDL-files +ifneq ($(strip $(FILTERED_IDL)),) + MAYBE_IDL := -cclib -lcamlidl +endif + +ifdef USE_CAMLP4 + CAMLP4PATH := \ + $(shell $(CAMLP4) -where 2>/dev/null || echo /usr/local/lib/camlp4) + INCFLAGS := -I $(CAMLP4PATH) + CINCFLAGS := -I$(CAMLP4PATH) +endif + +DINCFLAGS := $(INCFLAGS) $(SOURCE_DIRS:%=-I %) $(OCAML_DEFAULT_DIRS:%=-I %) +INCFLAGS := $(DINCFLAGS) $(INCDIRS:%=-I %) +CINCFLAGS += $(SOURCE_DIRS:%=-I%) $(INCDIRS:%=-I%) $(OCAML_DEFAULT_DIRS:%=-I%) + +ifndef MSVC + CLIBFLAGS += $(SOURCE_DIRS:%=-L%) $(LIBDIRS:%=-L%) \ + $(EXTLIBDIRS:%=-L%) $(OCAML_DEFAULT_DIRS:%=-L%) + + ifeq ($(ELF_RPATH), yes) + CLIBFLAGS += $(EXTLIBDIRS:%=-Wl,$(RPATH_FLAG)%) + endif +endif + +ifndef PROFILING + INTF_OCAMLC := $(OCAMLC) +else + ifndef THREADS + INTF_OCAMLC := $(OCAMLCP) -p $(OCAMLCPFLAGS) + else + # OCaml does not support profiling byte code + # with threads (yet), therefore we force an error. + ifndef REAL_OCAMLC + $(error Profiling of multithreaded byte code not yet supported by OCaml) + endif + INTF_OCAMLC := $(OCAMLC) + endif +endif + +ifndef MSVC + COMMON_LDFLAGS := $(LDFLAGS:%=-ccopt %) $(SOURCE_DIRS:%=-ccopt -L%) \ + $(LIBDIRS:%=-ccopt -L%) $(EXTLIBDIRS:%=-ccopt -L%) \ + $(EXTLIBDIRS:%=-ccopt -Wl $(OCAML_DEFAULT_DIRS:%=-ccopt -L%)) + + ifeq ($(ELF_RPATH),yes) + COMMON_LDFLAGS += $(EXTLIBDIRS:%=-ccopt -Wl,$(RPATH_FLAG)%) + endif +else + COMMON_LDFLAGS := -ccopt "/link -NODEFAULTLIB:LIBC $(LDFLAGS:%=%) $(SOURCE_DIRS:%=-LIBPATH:%) \ + $(LIBDIRS:%=-LIBPATH:%) $(EXTLIBDIRS:%=-LIBPATH:%) \ + $(OCAML_DEFAULT_DIRS:%=-LIBPATH:%) " +endif + +CLIBS_OPTS := $(CLIBS:%=-cclib -l%) $(CFRAMEWORKS:%=-cclib '-framework %') +ifdef MSVC + ifndef STATIC + # MSVC libraries do not have 'lib' prefix + CLIBS_OPTS := $(CLIBS:%=-cclib %.lib) + endif +endif + +ifneq ($(strip $(OBJ_LINK)),) + ifdef CREATE_LIB + OBJS_LIBS := -cclib -l$(CLIB_BASE) $(CLIBS_OPTS) $(MAYBE_IDL) + else + OBJS_LIBS := $(OBJ_LINK) $(CLIBS_OPTS) $(MAYBE_IDL) + endif +else + OBJS_LIBS := $(CLIBS_OPTS) $(MAYBE_IDL) +endif + +# If we have to make byte-code +ifndef REAL_OCAMLC + BYTE_OCAML := y + + # EXTRADEPS is added dependencies we have to insert for all + # executable files we generate. Ideally it should be all of the + # libraries we use, but it's hard to find the ones that get searched on + # the path since I don't know the paths built into the compiler, so + # just include the ones with slashes in their names. + EXTRADEPS := $(addsuffix .cma,$(foreach i,$(LIBS),$(if $(findstring /,$(i)),$(i)))) + SPECIAL_OCAMLFLAGS := $(OCAMLBCFLAGS) + + REAL_OCAMLC := $(INTF_OCAMLC) + + REAL_IMPL := $(IMPL_CMO) + REAL_IMPL_INTF := $(IMPLO_INTF) + IMPL_SUF := .cmo + + DEPFLAGS := + MAKE_DEPS := $(MLDEPS) $(BCDEPIS) + + ifdef CREATE_LIB + override CFLAGS := $(PIC_CFLAGS) $(CFLAGS) + override CPPFLAGS := $(PIC_CPPFLAGS) $(CPPFLAGS) + ifndef STATIC + ifneq ($(strip $(OBJ_LINK)),) + MAKEDLL := $(DLLSONAME) + ALL_LDFLAGS := -dllib $(DLLSONAME) + endif + endif + endif + + ifndef NO_CUSTOM + ifneq "$(strip $(OBJ_LINK) $(THREADS) $(MAYBE_IDL) $(CLIBS) $(CFRAMEWORKS))" "" + ALL_LDFLAGS += -custom + endif + endif + + ALL_LDFLAGS += $(INCFLAGS) $(OCAMLLDFLAGS) $(OCAMLBLDFLAGS) \ + $(COMMON_LDFLAGS) $(LIBS:%=%.cma) + CAMLIDLDLLFLAGS := + + ifdef THREADS + ifdef VMTHREADS + THREAD_FLAG := -vmthread + else + THREAD_FLAG := -thread + endif + ALL_LDFLAGS := $(THREAD_FLAG) $(ALL_LDFLAGS) + ifndef CREATE_LIB + ifndef REAL_OCAMLFIND + ALL_LDFLAGS := unix.cma threads.cma $(ALL_LDFLAGS) + endif + endif + endif + +# we have to make native-code +else + EXTRADEPS := $(addsuffix .cmxa,$(foreach i,$(LIBS),$(if $(findstring /,$(i)),$(i)))) + ifndef PROFILING + SPECIAL_OCAMLFLAGS := $(OCAMLNCFLAGS) + PLDFLAGS := + else + SPECIAL_OCAMLFLAGS := -p $(OCAMLNCFLAGS) + PLDFLAGS := -p + endif + + REAL_IMPL := $(IMPL_CMX) + REAL_IMPL_INTF := $(IMPLX_INTF) + IMPL_SUF := .cmx + + override CPPFLAGS := -DNATIVE_CODE $(CPPFLAGS) + + DEPFLAGS := -native + MAKE_DEPS := $(MLDEPS) $(NCDEPIS) + + ALL_LDFLAGS := $(PLDFLAGS) $(INCFLAGS) $(OCAMLLDFLAGS) \ + $(OCAMLNLDFLAGS) $(COMMON_LDFLAGS) + CAMLIDLDLLFLAGS := -opt + + ifndef CREATE_LIB + ALL_LDFLAGS += $(LIBS:%=%.cmxa) + else + override CFLAGS := $(PIC_CFLAGS) $(CFLAGS) + override CPPFLAGS := $(PIC_CPPFLAGS) $(CPPFLAGS) + endif + + ifdef THREADS + THREAD_FLAG := -thread + ALL_LDFLAGS := $(THREAD_FLAG) $(ALL_LDFLAGS) + ifndef CREATE_LIB + ifndef REAL_OCAMLFIND + ALL_LDFLAGS := unix.cmxa threads.cmxa $(ALL_LDFLAGS) + endif + endif + endif +endif + +export MAKE_DEPS + +ifdef ANNOTATE + ANNOT_FLAG := -dtypes +else +endif + +ALL_OCAMLCFLAGS := $(THREAD_FLAG) $(ANNOT_FLAG) $(OCAMLFLAGS) \ + $(INCFLAGS) $(SPECIAL_OCAMLFLAGS) + +ifdef make_deps + -include $(MAKE_DEPS) + PRE_TARGETS := +endif + +########################################################################### +# USER RULES + +# Call "OCamlMakefile QUIET=" to get rid of all of the @'s. +QUIET=@ + +# generates byte-code (default) +byte-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(BCRESULT) \ + REAL_RESULT="$(BCRESULT)" make_deps=yes +bc: byte-code + +byte-code-nolink: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) nolink \ + REAL_RESULT="$(BCRESULT)" make_deps=yes +bcnl: byte-code-nolink + +top: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(TOPRESULT) \ + REAL_RESULT="$(BCRESULT)" make_deps=yes + +# generates native-code + +native-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(NCRESULT) \ + REAL_RESULT="$(NCRESULT)" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + make_deps=yes +nc: native-code + +native-code-nolink: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) nolink \ + REAL_RESULT="$(NCRESULT)" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + make_deps=yes +ncnl: native-code-nolink + +# generates byte-code libraries +byte-code-library: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(BCRESULT).cma \ + REAL_RESULT="$(BCRESULT)" \ + CREATE_LIB=yes \ + make_deps=yes +bcl: byte-code-library + +# generates native-code libraries +native-code-library: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(NCRESULT).cmxa \ + REAL_RESULT="$(NCRESULT)" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + CREATE_LIB=yes \ + make_deps=yes +ncl: native-code-library + +ifdef WIN32 +# generates byte-code dll +byte-code-dll: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(BCRESULT).dll \ + REAL_RESULT="$(BCRESULT)" \ + make_deps=yes +bcd: byte-code-dll + +# generates native-code dll +native-code-dll: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(NCRESULT).dll \ + REAL_RESULT="$(NCRESULT)" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + make_deps=yes +ncd: native-code-dll +endif + +# generates byte-code with debugging information +debug-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(BCRESULT) \ + REAL_RESULT="$(BCRESULT)" make_deps=yes \ + OCAMLFLAGS="-g $(OCAMLFLAGS)" \ + OCAMLLDFLAGS="-g $(OCAMLLDFLAGS)" +dc: debug-code + +debug-code-nolink: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) nolink \ + REAL_RESULT="$(BCRESULT)" make_deps=yes \ + OCAMLFLAGS="-g $(OCAMLFLAGS)" \ + OCAMLLDFLAGS="-g $(OCAMLLDFLAGS)" +dcnl: debug-code-nolink + +# generates byte-code with debugging information (native code) +debug-native-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(NCRESULT) \ + REAL_RESULT="$(NCRESULT)" make_deps=yes \ + REAL_OCAMLC="$(OCAMLOPT)" \ + OCAMLFLAGS="-g $(OCAMLFLAGS)" \ + OCAMLLDFLAGS="-g $(OCAMLLDFLAGS)" +dnc: debug-native-code + +debug-native-code-nolink: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) nolink \ + REAL_RESULT="$(NCRESULT)" make_deps=yes \ + REAL_OCAMLC="$(OCAMLOPT)" \ + OCAMLFLAGS="-g $(OCAMLFLAGS)" \ + OCAMLLDFLAGS="-g $(OCAMLLDFLAGS)" +dncnl: debug-native-code-nolink + +# generates byte-code libraries with debugging information +debug-code-library: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(BCRESULT).cma \ + REAL_RESULT="$(BCRESULT)" make_deps=yes \ + CREATE_LIB=yes \ + OCAMLFLAGS="-g $(OCAMLFLAGS)" \ + OCAMLLDFLAGS="-g $(OCAMLLDFLAGS)" +dcl: debug-code-library + +# generates byte-code libraries with debugging information (native code) +debug-native-code-library: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(NCRESULT).cma \ + REAL_RESULT="$(NCRESULT)" make_deps=yes \ + REAL_OCAMLC="$(OCAMLOPT)" \ + CREATE_LIB=yes \ + OCAMLFLAGS="-g $(OCAMLFLAGS)" \ + OCAMLLDFLAGS="-g $(OCAMLLDFLAGS)" +dncl: debug-native-code-library + +# generates byte-code for profiling +profiling-byte-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(BCRESULT) \ + REAL_RESULT="$(BCRESULT)" PROFILING="y" \ + make_deps=yes +pbc: profiling-byte-code + +# generates native-code + +profiling-native-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(NCRESULT) \ + REAL_RESULT="$(NCRESULT)" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + PROFILING="y" \ + make_deps=yes +pnc: profiling-native-code + +# generates byte-code libraries +profiling-byte-code-library: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(BCRESULT).cma \ + REAL_RESULT="$(BCRESULT)" PROFILING="y" \ + CREATE_LIB=yes \ + make_deps=yes +pbcl: profiling-byte-code-library + +# generates native-code libraries +profiling-native-code-library: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(NCRESULT).cmxa \ + REAL_RESULT="$(NCRESULT)" PROFILING="y" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + CREATE_LIB=yes \ + make_deps=yes +pncl: profiling-native-code-library + +# packs byte-code objects +pack-byte-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(BCRESULT).cmo \ + REAL_RESULT="$(BCRESULT)" \ + PACK_LIB=yes make_deps=yes +pabc: pack-byte-code + +# packs native-code objects +pack-native-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(NCRESULT).cmx $(NCRESULT).o \ + REAL_RESULT="$(NCRESULT)" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + PACK_LIB=yes make_deps=yes +panc: pack-native-code + +# generates HTML-documentation +htdoc: $(DOC_DIR)/$(RESULT)/html/index.html + +# generates Latex-documentation +ladoc: $(DOC_DIR)/$(RESULT)/latex/doc.tex + +# generates PostScript-documentation +psdoc: $(DOC_DIR)/$(RESULT)/latex/doc.ps + +# generates PDF-documentation +pdfdoc: $(DOC_DIR)/$(RESULT)/latex/doc.pdf + +# generates all supported forms of documentation +doc: htdoc ladoc psdoc pdfdoc + +########################################################################### +# LOW LEVEL RULES + +$(REAL_RESULT): $(REAL_IMPL_INTF) $(OBJ_LINK) $(EXTRADEPS) $(RESULTDEPS) + $(REAL_OCAMLFIND) $(REAL_OCAMLC) \ + $(OCAML_FIND_PACKAGES) $(OCAML_FIND_LINKPKG) \ + $(ALL_LDFLAGS) $(OBJS_LIBS) -o $@$(EXE) \ + $(REAL_IMPL) + +nolink: $(REAL_IMPL_INTF) $(OBJ_LINK) + +ifdef WIN32 +$(REAL_RESULT).dll: $(REAL_IMPL_INTF) $(OBJ_LINK) + $(CAMLIDLDLL) $(CAMLIDLDLLFLAGS) $(OBJ_LINK) $(CLIBS) \ + -o $@ $(REAL_IMPL) +endif + +%$(TOPSUFFIX): $(REAL_IMPL_INTF) $(OBJ_LINK) $(EXTRADEPS) + $(REAL_OCAMLFIND) $(OCAMLMKTOP) \ + $(OCAML_FIND_PACKAGES) $(OCAML_FIND_LINKPKG) \ + $(ALL_LDFLAGS) $(OBJS_LIBS) -o $@$(EXE) \ + $(REAL_IMPL) + +.SUFFIXES: .mli .ml .cmi .cmo .cmx .cma .cmxa .$(EXT_OBJ) \ + .mly .di .d .$(EXT_LIB) .idl %.oxridl .c .m .$(EXT_CXX) .h .so \ + .rep .zog .glade + +ifndef STATIC +ifdef MINGW +$(DLLSONAME): $(OBJ_LINK) + $(CC) $(CFLAGS) $(CFLAGS_WIN32) $(OBJ_LINK) -shared -o $@ \ + -Wl,--whole-archive $(wildcard $(foreach dir,$(LIBDIRS),$(CLIBS:%=$(dir)/lib%.a))) \ + $(OCAMLLIBPATH)/ocamlrun.a \ + -Wl,--export-all-symbols \ + -Wl,--no-whole-archive +else +ifdef MSVC +$(DLLSONAME): $(OBJ_LINK) + link /NOLOGO /DLL /OUT:$@ $(OBJ_LINK) \ + $(wildcard $(foreach dir,$(LIBDIRS),$(CLIBS:%=$(dir)/%.lib))) \ + $(OCAMLLIBPATH)/ocamlrun.lib + +else +$(DLLSONAME): $(OBJ_LINK) + $(OCAMLMKLIB) $(INCFLAGS) $(CLIBFLAGS) \ + -o $(CLIB_BASE) $(OBJ_LINK) $(CLIBS:%=-l%) $(CFRAMEWORKS:%=-framework %) \ + $(OCAMLMKLIB_FLAGS) +endif +endif +endif + +ifndef LIB_PACK_NAME +$(RESULT).cma: $(REAL_IMPL_INTF) $(MAKEDLL) $(EXTRADEPS) $(RESULTDEPS) + $(REAL_OCAMLFIND) $(REAL_OCAMLC) -a $(ALL_LDFLAGS) $(OBJS_LIBS) -o $@ $(REAL_IMPL) + +$(RESULT).cmxa $(RESULT).$(EXT_LIB): $(REAL_IMPL_INTF) $(EXTRADEPS) $(RESULTDEPS) + $(REAL_OCAMLFIND) $(OCAMLOPT) -a $(ALL_LDFLAGS) $(OBJS_LIBS) -o $@ $(REAL_IMPL) +else +ifdef BYTE_OCAML +$(LIB_PACK_NAME).cmi $(LIB_PACK_NAME).cmo: $(REAL_IMPL_INTF) + $(REAL_OCAMLFIND) $(REAL_OCAMLC) -pack -o $(LIB_PACK_NAME).cmo $(OCAMLLDFLAGS) $(REAL_IMPL) +else +$(LIB_PACK_NAME).cmi $(LIB_PACK_NAME).cmx: $(REAL_IMPL_INTF) + $(REAL_OCAMLFIND) $(OCAMLOPT) -pack -o $(LIB_PACK_NAME).cmx $(OCAMLLDFLAGS) $(REAL_IMPL) +endif + +$(RESULT).cma: $(LIB_PACK_NAME).cmi $(LIB_PACK_NAME).cmo $(MAKEDLL) $(EXTRADEPS) $(RESULTDEPS) + $(REAL_OCAMLFIND) $(REAL_OCAMLC) -a $(ALL_LDFLAGS) $(OBJS_LIBS) -o $@ $(LIB_PACK_NAME).cmo + +$(RESULT).cmxa $(RESULT).$(EXT_LIB): $(LIB_PACK_NAME).cmi $(LIB_PACK_NAME).cmx $(EXTRADEPS) $(RESULTDEPS) + $(REAL_OCAMLFIND) $(OCAMLOPT) -a $(filter-out -custom, $(ALL_LDFLAGS)) $(OBJS_LIBS) -o $@ $(LIB_PACK_NAME).cmx +endif + +$(RES_CLIB): $(OBJ_LINK) +ifndef MSVC + ifneq ($(strip $(OBJ_LINK)),) + $(AR) rcs $@ $(OBJ_LINK) + endif +else + ifneq ($(strip $(OBJ_LINK)),) + lib -nologo -debugtype:cv -out:$(RES_CLIB) $(OBJ_LINK) + endif +endif + +.mli.cmi: $(EXTRADEPS) + $(QUIET)pp=`sed -n -e '/^#/d' -e 's/(\*pp \([^*]*\) \*)/\1/p;q' $<`; \ + if [ -z "$$pp" ]; then \ + $(ECHO) $(REAL_OCAMLFIND) $(INTF_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c $(THREAD_FLAG) $(ANNOT_FLAG) \ + $(OCAMLFLAGS) $(INCFLAGS) $<; \ + $(REAL_OCAMLFIND) $(INTF_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c $(THREAD_FLAG) $(ANNOT_FLAG) \ + $(OCAMLFLAGS) $(INCFLAGS) $<; \ + else \ + $(ECHO) $(REAL_OCAMLFIND) $(INTF_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c -pp \"$$pp $(PPFLAGS)\" $(THREAD_FLAG) $(ANNOT_FLAG) \ + $(OCAMLFLAGS) $(INCFLAGS) $<; \ + $(REAL_OCAMLFIND) $(INTF_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c -pp "$$pp $(PPFLAGS)" $(THREAD_FLAG) $(ANNOT_FLAG) \ + $(OCAMLFLAGS) $(INCFLAGS) $<; \ + fi + +.ml.cmi .ml.$(EXT_OBJ) .ml.cmx .ml.cmo: $(EXTRADEPS) + $(QUIET)pp=`sed -n -e '/^#/d' -e 's/(\*pp \([^*]*\) \*)/\1/p;q' $<`; \ + if [ -z "$$pp" ]; then \ + $(ECHO) $(REAL_OCAMLFIND) $(REAL_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c $(ALL_OCAMLCFLAGS) $<; \ + $(REAL_OCAMLFIND) $(REAL_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c $(ALL_OCAMLCFLAGS) $<; \ + else \ + $(ECHO) $(REAL_OCAMLFIND) $(REAL_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c -pp \"$$pp $(PPFLAGS)\" $(ALL_OCAMLCFLAGS) $<; \ + $(REAL_OCAMLFIND) $(REAL_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c -pp "$$pp $(PPFLAGS)" $(ALL_OCAMLCFLAGS) $<; \ + fi + +ifdef PACK_LIB +$(REAL_RESULT).cmo $(REAL_RESULT).cmx $(REAL_RESULT).o: $(REAL_IMPL_INTF) $(OBJ_LINK) $(EXTRADEPS) + $(REAL_OCAMLFIND) $(REAL_OCAMLC) -pack $(ALL_LDFLAGS) \ + $(OBJS_LIBS) -o $@ $(REAL_IMPL) +endif + +.PRECIOUS: %.ml +%.ml: %.mll + $(OCAMLLEX) $(LFLAGS) $< + +.PRECIOUS: %.ml %.mli +%.ml %.mli: %.mly + $(OCAMLYACC) $(YFLAGS) $< + $(QUIET)pp=`sed -n -e 's/.*(\*pp \([^*]*\) \*).*/\1/p;q' $<`; \ + if [ ! -z "$$pp" ]; then \ + mv $*.ml $*.ml.temporary; \ + echo "(*pp $$pp $(PPFLAGS)*)" > $*.ml; \ + cat $*.ml.temporary >> $*.ml; \ + rm $*.ml.temporary; \ + mv $*.mli $*.mli.temporary; \ + echo "(*pp $$pp $(PPFLAGS)*)" > $*.mli; \ + cat $*.mli.temporary >> $*.mli; \ + rm $*.mli.temporary; \ + fi + + +.PRECIOUS: %.ml +%.ml: %.rep + $(CAMELEON_REPORT) $(CAMELEON_REPORT_FLAGS) -gen $< + +.PRECIOUS: %.ml +%.ml: %.zog + $(CAMELEON_ZOGGY) $(CAMELEON_ZOGGY_FLAGS) -impl $< > $@ + +.PRECIOUS: %.ml +%.ml: %.glade + $(OCAML_GLADECC) $(OCAML_GLADECC_FLAGS) $< > $@ + +.PRECIOUS: %.ml %.mli +%.ml %.mli: %.oxridl + $(OXRIDL) $< + +.PRECIOUS: %.ml %.mli %_stubs.c %.h +%.ml %.mli %_stubs.c %.h: %.idl + $(CAMLIDL) $(MAYBE_IDL_HEADER) $(IDLFLAGS) \ + $(CAMLIDLFLAGS) $< + $(QUIET)if [ $(NOIDLHEADER) ]; then touch $*.h; fi + +.c.$(EXT_OBJ): + $(OCAMLC) -c -cc "$(CC)" -ccopt "$(CFLAGS) \ + $(CPPFLAGS) $(CPPFLAGS_WIN32) \ + $(CFLAGS_WIN32) $(CINCFLAGS) $(CFLAG_O)$@ " $< + +.m.$(EXT_OBJ): + $(CC) -c $(CFLAGS) $(CINCFLAGS) $(CPPFLAGS) \ + -I'$(OCAMLLIBPATH)' \ + $< $(CFLAG_O)$@ + +.$(EXT_CXX).$(EXT_OBJ): + $(CXX) -c $(CXXFLAGS) $(CINCFLAGS) $(CPPFLAGS) \ + -I'$(OCAMLLIBPATH)' \ + $< $(CFLAG_O)$@ + +$(MLDEPDIR)/%.d: %.ml + $(QUIET)if [ ! -d $(@D) ]; then mkdir -p $(@D); fi + $(QUIET)pp=`sed -n -e '/^#/d' -e 's/(\*pp \([^*]*\) \*)/\1/p;q' $<`; \ + if [ -z "$$pp" ]; then \ + $(ECHO) $(REAL_OCAMLFIND) $(OCAMLDEP) $(OCAML_DEP_PACKAGES) \ + $(DINCFLAGS) $< \> $@; \ + $(REAL_OCAMLFIND) $(OCAMLDEP) $(OCAML_DEP_PACKAGES) \ + $(DINCFLAGS) $< > $@; \ + else \ + $(ECHO) $(REAL_OCAMLFIND) $(OCAMLDEP) $(OCAML_DEP_PACKAGES) \ + -pp \"$$pp $(PPFLAGS)\" $(DINCFLAGS) $< \> $@; \ + $(REAL_OCAMLFIND) $(OCAMLDEP) $(OCAML_DEP_PACKAGES) \ + -pp "$$pp $(PPFLAGS)" $(DINCFLAGS) $< > $@; \ + fi + +$(BCDIDIR)/%.di $(NCDIDIR)/%.di: %.mli + $(QUIET)if [ ! -d $(@D) ]; then mkdir -p $(@D); fi + $(QUIET)pp=`sed -n -e '/^#/d' -e 's/(\*pp \([^*]*\) \*)/\1/p;q' $<`; \ + if [ -z "$$pp" ]; then \ + $(ECHO) $(REAL_OCAMLFIND) $(OCAMLDEP) $(DEPFLAGS) $(DINCFLAGS) $< \> $@; \ + $(REAL_OCAMLFIND) $(OCAMLDEP) $(DEPFLAGS) $(DINCFLAGS) $< > $@; \ + else \ + $(ECHO) $(REAL_OCAMLFIND) $(OCAMLDEP) $(DEPFLAGS) \ + -pp \"$$pp $(PPFLAGS)\" $(DINCFLAGS) $< \> $@; \ + $(REAL_OCAMLFIND) $(OCAMLDEP) $(DEPFLAGS) \ + -pp "$$pp $(PPFLAGS)" $(DINCFLAGS) $< > $@; \ + fi + +$(DOC_DIR)/$(RESULT)/html: + mkdir -p $@ + +$(DOC_DIR)/$(RESULT)/html/index.html: $(DOC_DIR)/$(RESULT)/html $(DOC_FILES) + rm -rf $</* + $(QUIET)pp=`sed -n -e '/^#/d' -e 's/(\*pp \([^*]*\) \*)/\1/p;q' $(FIRST_DOC_FILE)`; \ + if [ -z "$$pp" ]; then \ + $(ECHO) $(REAL_OCAMLFIND) $(OCAMLDOC) $(OCAML_FIND_PACKAGES) -html -d $< $(OCAMLDOCFLAGS) $(INCFLAGS) $(DOC_FILES); \ + $(REAL_OCAMLFIND) $(OCAMLDOC) $(OCAML_FIND_PACKAGES) -html -d $< $(OCAMLDOCFLAGS) $(INCFLAGS) $(DOC_FILES); \ + else \ + $(ECHO) $(REAL_OCAMLFIND) $(OCAMLDOC) $(OCAML_FIND_PACKAGES) -pp \"$$pp $(PPFLAGS)\" -html -d $< $(OCAMLDOCFLAGS) \ + $(INCFLAGS) $(DOC_FILES); \ + $(REAL_OCAMLFIND) $(OCAMLDOC) $(OCAML_FIND_PACKAGES) -pp "$$pp $(PPFLAGS)" -html -d $< $(OCAMLDOCFLAGS) \ + $(INCFLAGS) $(DOC_FILES); \ + fi + +$(DOC_DIR)/$(RESULT)/latex: + mkdir -p $@ + +$(DOC_DIR)/$(RESULT)/latex/doc.tex: $(DOC_DIR)/$(RESULT)/latex $(DOC_FILES) + rm -rf $</* + $(QUIET)pp=`sed -n -e '/^#/d' -e 's/(\*pp \([^*]*\) \*)/\1/p;q' $(FIRST_DOC_FILE)`; \ + if [ -z "$$pp" ]; then \ + $(ECHO) $(REAL_OCAMLFIND) $(OCAMLDOC) $(OCAML_FIND_PACKAGES) -latex $(OCAMLDOCFLAGS) $(INCFLAGS) \ + $(DOC_FILES) -o $@; \ + $(REAL_OCAMLFIND) $(OCAMLDOC) $(OCAML_FIND_PACKAGES) -latex $(OCAMLDOCFLAGS) $(INCFLAGS) $(DOC_FILES) \ + -o $@; \ + else \ + $(ECHO) $(REAL_OCAMLFIND) $(OCAMLDOC) $(OCAML_FIND_PACKAGES) -pp \"$$pp $(PPFLAGS)\" -latex $(OCAMLDOCFLAGS) \ + $(INCFLAGS) $(DOC_FILES) -o $@; \ + $(REAL_OCAMLFIND) $(OCAMLDOC) $(OCAML_FIND_PACKAGES) -pp "$$pp $(PPFLAGS)" -latex $(OCAMLDOCFLAGS) \ + $(INCFLAGS) $(DOC_FILES) -o $@; \ + fi + +$(DOC_DIR)/$(RESULT)/latex/doc.ps: $(DOC_DIR)/$(RESULT)/latex/doc.tex + cd $(DOC_DIR)/$(RESULT)/latex && \ + $(LATEX) doc.tex && \ + $(LATEX) doc.tex && \ + $(DVIPS) $(DVIPSFLAGS) doc.dvi -o $(@F) + +$(DOC_DIR)/$(RESULT)/latex/doc.pdf: $(DOC_DIR)/$(RESULT)/latex/doc.ps + cd $(DOC_DIR)/$(RESULT)/latex && $(PS2PDF) $(<F) + +define make_subproj +.PHONY: +subproj_$(1): + $$(eval $$(call PROJ_$(1))) + $(QUIET)if [ "$(SUBTARGET)" != "all" ]; then \ + $(MAKE) -f $(OCAMLMAKEFILE) $(SUBTARGET); \ + fi +endef + +$(foreach subproj,$(SUBPROJS),$(eval $(call make_subproj,$(subproj)))) + +.PHONY: +subprojs: $(SUBPROJS:%=subproj_%) + +########################################################################### +# (UN)INSTALL RULES FOR LIBRARIES + +.PHONY: libinstall +libinstall: all + $(QUIET)printf "\nInstalling library with ocamlfind\n" + $(OCAMLFIND) install $(OCAMLFIND_INSTFLAGS) $(RESULT) META $(LIBINSTALL_FILES) + $(QUIET)printf "\nInstallation successful.\n" + +.PHONY: libinstall-byte-code +libinstall-byte-code: all + $(QUIET)printf "\nInstalling byte-code library with ocamlfind\n" + $(OCAMLFIND) install $(OCAMLFIND_INSTFLAGS) $(RESULT) META \ + $(filter-out $(RESULT).$(EXT_LIB) $(RESULT).cmxa, $(LIBINSTALL_FILES)) + $(QUIET)printf "\nInstallation successful.\n" + +.PHONY: libinstall-native-code +libinstall-native-code: all + $(QUIET)printf "\nInstalling native-code library with ocamlfind\n" + $(OCAMLFIND) install $(OCAMLFIND_INSTFLAGS) $(RESULT) META \ + $(filter-out $(DLLSONAME) $(RESULT).cma, $(LIBINSTALL_FILES)) + $(QUIET)printf "\nInstallation successful.\n" + +.PHONY: libuninstall +libuninstall: + $(QUIET)printf "\nUninstalling library with ocamlfind\n" + $(OCAMLFIND) remove $(OCAMLFIND_INSTFLAGS) $(RESULT) + $(QUIET)printf "\nUninstallation successful.\n" + +.PHONY: rawinstall +rawinstall: all + $(QUIET)printf "\nInstalling library to: $(OCAML_LIB_INSTALL)\n" + -install -d $(OCAML_LIB_INSTALL) + for i in $(LIBINSTALL_FILES); do \ + if [ -f $$i ]; then \ + install -c -m 0644 $$i $(OCAML_LIB_INSTALL); \ + fi; \ + done + $(QUIET)printf "\nInstallation successful.\n" + +.PHONY: rawuninstall +rawuninstall: + $(QUIET)printf "\nUninstalling library from: $(OCAML_LIB_INSTALL)\n" + cd $(OCAML_LIB_INSTALL) && rm $(notdir $(LIBINSTALL_FILES)) + $(QUIET)printf "\nUninstallation successful.\n" + +########################################################################### +# MAINTENANCE RULES + +.PHONY: clean +clean:: + rm -f $(TARGETS) $(TRASH) + rm -rf $(BCDIDIR) $(NCDIDIR) $(MLDEPDIR) + +.PHONY: cleanup +cleanup:: + rm -f $(NONEXECS) $(TRASH) + rm -rf $(BCDIDIR) $(NCDIDIR) $(MLDEPDIR) + +.PHONY: clean-doc +clean-doc:: + rm -rf $(DOC_DIR)/$(RESULT) + +.PHONY: clean-all +clean-all:: clean clean-doc + +.PHONY: nobackup +nobackup: + rm -f *.bak *~ *.dup diff --git a/lib/ocaml/README b/lib/ocaml/README new file mode 100644 index 000000000..5a47a4247 --- /dev/null +++ b/lib/ocaml/README @@ -0,0 +1,119 @@ +Thrift OCaml Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + + +Library +======= + +The library abstract classes, exceptions, and general use functions +are mostly jammed in Thrift.ml (an exception being +TServer). + +Generally, classes are used, however they are often put in their own +module along with other relevant types and functions. The classes +often called t, exceptions are called E. + +Implementations live in their own files. There is TBinaryProtocol, +TSocket, TThreadedServer, TSimpleServer, and TServerSocket. + +A note on making the library: Running make should create native, debug +code libraries, and a toplevel. + + +Struct format +------------- +Structs are turned into classes. The fields are all option types and +are initially None. Write is a method, but reading is done by a +separate function (since there is no such thing as a static +class). The class type is t and is in a module with the name of the +struct. + + +enum format +----------- +Enums are put in their own module along with +functions to_i and of_i which convert the ocaml types into ints. For +example: + +enum Numberz +{ + ONE = 1, + TWO, + THREE, + FIVE = 5, + SIX, + EIGHT = 8 +} + +==> + +module Numberz = +struct +type t = +| ONE +| TWO +| THREE +| FIVE +| SIX +| EIGHT + +let of_i = ... +let to_i = ... +end + +typedef format +-------------- +Typedef turns into the type declaration: +typedef i64 UserId + +==> + +type userid Int64.t + +exception format +---------------- +The same as structs except that the module also has an exception type +E of t that is raised/caught. + +For example, with an exception Xception, +raise (Xception.E (new Xception.t)) +and +try + ... +with Xception.E e -> ... + +list format +----------- +Lists are turned into OCaml native lists. + +Map/Set formats +--------------- +These are both turned into Hashtbl.t's. Set values are bool. + +Services +-------- +The client is a class "client" parametrized on input and output +protocols. The processor is a class parametrized on a handler. A +handler is a class inheriting the iface abstract class. Unlike other +implementations, client does not implement iface since iface functions +must take option arguments so as to deal with the case where a client +does not send all the arguments. diff --git a/lib/ocaml/README-OCamlMakefile b/lib/ocaml/README-OCamlMakefile new file mode 100644 index 000000000..54787b967 --- /dev/null +++ b/lib/ocaml/README-OCamlMakefile @@ -0,0 +1,640 @@ +--------------------------------------------------------------------------- + + Distribution of "ocaml_make" + Copyright (C) 1999 - 2006 Markus Mottl - free to copy and modify! + USE AT YOUR OWN RISK! + +--------------------------------------------------------------------------- + + PREREQUISITES + + *** YOU WILL NEED GNU-MAKE VERSION >3.80 *** + +--------------------------------------------------------------------------- + + Contents of this distribution + +Changes - guess what? ;-) + +OCamlMakefile - Makefile for easy handling of compilation of not so easy + OCaml-projects. It generates dependencies of OCaml-files + automatically, is able to handle "ocamllex"-, + "ocamlyacc"-, IDL- and C-files, knows how to run + preprocessors and generates native- or byte-code, as + executable or as library - with thread-support if you + want! Profiling and debugging support can be added on + the fly! There is also support for installing libraries. + Ah, yes, and you can also create toplevels from any + sources: this allows you immediate interactive testing. + Automatic generation of documentation is easy due to + integration of support for OCamldoc. + +README - this file + +calc/ - Directory containing a quite fully-featured example + of what "OCamlMakefile" can do for you. This example + makes use of "ocamllex", "ocamlyacc", IDL + C and + threads. + +camlp4/ - This simple example demonstrates how to automatically + preprocess files with the camlp4-preprocessor. + +gtk/ - Demonstration of how to use OCamlMakefile with GTK + and threads. Courtesy of Tim Freeman <tim@fungible.com>. + +idl/ - Contains a very small example of how to use + "camlidl" together with "OCamlMakefile". Also intended + to show, how easy it is to interface OCaml and C. + +threads/ - Two examples of how to use threads (originally + posted by Xavier Leroy some time ago). Shows the use of + "OCamlMakefile" in an environment of multiple compilation + targets. + +--------------------------------------------------------------------------- + + Why should you use it? + +For several reasons: + + * It is well-tested (I use it in all of my projects). + + * In contrast to most other approaches it generates dependencies + correctly by ensuring that all automatically generated OCaml-files + exist before dependency calculation. This is the only way to + guarantee that "ocamldep" works correctly. + + * It is extremely convenient (at least I think so ;-). + Even quite complex compilation processes (see example "calc.ml") + need very little information to work correctly - actually just about + the minimum (file names of sources). + +--------------------------------------------------------------------------- + + When you shouldn't use it... + +In projects where every compilation unit needs different flags - but +in such complicated cases you will be on your own anyway. Luckily, +this doesn't happen too frequently... + +--------------------------------------------------------------------------- + + How to use "OCamlMakefile" in your own project + (Take a look at the examples for a quick introduction!) + +Create your project-specific "Makefile" in the appropriate directory. + +Now there are two ways of making use of "OCamlMakefile": + + 1) Have a look at the default settings in "OCamlMakefile" and set + them to the values that are vaild on your system - whether the + path to the standard libraries is ok, what executables shall be + used, etc... + + 2) Copy it into the directory of the project to be compiled. + Add "-include OCamlMakefile" as a last line of your "Makefile". + + 3) Put it somewhere else on the system. In this case you will have to + set a variable "OCAMLMAKEFILE" in your project-specific "Makefile". + This is the way in which the examples are written: so you need + only one version of "OCamlMakefile" to manage all your projects! + See the examples for details. + +You should usually specify two further variables for your project: + + * SOURCES (default: foo.ml) + * RESULT (default: foo) + +Put all the sources necessary for a target into variable "SOURCES". +Then set "RESULT" to the name of the target. If you want to generate +libraries, you should *not* specify the suffix (".cma", ".cmxa", ".a") +- it will be added automatically if you specify that you want to build +a library. + + ** Don't forget to add the ".mli"-files, too! ** + ** Don't forget that order of the source files matters! ** + +The order is important, because it matters during linking anyway +due to potential side effects caused at program startup. This is +why OCamlMakefile does not attempt to partially order dependencies by +itself, which might confuse users even more. It just compiles and links +OCaml-sources in the order specified by the user, even if it could +determine automatically that the order cannot be correct. + +The minimum of your "Makefile" looks like this (assuming that +"OCamlMakefile" is in the search path of "make"): + + -include OCamlMakefile + +This will assume that you want to compile a file "foo.ml" to a binary +"foo". + +Otherwise, your Makefile will probably contain something like this: + + SOURCES = foo.ml + RESULT = foo + -include OCamlMakefile + +Be careful with the names you put into these variables: if they are wrong, +a "make clean" might erase the wrong files - but I know you will not do +that ;-) + +A simple "make" will generate a byte-code executable. If you want to +change this, you may add an "all"-rule that generates something else. + +E.g.: + + SOURCES = foo.ml + RESULT = foo + all: native-code-library + -include OCamlMakefile + +This will build a native-code library "foo.cmxa" (+ "foo.a") from file +"foo.ml". + +You may even build several targets at once. To produce byte- and native-code +executables with one "make", add the following rule: + + all: byte-code native-code + +You will probably want to use a different suffix for each of these targets +so that the result will not be overwritten (see optional variables below +for details). + +You may also tell "make" at the command-line what kind of target to +produce (e.g. "make nc"). Here all the possibilities with shortcuts +between parenthesis: + + * byte-code (bc) + * byte-code-nolink (bcnl) - no linking stage + * byte-code-library (bcl) + * native-code (nc) + * native-code-nolink (ncnl) - no linking stage + * native-code-library (ncl) + * debug-code (dc) + * debug-code-nolink (dcnl) - no linking stage + * debug-code-library (dcl) + * profiling-byte-code (pbc) + * profiling-byte-code-library (pbcl) + * profiling-native-code (pnc) + * profiling-native-code-library (pncl) + * byte-code-dll (bcd) + * native-code-dll (ncd) + * pack-byte-code (pabc) + * pack-native-code (panc) + * toplevel interpreter (top) + * subprojs + +Here a short note concerning building and linking byte code libraries +with C-files: + + OCaml links C-object files only when they are used in an executable. + After compilation they should be placed in some directory that is in + your include path if you link your library against an executable. + + It is sometimes more convenient to link all C-object files into a + single C-library. Then you have to override the automatic link flags + of your library using "-noautolink" and add another linkflag that + links in your C-library explicitly. + +What concerns maintainance: + + "make clean" removes all (all!) automatically generated files - so + again: make sure your variables are ok! + + "make cleanup" is similar to "make clean" but leaves executables. + +Another way to destroy some important files is by having "OCamlMakefile" +automatically generate files with the same name. Read the documentation +about the tools in the OCaml-distribution to see what kind of files are +generated. "OCamlMakefile" additionally generates ('%' is basename of +source file): + + %_idl.c - "camlidl" generates a file "%.c" from "%.idl", but this is + not such a good idea, because when generating native-code, + both the file "%.c" and "%.ml" would generate files "%.o" + which would overwrite each other. Thus, "OCamlMakefile" + renames "%.c" to "%_idl.c" to work around this problem. + +The dependencies are stored in three different subdirectories (dot dirs): + + ._d - contains dependencies for .ml-files + ._bcdi - contains byte code dependencies for .mli-files + ._ncdi - contains native code dependencies for .mli-files + +The endings of the dependency files are: "%.d" for those generated from +"%.ml"-files, "%.di" for ones derived from "%.mli"-files. + +--------------------------------------------------------------------------- + + Debugging + + This is easy: if you discover a bug, just do a "make clean; make dc" + to recompile your project with debugging information. Then you can + immediately apply "ocamldebug" to the executable. + +--------------------------------------------------------------------------- + + Profiling + + For generating code that can be profiled with "ocamlprof" (byte code) + or "gprof" (native code), compile your project with one of the profiling + targets (see targets above). E.g.: + + * "make pbc" will build byte code that can be profiled with + "ocamlprof". + + * "make pnc" will build native code that can be profiled with + "gprof". + + Please note that it is not currently possible to profile byte code with + threads. OCamlMakefile will force an error if you try to do this. + + A short hint for DEC Alpha-users (under Digital Unix): you may also + compile your sources to native code without any further profiling + options/targets. Then call "pixie my_exec", "my_exec" being your + executable. This will produce (among other files) an executable + "my_exec.pixie". Call it and it will produce profiling information which + can be analysed using "prof -pixie my_exec". The resulting information + is extremely detailed and allows analysis up to the clock cycle level... + +--------------------------------------------------------------------------- + + Using Preprocessors + + Because one could employ any kind of program that reads from standard + input and prints to standard output as preprocessor, there cannot be any + default way to handle all of them correctly without further knowledge. + + Therefore you have to cooperate a bit with OCamlMakefile to let + preprocessing happen automatically. Basically, this only requires + that you put a comment into the first line of files that should be + preprocessed, e.g.: + + (*pp cat *) + ... rest of program ... + + OCamlMakefile looks at the first line of your files, and if it finds + a comment that starts with "(*pp", then it will assume that the + rest of the comment tells it how to correctly call the appropriate + preprocessor. In this case the program "cat" will be called, which will, + of course, just output the source text again without changing it. + + If you are, for example, an advocate of the new "revised syntax", + which is supported by the camlp4 preprocessor, you could simply write: + + (*pp camlp4r *) + ... rest of program in revised syntax ... + + Simple, isn't it? + + If you want to write your own syntax extensions, just take a look at the + example in the directory "camlp4": it implements the "repeat ... until" + extension as described in the camlp4-tutorial. + +--------------------------------------------------------------------------- + + Library (Un-)Installation Support + + OCamlMakefile contains two targets using "ocamlfind" for this purpose: + + * libinstall + * libuninstall + + These two targets require the existence of the variable + "LIBINSTALL_FILES", which should be set to all the files that you + want to install in the library directory (usually %.mli, %.cmi, %.cma, + %.cmxa, %.a and possibly further C-libraries). The target "libinstall" + has the dependency "all" to force compilation of the library so make + sure you define target "all" in your Makefile appropriately. + + The targets inform the user about the configured install path and ask + for confirmation to (un)install there. If you want to use them, it + is often a good idea to just alias them in your Makefile to "install" + and "uninstall" respectively. + + Two other targets allow installation of files into a particular + directory (without using ocamlfind): + + * rawinstall + * rawuninstall + +--------------------------------------------------------------------------- + + Building toplevels + + There is just one target for this: + + * top + + The generated file can be used immediately for interactive sessions - + even with scanners, parsers, C-files, etc.! + +--------------------------------------------------------------------------- + + Generating documentation + + The following targets are supported: + + * htdoc - generates HTML-documentation + * ladoc - generates Latex-documentation + * psdoc - generates PostScript-documentation + * pdfdoc - generates PDF-documentation + * doc - generates all supported forms of documentation + * clean-doc - generates all supported forms of documentation + + All of them generate a sub-directory "doc". More precisely, for HTML it + is "doc/$(RESULT)/html" and for Latex, PostScript and PDF the directory + "doc/$(RESULT)/latex". See the OCamldoc-manual for details and the + optional variables below for settings you can control. + +--------------------------------------------------------------------------- + + Handling subprojects + + You can have several targets in the same directory and manage them + from within an single Makefile. + + Give each subproject a name, e.g. "p1", "p2", etc. Then you export + settings specific to each project by using variables of the form + "PROJ_p1", "PROJ_p2", etc. E.g.: + + define PROJ_p1 + SOURCES="foo.ml main.ml" + RESULT="p1" + OCAMLFLAGS="-unsafe" + endef + export PROJ_p1 + + define PROJ_p2 + ... + endef + export PROJ_p2 + + You may also export common settings used by all projects directly, e.g. + "export THREADS = y". + + Now it is a good idea to define, which projects should be affected by + commands by default. E.g.: + + ifndef SUBPROJS + export SUBPROJS = p1 p2 + endif + + This will automatically generate a given target for all those + subprojects if this variable has not been defined in the shell + environment or in the command line of the make-invocation by the user. + E.g., "make dc" will generate debug code for all subprojects. + + Then you need to define a default action for your subprojects if "make" + has been called without arguments: + + all: bc + + This will build byte code by default for all subprojects. + + Finally, you'll have to define a catch-all target that uses the target + provided by the user for all subprojects. Just add (assuming that + OCAMLMAKEFILE has been defined appropriately): + + %: + @make -f $(OCAMLMAKEFILE) subprojs SUBTARGET=$@ + + See the "threads"-directory in the distribution for a short example! + +--------------------------------------------------------------------------- + + Optional variables that may be passed to "OCamlMakefile" + + * LIB_PACK_NAME - packs all modules of a library into a module whose + name is given in variable "LIB_PACK_NAME". + + * RES_CLIB_SUF - when building a library that contains C-stubs, this + variable controls the suffix appended to the name + of the C-library (default: "_stubs"). + + * THREADS - say "THREADS = yes" if you need thread support compiled in, + otherwise leave it away. + + * VMTHREADS - say "VMTHREADS = yes" if you want to force VM-level + scheduling of threads (byte-code only). + + * ANNOTATE - say "ANNOTATE = yes" to generate type annotation files + (.annot) to support displaying of type information + in editors. + + * USE_CAMLP4 - say "USE_CAMLP4 = yes" in your "Makefile" if you + want to include the camlp4 directory during the + build process, otherwise leave it away. + + * INCDIRS - directories that should be searched for ".cmi"- and + ".cmo"-files. You need not write "-I ..." - just the + plain names. + * LIBDIRS - directories that should be searched for libraries + Also just put the plain paths into this variable + * EXTLIBDIRS - Same as "LIBDIRS", but paths in this variable are + also added to the binary via the "-R"-flag so that + dynamic libraries in non-standard places can be found. + * RESULTDEPS - Targets on which results (executables or libraries) + should additionally depend. + + * PACKS - adds packages under control of "findlib". + + * PREDS - specifies "findlib"-predicates. + + * LIBS - OCaml-libraries that should be linked (just plain names). + E.g. if you want to link the Str-library, just write + "str" (without quotes). + The new OCaml-compiler handles libraries in such + a way that they "remember" whether they have to + be linked against a C-library and it gets linked + in automatically. + If there is a slash in the library name (such as + "./str" or "lib/foo") then make is told that the + generated files depend on the library. This + helps to ensure that changes to your libraries are + taken into account, which is important if you are + regenerating your libraries frequently. + * CLIBS - C-libraries that should be linked (just plain names). + + * PRE_TARGETS - set this to a list of target files that you want + to have buildt before dependency calculation actually + takes place. E.g. use this to automatically compile + modules needed by camlp4, which have to be available + before other modules can be parsed at all. + + ** WARNING **: the files mentioned in this variable + will be removed when "make clean" is executed! + + * LIBINSTALL_FILES - the files of a library that should be installed + using "findlib". Default: + + $(RESULT).mli $(RESULT).cmi $(RESULT).cma + $(RESULT).cmxa $(RESULT).a lib$(RESULT).a + + * OCAML_LIB_INSTALL - target directory for "rawinstall/rawuninstall". + (default: $(OCAMLLIBPATH)/contrib) + + * DOC_FILES - names of files from which documentation is generated. + (default: all .mli-files in your $(SOURCES)). + + * DOC_DIR - name of directory where documentation should be stored. + + * OCAMLFLAGS - flags passed to the compilers + * OCAMLBCFLAGS - flags passed to the byte code compiler only + * OCAMLNCFLAGS - flags passed to the native code compiler only + + * OCAMLLDFLAGS - flags passed to the OCaml-linker + * OCAMLBLDFLAGS - flags passed to the OCaml-linker when linking byte code + * OCAMLNLDFLAGS - flags passed to the OCaml-linker when linking + native code + + * OCAMLMKLIB_FLAGS - flags passed to the OCaml library tool + + * OCAMLCPFLAGS - profiling flags passed to "ocamlcp" (default: "a") + + * PPFLAGS - additional flags passed to the preprocessor (default: none) + + * LFLAGS - flags passed to "ocamllex" + * YFLAGS - flags passed to "ocamlyacc" + * IDLFLAGS - flags passed to "camlidl" + + * OCAMLDOCFLAGS - flags passed to "ocamldoc" + + * OCAMLFIND_INSTFLAGS - flags passed to "ocamlfind" during installation + (default: none) + + * DVIPSFLAGS - flags passed to dvips + (when generating documentation in PostScript). + + * STATIC - set this variable if you want to force creation + of static libraries + + * CC - the C-compiler to be used + * CXX - the C++-compiler to be used + + * CFLAGS - additional flags passed to the C-compiler. + The flag "-DNATIVE_CODE" will be passed automatically + if you choose to build native code. This allows you + to compile your C-files conditionally. But please + note: You should do a "make clean" or remove the + object files manually or touch the %.c-files: + otherwise, they may not be correctly recompiled + between different builds. + + * CXXFLAGS - additional flags passed to the C++-compiler. + + * CPPFLAGS - additional flags passed to the C-preprocessor. + + * CFRAMEWORKS - Objective-C framework to pass to linker on MacOS X. + + * LDFLAGS - additional flags passed to the C-linker + + * RPATH_FLAG - flag passed through to the C-linker to set a path for + dynamic libraries. May need to be set by user on + exotic platforms. (default: "-R"). + + * ELF_RPATH_FLAG - this flag is used to set the rpath on ELF-platforms. + (default: "-R") + + * ELF_RPATH - if this flag is "yes", then the RPATH_FLAG will be + passed by "-Wl" to the linker as normal on + ELF-platforms. + + * OCAMLLIBPATH - path to the OCaml-standard-libraries + (first default: `$(OCAMLC) -where`) + (second default: "/usr/local/lib/ocaml") + + * OCAML_DEFAULT_DIRS - additional path in which the user can supply + default directories to his own collection of + libraries. The idea is to pass this as an environment + variable so that the Makefiles do not have to contain + this path all the time. + + * OCAMLFIND - ocamlfind from findlib (default: "ocamlfind") + * OCAMLC - byte-code compiler (default: "ocamlc") + * OCAMLOPT - native-code compiler (default: "ocamlopt") + * OCAMLMKTOP - top-level compiler (default: "ocamlmktop") + * OCAMLCP - profiling byte-code compiler (default: "ocamlcp") + * OCAMLDEP - dependency generator (default: "ocamldep") + * OCAMLLEX - scanner generator (default: "ocamllex") + * OCAMLYACC - parser generator (default: "ocamlyacc") + * OCAMLMKLIB - tool to create libraries (default: "ocamlmklib") + * CAMLIDL - IDL-code generator (default: "camlidl") + * CAMLIDLDLL - IDL-utility (default: "camlidldll") + * CAMLP4 - camlp4 preprocessor (default: "camlp4") + * OCAMLDOC - OCamldoc-command (default: "ocamldoc") + + * LATEX - Latex-processor (default: "latex") + * DVIPS - dvips-command (default: "dvips") + * PS2PDF - PostScript-to-PDF converter (default: "ps2pdf") + + * CAMELEON_REPORT - report tool of Cameleon (default: "report") + * CAMELEON_REPORT_FLAGS - flags for the report tool of Cameleon + + * CAMELEON_ZOGGY - zoggy tool of Cameleon + (default: "camlp4o pa_zog.cma pr_o.cmo") + * CAMELEON_ZOGGY_FLAGS - flags for the zoggy tool of Cameleon + + * OCAML_GLADECC - Glade compiler for OCaml (default: "lablgladecc2") + * OCAML_GLADECC_FLAGS - flags for the Glade compiler + + * OXRIDL - OXRIDL-generator (default: "oxridl") + + * NOIDLHEADER - set to "yes" to prohibit "OCamlMakefile" from using + the default camlidl-flag "-header". + + * NO_CUSTOM - Prevent linking in custom mode. + + * QUIET - unsetting this variable (e.g. "make QUIET=") + will print all executed commands, including + intermediate ones. This allows more comfortable + debugging when things go wrong during a build. + + * REALLY_QUIET - when set this flag turns off output from some commands. + + * OCAMLMAKEFILE - location of (=path to) this "OCamlMakefile". + Because it calles itself recursively, it has to + know where it is. (default: "OCamlMakefile" = + local directory) + + * BCSUFFIX - Suffix for all byte-code files. E.g.: + + RESULT = foo + BCSUFFIX = _bc + + This will produce byte-code executables/libraries + with basename "foo_bc". + + * NCSUFFIX - Similar to "BCSUFFIX", but for native-code files. + * TOPSUFFIX - Suffix added to toplevel interpreters (default: ".top") + + * SUBPROJS - variable containing the names of subprojects to be + compiled. + + * SUBTARGET - target to be built for all projects in variable + SUBPROJS. + +--------------------------------------------------------------------------- + + Optional variables for Windows users + + * MINGW - variable to detect the MINGW-environment + * MSVC - variable to detect the MSVC-compiler + +--------------------------------------------------------------------------- + +Up-to-date information (newest release of distribution) can always be +found at: + + http://www.ocaml.info/home/ocaml_sources.html + +--------------------------------------------------------------------------- + +Enjoy! + +New York, 2007-04-22 +Markus Mottl + +e-mail: markus.mottl@gmail.com +WWW: http://www.ocaml.info diff --git a/lib/ocaml/TODO b/lib/ocaml/TODO new file mode 100644 index 000000000..4d1dc771b --- /dev/null +++ b/lib/ocaml/TODO @@ -0,0 +1,5 @@ +Write interfaces +Clean up the code generator +Avoid capture properly instead of relying on the user not to use _ + + diff --git a/lib/ocaml/src/Makefile b/lib/ocaml/src/Makefile new file mode 100644 index 000000000..42ec8dbdb --- /dev/null +++ b/lib/ocaml/src/Makefile @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +SOURCES = Thrift.ml TBinaryProtocol.ml TSocket.ml TChannelTransport.ml TServer.ml TSimpleServer.ml TServerSocket.ml TThreadedServer.ml +RESULT = thrift +LIBS = unix threads +THREADS = yes +all: native-code-library debug-code-library top +OCAMLMAKEFILE = ../OCamlMakefile +include $(OCAMLMAKEFILE) diff --git a/lib/ocaml/src/TBinaryProtocol.ml b/lib/ocaml/src/TBinaryProtocol.ml new file mode 100644 index 000000000..a06cc9a90 --- /dev/null +++ b/lib/ocaml/src/TBinaryProtocol.ml @@ -0,0 +1,171 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift + +module P = Protocol + +let get_byte i b = 255 land (i lsr (8*b)) +let get_byte64 i b = 255 land (Int64.to_int (Int64.shift_right i (8*b))) + + +let tv = P.t_type_to_i +let vt = P.t_type_of_i + + +let comp_int b n = + let s = ref 0l in + let sb = 32 - 8*n in + for i=0 to (n-1) do + s:= Int32.logor !s (Int32.shift_left (Int32.of_int (int_of_char b.[i])) (8*(n-1-i))) + done; + Int32.to_int (Int32.shift_right (Int32.shift_left !s sb) sb) + +let comp_int64 b n = + let s = ref 0L in + for i=0 to (n-1) do + s:=Int64.logor !s (Int64.shift_left (Int64.of_int (int_of_char b.[i])) (8*(n-1-i))) + done; + !s + +let version_mask = 0xffff0000 +let version_1 = 0x80010000 + +class t trans = +object (self) + inherit P.t trans + val ibyte = String.create 8 + method writeBool b = + ibyte.[0] <- char_of_int (if b then 1 else 0); + trans#write ibyte 0 1 + method writeByte i = + ibyte.[0] <- char_of_int (get_byte i 0); + trans#write ibyte 0 1 + method writeI16 i = + let gb = get_byte i in + ibyte.[1] <- char_of_int (gb 0); + ibyte.[0] <- char_of_int (gb 1); + trans#write ibyte 0 2 + method writeI32 i = + let gb = get_byte i in + for i=0 to 3 do + ibyte.[3-i] <- char_of_int (gb i) + done; + trans#write ibyte 0 4 + method writeI64 i= + let gb = get_byte64 i in + for i=0 to 7 do + ibyte.[7-i] <- char_of_int (gb i) + done; + trans#write ibyte 0 8 + method writeDouble d = + self#writeI64 (Int64.bits_of_float d) + method writeString s= + let n = String.length s in + self#writeI32(n); + trans#write s 0 n + method writeBinary a = self#writeString a + method writeMessageBegin (n,t,s) = + self#writeI32 (version_1 lor (P.message_type_to_i t)); + self#writeString n; + self#writeI32 s + method writeMessageEnd = () + method writeStructBegin s = () + method writeStructEnd = () + method writeFieldBegin (n,t,i) = + self#writeByte (tv t); + self#writeI16 i + method writeFieldEnd = () + method writeFieldStop = + self#writeByte (tv (Protocol.T_STOP)) + method writeMapBegin (k,v,s) = + self#writeByte (tv k); + self#writeByte (tv v); + self#writeI32 s + method writeMapEnd = () + method writeListBegin (t,s) = + self#writeByte (tv t); + self#writeI32 s + method writeListEnd = () + method writeSetBegin (t,s) = + self#writeByte (tv t); + self#writeI32 s + method writeSetEnd = () + method readByte = + ignore (trans#readAll ibyte 0 1); + (comp_int ibyte 1) + method readI16 = + ignore (trans#readAll ibyte 0 2); + comp_int ibyte 2 + method readI32 = + ignore (trans#readAll ibyte 0 4); + comp_int ibyte 4 + method readI64 = + ignore (trans#readAll ibyte 0 8); + comp_int64 ibyte 8 + method readDouble = + Int64.float_of_bits (self#readI64) + method readBool = + self#readByte = 1 + method readString = + let sz = self#readI32 in + let buf = String.create sz in + ignore (trans#readAll buf 0 sz); + buf + method readBinary = self#readString + method readMessageBegin = + let ver = self#readI32 in + if (ver land version_mask != version_1) then + (print_int ver; + raise (P.E (P.BAD_VERSION, "Missing version identifier"))) + else + let s = self#readString in + let mt = P.message_type_of_i (ver land 0xFF) in + (s,mt, self#readI32) + method readMessageEnd = () + method readStructBegin = + "" + method readStructEnd = () + method readFieldBegin = + let t = (vt (self#readByte)) + in + if t != P.T_STOP then + ("",t,self#readI16) + else ("",t,0); + method readFieldEnd = () + method readMapBegin = + let kt = vt (self#readByte) in + let vt = vt (self#readByte) in + (kt,vt, self#readI32) + method readMapEnd = () + method readListBegin = + let t = vt (self#readByte) in + (t,self#readI32) + method readListEnd = () + method readSetBegin = + let t = vt (self#readByte) in + (t, self#readI32); + method readSetEnd = () +end + +class factory = +object + inherit P.factory + method getProtocol tr = new t tr +end diff --git a/lib/ocaml/src/TChannelTransport.ml b/lib/ocaml/src/TChannelTransport.ml new file mode 100644 index 000000000..0f7d616f5 --- /dev/null +++ b/lib/ocaml/src/TChannelTransport.ml @@ -0,0 +1,39 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift +module T = Transport + +class t (i,o) = +object (self) + val mutable opened = true + inherit Transport.t + method isOpen = opened + method opn = () + method close = close_in i; opened <- false + method read buf off len = + if opened then + try + really_input i buf off len; len + with _ -> raise (T.E (T.UNKNOWN, ("TChannelTransport: Could not read "^(string_of_int len)))) + else + raise (T.E (T.NOT_OPEN, "TChannelTransport: Channel was closed")) + method write buf off len = output o buf off len + method flush = flush o +end diff --git a/lib/ocaml/src/TServer.ml b/lib/ocaml/src/TServer.ml new file mode 100644 index 000000000..fc51efa8f --- /dev/null +++ b/lib/ocaml/src/TServer.ml @@ -0,0 +1,42 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift + +class virtual t + (pf : Processor.t) + (st : Transport.server_t) + (tf : Transport.factory) + (ipf : Protocol.factory) + (opf : Protocol.factory)= +object + method virtual serve : unit +end;; + + + +let run_basic_server proc port = + Unix.establish_server (fun inp -> fun out -> + let trans = new TChannelTransport.t (inp,out) in + let proto = new TBinaryProtocol.t (trans :> Transport.t) in + try + while proc#process proto proto do () done; () + with e -> ()) (Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1",port)) + + diff --git a/lib/ocaml/src/TServerSocket.ml b/lib/ocaml/src/TServerSocket.ml new file mode 100644 index 000000000..405ef82c1 --- /dev/null +++ b/lib/ocaml/src/TServerSocket.ml @@ -0,0 +1,41 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift + +class t port = +object + inherit Transport.server_t + val mutable sock = None + method listen = + let s = Unix.socket Unix.PF_INET Unix.SOCK_STREAM 0 in + sock <- Some s; + Unix.bind s (Unix.ADDR_INET (Unix.inet_addr_any, port)); + Unix.listen s 256 + method close = + match sock with + Some s -> Unix.shutdown s Unix.SHUTDOWN_ALL; Unix.close s; + sock <- None + | _ -> () + method acceptImpl = + match sock with + Some s -> let (fd,_) = Unix.accept s in + new TChannelTransport.t (Unix.in_channel_of_descr fd,Unix.out_channel_of_descr fd) + | _ -> raise (Transport.E (Transport.NOT_OPEN,"TServerSocket: Not listening but tried to accept")) +end diff --git a/lib/ocaml/src/TSimpleServer.ml b/lib/ocaml/src/TSimpleServer.ml new file mode 100644 index 000000000..d19d8c557 --- /dev/null +++ b/lib/ocaml/src/TSimpleServer.ml @@ -0,0 +1,38 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift +module S = TServer + +class t pf st tf ipf opf = +object + inherit S.t pf st tf ipf opf + method serve = + try + st#listen; + let c = st#accept in + let trans = tf#getTransport c in + let inp = ipf#getProtocol trans in + let op = opf#getProtocol trans in + try + while (pf#process inp op) do () done; + trans#close + with e -> trans#close; raise e + with _ -> () +end diff --git a/lib/ocaml/src/TSocket.ml b/lib/ocaml/src/TSocket.ml new file mode 100644 index 000000000..109e11c56 --- /dev/null +++ b/lib/ocaml/src/TSocket.ml @@ -0,0 +1,59 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift + +module T = Transport + +class t host port= +object (self) + inherit T.t + val mutable chans = None + method isOpen = chans != None + method opn = + try + let addr = (let {Unix.h_addr_list=x} = Unix.gethostbyname host in x.(0)) in + chans <- Some(Unix.open_connection (Unix.ADDR_INET (addr,port))) + with + Unix.Unix_error (e,fn,_) -> raise (T.E (T.NOT_OPEN, ("TSocket: Could not connect to "^host^":"^(string_of_int port)^" because: "^fn^":"^(Unix.error_message e)))) + | _ -> raise (T.E (T.NOT_OPEN, ("TSocket: Could not connect to "^host^":"^(string_of_int port)))) + + method close = + match chans with + None -> () + | Some(inc,out) -> (Unix.shutdown_connection inc; + close_in inc; + chans <- None) + method read buf off len = match chans with + None -> raise (T.E (T.NOT_OPEN, "TSocket: Socket not open")) + | Some(i,o) -> + try + really_input i buf off len; len + with + Unix.Unix_error (e,fn,_) -> raise (T.E (T.UNKNOWN, ("TSocket: Could not read "^(string_of_int len)^" from "^host^":"^(string_of_int port)^" because: "^fn^":"^(Unix.error_message e)))) + | _ -> raise (T.E (T.UNKNOWN, ("TSocket: Could not read "^(string_of_int len)^" from "^host^":"^(string_of_int port)))) + method write buf off len = match chans with + None -> raise (T.E (T.NOT_OPEN, "TSocket: Socket not open")) + | Some(i,o) -> output o buf off len + method flush = match chans with + None -> raise (T.E (T.NOT_OPEN, "TSocket: Socket not open")) + | Some(i,o) -> flush o +end + + diff --git a/lib/ocaml/src/TThreadedServer.ml b/lib/ocaml/src/TThreadedServer.ml new file mode 100644 index 000000000..4462dbd73 --- /dev/null +++ b/lib/ocaml/src/TThreadedServer.ml @@ -0,0 +1,45 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift + +class t + (pf : Processor.t) + (st : Transport.server_t) + (tf : Transport.factory) + (ipf : Protocol.factory) + (opf : Protocol.factory)= +object + inherit TServer.t pf st tf ipf opf + method serve = + st#listen; + while true do + let tr = tf#getTransport (st#accept) in + ignore (Thread.create + (fun _ -> + let ip = ipf#getProtocol tr in + let op = opf#getProtocol tr in + try + while pf#process ip op do + () + done + with _ -> ()) ()) + done +end + diff --git a/lib/ocaml/src/Thrift.ml b/lib/ocaml/src/Thrift.ml new file mode 100644 index 000000000..8dc9afa35 --- /dev/null +++ b/lib/ocaml/src/Thrift.ml @@ -0,0 +1,368 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +exception Break;; +exception Thrift_error;; +exception Field_empty of string;; + +class t_exn = +object + val mutable message = "" + method get_message = message + method set_message s = message <- s +end;; + +module Transport = +struct + type exn_type = + | UNKNOWN + | NOT_OPEN + | ALREADY_OPEN + | TIMED_OUT + | END_OF_FILE;; + + exception E of exn_type * string + + class virtual t = + object (self) + method virtual isOpen : bool + method virtual opn : unit + method virtual close : unit + method virtual read : string -> int -> int -> int + method readAll buf off len = + let got = ref 0 in + let ret = ref 0 in + while !got < len do + ret := self#read buf (off+(!got)) (len - (!got)); + if !ret <= 0 then + raise (E (UNKNOWN, "Cannot read. Remote side has closed.")); + got := !got + !ret + done; + !got + method virtual write : string -> int -> int -> unit + method virtual flush : unit + end + + class factory = + object + method getTransport (t : t) = t + end + + class virtual server_t = + object (self) + method virtual listen : unit + method accept = self#acceptImpl + method virtual close : unit + method virtual acceptImpl : t + end + +end;; + + + +module Protocol = +struct + type t_type = + | T_STOP + | T_VOID + | T_BOOL + | T_BYTE + | T_I08 + | T_I16 + | T_I32 + | T_U64 + | T_I64 + | T_DOUBLE + | T_STRING + | T_UTF7 + | T_STRUCT + | T_MAP + | T_SET + | T_LIST + | T_UTF8 + | T_UTF16 + + let t_type_to_i = function + T_STOP -> 0 + | T_VOID -> 1 + | T_BOOL -> 2 + | T_BYTE -> 3 + | T_I08 -> 3 + | T_I16 -> 6 + | T_I32 -> 8 + | T_U64 -> 9 + | T_I64 -> 10 + | T_DOUBLE -> 4 + | T_STRING -> 11 + | T_UTF7 -> 11 + | T_STRUCT -> 12 + | T_MAP -> 13 + | T_SET -> 14 + | T_LIST -> 15 + | T_UTF8 -> 16 + | T_UTF16 -> 17 + + let t_type_of_i = function + 0 -> T_STOP + | 1 -> T_VOID + | 2 -> T_BOOL + | 3 -> T_BYTE + | 6-> T_I16 + | 8 -> T_I32 + | 9 -> T_U64 + | 10 -> T_I64 + | 4 -> T_DOUBLE + | 11 -> T_STRING + | 12 -> T_STRUCT + | 13 -> T_MAP + | 14 -> T_SET + | 15 -> T_LIST + | 16 -> T_UTF8 + | 17 -> T_UTF16 + | _ -> raise Thrift_error + + type message_type = + | CALL + | REPLY + | EXCEPTION + | ONEWAY + + let message_type_to_i = function + | CALL -> 1 + | REPLY -> 2 + | EXCEPTION -> 3 + | ONEWAY -> 4 + + let message_type_of_i = function + | 1 -> CALL + | 2 -> REPLY + | 3 -> EXCEPTION + | 4 -> ONEWAY + | _ -> raise Thrift_error + + class virtual t (trans: Transport.t) = + object (self) + val mutable trans_ = trans + method getTransport = trans_ + (* writing methods *) + method virtual writeMessageBegin : string * message_type * int -> unit + method virtual writeMessageEnd : unit + method virtual writeStructBegin : string -> unit + method virtual writeStructEnd : unit + method virtual writeFieldBegin : string * t_type * int -> unit + method virtual writeFieldEnd : unit + method virtual writeFieldStop : unit + method virtual writeMapBegin : t_type * t_type * int -> unit + method virtual writeMapEnd : unit + method virtual writeListBegin : t_type * int -> unit + method virtual writeListEnd : unit + method virtual writeSetBegin : t_type * int -> unit + method virtual writeSetEnd : unit + method virtual writeBool : bool -> unit + method virtual writeByte : int -> unit + method virtual writeI16 : int -> unit + method virtual writeI32 : int -> unit + method virtual writeI64 : Int64.t -> unit + method virtual writeDouble : float -> unit + method virtual writeString : string -> unit + method virtual writeBinary : string -> unit + (* reading methods *) + method virtual readMessageBegin : string * message_type * int + method virtual readMessageEnd : unit + method virtual readStructBegin : string + method virtual readStructEnd : unit + method virtual readFieldBegin : string * t_type * int + method virtual readFieldEnd : unit + method virtual readMapBegin : t_type * t_type * int + method virtual readMapEnd : unit + method virtual readListBegin : t_type * int + method virtual readListEnd : unit + method virtual readSetBegin : t_type * int + method virtual readSetEnd : unit + method virtual readBool : bool + method virtual readByte : int + method virtual readI16 : int + method virtual readI32: int + method virtual readI64 : Int64.t + method virtual readDouble : float + method virtual readString : string + method virtual readBinary : string + (* skippage *) + method skip typ = + match typ with + | T_STOP -> () + | T_VOID -> () + | T_BOOL -> ignore self#readBool + | T_BYTE + | T_I08 -> ignore self#readByte + | T_I16 -> ignore self#readI16 + | T_I32 -> ignore self#readI32 + | T_U64 + | T_I64 -> ignore self#readI64 + | T_DOUBLE -> ignore self#readDouble + | T_STRING -> ignore self#readString + | T_UTF7 -> () + | T_STRUCT -> ignore ((ignore self#readStructBegin); + (try + while true do + let (_,t,_) = self#readFieldBegin in + if t = T_STOP then + raise Break + else + (self#skip t; + self#readFieldEnd) + done + with Break -> ()); + self#readStructEnd) + | T_MAP -> ignore (let (k,v,s) = self#readMapBegin in + for i=0 to s do + self#skip k; + self#skip v; + done; + self#readMapEnd) + | T_SET -> ignore (let (t,s) = self#readSetBegin in + for i=0 to s do + self#skip t + done; + self#readSetEnd) + | T_LIST -> ignore (let (t,s) = self#readListBegin in + for i=0 to s do + self#skip t + done; + self#readListEnd) + | T_UTF8 -> () + | T_UTF16 -> () + end + + class virtual factory = + object + method virtual getProtocol : Transport.t -> t + end + + type exn_type = + | UNKNOWN + | INVALID_DATA + | NEGATIVE_SIZE + | SIZE_LIMIT + | BAD_VERSION + + exception E of exn_type * string;; + +end;; + + +module Processor = +struct + class virtual t = + object + method virtual process : Protocol.t -> Protocol.t -> bool + end;; + + class factory (processor : t) = + object + val processor_ = processor + method getProcessor (trans : Transport.t) = processor_ + end;; +end + + +(* Ugly *) +module Application_Exn = +struct + type typ= + | UNKNOWN + | UNKNOWN_METHOD + | INVALID_MESSAGE_TYPE + | WRONG_METHOD_NAME + | BAD_SEQUENCE_ID + | MISSING_RESULT + + let typ_of_i = function + 0 -> UNKNOWN + | 1 -> UNKNOWN_METHOD + | 2 -> INVALID_MESSAGE_TYPE + | 3 -> WRONG_METHOD_NAME + | 4 -> BAD_SEQUENCE_ID + | 5 -> MISSING_RESULT + | _ -> raise Thrift_error;; + let typ_to_i = function + | UNKNOWN -> 0 + | UNKNOWN_METHOD -> 1 + | INVALID_MESSAGE_TYPE -> 2 + | WRONG_METHOD_NAME -> 3 + | BAD_SEQUENCE_ID -> 4 + | MISSING_RESULT -> 5 + + class t = + object (self) + inherit t_exn + val mutable typ = UNKNOWN + method get_type = typ + method set_type t = typ <- t + method write (oprot : Protocol.t) = + oprot#writeStructBegin "TApplicationExeception"; + if self#get_message != "" then + (oprot#writeFieldBegin ("message",Protocol.T_STRING, 1); + oprot#writeString self#get_message; + oprot#writeFieldEnd) + else (); + oprot#writeFieldBegin ("type",Protocol.T_I32,2); + oprot#writeI32 (typ_to_i typ); + oprot#writeFieldEnd; + oprot#writeFieldStop; + oprot#writeStructEnd + end;; + + let create typ msg = + let e = new t in + e#set_type typ; + e#set_message msg; + e + + let read (iprot : Protocol.t) = + let msg = ref "" in + let typ = ref 0 in + ignore iprot#readStructBegin; + (try + while true do + let (name,ft,id) =iprot#readFieldBegin in + if ft = Protocol.T_STOP then + raise Break + else (); + (match id with + | 1 -> (if ft = Protocol.T_STRING then + msg := (iprot#readString) + else + iprot#skip ft) + | 2 -> (if ft = Protocol.T_I32 then + typ := iprot#readI32 + else + iprot#skip ft) + | _ -> iprot#skip ft); + iprot#readFieldEnd + done + with Break -> ()); + iprot#readStructEnd; + let e = new t in + e#set_type (typ_of_i !typ); + e#set_message !msg; + e;; + + exception E of t +end;; diff --git a/lib/perl/Makefile.PL b/lib/perl/Makefile.PL new file mode 100644 index 000000000..94ea37ce8 --- /dev/null +++ b/lib/perl/Makefile.PL @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +use ExtUtils::MakeMaker; +WriteMakefile( 'NAME' => 'Thrift', + 'VERSION_FROM' => 'lib/Thrift.pm', + 'PREREQ_PM' => { + 'Bit::Vector' => 0, + 'Class::Accessor' => 0 + }, + ($] >= 5.005 ? + ( AUTHOR => 'T Jake Luciani <jakers@gmail.com>') : ()), + ); diff --git a/lib/perl/Makefile.am b/lib/perl/Makefile.am new file mode 100644 index 000000000..163d01583 --- /dev/null +++ b/lib/perl/Makefile.am @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +SUBDIRS = test + +Makefile-perl.mk : Makefile.PL + $(PERL) Makefile.PL MAKEFILE=Makefile-perl.mk INSTALLDIRS=$(INSTALLDIRS) + +all-local: Makefile-perl.mk + $(MAKE) -f Makefile-perl.mk + find blib -name 'Makefile*' -exec rm -f {} \; + +check-local: + $(PERL) -Iblib/lib -I@abs_srcdir@ -I@builddir@/test/gen-perl \ + @abs_srcdir@/test.pl @abs_srcdir@/test/*.t + +install-exec-local: Makefile-perl.mk + $(MAKE) -f Makefile-perl.mk install DESTDIR=$(DESTDIR)/ + +clean-local: + if test -f Makefile-perl.mk ; then \ + $(MAKE) -f Makefile-perl.mk clean ; \ + fi + rm -f Makefile-perl.mk.old + +EXTRA_DIST = \ + Makefile.PL \ + test.pl \ + lib/Thrift.pm \ + lib/Thrift.pm \ + lib/Thrift/BinaryProtocol.pm \ + lib/Thrift/BufferedTransport.pm \ + lib/Thrift/FramedTransport.pm \ + lib/Thrift/HttpClient.pm \ + lib/Thrift/MemoryBuffer.pm \ + lib/Thrift/Protocol.pm \ + lib/Thrift/Socket.pm \ + lib/Thrift/Transport.pm diff --git a/lib/perl/README b/lib/perl/README new file mode 100644 index 000000000..691488b47 --- /dev/null +++ b/lib/perl/README @@ -0,0 +1,41 @@ +Thrift Perl Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Using Thrift with Perl +===================== + +Thrift requires Perl >= 5.6.0 + +Exceptions are thrown with die so be sure to wrap eval{} statments +around any code that contains exceptions. + +The 64bit Integers work only upto 2^42 on my machine :-? +Math::BigInt is probably needed. + +Please see tutoral and test dirs for examples... + +Dependencies +============ + +Bit::Vector - comes with modern perl installations. +Class::Accessor + diff --git a/lib/perl/lib/Thrift.pm b/lib/perl/lib/Thrift.pm new file mode 100644 index 000000000..fe0f8e726 --- /dev/null +++ b/lib/perl/lib/Thrift.pm @@ -0,0 +1,177 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +our $VERSION = '0.1'; + +require 5.6.0; +use strict; +use warnings; + +# +# Data types that can be sent via Thrift +# +package TType; +use constant STOP => 0; +use constant VOID => 1; +use constant BOOL => 2; +use constant BYTE => 3; +use constant I08 => 3; +use constant DOUBLE => 4; +use constant I16 => 6; +use constant I32 => 8; +use constant I64 => 10; +use constant STRING => 11; +use constant UTF7 => 11; +use constant STRUCT => 12; +use constant MAP => 13; +use constant SET => 14; +use constant LIST => 15; +use constant UTF8 => 16; +use constant UTF16 => 17; +1; + +# +# Message types for RPC +# +package TMessageType; +use constant CALL => 1; +use constant REPLY => 2; +use constant EXCEPTION => 3; +use constant ONEWAY => 4; +1; + +package Thrift::TException; + +sub new { + my $classname = shift; + my $self = {message => shift, code => shift || 0}; + + return bless($self,$classname); +} +1; + +package TApplicationException; +use base('Thrift::TException'); + +use constant UNKNOWN => 0; +use constant UNKNOWN_METHOD => 1; +use constant INVALID_MESSAGE_TYPE => 2; +use constant WRONG_METHOD_NAME => 3; +use constant BAD_SEQUENCE_ID => 4; +use constant MISSING_RESULT => 5; + +sub new { + my $classname = shift; + + my $self = $classname->SUPER::new(); + + return bless($self,$classname); +} + +sub read { + my $self = shift; + my $input = shift; + + my $xfer = 0; + my $fname = undef; + my $ftype = 0; + my $fid = 0; + + $xfer += $input->readStructBegin($fname); + + while (1) + { + $xfer += $input->readFieldBegin($fname, $ftype, $fid); + if ($ftype == TType::STOP) { + last; next; + } + + SWITCH: for($fid) + { + /1/ && do{ + + if ($ftype == TType::STRING) { + $xfer += $input->readString($self->{message}); + } else { + $xfer += $input->skip($ftype); + } + + last; + }; + + /2/ && do{ + if ($ftype == TType::I32) { + $xfer += $input->readI32($self->{code}); + } else { + $xfer += $input->skip($ftype); + } + last; + }; + + $xfer += $input->skip($ftype); + } + + $xfer += $input->readFieldEnd(); + } + $xfer += $input->readStructEnd(); + + return $xfer; +} + +sub write { + my $self = shift; + my $output = shift; + + my $xfer = 0; + + $xfer += $output->writeStructBegin('TApplicationException'); + + if ($self->getMessage()) { + $xfer += $output->writeFieldBegin('message', TType::STRING, 1); + $xfer += $output->writeString($self->getMessage()); + $xfer += $output->writeFieldEnd(); + } + + if ($self->getCode()) { + $xfer += $output->writeFieldBegin('type', TType::I32, 2); + $xfer += $output->writeI32($self->getCode()); + $xfer += $output->writeFieldEnd(); + } + + $xfer += $output->writeFieldStop(); + $xfer += $output->writeStructEnd(); + + return $xfer; +} + +sub getMessage +{ + my $self = shift; + + return $self->{message}; +} + +sub getCode +{ + my $self = shift; + + return $self->{code}; +} + +1; diff --git a/lib/perl/lib/Thrift/BinaryProtocol.pm b/lib/perl/lib/Thrift/BinaryProtocol.pm new file mode 100644 index 000000000..0e5d61d38 --- /dev/null +++ b/lib/perl/lib/Thrift/BinaryProtocol.pm @@ -0,0 +1,498 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; + +use strict; +use warnings; + +use utf8; +use Encode; + +use Thrift; +use Thrift::Protocol; + +use Bit::Vector; + +# +# Binary implementation of the Thrift protocol. +# +package Thrift::BinaryProtocol; +use base('Thrift::Protocol'); + +use constant VERSION_MASK => 0xffff0000; +use constant VERSION_1 => 0x80010000; + +sub new +{ + my $classname = shift; + my $trans = shift; + my $self = $classname->SUPER::new($trans); + + return bless($self,$classname); +} + +sub writeMessageBegin +{ + my $self = shift; + my ($name, $type, $seqid) = @_; + + return + $self->writeI32(VERSION_1 | $type) + + $self->writeString($name) + + $self->writeI32($seqid); +} + +sub writeMessageEnd +{ + my $self = shift; + return 0; +} + +sub writeStructBegin{ + my $self = shift; + my $name = shift; + return 0; +} + +sub writeStructEnd +{ + my $self = shift; + return 0; +} + +sub writeFieldBegin +{ + my $self = shift; + my ($fieldName, $fieldType, $fieldId) = @_; + + return + $self->writeByte($fieldType) + + $self->writeI16($fieldId); +} + +sub writeFieldEnd +{ + my $self = shift; + return 0; +} + +sub writeFieldStop +{ + my $self = shift; + return $self->writeByte(TType::STOP); +} + +sub writeMapBegin +{ + my $self = shift; + my ($keyType, $valType, $size) = @_; + + return + $self->writeByte($keyType) + + $self->writeByte($valType) + + $self->writeI32($size); +} + +sub writeMapEnd +{ + my $self = shift; + return 0; +} + +sub writeListBegin +{ + my $self = shift; + my ($elemType, $size) = @_; + + return + $self->writeByte($elemType) + + $self->writeI32($size); +} + +sub writeListEnd +{ + my $self = shift; + return 0; +} + +sub writeSetBegin +{ + my $self = shift; + my ($elemType, $size) = @_; + + return + $self->writeByte($elemType) + + $self->writeI32($size); +} + +sub writeSetEnd +{ + my $self = shift; + return 0; +} + +sub writeBool +{ + my $self = shift; + my $value = shift; + + my $data = pack('c', $value ? 1 : 0); + $self->{trans}->write($data, 1); + return 1; +} + +sub writeByte +{ + my $self = shift; + my $value= shift; + + my $data = pack('c', $value); + $self->{trans}->write($data, 1); + return 1; +} + +sub writeI16 +{ + my $self = shift; + my $value= shift; + + my $data = pack('n', $value); + $self->{trans}->write($data, 2); + return 2; +} + +sub writeI32 +{ + my $self = shift; + my $value= shift; + + my $data = pack('N', $value); + $self->{trans}->write($data, 4); + return 4; +} + +sub writeI64 +{ + my $self = shift; + my $value= shift; + my $data; + + my $vec; + #stop annoying error + $vec = Bit::Vector->new_Dec(64, $value); + $data = pack 'NN', $vec->Chunk_Read(32, 32), $vec->Chunk_Read(32, 0); + + $self->{trans}->write($data, 8); + + return 8; +} + + +sub writeDouble +{ + my $self = shift; + my $value= shift; + + my $data = pack('d', $value); + $self->{trans}->write(scalar reverse($data), 8); + return 8; +} + +sub writeString{ + my $self = shift; + my $value= shift; + + if( utf8::is_utf8($value) ){ + $value = Encode::encode_utf8($value); + } + + my $len = length($value); + + my $result = $self->writeI32($len); + + if ($len) { + $self->{trans}->write($value,$len); + } + return $result + $len; + } + + +# +#All references +# +sub readMessageBegin +{ + my $self = shift; + my ($name, $type, $seqid) = @_; + + my $version = 0; + my $result = $self->readI32(\$version); + if (($version & VERSION_MASK) > 0) { + if (($version & VERSION_MASK) != VERSION_1) { + die new Thrift::TException('Missing version identifier') + } + $$type = $version & 0x000000ff; + return + $result + + $self->readString($name) + + $self->readI32($seqid); + } else { # old client support code + return + $result + + $self->readStringBody($name, $version) + # version here holds the size of the string + $self->readByte($type) + + $self->readI32($seqid); + } +} + +sub readMessageEnd +{ + my $self = shift; + return 0; +} + +sub readStructBegin +{ + my $self = shift; + my $name = shift; + + $$name = ''; + + return 0; +} + +sub readStructEnd +{ + my $self = shift; + return 0; +} + +sub readFieldBegin +{ + my $self = shift; + my ($name, $fieldType, $fieldId) = @_; + + my $result = $self->readByte($fieldType); + + if ($$fieldType == TType::STOP) { + $$fieldId = 0; + return $result; + } + + $result += $self->readI16($fieldId); + + return $result; +} + +sub readFieldEnd() { + my $self = shift; + return 0; +} + +sub readMapBegin +{ + my $self = shift; + my ($keyType, $valType, $size) = @_; + + return + $self->readByte($keyType) + + $self->readByte($valType) + + $self->readI32($size); +} + +sub readMapEnd() +{ + my $self = shift; + return 0; +} + +sub readListBegin +{ + my $self = shift; + my ($elemType, $size) = @_; + + return + $self->readByte($elemType) + + $self->readI32($size); +} + +sub readListEnd +{ + my $self = shift; + return 0; +} + +sub readSetBegin +{ + my $self = shift; + my ($elemType, $size) = @_; + + return + $self->readByte($elemType) + + $self->readI32($size); +} + +sub readSetEnd +{ + my $self = shift; + return 0; +} + +sub readBool +{ + my $self = shift; + my $value = shift; + + my $data = $self->{trans}->readAll(1); + my @arr = unpack('c', $data); + $$value = $arr[0] == 1; + return 1; +} + +sub readByte +{ + my $self = shift; + my $value = shift; + + my $data = $self->{trans}->readAll(1); + my @arr = unpack('c', $data); + $$value = $arr[0]; + return 1; +} + +sub readI16 +{ + my $self = shift; + my $value = shift; + + my $data = $self->{trans}->readAll(2); + + my @arr = unpack('n', $data); + + $$value = $arr[0]; + + if ($$value > 0x7fff) { + $$value = 0 - (($$value - 1) ^ 0xffff); + } + + return 2; +} + +sub readI32 +{ + my $self = shift; + my $value= shift; + + my $data = $self->{trans}->readAll(4); + my @arr = unpack('N', $data); + + $$value = $arr[0]; + if ($$value > 0x7fffffff) { + $$value = 0 - (($$value - 1) ^ 0xffffffff); + } + return 4; +} + +sub readI64 +{ + my $self = shift; + my $value = shift; + + my $data = $self->{trans}->readAll(8); + + my ($hi,$lo)=unpack('NN',$data); + + my $vec = new Bit::Vector(64); + + $vec->Chunk_Store(32,32,$hi); + $vec->Chunk_Store(32,0,$lo); + + $$value = $vec->to_Dec(); + + return 8; +} + +sub readDouble +{ + my $self = shift; + my $value = shift; + + my $data = scalar reverse($self->{trans}->readAll(8)); + my @arr = unpack('d', $data); + + $$value = $arr[0]; + + return 8; +} + +sub readString +{ + my $self = shift; + my $value = shift; + + my $len; + my $result = $self->readI32(\$len); + + if ($len) { + $$value = $self->{trans}->readAll($len); + } else { + $$value = ''; + } + + return $result + $len; +} + +sub readStringBody +{ + my $self = shift; + my $value = shift; + my $len = shift; + + if ($len) { + $$value = $self->{trans}->readAll($len); + } else { + $$value = ''; + } + + return $len; +} + +# +# Binary Protocol Factory +# +package TBinaryProtocolFactory; +use base('TProtocolFactory'); + +sub new +{ + my $classname = shift; + my $self = $classname->SUPER::new(); + + return bless($self,$classname); +} + +sub getProtocol{ + my $self = shift; + my $trans = shift; + + return new TBinaryProtocol($trans); +} + +1; diff --git a/lib/perl/lib/Thrift/BufferedTransport.pm b/lib/perl/lib/Thrift/BufferedTransport.pm new file mode 100644 index 000000000..bef564d67 --- /dev/null +++ b/lib/perl/lib/Thrift/BufferedTransport.pm @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; +use strict; +use warnings; + +use Thrift; +use Thrift::Transport; + +package Thrift::BufferedTransport; +use base('Thrift::Transport'); + +sub new +{ + my $classname = shift; + my $transport = shift; + my $rBufSize = shift || 512; + my $wBufSize = shift || 512; + + my $self = { + transport => $transport, + rBufSize => $rBufSize, + wBufSize => $wBufSize, + wBuf => '', + rBuf => '', + }; + + return bless($self,$classname); +} + +sub isOpen +{ + my $self = shift; + + return $self->{transport}->isOpen(); +} + +sub open +{ + my $self = shift; + $self->{transport}->open(); +} + +sub close() +{ + my $self = shift; + $self->{transport}->close(); +} + +sub readAll +{ + my $self = shift; + my $len = shift; + + return $self->{transport}->readAll($len); +} + +sub read +{ + my $self = shift; + my $len = shift; + my $ret; + + # Methinks Perl is already buffering these for us + return $self->{transport}->read($len); +} + +sub write +{ + my $self = shift; + my $buf = shift; + + $self->{wBuf} .= $buf; + if (length($self->{wBuf}) >= $self->{wBufSize}) { + $self->{transport}->write($self->{wBuf}); + $self->{wBuf} = ''; + } +} + +sub flush +{ + my $self = shift; + + if (length($self->{wBuf}) > 0) { + $self->{transport}->write($self->{wBuf}); + $self->{wBuf} = ''; + } + $self->{transport}->flush(); +} + + +1; diff --git a/lib/perl/lib/Thrift/FramedTransport.pm b/lib/perl/lib/Thrift/FramedTransport.pm new file mode 100644 index 000000000..b78b1989f --- /dev/null +++ b/lib/perl/lib/Thrift/FramedTransport.pm @@ -0,0 +1,164 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +use strict; +use warnings; + +use Thrift; +use Thrift::Transport; + +# +# Framed transport. Writes and reads data in chunks that are stamped with +# their length. +# +# @package thrift.transport +# +package Thrift::FramedTransport; + +use base('Thrift::Transport'); + +sub new +{ + my $classname = shift; + my $transport = shift; + my $read = shift || 1; + my $write = shift || 1; + + my $self = { + transport => $transport, + read => $read, + write => $write, + wBuf => '', + rBuf => '', + }; + + return bless($self,$classname); +} + +sub isOpen +{ + my $self = shift; + return $self->{transport}->isOpen(); +} + +sub open +{ + my $self = shift; + + $self->{transport}->open(); +} + +sub close +{ + my $self = shift; + + $self->{transport}->close(); +} + +# +# Reads from the buffer. When more data is required reads another entire +# chunk and serves future reads out of that. +# +# @param int $len How much data +# +sub read +{ + + my $self = shift; + my $len = shift; + + if (!$self->{read}) { + return $self->{transport}->read($len); + } + + if (length($self->{rBuf}) == 0) { + $self->_readFrame(); + } + + + # Just return full buff + if ($len > length($self->{rBuf})) { + my $out = $self->{rBuf}; + $self->{rBuf} = ''; + return $out; + } + + # Return substr + my $out = substr($self->{rBuf}, 0, $len); + $self->{rBuf} = substr($self->{rBuf}, $len); + return $out; +} + +# +# Reads a chunk of data into the internal read buffer. +# (private) +sub _readFrame +{ + my $self = shift; + my $buf = $self->{transport}->readAll(4); + my @val = unpack('N', $buf); + my $sz = $val[0]; + + $self->{rBuf} = $self->{transport}->readAll($sz); +} + +# +# Writes some data to the pending output buffer. +# +# @param string $buf The data +# @param int $len Limit of bytes to write +# +sub write +{ + my $self = shift; + my $buf = shift; + my $len = shift; + + unless($self->{write}) { + return $self->{transport}->write($buf, $len); + } + + if ( defined $len && $len < length($buf)) { + $buf = substr($buf, 0, $len); + } + + $self->{wBuf} .= $buf; + } + +# +# Writes the output buffer to the stream in the format of a 4-byte length +# followed by the actual data. +# +sub flush +{ + my $self = shift; + + unless ($self->{write}) { + return $self->{transport}->flush(); + } + + my $out = pack('N', length($self->{wBuf})); + $out .= $self->{wBuf}; + $self->{transport}->write($out); + $self->{transport}->flush(); + $self->{wBuf} = ''; + +} + +1; diff --git a/lib/perl/lib/Thrift/HttpClient.pm b/lib/perl/lib/Thrift/HttpClient.pm new file mode 100644 index 000000000..d6fc8be34 --- /dev/null +++ b/lib/perl/lib/Thrift/HttpClient.pm @@ -0,0 +1,200 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; +use strict; +use warnings; + +use Thrift; +use Thrift::Transport; + +use HTTP::Request; +use LWP::UserAgent; +use IO::String; + +package Thrift::HttpClient; + +use base('Thrift::Transport'); + +sub new +{ + my $classname = shift; + my $url = shift || 'http://localhost:9090'; + my $debugHandler = shift; + + my $out = IO::String->new; + binmode($out); + + my $self = { + url => $url, + out => $out, + debugHandler => $debugHandler, + debug => 0, + sendTimeout => 100, + recvTimeout => 750, + handle => undef, + }; + + return bless($self,$classname); +} + +sub setSendTimeout +{ + my $self = shift; + my $timeout = shift; + + $self->{sendTimeout} = $timeout; +} + +sub setRecvTimeout +{ + my $self = shift; + my $timeout = shift; + + $self->{recvTimeout} = $timeout; +} + + +# +#Sets debugging output on or off +# +# @param bool $debug +# +sub setDebug +{ + my $self = shift; + my $debug = shift; + + $self->{debug} = $debug; +} + +# +# Tests whether this is open +# +# @return bool true if the socket is open +# +sub isOpen +{ + return 1; +} + +sub open {} + +# +# Cleans up the buffer. +# +sub close +{ + my $self = shift; + if (defined($self->{io})) { + close($self->{io}); + $self->{io} = undef; + } +} + +# +# Guarantees that the full amount of data is read. +# +# @return string The data, of exact length +# @throws TTransportException if cannot read data +# +sub readAll +{ + my $self = shift; + my $len = shift; + + my $buf = $self->read($len); + + if (!defined($buf)) { + die new Thrift::TException('TSocket: Could not read '.$len.' bytes from input buffer'); + } + return $buf; +} + +# +# Read and return string +# +sub read +{ + my $self = shift; + my $len = shift; + + my $buf; + + my $in = $self->{in}; + + if (!defined($in)) { + die new Thrift::TException("Response buffer is empty, no request."); + } + eval { + my $ret = sysread($in, $buf, $len); + if (! defined($ret)) { + die new Thrift::TException("No more data available."); + } + }; if($@){ + die new Thrift::TException($@); + } + + return $buf; +} + +# +# Write string +# +sub write +{ + my $self = shift; + my $buf = shift; + $self->{out}->print($buf); +} + +# +# Flush output (do the actual HTTP/HTTPS request) +# +sub flush +{ + my $self = shift; + + my $ua = LWP::UserAgent->new('timeout' => ($self->{sendTimeout} / 1000), + 'agent' => 'Perl/THttpClient' + ); + $ua->default_header('Accept' => 'application/x-thrift'); + $ua->default_header('Content-Type' => 'application/x-thrift'); + $ua->cookie_jar({}); # hash to remember cookies between redirects + + my $out = $self->{out}; + $out->setpos(0); # rewind + my $buf = join('', <$out>); + + my $request = new HTTP::Request(POST => $self->{url}, undef, $buf); + my $response = $ua->request($request); + my $content_ref = $response->content_ref; + + my $in = IO::String->new($content_ref); + binmode($in); + $self->{in} = $in; + $in->setpos(0); # rewind + + # reset write buffer + $out = IO::String->new; + binmode($out); + $self->{out} = $out; +} + +1; diff --git a/lib/perl/lib/Thrift/MemoryBuffer.pm b/lib/perl/lib/Thrift/MemoryBuffer.pm new file mode 100644 index 000000000..32f144241 --- /dev/null +++ b/lib/perl/lib/Thrift/MemoryBuffer.pm @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; +use strict; +use warnings; + +use Thrift; +use Thrift::Transport; + +package Thrift::MemoryBuffer; +use base('Thrift::Transport'); + +sub new +{ + my $classname = shift; + + my $bufferSize= shift || 1024; + + my $self = { + buffer => '', + bufferSize=> $bufferSize, + wPos => 0, + rPos => 0, + }; + + return bless($self,$classname); +} + +sub isOpen +{ + return 1; +} + +sub open +{ + +} + +sub close +{ + +} + +sub peek +{ + my $self = shift; + return($self->{rPos} < $self->{wPos}); +} + + +sub getBuffer +{ + my $self = shift; + return $self->{buffer}; +} + +sub resetBuffer +{ + my $self = shift; + + my $new_buffer = shift || ''; + + $self->{buffer} = $new_buffer; + $self->{bufferSize} = length($new_buffer); + $self->{wPos} = length($new_buffer); + $self->{rPos} = 0; +} + +sub available +{ + my $self = shift; + return ($self->{wPos} - $self->{rPos}); +} + +sub read +{ + my $self = shift; + my $len = shift; + my $ret; + + my $avail = ($self->{wPos} - $self->{rPos}); + return '' if $avail == 0; + + #how much to give + my $give = $len; + $give = $avail if $avail < $len; + + $ret = substr($self->{buffer},$self->{rPos},$give); + + $self->{rPos} += $give; + + return $ret; +} + +sub write +{ + my $self = shift; + my $buf = shift; + + $self->{buffer} .= $buf; + $self->{wPos} += length($buf); +} + +sub flush +{ + +} + +1; diff --git a/lib/perl/lib/Thrift/Protocol.pm b/lib/perl/lib/Thrift/Protocol.pm new file mode 100644 index 000000000..034711f3e --- /dev/null +++ b/lib/perl/lib/Thrift/Protocol.pm @@ -0,0 +1,543 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; +use strict; +use warnings; + +use Thrift; + +# +# Protocol exceptions +# +package TProtocolException; +use base('Thrift::TException'); + +use constant UNKNOWN => 0; +use constant INVALID_DATA => 1; +use constant NEGATIVE_SIZE => 2; +use constant SIZE_LIMIT => 3; +use constant BAD_VERSION => 4; + +sub new { + my $classname = shift; + + my $self = $classname->SUPER::new(); + + return bless($self,$classname); +} + +# +# Protocol base class module. +# +package Thrift::Protocol; + +sub new { + my $classname = shift; + my $self = {}; + + my $trans = shift; + $self->{trans}= $trans; + + return bless($self,$classname); +} + +sub getTransport +{ + my $self = shift; + + return $self->{trans}; +} + +# +# Writes the message header +# +# @param string $name Function name +# @param int $type message type TMessageType::CALL or TMessageType::REPLY +# @param int $seqid The sequence id of this message +# +sub writeMessageBegin +{ + my ($name, $type, $seqid); + die "abstract"; +} + +# +# Close the message +# +sub writeMessageEnd { + die "abstract"; +} + +# +# Writes a struct header. +# +# @param string $name Struct name +# @throws TException on write error +# @return int How many bytes written +# +sub writeStructBegin { + my ($name); + + die "abstract"; +} + +# +# Close a struct. +# +# @throws TException on write error +# @return int How many bytes written +# +sub writeStructEnd { + die "abstract"; +} + +# +# Starts a field. +# +# @param string $name Field name +# @param int $type Field type +# @param int $fid Field id +# @throws TException on write error +# @return int How many bytes written +# +sub writeFieldBegin { + my ($fieldName, $fieldType, $fieldId); + + die "abstract"; +} + +sub writeFieldEnd { + die "abstract"; +} + +sub writeFieldStop { + die "abstract"; +} + +sub writeMapBegin { + my ($keyType, $valType, $size); + + die "abstract"; +} + +sub writeMapEnd { + die "abstract"; +} + +sub writeListBegin { + my ($elemType, $size); + die "abstract"; +} + +sub writeListEnd { + die "abstract"; +} + +sub writeSetBegin { + my ($elemType, $size); + die "abstract"; +} + +sub writeSetEnd { + die "abstract"; +} + +sub writeBool { + my ($bool); + die "abstract"; +} + +sub writeByte { + my ($byte); + die "abstract"; +} + +sub writeI16 { + my ($i16); + die "abstract"; +} + +sub writeI32 { + my ($i32); + die "abstract"; +} + +sub writeI64 { + my ($i64); + die "abstract"; +} + +sub writeDouble { + my ($dub); + die "abstract"; +} + +sub writeString +{ + my ($str); + die "abstract"; +} + +# +# Reads the message header +# +# @param string $name Function name +# @param int $type message type TMessageType::CALL or TMessageType::REPLY +# @parem int $seqid The sequence id of this message +# +sub readMessageBegin +{ + my ($name, $type, $seqid); + die "abstract"; +} + +# +# Read the close of message +# +sub readMessageEnd +{ + die "abstract"; +} + +sub readStructBegin +{ + my($name); + + die "abstract"; +} + +sub readStructEnd +{ + die "abstract"; +} + +sub readFieldBegin +{ + my ($name, $fieldType, $fieldId); + die "abstract"; +} + +sub readFieldEnd +{ + die "abstract"; +} + +sub readMapBegin +{ + my ($keyType, $valType, $size); + die "abstract"; +} + +sub readMapEnd +{ + die "abstract"; +} + +sub readListBegin +{ + my ($elemType, $size); + die "abstract"; +} + +sub readListEnd +{ + die "abstract"; +} + +sub readSetBegin +{ + my ($elemType, $size); + die "abstract"; +} + +sub readSetEnd +{ + die "abstract"; +} + +sub readBool +{ + my ($bool); + die "abstract"; +} + +sub readByte +{ + my ($byte); + die "abstract"; +} + +sub readI16 +{ + my ($i16); + die "abstract"; +} + +sub readI32 +{ + my ($i32); + die "abstract"; +} + +sub readI64 +{ + my ($i64); + die "abstract"; +} + +sub readDouble +{ + my ($dub); + die "abstract"; +} + +sub readString +{ + my ($str); + die "abstract"; +} + +# +# The skip function is a utility to parse over unrecognized data without +# causing corruption. +# +# @param TType $type What type is it +# +sub skip +{ + my $self = shift; + my $type = shift; + + my $ref; + my $result; + my $i; + + if($type == TType::BOOL) + { + return $self->readBool(\$ref); + } + elsif($type == TType::BYTE){ + return $self->readByte(\$ref); + } + elsif($type == TType::I16){ + return $self->readI16(\$ref); + } + elsif($type == TType::I32){ + return $self->readI32(\$ref); + } + elsif($type == TType::I64){ + return $self->readI64(\$ref); + } + elsif($type == TType::DOUBLE){ + return $self->readDouble(\$ref); + } + elsif($type == TType::STRING) + { + return $self->readString(\$ref); + } + elsif($type == TType::STRUCT) + { + $result = $self->readStructBegin(\$ref); + while (1) { + my ($ftype,$fid); + $result += $self->readFieldBegin(\$ref, \$ftype, \$fid); + if ($ftype == TType::STOP) { + last; + } + $result += $self->skip($ftype); + $result += $self->readFieldEnd(); + } + $result += $self->readStructEnd(); + return $result; + } + elsif($type == TType::MAP) + { + my($keyType,$valType,$size); + $result = $self->readMapBegin(\$keyType, \$valType, \$size); + for ($i = 0; $i < $size; $i++) { + $result += $self->skip($keyType); + $result += $self->skip($valType); + } + $result += $self->readMapEnd(); + return $result; + } + elsif($type == TType::SET) + { + my ($elemType,$size); + $result = $self->readSetBegin(\$elemType, \$size); + for ($i = 0; $i < $size; $i++) { + $result += $self->skip($elemType); + } + $result += $self->readSetEnd(); + return $result; + } + elsif($type == TType::LIST) + { + my ($elemType,$size); + $result = $self->readListBegin(\$elemType, \$size); + for ($i = 0; $i < $size; $i++) { + $result += $self->skip($elemType); + } + $result += $self->readListEnd(); + return $result; + } + + + return 0; + + } + +# +# Utility for skipping binary data +# +# @param TTransport $itrans TTransport object +# @param int $type Field type +# +sub skipBinary +{ + my $self = shift; + my $itrans = shift; + my $type = shift; + + if($type == TType::BOOL) + { + return $itrans->readAll(1); + } + elsif($type == TType::BYTE) + { + return $itrans->readAll(1); + } + elsif($type == TType::I16) + { + return $itrans->readAll(2); + } + elsif($type == TType::I32) + { + return $itrans->readAll(4); + } + elsif($type == TType::I64) + { + return $itrans->readAll(8); + } + elsif($type == TType::DOUBLE) + { + return $itrans->readAll(8); + } + elsif( $type == TType::STRING ) + { + my @len = unpack('N', $itrans->readAll(4)); + my $len = $len[0]; + if ($len > 0x7fffffff) { + $len = 0 - (($len - 1) ^ 0xffffffff); + } + return 4 + $itrans->readAll($len); + } + elsif( $type == TType::STRUCT ) + { + my $result = 0; + while (1) { + my $ftype = 0; + my $fid = 0; + my $data = $itrans->readAll(1); + my @arr = unpack('c', $data); + $ftype = $arr[0]; + if ($ftype == TType::STOP) { + last; + } + # I16 field id + $result += $itrans->readAll(2); + $result += $self->skipBinary($itrans, $ftype); + } + return $result; + } + elsif($type == TType::MAP) + { + # Ktype + my $data = $itrans->readAll(1); + my @arr = unpack('c', $data); + my $ktype = $arr[0]; + # Vtype + $data = $itrans->readAll(1); + @arr = unpack('c', $data); + my $vtype = $arr[0]; + # Size + $data = $itrans->readAll(4); + @arr = unpack('N', $data); + my $size = $arr[0]; + if ($size > 0x7fffffff) { + $size = 0 - (($size - 1) ^ 0xffffffff); + } + my $result = 6; + for (my $i = 0; $i < $size; $i++) { + $result += $self->skipBinary($itrans, $ktype); + $result += $self->skipBinary($itrans, $vtype); + } + return $result; + } + elsif($type == TType::SET || $type == TType::LIST) + { + # Vtype + my $data = $itrans->readAll(1); + my @arr = unpack('c', $data); + my $vtype = $arr[0]; + # Size + $data = $itrans->readAll(4); + @arr = unpack('N', $data); + my $size = $arr[0]; + if ($size > 0x7fffffff) { + $size = 0 - (($size - 1) ^ 0xffffffff); + } + my $result = 5; + for (my $i = 0; $i < $size; $i++) { + $result += $self->skipBinary($itrans, $vtype); + } + return $result; + } + + return 0; + +} + +# +# Protocol factory creates protocol objects from transports +# +package TProtocolFactory; + + +sub new { + my $classname = shift; + my $self = {}; + + return bless($self,$classname); +} + +# +# Build a protocol from the base transport +# +# @return TProtcol protocol +# +sub getProtocol +{ + my ($trans); + die "interface"; +} + + +1; diff --git a/lib/perl/lib/Thrift/Socket.pm b/lib/perl/lib/Thrift/Socket.pm new file mode 100644 index 000000000..67faa5102 --- /dev/null +++ b/lib/perl/lib/Thrift/Socket.pm @@ -0,0 +1,271 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; +use strict; +use warnings; + +use Thrift; +use Thrift::Transport; + +use IO::Socket::INET; +use IO::Select; + +package Thrift::Socket; + +use base('Thrift::Transport'); + +sub new +{ + my $classname = shift; + my $host = shift || "localhost"; + my $port = shift || 9090; + my $debugHandler = shift; + + my $self = { + host => $host, + port => $port, + debugHandler => $debugHandler, + debug => 0, + sendTimeout => 100, + recvTimeout => 750, + handle => undef, + }; + + return bless($self,$classname); +} + + +sub setSendTimeout +{ + my $self = shift; + my $timeout = shift; + + $self->{sendTimeout} = $timeout; +} + +sub setRecvTimeout +{ + my $self = shift; + my $timeout = shift; + + $self->{recvTimeout} = $timeout; +} + + +# +#Sets debugging output on or off +# +# @param bool $debug +# +sub setDebug +{ + my $self = shift; + my $debug = shift; + + $self->{debug} = $debug; +} + +# +# Tests whether this is open +# +# @return bool true if the socket is open +# +sub isOpen +{ + my $self = shift; + + if( defined $self->{handle} ){ + return ($self->{handle}->handles())[0]->connected; + } + + return 0; +} + +# +# Connects the socket. +# +sub open +{ + my $self = shift; + + my $sock = IO::Socket::INET->new(PeerAddr => $self->{host}, + PeerPort => $self->{port}, + Proto => 'tcp', + Timeout => $self->{sendTimeout}/1000) + || do { + my $error = 'TSocket: Could not connect to '.$self->{host}.':'.$self->{port}.' ('.$!.')'; + + if ($self->{debug}) { + $self->{debugHandler}->($error); + } + + die new Thrift::TException($error); + + }; + + + $self->{handle} = new IO::Select( $sock ); +} + +# +# Closes the socket. +# +sub close +{ + my $self = shift; + + if( defined $self->{handle} ){ + close( ($self->{handle}->handles())[0] ); + } +} + +# +# Uses stream get contents to do the reading +# +# @param int $len How many bytes +# @return string Binary data +# +sub readAll +{ + my $self = shift; + my $len = shift; + + + return unless defined $self->{handle}; + + my $pre = ""; + while (1) { + + #check for timeout + my @sockets = $self->{handle}->can_read( $self->{recvTimeout} / 1000 ); + + if(@sockets == 0){ + die new Thrift::TException('TSocket: timed out reading '.$len.' bytes from '. + $self->{host}.':'.$self->{port}); + } + + my $sock = $sockets[0]; + + my ($buf,$sz); + $sock->recv($buf, $len); + + if (!defined $buf || $buf eq '') { + + die new Thrift::TException('TSocket: Could not read '.$len.' bytes from '. + $self->{host}.':'.$self->{port}); + + } elsif (($sz = length($buf)) < $len) { + + $pre .= $buf; + $len -= $sz; + + } else { + return $pre.$buf; + } + } +} + +# +# Read from the socket +# +# @param int $len How many bytes +# @return string Binary data +# +sub read +{ + my $self = shift; + my $len = shift; + + return unless defined $self->{handle}; + + #check for timeout + my @sockets = $self->{handle}->can_read( $self->{sendTimeout} / 1000 ); + + if(@sockets == 0){ + die new Thrift::TException('TSocket: timed out reading '.$len.' bytes from '. + $self->{host}.':'.$self->{port}); + } + + my $sock = $sockets[0]; + + my ($buf,$sz); + $sock->recv($buf, $len); + + if (!defined $buf || $buf eq '') { + + die new TException('TSocket: Could not read '.$len.' bytes from '. + $self->{host}.':'.$self->{port}); + + } + + return $buf; +} + + +# +# Write to the socket. +# +# @param string $buf The data to write +# +sub write +{ + my $self = shift; + my $buf = shift; + + + return unless defined $self->{handle}; + + while (length($buf) > 0) { + + + #check for timeout + my @sockets = $self->{handle}->can_write( $self->{recvTimeout} / 1000 ); + + if(@sockets == 0){ + die new Thrift::TException('TSocket: timed out writing to bytes from '. + $self->{host}.':'.$self->{port}); + } + + my $sock = $sockets[0]; + + my $got = $sock->send($buf); + + if (!defined $got || $got == 0 ) { + die new Thrift::TException('TSocket: Could not write '.length($buf).' bytes '. + $self->{host}.':'.$self->{host}); + } + + $buf = substr($buf, $got); + } +} + +# +# Flush output to the socket. +# +sub flush +{ + my $self = shift; + + return unless defined $self->{handle}; + + my $ret = ($self->{handle}->handles())[0]->flush; +} + +1; diff --git a/lib/perl/lib/Thrift/Transport.pm b/lib/perl/lib/Thrift/Transport.pm new file mode 100644 index 000000000..e22592bec --- /dev/null +++ b/lib/perl/lib/Thrift/Transport.pm @@ -0,0 +1,129 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; +use strict; +use warnings; + +use Thrift; + +# +# Transport exceptions +# +package TTransportException; +use base('Thrift::TException'); + +use constant UNKNOWN => 0; +use constant NOT_OPEN => 1; +use constant ALREADY_OPEN => 2; +use constant TIMED_OUT => 3; +use constant END_OF_FILE => 4; + +sub new{ + my $classname = shift; + my $self = $classname->SUPER::new(@_); + + return bless($self,$classname); +} + +package Thrift::Transport; + +# +# Whether this transport is open. +# +# @return boolean true if open +# +sub isOpen +{ + die "abstract"; +} + +# +# Open the transport for reading/writing +# +# @throws TTransportException if cannot open +# +sub open +{ + die "abstract"; +} + +# +# Close the transport. +# +sub close +{ + die "abstract"; +} + +# +# Read some data into the array. +# +# @param int $len How much to read +# @return string The data that has been read +# @throws TTransportException if cannot read any more data +# +sub read +{ + my ($len); + die("abstract"); +} + +# +# Guarantees that the full amount of data is read. +# +# @return string The data, of exact length +# @throws TTransportException if cannot read data +# +sub readAll +{ + my $self = shift; + my $len = shift; + + my $data = ''; + my $got = 0; + + while (($got = length($data)) < $len) { + $data .= $self->read($len - $got); + } + + return $data; +} + +# +# Writes the given data out. +# +# @param string $buf The data to write +# @throws TTransportException if writing fails +# +sub write +{ + my ($buf); + die "abstract"; +} + +# +# Flushes any pending data out of a buffer +# +# @throws TTransportException if a writing error occurs +# +sub flush {} + +1; + diff --git a/lib/perl/test.pl b/lib/perl/test.pl new file mode 100644 index 000000000..7e068402f --- /dev/null +++ b/lib/perl/test.pl @@ -0,0 +1,25 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +use strict; +use warnings; + +use Test::Harness; + +runtests(@ARGV); diff --git a/lib/perl/test/Makefile.am b/lib/perl/test/Makefile.am new file mode 100644 index 000000000..ce87c48d9 --- /dev/null +++ b/lib/perl/test/Makefile.am @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +THRIFT = @top_builddir@/compiler/cpp/thrift +THRIFT_IF = @top_srcdir@/test/ThriftTest.thrift + +check-local: gen-perl/ThriftTest/Types.pm + +gen-perl/ThriftTest/Types.pm: $(THRIFT_IF) + $(THRIFT) --gen perl $(THRIFT_IF) + +clean-local: + rm -rf gen-perl + +EXTRA_DIST = memory_buffer.t diff --git a/lib/perl/test/memory_buffer.t b/lib/perl/test/memory_buffer.t new file mode 100644 index 000000000..8fa9fd72e --- /dev/null +++ b/lib/perl/test/memory_buffer.t @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +use Test::More tests => 6; + +use strict; +use warnings; + +use Data::Dumper; + +use Thrift::BinaryProtocol; +use Thrift::MemoryBuffer; + +use ThriftTest::Types; + + +my $transport = Thrift::MemoryBuffer->new(); +my $protocol = Thrift::BinaryProtocol->new($transport); + +my $a = ThriftTest::Xtruct->new(); +$a->i32_thing(10); +$a->i64_thing(30); +$a->string_thing('Hello, world!'); +$a->write($protocol); + +my $b = ThriftTest::Xtruct->new(); +$b->read($protocol); +is($b->i32_thing, $a->i32_thing); +is($b->i64_thing, $a->i64_thing); +is($b->string_thing, $a->string_thing); + +$b->write($protocol); +my $c = ThriftTest::Xtruct->new(); +$c->read($protocol); +is($c->i32_thing, $a->i32_thing); +is($c->i64_thing, $a->i64_thing); +is($c->string_thing, $a->string_thing); diff --git a/lib/php/README b/lib/php/README new file mode 100644 index 000000000..bb566f42a --- /dev/null +++ b/lib/php/README @@ -0,0 +1,63 @@ +Thrift PHP Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Using Thrift with PHP +===================== + +Thrift requires PHP 5. Thrift makes as few assumptions about your PHP +environment as possible while trying to make some more advanced PHP +features (i.e. APC cacheing using asbolute path URLs) as simple as possible. + +To use Thrift in your PHP codebase, take the following steps: + +#1) Copy all of thrift/lib/php/src into your PHP codebase +#2) Set $GLOBALS['THRIFT_ROOT'] to the path you installed Thrift +#3) include_once $GLOBALS['THRIFT_ROOT'].'/Thrift.php'; + +Note that #3 must be done before including any other Thrift files. +If you do not do #2, Thrift.php will set this global for you, but it will be +done using dirname(__FILE__), which is less efficient than providing the static +string yourself. + +When you generate a Thrift package using the compiler, it makes an assumption +about where your generated code will live. If your file is "MyPackage.thrift", +the generated files must be installed into: + +$GLOBALS['THRIFT_ROOT'].'/packages/MyPackage/'; + +This allows the code generator to compile your code without any extra flags +for the target directory names while still allowing your include paths to +be absolute (if you have an absolute THRIFT_ROOT). + +Dependencies +============ + +PHP_INT_SIZE + + This built-in signals whether your architecture is 32 or 64 bit and is + used by the TBinaryProtocol to properly use pack() and unpack() to + serialize data. + +apc_fetch(), apc_store() + + APC cache is used by the TSocketPool class. If you do not have APC installed, + Thrift will fill in null stub function definitions. diff --git a/lib/php/README.apache b/lib/php/README.apache new file mode 100644 index 000000000..8c41833d1 --- /dev/null +++ b/lib/php/README.apache @@ -0,0 +1,62 @@ +Thrift PHP/Apache Integration + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Building PHP Thrift Services with Apache +======================================== + +Thrift can be embedded in the Apache webserver with PHP installed. Sample +code is provided below. Note that to make requests to this type of server +you must use a THttpClient transport. + +Sample Code +=========== + +<?php + +/** + * Example of how to build a Thrift server in Apache/PHP + * + */ + +$GLOBALS['THRIFT_ROOT'] = '/your/thrift/root'; + +include_once $GLOBALS['THRIFT_ROOT'].'/Thrift.php'; +include_once $GLOBALS['THRIFT_ROOT'].'/packages/Service/Service.php'; +include_once $GLOBALS['THRIFT_ROOT'].'/transport/TPhpStream.php'; +include_once $GLOBALS['THRIFT_ROOT'].'/protocol/TBinaryProtocol.php'; + +class ServiceHandler implements ServiceIf { + // Implement your interface and methods here +} + +header('Content-Type: application/x-thrift'); + +$handler = new ServiceHandler(); +$processor = new ServiceProcessor($handler); + +// Use the TPhpStream transport to read/write directly from HTTP +$transport = new TPhpStream(TPhpStream::MODE_R | TPhpStream::MODE_W); +$protocol = new TBinaryProtocol($transport); + +$transport->open(); +$processor->process($protocol, $protocol); +$transport->close(); diff --git a/lib/php/src/Thrift.php b/lib/php/src/Thrift.php new file mode 100644 index 000000000..ef6ab8a49 --- /dev/null +++ b/lib/php/src/Thrift.php @@ -0,0 +1,787 @@ +<?php +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @package thrift + */ + + +/** + * Data types that can be sent via Thrift + */ +class TType { + const STOP = 0; + const VOID = 1; + const BOOL = 2; + const BYTE = 3; + const I08 = 3; + const DOUBLE = 4; + const I16 = 6; + const I32 = 8; + const I64 = 10; + const STRING = 11; + const UTF7 = 11; + const STRUCT = 12; + const MAP = 13; + const SET = 14; + const LST = 15; // N.B. cannot use LIST keyword in PHP! + const UTF8 = 16; + const UTF16 = 17; +} + +/** + * Message types for RPC + */ +class TMessageType { + const CALL = 1; + const REPLY = 2; + const EXCEPTION = 3; + const ONEWAY = 4; +} + +/** + * NOTE(mcslee): This currently contains a ton of duplicated code from TBase + * because we need to save CPU cycles and this is not yet in an extension. + * Ideally we'd multiply-inherit TException from both Exception and Base, but + * that's not possible in PHP and there are no modules either, so for now we + * apologetically take a trip to HackTown. + * + * Can be called with standard Exception constructor (message, code) or with + * Thrift Base object constructor (spec, vals). + * + * @param mixed $p1 Message (string) or type-spec (array) + * @param mixed $p2 Code (integer) or values (array) + */ +class TException extends Exception { + function __construct($p1=null, $p2=0) { + if (is_array($p1) && is_array($p2)) { + $spec = $p1; + $vals = $p2; + foreach ($spec as $fid => $fspec) { + $var = $fspec['var']; + if (isset($vals[$var])) { + $this->$var = $vals[$var]; + } + } + } else { + parent::__construct($p1, $p2); + } + } + + static $tmethod = array(TType::BOOL => 'Bool', + TType::BYTE => 'Byte', + TType::I16 => 'I16', + TType::I32 => 'I32', + TType::I64 => 'I64', + TType::DOUBLE => 'Double', + TType::STRING => 'String'); + + private function _readMap(&$var, $spec, $input) { + $xfer = 0; + $ktype = $spec['ktype']; + $vtype = $spec['vtype']; + $kread = $vread = null; + if (isset(TBase::$tmethod[$ktype])) { + $kread = 'read'.TBase::$tmethod[$ktype]; + } else { + $kspec = $spec['key']; + } + if (isset(TBase::$tmethod[$vtype])) { + $vread = 'read'.TBase::$tmethod[$vtype]; + } else { + $vspec = $spec['val']; + } + $var = array(); + $_ktype = $_vtype = $size = 0; + $xfer += $input->readMapBegin($_ktype, $_vtype, $size); + for ($i = 0; $i < $size; ++$i) { + $key = $val = null; + if ($kread !== null) { + $xfer += $input->$kread($key); + } else { + switch ($ktype) { + case TType::STRUCT: + $class = $kspec['class']; + $key = new $class(); + $xfer += $key->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($key, $kspec, $input); + break; + case TType::LST: + $xfer += $this->_readList($key, $kspec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($key, $kspec, $input, true); + break; + } + } + if ($vread !== null) { + $xfer += $input->$vread($val); + } else { + switch ($vtype) { + case TType::STRUCT: + $class = $vspec['class']; + $val = new $class(); + $xfer += $val->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($val, $vspec, $input); + break; + case TType::LST: + $xfer += $this->_readList($val, $vspec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($val, $vspec, $input, true); + break; + } + } + $var[$key] = $val; + } + $xfer += $input->readMapEnd(); + return $xfer; + } + + private function _readList(&$var, $spec, $input, $set=false) { + $xfer = 0; + $etype = $spec['etype']; + $eread = $vread = null; + if (isset(TBase::$tmethod[$etype])) { + $eread = 'read'.TBase::$tmethod[$etype]; + } else { + $espec = $spec['elem']; + } + $var = array(); + $_etype = $size = 0; + if ($set) { + $xfer += $input->readSetBegin($_etype, $size); + } else { + $xfer += $input->readListBegin($_etype, $size); + } + for ($i = 0; $i < $size; ++$i) { + $elem = null; + if ($eread !== null) { + $xfer += $input->$eread($elem); + } else { + $espec = $spec['elem']; + switch ($etype) { + case TType::STRUCT: + $class = $espec['class']; + $elem = new $class(); + $xfer += $elem->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($elem, $espec, $input); + break; + case TType::LST: + $xfer += $this->_readList($elem, $espec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($elem, $espec, $input, true); + break; + } + } + if ($set) { + $var[$elem] = true; + } else { + $var []= $elem; + } + } + if ($set) { + $xfer += $input->readSetEnd(); + } else { + $xfer += $input->readListEnd(); + } + return $xfer; + } + + protected function _read($class, $spec, $input) { + $xfer = 0; + $fname = null; + $ftype = 0; + $fid = 0; + $xfer += $input->readStructBegin($fname); + while (true) { + $xfer += $input->readFieldBegin($fname, $ftype, $fid); + if ($ftype == TType::STOP) { + break; + } + if (isset($spec[$fid])) { + $fspec = $spec[$fid]; + $var = $fspec['var']; + if ($ftype == $fspec['type']) { + $xfer = 0; + if (isset(TBase::$tmethod[$ftype])) { + $func = 'read'.TBase::$tmethod[$ftype]; + $xfer += $input->$func($this->$var); + } else { + switch ($ftype) { + case TType::STRUCT: + $class = $fspec['class']; + $this->$var = new $class(); + $xfer += $this->$var->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($this->$var, $fspec, $input); + break; + case TType::LST: + $xfer += $this->_readList($this->$var, $fspec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($this->$var, $fspec, $input, true); + break; + } + } + } else { + $xfer += $input->skip($ftype); + } + } else { + $xfer += $input->skip($ftype); + } + $xfer += $input->readFieldEnd(); + } + $xfer += $input->readStructEnd(); + return $xfer; + } + + private function _writeMap($var, $spec, $output) { + $xfer = 0; + $ktype = $spec['ktype']; + $vtype = $spec['vtype']; + $kwrite = $vwrite = null; + if (isset(TBase::$tmethod[$ktype])) { + $kwrite = 'write'.TBase::$tmethod[$ktype]; + } else { + $kspec = $spec['key']; + } + if (isset(TBase::$tmethod[$vtype])) { + $vwrite = 'write'.TBase::$tmethod[$vtype]; + } else { + $vspec = $spec['val']; + } + $xfer += $output->writeMapBegin($ktype, $vtype, count($var)); + foreach ($var as $key => $val) { + if (isset($kwrite)) { + $xfer += $output->$kwrite($key); + } else { + switch ($ktype) { + case TType::STRUCT: + $xfer += $key->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($key, $kspec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($key, $kspec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($key, $kspec, $output, true); + break; + } + } + if (isset($vwrite)) { + $xfer += $output->$vwrite($val); + } else { + switch ($vtype) { + case TType::STRUCT: + $xfer += $val->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($val, $vspec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($val, $vspec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($val, $vspec, $output, true); + break; + } + } + } + $xfer += $output->writeMapEnd(); + return $xfer; + } + + private function _writeList($var, $spec, $output, $set=false) { + $xfer = 0; + $etype = $spec['etype']; + $ewrite = null; + if (isset(TBase::$tmethod[$etype])) { + $ewrite = 'write'.TBase::$tmethod[$etype]; + } else { + $espec = $spec['elem']; + } + if ($set) { + $xfer += $output->writeSetBegin($etype, count($var)); + } else { + $xfer += $output->writeListBegin($etype, count($var)); + } + foreach ($var as $key => $val) { + $elem = $set ? $key : $val; + if (isset($ewrite)) { + $xfer += $output->$ewrite($elem); + } else { + switch ($etype) { + case TType::STRUCT: + $xfer += $elem->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($elem, $espec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($elem, $espec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($elem, $espec, $output, true); + break; + } + } + } + if ($set) { + $xfer += $output->writeSetEnd(); + } else { + $xfer += $output->writeListEnd(); + } + return $xfer; + } + + protected function _write($class, $spec, $output) { + $xfer = 0; + $xfer += $output->writeStructBegin($class); + foreach ($spec as $fid => $fspec) { + $var = $fspec['var']; + if ($this->$var !== null) { + $ftype = $fspec['type']; + $xfer += $output->writeFieldBegin($var, $ftype, $fid); + if (isset(TBase::$tmethod[$ftype])) { + $func = 'write'.TBase::$tmethod[$ftype]; + $xfer += $output->$func($this->$var); + } else { + switch ($ftype) { + case TType::STRUCT: + $xfer += $this->$var->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($this->$var, $fspec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($this->$var, $fspec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($this->$var, $fspec, $output, true); + break; + } + } + $xfer += $output->writeFieldEnd(); + } + } + $xfer += $output->writeFieldStop(); + $xfer += $output->writeStructEnd(); + return $xfer; + } + +} + +/** + * Base class from which other Thrift structs extend. This is so that we can + * cut back on the size of the generated code which is turning out to have a + * nontrivial cost just to load thanks to the wondrously abysmal implementation + * of PHP. Note that code is intentionally duplicated in here to avoid making + * function calls for every field or member of a container.. + */ +abstract class TBase { + + static $tmethod = array(TType::BOOL => 'Bool', + TType::BYTE => 'Byte', + TType::I16 => 'I16', + TType::I32 => 'I32', + TType::I64 => 'I64', + TType::DOUBLE => 'Double', + TType::STRING => 'String'); + + abstract function read($input); + + abstract function write($output); + + public function __construct($spec=null, $vals=null) { + if (is_array($spec) && is_array($vals)) { + foreach ($spec as $fid => $fspec) { + $var = $fspec['var']; + if (isset($vals[$var])) { + $this->$var = $vals[$var]; + } + } + } + } + + private function _readMap(&$var, $spec, $input) { + $xfer = 0; + $ktype = $spec['ktype']; + $vtype = $spec['vtype']; + $kread = $vread = null; + if (isset(TBase::$tmethod[$ktype])) { + $kread = 'read'.TBase::$tmethod[$ktype]; + } else { + $kspec = $spec['key']; + } + if (isset(TBase::$tmethod[$vtype])) { + $vread = 'read'.TBase::$tmethod[$vtype]; + } else { + $vspec = $spec['val']; + } + $var = array(); + $_ktype = $_vtype = $size = 0; + $xfer += $input->readMapBegin($_ktype, $_vtype, $size); + for ($i = 0; $i < $size; ++$i) { + $key = $val = null; + if ($kread !== null) { + $xfer += $input->$kread($key); + } else { + switch ($ktype) { + case TType::STRUCT: + $class = $kspec['class']; + $key = new $class(); + $xfer += $key->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($key, $kspec, $input); + break; + case TType::LST: + $xfer += $this->_readList($key, $kspec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($key, $kspec, $input, true); + break; + } + } + if ($vread !== null) { + $xfer += $input->$vread($val); + } else { + switch ($vtype) { + case TType::STRUCT: + $class = $vspec['class']; + $val = new $class(); + $xfer += $val->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($val, $vspec, $input); + break; + case TType::LST: + $xfer += $this->_readList($val, $vspec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($val, $vspec, $input, true); + break; + } + } + $var[$key] = $val; + } + $xfer += $input->readMapEnd(); + return $xfer; + } + + private function _readList(&$var, $spec, $input, $set=false) { + $xfer = 0; + $etype = $spec['etype']; + $eread = $vread = null; + if (isset(TBase::$tmethod[$etype])) { + $eread = 'read'.TBase::$tmethod[$etype]; + } else { + $espec = $spec['elem']; + } + $var = array(); + $_etype = $size = 0; + if ($set) { + $xfer += $input->readSetBegin($_etype, $size); + } else { + $xfer += $input->readListBegin($_etype, $size); + } + for ($i = 0; $i < $size; ++$i) { + $elem = null; + if ($eread !== null) { + $xfer += $input->$eread($elem); + } else { + $espec = $spec['elem']; + switch ($etype) { + case TType::STRUCT: + $class = $espec['class']; + $elem = new $class(); + $xfer += $elem->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($elem, $espec, $input); + break; + case TType::LST: + $xfer += $this->_readList($elem, $espec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($elem, $espec, $input, true); + break; + } + } + if ($set) { + $var[$elem] = true; + } else { + $var []= $elem; + } + } + if ($set) { + $xfer += $input->readSetEnd(); + } else { + $xfer += $input->readListEnd(); + } + return $xfer; + } + + protected function _read($class, $spec, $input) { + $xfer = 0; + $fname = null; + $ftype = 0; + $fid = 0; + $xfer += $input->readStructBegin($fname); + while (true) { + $xfer += $input->readFieldBegin($fname, $ftype, $fid); + if ($ftype == TType::STOP) { + break; + } + if (isset($spec[$fid])) { + $fspec = $spec[$fid]; + $var = $fspec['var']; + if ($ftype == $fspec['type']) { + $xfer = 0; + if (isset(TBase::$tmethod[$ftype])) { + $func = 'read'.TBase::$tmethod[$ftype]; + $xfer += $input->$func($this->$var); + } else { + switch ($ftype) { + case TType::STRUCT: + $class = $fspec['class']; + $this->$var = new $class(); + $xfer += $this->$var->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($this->$var, $fspec, $input); + break; + case TType::LST: + $xfer += $this->_readList($this->$var, $fspec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($this->$var, $fspec, $input, true); + break; + } + } + } else { + $xfer += $input->skip($ftype); + } + } else { + $xfer += $input->skip($ftype); + } + $xfer += $input->readFieldEnd(); + } + $xfer += $input->readStructEnd(); + return $xfer; + } + + private function _writeMap($var, $spec, $output) { + $xfer = 0; + $ktype = $spec['ktype']; + $vtype = $spec['vtype']; + $kwrite = $vwrite = null; + if (isset(TBase::$tmethod[$ktype])) { + $kwrite = 'write'.TBase::$tmethod[$ktype]; + } else { + $kspec = $spec['key']; + } + if (isset(TBase::$tmethod[$vtype])) { + $vwrite = 'write'.TBase::$tmethod[$vtype]; + } else { + $vspec = $spec['val']; + } + $xfer += $output->writeMapBegin($ktype, $vtype, count($var)); + foreach ($var as $key => $val) { + if (isset($kwrite)) { + $xfer += $output->$kwrite($key); + } else { + switch ($ktype) { + case TType::STRUCT: + $xfer += $key->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($key, $kspec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($key, $kspec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($key, $kspec, $output, true); + break; + } + } + if (isset($vwrite)) { + $xfer += $output->$vwrite($val); + } else { + switch ($vtype) { + case TType::STRUCT: + $xfer += $val->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($val, $vspec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($val, $vspec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($val, $vspec, $output, true); + break; + } + } + } + $xfer += $output->writeMapEnd(); + return $xfer; + } + + private function _writeList($var, $spec, $output, $set=false) { + $xfer = 0; + $etype = $spec['etype']; + $ewrite = null; + if (isset(TBase::$tmethod[$etype])) { + $ewrite = 'write'.TBase::$tmethod[$etype]; + } else { + $espec = $spec['elem']; + } + if ($set) { + $xfer += $output->writeSetBegin($etype, count($var)); + } else { + $xfer += $output->writeListBegin($etype, count($var)); + } + foreach ($var as $key => $val) { + $elem = $set ? $key : $val; + if (isset($ewrite)) { + $xfer += $output->$ewrite($elem); + } else { + switch ($etype) { + case TType::STRUCT: + $xfer += $elem->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($elem, $espec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($elem, $espec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($elem, $espec, $output, true); + break; + } + } + } + if ($set) { + $xfer += $output->writeSetEnd(); + } else { + $xfer += $output->writeListEnd(); + } + return $xfer; + } + + protected function _write($class, $spec, $output) { + $xfer = 0; + $xfer += $output->writeStructBegin($class); + foreach ($spec as $fid => $fspec) { + $var = $fspec['var']; + if ($this->$var !== null) { + $ftype = $fspec['type']; + $xfer += $output->writeFieldBegin($var, $ftype, $fid); + if (isset(TBase::$tmethod[$ftype])) { + $func = 'write'.TBase::$tmethod[$ftype]; + $xfer += $output->$func($this->$var); + } else { + switch ($ftype) { + case TType::STRUCT: + $xfer += $this->$var->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($this->$var, $fspec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($this->$var, $fspec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($this->$var, $fspec, $output, true); + break; + } + } + $xfer += $output->writeFieldEnd(); + } + } + $xfer += $output->writeFieldStop(); + $xfer += $output->writeStructEnd(); + return $xfer; + } +} + +class TApplicationException extends TException { + static $_TSPEC = + array(1 => array('var' => 'message', + 'type' => TType::STRING), + 2 => array('var' => 'code', + 'type' => TType::I32)); + + const UNKNOWN = 0; + const UNKNOWN_METHOD = 1; + const INVALID_MESSAGE_TYPE = 2; + const WRONG_METHOD_NAME = 3; + const BAD_SEQUENCE_ID = 4; + const MISSING_RESULT = 5; + + function __construct($message=null, $code=0) { + parent::__construct($message, $code); + } + + public function read($output) { + return $this->_read('TApplicationException', self::$_TSPEC, $output); + } + + public function write($output) { + $xfer = 0; + $xfer += $output->writeStructBegin('TApplicationException'); + if ($message = $this->getMessage()) { + $xfer += $output->writeFieldBegin('message', TType::STRING, 1); + $xfer += $output->writeString($message); + $xfer += $output->writeFieldEnd(); + } + if ($code = $this->getCode()) { + $xfer += $output->writeFieldBegin('type', TType::I32, 2); + $xfer += $output->writeI32($code); + $xfer += $output->writeFieldEnd(); + } + $xfer += $output->writeFieldStop(); + $xfer += $output->writeStructEnd(); + return $xfer; + } +} + +/** + * Set global THRIFT ROOT automatically via inclusion here + */ +if (!isset($GLOBALS['THRIFT_ROOT'])) { + $GLOBALS['THRIFT_ROOT'] = dirname(__FILE__); +} +include_once $GLOBALS['THRIFT_ROOT'].'/protocol/TProtocol.php'; +include_once $GLOBALS['THRIFT_ROOT'].'/transport/TTransport.php'; + +?> diff --git a/lib/php/src/autoload.php b/lib/php/src/autoload.php new file mode 100644 index 000000000..3a35545d1 --- /dev/null +++ b/lib/php/src/autoload.php @@ -0,0 +1,51 @@ +<?php +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @package thrift + */ + + +/** + * Include this file if you wish to use autoload with your PHP generated Thrift + * code. The generated code will *not* include any defined Thrift classes by + * default, except for the service interfaces. The generated code will populate + * values into $GLOBALS['THRIFT_AUTOLOAD'] which can be used by the autoload + * method below. If you have your own autoload system already in place, rename your + * __autoload function to something else and then do: + * $GLOBALS['AUTOLOAD_HOOKS'][] = 'my_autoload_func'; + * + * Generate this code using the --gen php:autoload Thrift generator flag. + */ + +$GLOBALS['THRIFT_AUTOLOAD'] = array(); +$GLOBALS['AUTOLOAD_HOOKS'] = array(); + +if (!function_exists('__autoload')) { + function __autoload($class) { + global $THRIFT_AUTOLOAD; + $classl = strtolower($class); + if (isset($THRIFT_AUTOLOAD[$classl])) { + include_once $GLOBALS['THRIFT_ROOT'].'/packages/'.$THRIFT_AUTOLOAD[$classl]; + } else if (!empty($GLOBALS['AUTOLOAD_HOOKS'])) { + foreach ($GLOBALS['AUTOLOAD_HOOKS'] as $hook) { + $hook($class); + } + } + } +} diff --git a/lib/php/src/ext/thrift_protocol/config.m4 b/lib/php/src/ext/thrift_protocol/config.m4 new file mode 100644 index 000000000..8cfb37d7f --- /dev/null +++ b/lib/php/src/ext/thrift_protocol/config.m4 @@ -0,0 +1,13 @@ +dnl Copyright (C) 2009 Facebook +dnl Copying and distribution of this file, with or without modification, +dnl are permitted in any medium without royalty provided the copyright +dnl notice and this notice are preserved. + +PHP_ARG_ENABLE(thrift_protocol, whether to enable the thrift_protocol extension, +[ --enable-thrift_protocol Enable the fbthrift_protocol extension]) + +if test "$PHP_THRIFT_PROTOCOL" != "no"; then + PHP_REQUIRE_CXX() + PHP_NEW_EXTENSION(thrift_protocol, php_thrift_protocol.cpp, $ext_shared) +fi + diff --git a/lib/php/src/ext/thrift_protocol/php_thrift_protocol.cpp b/lib/php/src/ext/thrift_protocol/php_thrift_protocol.cpp new file mode 100644 index 000000000..399cbe624 --- /dev/null +++ b/lib/php/src/ext/thrift_protocol/php_thrift_protocol.cpp @@ -0,0 +1,999 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include <sys/types.h> +#include <netinet/in.h> +#include <unistd.h> +#include <endian.h> +#include <byteswap.h> +#include <stdexcept> + +#if __BYTE_ORDER == __LITTLE_ENDIAN +#define htonll(x) bswap_64(x) +#define ntohll(x) bswap_64(x) +#else +#define htonll(x) x +#define ntohll(x) x +#endif + +enum TType { + T_STOP = 0, + T_VOID = 1, + T_BOOL = 2, + T_BYTE = 3, + T_I08 = 3, + T_I16 = 6, + T_I32 = 8, + T_U64 = 9, + T_I64 = 10, + T_DOUBLE = 4, + T_STRING = 11, + T_UTF7 = 11, + T_STRUCT = 12, + T_MAP = 13, + T_SET = 14, + T_LIST = 15, + T_UTF8 = 16, + T_UTF16 = 17 +}; + +const int32_t VERSION_MASK = 0xffff0000; +const int32_t VERSION_1 = 0x80010000; +const int8_t T_CALL = 1; +const int8_t T_REPLY = 2; +const int8_t T_EXCEPTION = 3; +// tprotocolexception +const int INVALID_DATA = 1; +const int BAD_VERSION = 4; + +#include "php.h" +#include "zend_interfaces.h" +#include "zend_exceptions.h" +#include "php_thrift_protocol.h" + +static function_entry thrift_protocol_functions[] = { + PHP_FE(thrift_protocol_write_binary, NULL) + PHP_FE(thrift_protocol_read_binary, NULL) + {NULL, NULL, NULL} +} ; + +zend_module_entry thrift_protocol_module_entry = { + STANDARD_MODULE_HEADER, + "thrift_protocol", + thrift_protocol_functions, + NULL, + NULL, + NULL, + NULL, + NULL, + "1.0", + STANDARD_MODULE_PROPERTIES +}; + +#ifdef COMPILE_DL_THRIFT_PROTOCOL +ZEND_GET_MODULE(thrift_protocol) +#endif + +class PHPExceptionWrapper : public std::exception { +public: + PHPExceptionWrapper(zval* _ex) throw() : ex(_ex) { + snprintf(_what, 40, "PHP exception zval=%p", ex); + } + const char* what() const throw() { return _what; } + ~PHPExceptionWrapper() throw() {} + operator zval*() const throw() { return const_cast<zval*>(ex); } // Zend API doesn't do 'const'... +protected: + zval* ex; + char _what[40]; +} ; + +class PHPTransport { +public: + zval* protocol() { return p; } + zval* transport() { return t; } +protected: + PHPTransport() {} + + void construct_with_zval(zval* _p, size_t _buffer_size) { + buffer = reinterpret_cast<char*>(emalloc(_buffer_size)); + buffer_ptr = buffer; + buffer_used = 0; + buffer_size = _buffer_size; + p = _p; + + // Get the transport for the passed protocol + zval gettransport; + ZVAL_STRING(&gettransport, "getTransport", 0); + MAKE_STD_ZVAL(t); + ZVAL_NULL(t); + TSRMLS_FETCH(); + call_user_function(EG(function_table), &p, &gettransport, t, 0, NULL TSRMLS_CC); + } + ~PHPTransport() { + efree(buffer); + zval_ptr_dtor(&t); + } + + char* buffer; + char* buffer_ptr; + size_t buffer_used; + size_t buffer_size; + + zval* p; + zval* t; +}; + + +class PHPOutputTransport : public PHPTransport { +public: + PHPOutputTransport(zval* _p, size_t _buffer_size = 8192) { + construct_with_zval(_p, _buffer_size); + } + + ~PHPOutputTransport() { + flush(); + directFlush(); + } + + void write(const char* data, size_t len) { + if ((len + buffer_used) > buffer_size) { + flush(); + } + if (len > buffer_size) { + directWrite(data, len); + } else { + memcpy(buffer_ptr, data, len); + buffer_used += len; + buffer_ptr += len; + } + } + + void writeI64(int64_t i) { + i = htonll(i); + write((const char*)&i, 8); + } + + void writeU32(uint32_t i) { + i = htonl(i); + write((const char*)&i, 4); + } + + void writeI32(int32_t i) { + i = htonl(i); + write((const char*)&i, 4); + } + + void writeI16(int16_t i) { + i = htons(i); + write((const char*)&i, 2); + } + + void writeI8(int8_t i) { + write((const char*)&i, 1); + } + + void writeString(const char* str, size_t len) { + writeU32(len); + write(str, len); + } + + void flush() { + if (buffer_used) { + directWrite(buffer, buffer_used); + buffer_ptr = buffer; + buffer_used = 0; + } + } + +protected: + void directFlush() { + zval ret; + ZVAL_NULL(&ret); + zval flushfn; + ZVAL_STRING(&flushfn, "flush", 0); + TSRMLS_FETCH(); + call_user_function(EG(function_table), &t, &flushfn, &ret, 0, NULL TSRMLS_CC); + zval_dtor(&ret); + } + void directWrite(const char* data, size_t len) { + zval writefn; + ZVAL_STRING(&writefn, "write", 0); + char* newbuf = (char*)emalloc(buffer_used + 1); + memcpy(newbuf, buffer, buffer_used); + newbuf[buffer_used] = '\0'; + zval *args[1]; + MAKE_STD_ZVAL(args[0]); + ZVAL_STRINGL(args[0], newbuf, buffer_used, 0); + TSRMLS_FETCH(); + zval ret; + ZVAL_NULL(&ret); + call_user_function(EG(function_table), &t, &writefn, &ret, 1, args TSRMLS_CC); + zval_ptr_dtor(args); + zval_dtor(&ret); + if (EG(exception)) { + zval* ex = EG(exception); + EG(exception) = NULL; + throw PHPExceptionWrapper(ex); + } + } +}; + +class PHPInputTransport : public PHPTransport { +public: + PHPInputTransport(zval* _p, size_t _buffer_size = 8192) { + construct_with_zval(_p, _buffer_size); + } + + ~PHPInputTransport() { + put_back(); + } + + void put_back() { + if (buffer_used) { + zval putbackfn; + ZVAL_STRING(&putbackfn, "putBack", 0); + + char* newbuf = (char*)emalloc(buffer_used + 1); + memcpy(newbuf, buffer_ptr, buffer_used); + newbuf[buffer_used] = '\0'; + + zval *args[1]; + MAKE_STD_ZVAL(args[0]); + ZVAL_STRINGL(args[0], newbuf, buffer_used, 0); + + TSRMLS_FETCH(); + + zval ret; + ZVAL_NULL(&ret); + call_user_function(EG(function_table), &t, &putbackfn, &ret, 1, args TSRMLS_CC); + zval_ptr_dtor(args); + zval_dtor(&ret); + } + buffer_used = 0; + buffer_ptr = buffer; + } + + void skip(size_t len) { + while (len) { + size_t chunk_size = MIN(len, buffer_used); + if (chunk_size) { + buffer_ptr = reinterpret_cast<char*>(buffer_ptr) + chunk_size; + buffer_used -= chunk_size; + len -= chunk_size; + } + if (! len) break; + refill(); + } + } + + void readBytes(void* buf, size_t len) { + while (len) { + size_t chunk_size = MIN(len, buffer_used); + if (chunk_size) { + memcpy(buf, buffer_ptr, chunk_size); + buffer_ptr = reinterpret_cast<char*>(buffer_ptr) + chunk_size; + buffer_used -= chunk_size; + buf = reinterpret_cast<char*>(buf) + chunk_size; + len -= chunk_size; + } + if (! len) break; + refill(); + } + } + + int8_t readI8() { + int8_t c; + readBytes(&c, 1); + return c; + } + + int16_t readI16() { + int16_t c; + readBytes(&c, 2); + return (int16_t)ntohs(c); + } + + uint32_t readU32() { + uint32_t c; + readBytes(&c, 4); + return (uint32_t)ntohl(c); + } + + int32_t readI32() { + int32_t c; + readBytes(&c, 4); + return (int32_t)ntohl(c); + } + +protected: + void refill() { + assert(buffer_used == 0); + zval retval; + ZVAL_NULL(&retval); + + zval *args[1]; + MAKE_STD_ZVAL(args[0]); + ZVAL_LONG(args[0], buffer_size); + + TSRMLS_FETCH(); + + zval funcname; + ZVAL_STRING(&funcname, "read", 0); + + call_user_function(EG(function_table), &t, &funcname, &retval, 1, args TSRMLS_CC); + zval_ptr_dtor(args); + + if (EG(exception)) { + zval_dtor(&retval); + zval* ex = EG(exception); + EG(exception) = NULL; + throw PHPExceptionWrapper(ex); + } + + buffer_used = Z_STRLEN(retval); + memcpy(buffer, Z_STRVAL(retval), buffer_used); + zval_dtor(&retval); + + buffer_ptr = buffer; + } + +}; + +void binary_deserialize_spec(zval* zthis, PHPInputTransport& transport, HashTable* spec); +void binary_serialize_spec(zval* zthis, PHPOutputTransport& transport, HashTable* spec); +void binary_serialize(int8_t thrift_typeID, PHPOutputTransport& transport, zval** value, HashTable* fieldspec); +void skip_element(long thrift_typeID, PHPInputTransport& transport); + +// Create a PHP object given a typename and call the ctor, optionally passing up to 2 arguments +void createObject(char* obj_typename, zval* return_value, int nargs = 0, zval* arg1 = NULL, zval* arg2 = NULL) { + TSRMLS_FETCH(); + size_t obj_typename_len = strlen(obj_typename); + zend_class_entry* ce = zend_fetch_class(obj_typename, obj_typename_len, ZEND_FETCH_CLASS_DEFAULT TSRMLS_CC); + if (! ce) { + php_error_docref(NULL TSRMLS_CC, E_ERROR, "Class %s does not exist", obj_typename); + RETURN_NULL(); + } + + object_and_properties_init(return_value, ce, NULL); + zend_function* constructor = zend_std_get_constructor(return_value TSRMLS_CC); + zval* ctor_rv = NULL; + zend_call_method(&return_value, ce, &constructor, NULL, 0, &ctor_rv, nargs, arg1, arg2 TSRMLS_CC); + zval_ptr_dtor(&ctor_rv); +} + +void throw_tprotocolexception(char* what, long errorcode) { + TSRMLS_FETCH(); + + zval *zwhat, *zerrorcode; + MAKE_STD_ZVAL(zwhat); + MAKE_STD_ZVAL(zerrorcode); + + ZVAL_STRING(zwhat, what, 1); + ZVAL_LONG(zerrorcode, errorcode); + + zval* ex; + MAKE_STD_ZVAL(ex); + createObject("TProtocolException", ex, 2, zwhat, zerrorcode); + zval_ptr_dtor(&zwhat); + zval_ptr_dtor(&zerrorcode); + throw PHPExceptionWrapper(ex); +} + +void binary_deserialize(int8_t thrift_typeID, PHPInputTransport& transport, zval* return_value, HashTable* fieldspec) { + zval** val_ptr; + Z_TYPE_P(return_value) = IS_NULL; // just in case + + switch (thrift_typeID) { + case T_STOP: + case T_VOID: + RETURN_NULL(); + return; + case T_STRUCT: { + if (zend_hash_find(fieldspec, "class", 6, (void**)&val_ptr) != SUCCESS) { + throw_tprotocolexception("no class type in spec", INVALID_DATA); + skip_element(T_STRUCT, transport); + RETURN_NULL(); + } + char* structType = Z_STRVAL_PP(val_ptr); + createObject(structType, return_value); + if (Z_TYPE_P(return_value) == IS_NULL) { + // unable to create class entry + skip_element(T_STRUCT, transport); + RETURN_NULL(); + } + TSRMLS_FETCH(); + zval* spec = zend_read_static_property(zend_get_class_entry(return_value TSRMLS_CC), "_TSPEC", 6, false TSRMLS_CC); + if (Z_TYPE_P(spec) != IS_ARRAY) { + char errbuf[128]; + snprintf(errbuf, 128, "spec for %s is wrong type: %d\n", structType, Z_TYPE_P(spec)); + throw_tprotocolexception(errbuf, INVALID_DATA); + RETURN_NULL(); + } + binary_deserialize_spec(return_value, transport, Z_ARRVAL_P(spec)); + return; + } break; + case T_BOOL: { + uint8_t c; + transport.readBytes(&c, 1); + RETURN_BOOL(c != 0); + } + //case T_I08: // same numeric value as T_BYTE + case T_BYTE: { + uint8_t c; + transport.readBytes(&c, 1); + RETURN_LONG(c); + } + case T_I16: { + uint16_t c; + transport.readBytes(&c, 2); + RETURN_LONG(ntohs(c)); + } + case T_I32: { + uint32_t c; + transport.readBytes(&c, 4); + RETURN_LONG(ntohl(c)); + } + case T_U64: + case T_I64: { + uint64_t c; + transport.readBytes(&c, 8); + RETURN_LONG(ntohll(c)); + } + case T_DOUBLE: { + union { + uint64_t c; + double d; + } a; + transport.readBytes(&(a.c), 8); + a.c = ntohll(a.c); + RETURN_DOUBLE(a.d); + } + //case T_UTF7: // aliases T_STRING + case T_UTF8: + case T_UTF16: + case T_STRING: { + uint32_t size = transport.readU32(); + if (size) { + char* strbuf = (char*) emalloc(size + 1); + transport.readBytes(strbuf, size); + strbuf[size] = '\0'; + ZVAL_STRINGL(return_value, strbuf, size, 0); + } else { + ZVAL_EMPTY_STRING(return_value); + } + return; + } + case T_MAP: { // array of key -> value + uint8_t types[2]; + transport.readBytes(types, 2); + uint32_t size = transport.readU32(); + array_init(return_value); + + zend_hash_find(fieldspec, "key", 4, (void**)&val_ptr); + HashTable* keyspec = Z_ARRVAL_PP(val_ptr); + zend_hash_find(fieldspec, "val", 4, (void**)&val_ptr); + HashTable* valspec = Z_ARRVAL_PP(val_ptr); + + for (uint32_t s = 0; s < size; ++s) { + zval *value; + MAKE_STD_ZVAL(value); + + zval* key; + MAKE_STD_ZVAL(key); + + binary_deserialize(types[0], transport, key, keyspec); + binary_deserialize(types[1], transport, value, valspec); + if (Z_TYPE_P(key) == IS_LONG) { + zend_hash_index_update(return_value->value.ht, Z_LVAL_P(key), &value, sizeof(zval *), NULL); + } + else { + if (Z_TYPE_P(key) != IS_STRING) convert_to_string(key); + zend_hash_update(return_value->value.ht, Z_STRVAL_P(key), Z_STRLEN_P(key) + 1, &value, sizeof(zval *), NULL); + } + zval_ptr_dtor(&key); + } + return; // return_value already populated + } + case T_LIST: { // array with autogenerated numeric keys + int8_t type = transport.readI8(); + uint32_t size = transport.readU32(); + zend_hash_find(fieldspec, "elem", 5, (void**)&val_ptr); + HashTable* elemspec = Z_ARRVAL_PP(val_ptr); + + array_init(return_value); + for (uint32_t s = 0; s < size; ++s) { + zval *value; + MAKE_STD_ZVAL(value); + binary_deserialize(type, transport, value, elemspec); + zend_hash_next_index_insert(return_value->value.ht, &value, sizeof(zval *), NULL); + } + return; + } + case T_SET: { // array of key -> TRUE + uint8_t type; + uint32_t size; + transport.readBytes(&type, 1); + transport.readBytes(&size, 4); + size = ntohl(size); + zend_hash_find(fieldspec, "elem", 5, (void**)&val_ptr); + HashTable* elemspec = Z_ARRVAL_PP(val_ptr); + + array_init(return_value); + + for (uint32_t s = 0; s < size; ++s) { + zval* key; + zval* value; + MAKE_STD_ZVAL(key); + MAKE_STD_ZVAL(value); + ZVAL_TRUE(value); + + binary_deserialize(type, transport, key, elemspec); + + if (Z_TYPE_P(key) == IS_LONG) { + zend_hash_index_update(return_value->value.ht, Z_LVAL_P(key), &value, sizeof(zval *), NULL); + } + else { + if (Z_TYPE_P(key) != IS_STRING) convert_to_string(key); + zend_hash_update(return_value->value.ht, Z_STRVAL_P(key), Z_STRLEN_P(key) + 1, &value, sizeof(zval *), NULL); + } + zval_ptr_dtor(&key); + } + return; + } + }; + + char errbuf[128]; + sprintf(errbuf, "Unknown thrift typeID %d", thrift_typeID); + throw_tprotocolexception(errbuf, INVALID_DATA); +} + +void skip_element(long thrift_typeID, PHPInputTransport& transport) { + switch (thrift_typeID) { + case T_STOP: + case T_VOID: + return; + case T_STRUCT: + while (true) { + int8_t ttype = transport.readI8(); // get field type + if (ttype == T_STOP) break; + transport.skip(2); // skip field number, I16 + skip_element(ttype, transport); // skip field payload + } + return; + case T_BOOL: + case T_BYTE: + transport.skip(1); + return; + case T_I16: + transport.skip(2); + return; + case T_I32: + transport.skip(4); + return; + case T_U64: + case T_I64: + case T_DOUBLE: + transport.skip(8); + return; + //case T_UTF7: // aliases T_STRING + case T_UTF8: + case T_UTF16: + case T_STRING: { + uint32_t len = transport.readU32(); + transport.skip(len); + } return; + case T_MAP: { + int8_t keytype = transport.readI8(); + int8_t valtype = transport.readI8(); + uint32_t size = transport.readU32(); + for (uint32_t i = 0; i < size; ++i) { + skip_element(keytype, transport); + skip_element(valtype, transport); + } + } return; + case T_LIST: + case T_SET: { + int8_t valtype = transport.readI8(); + uint32_t size = transport.readU32(); + for (uint32_t i = 0; i < size; ++i) { + skip_element(valtype, transport); + } + } return; + }; + + char errbuf[128]; + sprintf(errbuf, "Unknown thrift typeID %ld", thrift_typeID); + throw_tprotocolexception(errbuf, INVALID_DATA); +} + +void binary_serialize_hashtable_key(int8_t keytype, PHPOutputTransport& transport, HashTable* ht, HashPosition& ht_pos) { + bool keytype_is_numeric = (!((keytype == T_STRING) || (keytype == T_UTF8) || (keytype == T_UTF16))); + + char* key; + uint key_len; + long index = 0; + + zval* z; + MAKE_STD_ZVAL(z); + + int res = zend_hash_get_current_key_ex(ht, &key, &key_len, (ulong*)&index, 0, &ht_pos); + if (keytype_is_numeric) { + if (res == HASH_KEY_IS_STRING) { + index = strtol(key, NULL, 10); + } + ZVAL_LONG(z, index); + } else { + char buf[64]; + if (res == HASH_KEY_IS_STRING) { + key_len -= 1; // skip the null terminator + } else { + sprintf(buf, "%ld", index); + key = buf; key_len = strlen(buf); + } + ZVAL_STRINGL(z, key, key_len, 1); + } + binary_serialize(keytype, transport, &z, NULL); + zval_ptr_dtor(&z); +} + +inline bool ttype_is_int(int8_t t) { + return ((t == T_BYTE) || ((t >= T_I16) && (t <= T_I64))); +} + +inline bool ttypes_are_compatible(int8_t t1, int8_t t2) { + // Integer types of different widths are considered compatible; + // otherwise the typeID must match. + return ((t1 == t2) || (ttype_is_int(t1) && ttype_is_int(t2))); +} + +void binary_deserialize_spec(zval* zthis, PHPInputTransport& transport, HashTable* spec) { + // SET and LIST have 'elem' => array('type', [optional] 'class') + // MAP has 'val' => array('type', [optiona] 'class') + TSRMLS_FETCH(); + zend_class_entry* ce = zend_get_class_entry(zthis TSRMLS_CC); + while (true) { + zval** val_ptr = NULL; + + int8_t ttype = transport.readI8(); + if (ttype == T_STOP) return; + int16_t fieldno = transport.readI16(); + if (zend_hash_index_find(spec, fieldno, (void**)&val_ptr) == SUCCESS) { + HashTable* fieldspec = Z_ARRVAL_PP(val_ptr); + // pull the field name + // zend hash tables use the null at the end in the length... so strlen(hash key) + 1. + zend_hash_find(fieldspec, "var", 4, (void**)&val_ptr); + char* varname = Z_STRVAL_PP(val_ptr); + + // and the type + zend_hash_find(fieldspec, "type", 5, (void**)&val_ptr); + if (Z_TYPE_PP(val_ptr) != IS_LONG) convert_to_long(*val_ptr); + int8_t expected_ttype = Z_LVAL_PP(val_ptr); + + if (ttypes_are_compatible(ttype, expected_ttype)) { + zval* rv = NULL; + MAKE_STD_ZVAL(rv); + binary_deserialize(ttype, transport, rv, fieldspec); + zend_update_property(ce, zthis, varname, strlen(varname), rv TSRMLS_CC); + zval_ptr_dtor(&rv); + } else { + skip_element(ttype, transport); + } + } else { + skip_element(ttype, transport); + } + } +} + +void binary_serialize(int8_t thrift_typeID, PHPOutputTransport& transport, zval** value, HashTable* fieldspec) { + // At this point the typeID (and field num, if applicable) should've already been written to the output so all we need to do is write the payload. + switch (thrift_typeID) { + case T_STOP: + case T_VOID: + return; + case T_STRUCT: { + TSRMLS_FETCH(); + if (Z_TYPE_PP(value) != IS_OBJECT) { + throw_tprotocolexception("Attempt to send non-object type as a T_STRUCT", INVALID_DATA); + } + zval* spec = zend_read_static_property(zend_get_class_entry(*value TSRMLS_CC), "_TSPEC", 6, false TSRMLS_CC); + binary_serialize_spec(*value, transport, Z_ARRVAL_P(spec)); + } return; + case T_BOOL: + if (Z_TYPE_PP(value) != IS_BOOL) convert_to_boolean(*value); + transport.writeI8(Z_BVAL_PP(value) ? 1 : 0); + return; + case T_BYTE: + if (Z_TYPE_PP(value) != IS_LONG) convert_to_long(*value); + transport.writeI8(Z_LVAL_PP(value)); + return; + case T_I16: + if (Z_TYPE_PP(value) != IS_LONG) convert_to_long(*value); + transport.writeI16(Z_LVAL_PP(value)); + return; + case T_I32: + if (Z_TYPE_PP(value) != IS_LONG) convert_to_long(*value); + transport.writeI32(Z_LVAL_PP(value)); + return; + case T_I64: + case T_U64: + if (Z_TYPE_PP(value) != IS_LONG) convert_to_long(*value); + transport.writeI64(Z_LVAL_PP(value)); + return; + case T_DOUBLE: { + union { + int64_t c; + double d; + } a; + if (Z_TYPE_PP(value) != IS_DOUBLE) convert_to_double(*value); + a.d = Z_DVAL_PP(value); + transport.writeI64(a.c); + } return; + //case T_UTF7: + case T_UTF8: + case T_UTF16: + case T_STRING: + if (Z_TYPE_PP(value) != IS_STRING) convert_to_string(*value); + transport.writeString(Z_STRVAL_PP(value), Z_STRLEN_PP(value)); + return; + case T_MAP: { + if (Z_TYPE_PP(value) != IS_ARRAY) convert_to_array(*value); + if (Z_TYPE_PP(value) != IS_ARRAY) { + throw_tprotocolexception("Attempt to send an incompatible type as an array (T_MAP)", INVALID_DATA); + } + HashTable* ht = Z_ARRVAL_PP(value); + zval** val_ptr; + + zend_hash_find(fieldspec, "ktype", 6, (void**)&val_ptr); + if (Z_TYPE_PP(val_ptr) != IS_LONG) convert_to_long(*val_ptr); + uint8_t keytype = Z_LVAL_PP(val_ptr); + transport.writeI8(keytype); + zend_hash_find(fieldspec, "vtype", 6, (void**)&val_ptr); + if (Z_TYPE_PP(val_ptr) != IS_LONG) convert_to_long(*val_ptr); + uint8_t valtype = Z_LVAL_PP(val_ptr); + transport.writeI8(valtype); + + zend_hash_find(fieldspec, "val", 4, (void**)&val_ptr); + HashTable* valspec = Z_ARRVAL_PP(val_ptr); + + transport.writeI32(zend_hash_num_elements(ht)); + HashPosition key_ptr; + for (zend_hash_internal_pointer_reset_ex(ht, &key_ptr); zend_hash_get_current_data_ex(ht, (void**)&val_ptr, &key_ptr) == SUCCESS; zend_hash_move_forward_ex(ht, &key_ptr)) { + binary_serialize_hashtable_key(keytype, transport, ht, key_ptr); + binary_serialize(valtype, transport, val_ptr, valspec); + } + } return; + case T_LIST: { + if (Z_TYPE_PP(value) != IS_ARRAY) convert_to_array(*value); + if (Z_TYPE_PP(value) != IS_ARRAY) { + throw_tprotocolexception("Attempt to send an incompatible type as an array (T_LIST)", INVALID_DATA); + } + HashTable* ht = Z_ARRVAL_PP(value); + zval** val_ptr; + + zend_hash_find(fieldspec, "etype", 6, (void**)&val_ptr); + if (Z_TYPE_PP(val_ptr) != IS_LONG) convert_to_long(*val_ptr); + uint8_t valtype = Z_LVAL_PP(val_ptr); + transport.writeI8(valtype); + + zend_hash_find(fieldspec, "elem", 5, (void**)&val_ptr); + HashTable* valspec = Z_ARRVAL_PP(val_ptr); + + transport.writeI32(zend_hash_num_elements(ht)); + HashPosition key_ptr; + for (zend_hash_internal_pointer_reset_ex(ht, &key_ptr); zend_hash_get_current_data_ex(ht, (void**)&val_ptr, &key_ptr) == SUCCESS; zend_hash_move_forward_ex(ht, &key_ptr)) { + binary_serialize(valtype, transport, val_ptr, valspec); + } + } return; + case T_SET: { + if (Z_TYPE_PP(value) != IS_ARRAY) convert_to_array(*value); + if (Z_TYPE_PP(value) != IS_ARRAY) { + throw_tprotocolexception("Attempt to send an incompatible type as an array (T_SET)", INVALID_DATA); + } + HashTable* ht = Z_ARRVAL_PP(value); + zval** val_ptr; + + zend_hash_find(fieldspec, "etype", 6, (void**)&val_ptr); + if (Z_TYPE_PP(val_ptr) != IS_LONG) convert_to_long(*val_ptr); + uint8_t keytype = Z_LVAL_PP(val_ptr); + transport.writeI8(keytype); + + transport.writeI32(zend_hash_num_elements(ht)); + HashPosition key_ptr; + for (zend_hash_internal_pointer_reset_ex(ht, &key_ptr); zend_hash_get_current_data_ex(ht, (void**)&val_ptr, &key_ptr) == SUCCESS; zend_hash_move_forward_ex(ht, &key_ptr)) { + binary_serialize_hashtable_key(keytype, transport, ht, key_ptr); + } + } return; + }; + char errbuf[128]; + sprintf(errbuf, "Unknown thrift typeID %d", thrift_typeID); + throw_tprotocolexception(errbuf, INVALID_DATA); +} + + +void binary_serialize_spec(zval* zthis, PHPOutputTransport& transport, HashTable* spec) { + HashPosition key_ptr; + zval** val_ptr; + + TSRMLS_FETCH(); + zend_class_entry* ce = zend_get_class_entry(zthis TSRMLS_CC); + + for (zend_hash_internal_pointer_reset_ex(spec, &key_ptr); zend_hash_get_current_data_ex(spec, (void**)&val_ptr, &key_ptr) == SUCCESS; zend_hash_move_forward_ex(spec, &key_ptr)) { + ulong fieldno; + if (zend_hash_get_current_key_ex(spec, NULL, NULL, &fieldno, 0, &key_ptr) != HASH_KEY_IS_LONG) { + throw_tprotocolexception("Bad keytype in TSPEC (expected 'long')", INVALID_DATA); + return; + } + HashTable* fieldspec = Z_ARRVAL_PP(val_ptr); + + // field name + zend_hash_find(fieldspec, "var", 4, (void**)&val_ptr); + char* varname = Z_STRVAL_PP(val_ptr); + + // thrift type + zend_hash_find(fieldspec, "type", 5, (void**)&val_ptr); + if (Z_TYPE_PP(val_ptr) != IS_LONG) convert_to_long(*val_ptr); + int8_t ttype = Z_LVAL_PP(val_ptr); + + zval* prop = zend_read_property(ce, zthis, varname, strlen(varname), false TSRMLS_CC); + if (Z_TYPE_P(prop) != IS_NULL) { + transport.writeI8(ttype); + transport.writeI16(fieldno); + binary_serialize(ttype, transport, &prop, fieldspec); + } + } + transport.writeI8(T_STOP); // struct end +} + +// 6 params: $transport $method_name $ttype $request_struct $seqID $strict_write +PHP_FUNCTION(thrift_protocol_write_binary) { + int argc = ZEND_NUM_ARGS(); + if (argc < 6) { + WRONG_PARAM_COUNT; + } + + zval ***args = (zval***) emalloc(argc * sizeof(zval**)); + zend_get_parameters_array_ex(argc, args); + + if (Z_TYPE_PP(args[0]) != IS_OBJECT) { + php_error_docref(NULL TSRMLS_CC, E_ERROR, "1st parameter is not an object (transport)"); + efree(args); + RETURN_NULL(); + } + + if (Z_TYPE_PP(args[1]) != IS_STRING) { + php_error_docref(NULL TSRMLS_CC, E_ERROR, "2nd parameter is not a string (method name)"); + efree(args); + RETURN_NULL(); + } + + if (Z_TYPE_PP(args[3]) != IS_OBJECT) { + php_error_docref(NULL TSRMLS_CC, E_ERROR, "4th parameter is not an object (request struct)"); + efree(args); + RETURN_NULL(); + } + + PHPOutputTransport transport(*args[0]); + const char* method_name = Z_STRVAL_PP(args[1]); + convert_to_long(*args[2]); + int32_t msgtype = Z_LVAL_PP(args[2]); + zval* request_struct = *args[3]; + convert_to_long(*args[4]); + int32_t seqID = Z_LVAL_PP(args[4]); + convert_to_boolean(*args[5]); + bool strictWrite = Z_BVAL_PP(args[5]); + efree(args); + args = NULL; + + try { + if (strictWrite) { + int32_t version = VERSION_1 | msgtype; + transport.writeI32(version); + transport.writeString(method_name, strlen(method_name)); + transport.writeI32(seqID); + } else { + transport.writeString(method_name, strlen(method_name)); + transport.writeI8(msgtype); + transport.writeI32(seqID); + } + + zval* spec = zend_read_static_property(zend_get_class_entry(request_struct TSRMLS_CC), "_TSPEC", 6, false TSRMLS_CC); + binary_serialize_spec(request_struct, transport, Z_ARRVAL_P(spec)); + } catch (const PHPExceptionWrapper& ex) { + zend_throw_exception_object(ex TSRMLS_CC); + RETURN_NULL(); + } +} + +// 3 params: $transport $response_Typename $strict_read +PHP_FUNCTION(thrift_protocol_read_binary) { + int argc = ZEND_NUM_ARGS(); + + if (argc < 3) { + WRONG_PARAM_COUNT; + } + + zval ***args = (zval***) emalloc(argc * sizeof(zval**)); + zend_get_parameters_array_ex(argc, args); + + if (Z_TYPE_PP(args[0]) != IS_OBJECT) { + php_error_docref(NULL TSRMLS_CC, E_ERROR, "1st parameter is not an object (transport)"); + efree(args); + RETURN_NULL(); + } + + if (Z_TYPE_PP(args[1]) != IS_STRING) { + php_error_docref(NULL TSRMLS_CC, E_ERROR, "2nd parameter is not a string (typename of expected response struct)"); + efree(args); + RETURN_NULL(); + } + + PHPInputTransport transport(*args[0]); + char* obj_typename = Z_STRVAL_PP(args[1]); + convert_to_boolean(*args[2]); + bool strict_read = Z_BVAL_PP(args[2]); + efree(args); + args = NULL; + + try { + int8_t messageType = 0; + int32_t sz = transport.readI32(); + + if (sz < 0) { + // Check for correct version number + int32_t version = sz & VERSION_MASK; + if (version != VERSION_1) { + throw_tprotocolexception("Bad version identifier", BAD_VERSION); + } + messageType = (sz & 0x000000ff); + int32_t namelen = transport.readI32(); + // skip the name string and the sequence ID, we don't care about those + transport.skip(namelen + 4); + } else { + if (strict_read) { + throw_tprotocolexception("No version identifier... old protocol client in strict mode?", BAD_VERSION); + } else { + // Handle pre-versioned input + transport.skip(sz); // skip string body + messageType = transport.readI8(); + transport.skip(4); // skip sequence number + } + } + + if (messageType == T_EXCEPTION) { + zval* ex; + MAKE_STD_ZVAL(ex); + createObject("TApplicationException", ex); + zval* spec = zend_read_static_property(zend_get_class_entry(ex TSRMLS_CC), "_TSPEC", 6, false TSRMLS_CC); + binary_deserialize_spec(ex, transport, Z_ARRVAL_P(spec)); + throw PHPExceptionWrapper(ex); + } + + createObject(obj_typename, return_value); + zval* spec = zend_read_static_property(zend_get_class_entry(return_value TSRMLS_CC), "_TSPEC", 6, false TSRMLS_CC); + binary_deserialize_spec(return_value, transport, Z_ARRVAL_P(spec)); + } catch (const PHPExceptionWrapper& ex) { + zend_throw_exception_object(ex TSRMLS_CC); + RETURN_NULL(); + } +} + diff --git a/lib/php/src/ext/thrift_protocol/php_thrift_protocol.h b/lib/php/src/ext/thrift_protocol/php_thrift_protocol.h new file mode 100644 index 000000000..c9a3e00f0 --- /dev/null +++ b/lib/php/src/ext/thrift_protocol/php_thrift_protocol.h @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +PHP_FUNCTION(thrift_protocol_write_binary); +PHP_FUNCTION(thrift_protocol_read_binary); + +extern zend_module_entry thrift_protocole_module_entry; + diff --git a/lib/php/src/protocol/TBinaryProtocol.php b/lib/php/src/protocol/TBinaryProtocol.php new file mode 100644 index 000000000..31bbbf9d0 --- /dev/null +++ b/lib/php/src/protocol/TBinaryProtocol.php @@ -0,0 +1,431 @@ +<?php +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @package thrift.protocol + */ + +include_once $GLOBALS['THRIFT_ROOT'].'/transport/TBufferedTransport.php'; + +/** + * Binary implementation of the Thrift protocol. + * + */ +class TBinaryProtocol extends TProtocol { + + const VERSION_MASK = 0xffff0000; + const VERSION_1 = 0x80010000; + + protected $strictRead_ = false; + protected $strictWrite_ = true; + + public function __construct($trans, $strictRead=false, $strictWrite=true) { + parent::__construct($trans); + $this->strictRead_ = $strictRead; + $this->strictWrite_ = $strictWrite; + } + + public function writeMessageBegin($name, $type, $seqid) { + if ($this->strictWrite_) { + $version = self::VERSION_1 | $type; + return + $this->writeI32($version) + + $this->writeString($name) + + $this->writeI32($seqid); + } else { + return + $this->writeString($name) + + $this->writeByte($type) + + $this->writeI32($seqid); + } + } + + public function writeMessageEnd() { + return 0; + } + + public function writeStructBegin($name) { + return 0; + } + + public function writeStructEnd() { + return 0; + } + + public function writeFieldBegin($fieldName, $fieldType, $fieldId) { + return + $this->writeByte($fieldType) + + $this->writeI16($fieldId); + } + + public function writeFieldEnd() { + return 0; + } + + public function writeFieldStop() { + return + $this->writeByte(TType::STOP); + } + + public function writeMapBegin($keyType, $valType, $size) { + return + $this->writeByte($keyType) + + $this->writeByte($valType) + + $this->writeI32($size); + } + + public function writeMapEnd() { + return 0; + } + + public function writeListBegin($elemType, $size) { + return + $this->writeByte($elemType) + + $this->writeI32($size); + } + + public function writeListEnd() { + return 0; + } + + public function writeSetBegin($elemType, $size) { + return + $this->writeByte($elemType) + + $this->writeI32($size); + } + + public function writeSetEnd() { + return 0; + } + + public function writeBool($value) { + $data = pack('c', $value ? 1 : 0); + $this->trans_->write($data, 1); + return 1; + } + + public function writeByte($value) { + $data = pack('c', $value); + $this->trans_->write($data, 1); + return 1; + } + + public function writeI16($value) { + $data = pack('n', $value); + $this->trans_->write($data, 2); + return 2; + } + + public function writeI32($value) { + $data = pack('N', $value); + $this->trans_->write($data, 4); + return 4; + } + + public function writeI64($value) { + // If we are on a 32bit architecture we have to explicitly deal with + // 64-bit twos-complement arithmetic since PHP wants to treat all ints + // as signed and any int over 2^31 - 1 as a float + if (PHP_INT_SIZE == 4) { + $neg = $value < 0; + + if ($neg) { + $value *= -1; + } + + $hi = (int)($value / 4294967296); + $lo = (int)$value; + + if ($neg) { + $hi = ~$hi; + $lo = ~$lo; + if (($lo & (int)0xffffffff) == (int)0xffffffff) { + $lo = 0; + $hi++; + } else { + $lo++; + } + } + $data = pack('N2', $hi, $lo); + + } else { + $hi = $value >> 32; + $lo = $value & 0xFFFFFFFF; + $data = pack('N2', $hi, $lo); + } + + $this->trans_->write($data, 8); + return 8; + } + + public function writeDouble($value) { + $data = pack('d', $value); + $this->trans_->write(strrev($data), 8); + return 8; + } + + public function writeString($value) { + $len = strlen($value); + $result = $this->writeI32($len); + if ($len) { + $this->trans_->write($value, $len); + } + return $result + $len; + } + + public function readMessageBegin(&$name, &$type, &$seqid) { + $result = $this->readI32($sz); + if ($sz < 0) { + $version = (int) ($sz & self::VERSION_MASK); + if ($version != (int) self::VERSION_1) { + throw new TProtocolException('Bad version identifier: '.$sz, TProtocolException::BAD_VERSION); + } + $type = $sz & 0x000000ff; + $result += + $this->readString($name) + + $this->readI32($seqid); + } else { + if ($this->strictRead_) { + throw new TProtocolException('No version identifier, old protocol client?', TProtocolException::BAD_VERSION); + } else { + // Handle pre-versioned input + $name = $this->trans_->readAll($sz); + $result += + $sz + + $this->readByte($type) + + $this->readI32($seqid); + } + } + return $result; + } + + public function readMessageEnd() { + return 0; + } + + public function readStructBegin(&$name) { + $name = ''; + return 0; + } + + public function readStructEnd() { + return 0; + } + + public function readFieldBegin(&$name, &$fieldType, &$fieldId) { + $result = $this->readByte($fieldType); + if ($fieldType == TType::STOP) { + $fieldId = 0; + return $result; + } + $result += $this->readI16($fieldId); + return $result; + } + + public function readFieldEnd() { + return 0; + } + + public function readMapBegin(&$keyType, &$valType, &$size) { + return + $this->readByte($keyType) + + $this->readByte($valType) + + $this->readI32($size); + } + + public function readMapEnd() { + return 0; + } + + public function readListBegin(&$elemType, &$size) { + return + $this->readByte($elemType) + + $this->readI32($size); + } + + public function readListEnd() { + return 0; + } + + public function readSetBegin(&$elemType, &$size) { + return + $this->readByte($elemType) + + $this->readI32($size); + } + + public function readSetEnd() { + return 0; + } + + public function readBool(&$value) { + $data = $this->trans_->readAll(1); + $arr = unpack('c', $data); + $value = $arr[1] == 1; + return 1; + } + + public function readByte(&$value) { + $data = $this->trans_->readAll(1); + $arr = unpack('c', $data); + $value = $arr[1]; + return 1; + } + + public function readI16(&$value) { + $data = $this->trans_->readAll(2); + $arr = unpack('n', $data); + $value = $arr[1]; + if ($value > 0x7fff) { + $value = 0 - (($value - 1) ^ 0xffff); + } + return 2; + } + + public function readI32(&$value) { + $data = $this->trans_->readAll(4); + $arr = unpack('N', $data); + $value = $arr[1]; + if ($value > 0x7fffffff) { + $value = 0 - (($value - 1) ^ 0xffffffff); + } + return 4; + } + + public function readI64(&$value) { + $data = $this->trans_->readAll(8); + + $arr = unpack('N2', $data); + + // If we are on a 32bit architecture we have to explicitly deal with + // 64-bit twos-complement arithmetic since PHP wants to treat all ints + // as signed and any int over 2^31 - 1 as a float + if (PHP_INT_SIZE == 4) { + + $hi = $arr[1]; + $lo = $arr[2]; + $isNeg = $hi < 0; + + // Check for a negative + if ($isNeg) { + $hi = ~$hi & (int)0xffffffff; + $lo = ~$lo & (int)0xffffffff; + + if ($lo == (int)0xffffffff) { + $hi++; + $lo = 0; + } else { + $lo++; + } + } + + // Force 32bit words in excess of 2G to pe positive - we deal wigh sign + // explicitly below + + if ($hi & (int)0x80000000) { + $hi &= (int)0x7fffffff; + $hi += 0x80000000; + } + + if ($lo & (int)0x80000000) { + $lo &= (int)0x7fffffff; + $lo += 0x80000000; + } + + $value = $hi * 4294967296 + $lo; + + if ($isNeg) { + $value = 0 - $value; + } + } else { + + // Upcast negatives in LSB bit + if ($arr[2] & 0x80000000) { + $arr[2] = $arr[2] & 0xffffffff; + } + + // Check for a negative + if ($arr[1] & 0x80000000) { + $arr[1] = $arr[1] & 0xffffffff; + $arr[1] = $arr[1] ^ 0xffffffff; + $arr[2] = $arr[2] ^ 0xffffffff; + $value = 0 - $arr[1]*4294967296 - $arr[2] - 1; + } else { + $value = $arr[1]*4294967296 + $arr[2]; + } + } + + return 8; + } + + public function readDouble(&$value) { + $data = strrev($this->trans_->readAll(8)); + $arr = unpack('d', $data); + $value = $arr[1]; + return 8; + } + + public function readString(&$value) { + $result = $this->readI32($len); + if ($len) { + $value = $this->trans_->readAll($len); + } else { + $value = ''; + } + return $result + $len; + } +} + +/** + * Binary Protocol Factory + */ +class TBinaryProtocolFactory implements TProtocolFactory { + private $strictRead_ = false; + private $strictWrite_ = false; + + public function __construct($strictRead=false, $strictWrite=false) { + $this->strictRead_ = $strictRead; + $this->strictWrite_ = $strictWrite; + } + + public function getProtocol($trans) { + return new TBinaryProtocol($trans, $this->strictRead, $this->strictWrite); + } +} + +/** + * Accelerated binary protocol: used in conjunction with the thrift_protocol + * extension for faster deserialization + */ +class TBinaryProtocolAccelerated extends TBinaryProtocol { + public function __construct($trans, $strictRead=false, $strictWrite=true) { + // If the transport doesn't implement putBack, wrap it in a + // TBufferedTransport (which does) + if (!method_exists($trans, 'putBack')) { + $trans = new TBufferedTransport($trans); + } + parent::__construct($trans, $strictRead, $strictWrite); + } + public function isStrictRead() { + return $this->strictRead_; + } + public function isStrictWrite() { + return $this->strictWrite_; + } +} + +?> diff --git a/lib/php/src/protocol/TProtocol.php b/lib/php/src/protocol/TProtocol.php new file mode 100644 index 000000000..e9ff41a32 --- /dev/null +++ b/lib/php/src/protocol/TProtocol.php @@ -0,0 +1,377 @@ +<?php +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @package thrift.protocol + */ + + +/** + * Protocol module. Contains all the types and definitions needed to implement + * a protocol encoder/decoder. + * + * @package thrift.protocol + */ + +/** + * Protocol exceptions + */ +class TProtocolException extends TException { + const UNKNOWN = 0; + const INVALID_DATA = 1; + const NEGATIVE_SIZE = 2; + const SIZE_LIMIT = 3; + const BAD_VERSION = 4; + + function __construct($message=null, $code=0) { + parent::__construct($message, $code); + } +} + +/** + * Protocol base class module. + */ +abstract class TProtocol { + // The below may seem silly, but it is to get around the problem that the + // "instanceof" operator can only take in a T_VARIABLE and not a T_STRING + // or T_CONSTANT_ENCAPSED_STRING. Using "is_a()" instead of "instanceof" is + // a workaround but is deprecated in PHP5. This is used in the generated + // deserialization code. + static $TBINARYPROTOCOLACCELERATED = 'TBinaryProtocolAccelerated'; + + /** + * Underlying transport + * + * @var TTransport + */ + protected $trans_; + + /** + * Constructor + */ + protected function __construct($trans) { + $this->trans_ = $trans; + } + + /** + * Accessor for transport + * + * @return TTransport + */ + public function getTransport() { + return $this->trans_; + } + + /** + * Writes the message header + * + * @param string $name Function name + * @param int $type message type TMessageType::CALL or TMessageType::REPLY + * @param int $seqid The sequence id of this message + */ + public abstract function writeMessageBegin($name, $type, $seqid); + + /** + * Close the message + */ + public abstract function writeMessageEnd(); + + /** + * Writes a struct header. + * + * @param string $name Struct name + * @throws TException on write error + * @return int How many bytes written + */ + public abstract function writeStructBegin($name); + + /** + * Close a struct. + * + * @throws TException on write error + * @return int How many bytes written + */ + public abstract function writeStructEnd(); + + /* + * Starts a field. + * + * @param string $name Field name + * @param int $type Field type + * @param int $fid Field id + * @throws TException on write error + * @return int How many bytes written + */ + public abstract function writeFieldBegin($fieldName, $fieldType, $fieldId); + + public abstract function writeFieldEnd(); + + public abstract function writeFieldStop(); + + public abstract function writeMapBegin($keyType, $valType, $size); + + public abstract function writeMapEnd(); + + public abstract function writeListBegin($elemType, $size); + + public abstract function writeListEnd(); + + public abstract function writeSetBegin($elemType, $size); + + public abstract function writeSetEnd(); + + public abstract function writeBool($bool); + + public abstract function writeByte($byte); + + public abstract function writeI16($i16); + + public abstract function writeI32($i32); + + public abstract function writeI64($i64); + + public abstract function writeDouble($dub); + + public abstract function writeString($str); + + /** + * Reads the message header + * + * @param string $name Function name + * @param int $type message type TMessageType::CALL or TMessageType::REPLY + * @parem int $seqid The sequence id of this message + */ + public abstract function readMessageBegin(&$name, &$type, &$seqid); + + /** + * Read the close of message + */ + public abstract function readMessageEnd(); + + public abstract function readStructBegin(&$name); + + public abstract function readStructEnd(); + + public abstract function readFieldBegin(&$name, &$fieldType, &$fieldId); + + public abstract function readFieldEnd(); + + public abstract function readMapBegin(&$keyType, &$valType, &$size); + + public abstract function readMapEnd(); + + public abstract function readListBegin(&$elemType, &$size); + + public abstract function readListEnd(); + + public abstract function readSetBegin(&$elemType, &$size); + + public abstract function readSetEnd(); + + public abstract function readBool(&$bool); + + public abstract function readByte(&$byte); + + public abstract function readI16(&$i16); + + public abstract function readI32(&$i32); + + public abstract function readI64(&$i64); + + public abstract function readDouble(&$dub); + + public abstract function readString(&$str); + + /** + * The skip function is a utility to parse over unrecognized date without + * causing corruption. + * + * @param TType $type What type is it + */ + public function skip($type) { + switch ($type) { + case TType::BOOL: + return $this->readBool($bool); + case TType::BYTE: + return $this->readByte($byte); + case TType::I16: + return $this->readI16($i16); + case TType::I32: + return $this->readI32($i32); + case TType::I64: + return $this->readI64($i64); + case TType::DOUBLE: + return $this->readDouble($dub); + case TType::STRING: + return $this->readString($str); + case TType::STRUCT: + { + $result = $this->readStructBegin($name); + while (true) { + $result += $this->readFieldBegin($name, $ftype, $fid); + if ($ftype == TType::STOP) { + break; + } + $result += $this->skip($ftype); + $result += $this->readFieldEnd(); + } + $result += $this->readStructEnd(); + return $result; + } + case TType::MAP: + { + $result = $this->readMapBegin($keyType, $valType, $size); + for ($i = 0; $i < $size; $i++) { + $result += $this->skip($keyType); + $result += $this->skip($valType); + } + $result += $this->readMapEnd(); + return $result; + } + case TType::SET: + { + $result = $this->readSetBegin($elemType, $size); + for ($i = 0; $i < $size; $i++) { + $result += $this->skip($elemType); + } + $result += $this->readSetEnd(); + return $result; + } + case TType::LST: + { + $result = $this->readListBegin($elemType, $size); + for ($i = 0; $i < $size; $i++) { + $result += $this->skip($elemType); + } + $result += $this->readListEnd(); + return $result; + } + default: + return 0; + } + } + + /** + * Utility for skipping binary data + * + * @param TTransport $itrans TTransport object + * @param int $type Field type + */ + public static function skipBinary($itrans, $type) { + switch ($type) { + case TType::BOOL: + return $itrans->readAll(1); + case TType::BYTE: + return $itrans->readAll(1); + case TType::I16: + return $itrans->readAll(2); + case TType::I32: + return $itrans->readAll(4); + case TType::I64: + return $itrans->readAll(8); + case TType::DOUBLE: + return $itrans->readAll(8); + case TType::STRING: + $len = unpack('N', $itrans->readAll(4)); + $len = $len[1]; + if ($len > 0x7fffffff) { + $len = 0 - (($len - 1) ^ 0xffffffff); + } + return 4 + $itrans->readAll($len); + case TType::STRUCT: + { + $result = 0; + while (true) { + $ftype = 0; + $fid = 0; + $data = $itrans->readAll(1); + $arr = unpack('c', $data); + $ftype = $arr[1]; + if ($ftype == TType::STOP) { + break; + } + // I16 field id + $result += $itrans->readAll(2); + $result += self::skipBinary($itrans, $ftype); + } + return $result; + } + case TType::MAP: + { + // Ktype + $data = $itrans->readAll(1); + $arr = unpack('c', $data); + $ktype = $arr[1]; + // Vtype + $data = $itrans->readAll(1); + $arr = unpack('c', $data); + $vtype = $arr[1]; + // Size + $data = $itrans->readAll(4); + $arr = unpack('N', $data); + $size = $arr[1]; + if ($size > 0x7fffffff) { + $size = 0 - (($size - 1) ^ 0xffffffff); + } + $result = 6; + for ($i = 0; $i < $size; $i++) { + $result += self::skipBinary($itrans, $ktype); + $result += self::skipBinary($itrans, $vtype); + } + return $result; + } + case TType::SET: + case TType::LST: + { + // Vtype + $data = $itrans->readAll(1); + $arr = unpack('c', $data); + $vtype = $arr[1]; + // Size + $data = $itrans->readAll(4); + $arr = unpack('N', $data); + $size = $arr[1]; + if ($size > 0x7fffffff) { + $size = 0 - (($size - 1) ^ 0xffffffff); + } + $result = 5; + for ($i = 0; $i < $size; $i++) { + $result += self::skipBinary($itrans, $vtype); + } + return $result; + } + default: + return 0; + } + } +} + +/** + * Protocol factory creates protocol objects from transports + */ +interface TProtocolFactory { + /** + * Build a protocol from the base transport + * + * @return TProtcol protocol + */ + public function getProtocol($trans); +} + + +?> diff --git a/lib/php/src/transport/TBufferedTransport.php b/lib/php/src/transport/TBufferedTransport.php new file mode 100644 index 000000000..cfae767ec --- /dev/null +++ b/lib/php/src/transport/TBufferedTransport.php @@ -0,0 +1,163 @@ +<?php +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @package thrift.transport + */ + + +/** + * Buffered transport. Stores data to an internal buffer that it doesn't + * actually write out until flush is called. For reading, we do a greedy + * read and then serve data out of the internal buffer. + * + * @package thrift.transport + */ +class TBufferedTransport extends TTransport { + + /** + * Constructor. Creates a buffered transport around an underlying transport + */ + public function __construct($transport=null, $rBufSize=512, $wBufSize=512) { + $this->transport_ = $transport; + $this->rBufSize_ = $rBufSize; + $this->wBufSize_ = $wBufSize; + } + + /** + * The underlying transport + * + * @var TTransport + */ + protected $transport_ = null; + + /** + * The receive buffer size + * + * @var int + */ + protected $rBufSize_ = 512; + + /** + * The write buffer size + * + * @var int + */ + protected $wBufSize_ = 512; + + /** + * The write buffer. + * + * @var string + */ + protected $wBuf_ = ''; + + /** + * The read buffer. + * + * @var string + */ + protected $rBuf_ = ''; + + public function isOpen() { + return $this->transport_->isOpen(); + } + + public function open() { + $this->transport_->open(); + } + + public function close() { + $this->transport_->close(); + } + + public function putBack($data) { + if (strlen($this->rBuf_) === 0) { + $this->rBuf_ = $data; + } else { + $this->rBuf_ = ($data . $this->rBuf_); + } + } + + /** + * The reason that we customize readAll here is that the majority of PHP + * streams are already internally buffered by PHP. The socket stream, for + * example, buffers internally and blocks if you call read with $len greater + * than the amount of data available, unlike recv() in C. + * + * Therefore, use the readAll method of the wrapped transport inside + * the buffered readAll. + */ + public function readAll($len) { + $have = strlen($this->rBuf_); + if ($have == 0) { + $data = $this->transport_->readAll($len); + } else if ($have < $len) { + $data = $this->rBuf_; + $this->rBuf_ = ''; + $data .= $this->transport_->readAll($len - $have); + } else if ($have == $len) { + $data = $this->rBuf_; + $this->rBuf_ = ''; + } else if ($have > $len) { + $data = substr($this->rBuf_, 0, $len); + $this->rBuf_ = substr($this->rBuf_, $len); + } + return $data; + } + + public function read($len) { + if (strlen($this->rBuf_) === 0) { + $this->rBuf_ = $this->transport_->read($this->rBufSize_); + } + + if (strlen($this->rBuf_) <= $len) { + $ret = $this->rBuf_; + $this->rBuf_ = ''; + return $ret; + } + + $ret = substr($this->rBuf_, 0, $len); + $this->rBuf_ = substr($this->rBuf_, $len); + return $ret; + } + + public function write($buf) { + $this->wBuf_ .= $buf; + if (strlen($this->wBuf_) >= $this->wBufSize_) { + $out = $this->wBuf_; + + // Note that we clear the internal wBuf_ prior to the underlying write + // to ensure we're in a sane state (i.e. internal buffer cleaned) + // if the underlying write throws up an exception + $this->wBuf_ = ''; + $this->transport_->write($out); + } + } + + public function flush() { + if (strlen($this->wBuf_) > 0) { + $this->transport_->write($this->wBuf_); + $this->wBuf_ = ''; + } + $this->transport_->flush(); + } + +} + +?> diff --git a/lib/php/src/transport/TFramedTransport.php b/lib/php/src/transport/TFramedTransport.php new file mode 100644 index 000000000..dc57392f7 --- /dev/null +++ b/lib/php/src/transport/TFramedTransport.php @@ -0,0 +1,179 @@ +<?php +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @package thrift.transport + */ + + +/** + * Framed transport. Writes and reads data in chunks that are stamped with + * their length. + * + * @package thrift.transport + */ +class TFramedTransport extends TTransport { + + /** + * Underlying transport object. + * + * @var TTransport + */ + private $transport_; + + /** + * Buffer for read data. + * + * @var string + */ + private $rBuf_; + + /** + * Buffer for queued output data + * + * @var string + */ + private $wBuf_; + + /** + * Whether to frame reads + * + * @var bool + */ + private $read_; + + /** + * Whether to frame writes + * + * @var bool + */ + private $write_; + + /** + * Constructor. + * + * @param TTransport $transport Underlying transport + */ + public function __construct($transport=null, $read=true, $write=true) { + $this->transport_ = $transport; + $this->read_ = $read; + $this->write_ = $write; + } + + public function isOpen() { + return $this->transport_->isOpen(); + } + + public function open() { + $this->transport_->open(); + } + + public function close() { + $this->transport_->close(); + } + + /** + * Reads from the buffer. When more data is required reads another entire + * chunk and serves future reads out of that. + * + * @param int $len How much data + */ + public function read($len) { + if (!$this->read_) { + return $this->transport_->read($len); + } + + if (strlen($this->rBuf_) === 0) { + $this->readFrame(); + } + + // Just return full buff + if ($len >= strlen($this->rBuf_)) { + $out = $this->rBuf_; + $this->rBuf_ = null; + return $out; + } + + // Return substr + $out = substr($this->rBuf_, 0, $len); + $this->rBuf_ = substr($this->rBuf_, $len); + return $out; + } + + /** + * Put previously read data back into the buffer + * + * @param string $data data to return + */ + public function putBack($data) { + if (strlen($this->rBuf_) === 0) { + $this->rBuf_ = $data; + } else { + $this->rBuf_ = ($data . $this->rBuf_); + } + } + + /** + * Reads a chunk of data into the internal read buffer. + */ + private function readFrame() { + $buf = $this->transport_->readAll(4); + $val = unpack('N', $buf); + $sz = $val[1]; + + $this->rBuf_ = $this->transport_->readAll($sz); + } + + /** + * Writes some data to the pending output buffer. + * + * @param string $buf The data + * @param int $len Limit of bytes to write + */ + public function write($buf, $len=null) { + if (!$this->write_) { + return $this->transport_->write($buf, $len); + } + + if ($len !== null && $len < strlen($buf)) { + $buf = substr($buf, 0, $len); + } + $this->wBuf_ .= $buf; + } + + /** + * Writes the output buffer to the stream in the format of a 4-byte length + * followed by the actual data. + */ + public function flush() { + if (!$this->write_) { + return $this->transport_->flush(); + } + + $out = pack('N', strlen($this->wBuf_)); + $out .= $this->wBuf_; + + // Note that we clear the internal wBuf_ prior to the underlying write + // to ensure we're in a sane state (i.e. internal buffer cleaned) + // if the underlying write throws up an exception + $this->wBuf_ = ''; + $this->transport_->write($out); + $this->transport_->flush(); + } + +} diff --git a/lib/php/src/transport/THttpClient.php b/lib/php/src/transport/THttpClient.php new file mode 100644 index 000000000..224d403b4 --- /dev/null +++ b/lib/php/src/transport/THttpClient.php @@ -0,0 +1,202 @@ +<?php +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @package thrift.transport + */ + + +/** + * HTTP client for Thrift + * + * @package thrift.transport + */ +class THttpClient extends TTransport { + + /** + * The host to connect to + * + * @var string + */ + protected $host_; + + /** + * The port to connect on + * + * @var int + */ + protected $port_; + + /** + * The URI to request + * + * @var string + */ + protected $uri_; + + /** + * The scheme to use for the request, i.e. http, https + * + * @var string + */ + protected $scheme_; + + /** + * Buffer for the HTTP request data + * + * @var string + */ + protected $buf_; + + /** + * Input socket stream. + * + * @var resource + */ + protected $handle_; + + /** + * Read timeout + * + * @var float + */ + protected $timeout_; + + /** + * Make a new HTTP client. + * + * @param string $host + * @param int $port + * @param string $uri + */ + public function __construct($host, $port=80, $uri='', $scheme = 'http') { + if ((strlen($uri) > 0) && ($uri{0} != '/')) { + $uri = '/'.$uri; + } + $this->scheme_ = $scheme; + $this->host_ = $host; + $this->port_ = $port; + $this->uri_ = $uri; + $this->buf_ = ''; + $this->handle_ = null; + $this->timeout_ = null; + } + + /** + * Set read timeout + * + * @param float $timeout + */ + public function setTimeoutSecs($timeout) { + $this->timeout_ = $timeout; + } + + /** + * Whether this transport is open. + * + * @return boolean true if open + */ + public function isOpen() { + return true; + } + + /** + * Open the transport for reading/writing + * + * @throws TTransportException if cannot open + */ + public function open() {} + + /** + * Close the transport. + */ + public function close() { + if ($this->handle_) { + @fclose($this->handle_); + $this->handle_ = null; + } + } + + /** + * Read some data into the array. + * + * @param int $len How much to read + * @return string The data that has been read + * @throws TTransportException if cannot read any more data + */ + public function read($len) { + $data = @fread($this->handle_, $len); + if ($data === FALSE || $data === '') { + $md = stream_get_meta_data($this->handle_); + if ($md['timed_out']) { + throw new TTransportException('THttpClient: timed out reading '.$len.' bytes from '.$this->host_.':'.$this->port_.'/'.$this->uri_, TTransportException::TIMED_OUT); + } else { + throw new TTransportException('THttpClient: Could not read '.$len.' bytes from '.$this->host_.':'.$this->port_.'/'.$this->uri_, TTransportException::UNKNOWN); + } + } + return $data; + } + + /** + * Writes some data into the pending buffer + * + * @param string $buf The data to write + * @throws TTransportException if writing fails + */ + public function write($buf) { + $this->buf_ .= $buf; + } + + /** + * Opens and sends the actual request over the HTTP connection + * + * @throws TTransportException if a writing error occurs + */ + public function flush() { + // God, PHP really has some esoteric ways of doing simple things. + $host = $this->host_.($this->port_ != 80 ? ':'.$this->port_ : ''); + + $headers = array('Host: '.$host, + 'Accept: application/x-thrift', + 'User-Agent: PHP/THttpClient', + 'Content-Type: application/x-thrift', + 'Content-Length: '.strlen($this->buf_)); + + $options = array('method' => 'POST', + 'header' => implode("\r\n", $headers), + 'max_redirects' => 1, + 'content' => $this->buf_); + if ($this->timeout_ > 0) { + $options['timeout'] = $this->timeout_; + } + $this->buf_ = ''; + + $contextid = stream_context_create(array('http' => $options)); + $this->handle_ = @fopen($this->scheme_.'://'.$host.$this->uri_, 'r', false, $contextid); + + // Connect failed? + if ($this->handle_ === FALSE) { + $this->handle_ = null; + $error = 'THttpClient: Could not connect to '.$host.$this->uri_; + throw new TTransportException($error, TTransportException::NOT_OPEN); + } + } + +} + +?> diff --git a/lib/php/src/transport/TMemoryBuffer.php b/lib/php/src/transport/TMemoryBuffer.php new file mode 100644 index 000000000..01eb0f5a6 --- /dev/null +++ b/lib/php/src/transport/TMemoryBuffer.php @@ -0,0 +1,84 @@ +<?php +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @package thrift.transport + */ + + +/** + * A memory buffer is a tranpsort that simply reads from and writes to an + * in-memory string buffer. Anytime you call write on it, the data is simply + * placed into a buffer, and anytime you call read, data is read from that + * buffer. + * + * @package thrift.transport + */ +class TMemoryBuffer extends TTransport { + + /** + * Constructor. Optionally pass an initial value + * for the buffer. + */ + public function __construct($buf = '') { + $this->buf_ = $buf; + } + + protected $buf_ = ''; + + public function isOpen() { + return true; + } + + public function open() {} + + public function close() {} + + public function write($buf) { + $this->buf_ .= $buf; + } + + public function read($len) { + if (strlen($this->buf_) === 0) { + throw new TTransportException('TMemoryBuffer: Could not read ' . + $len . ' bytes from buffer.', + TTransportException::UNKNOWN); + } + + if (strlen($this->buf_) <= $len) { + $ret = $this->buf_; + $this->buf_ = ''; + return $ret; + } + + $ret = substr($this->buf_, 0, $len); + $this->buf_ = substr($this->buf_, $len); + + return $ret; + } + + function getBuffer() { + return $this->buf_; + } + + public function available() { + return strlen($this->buf_); + } +} + +?> diff --git a/lib/php/src/transport/TNullTransport.php b/lib/php/src/transport/TNullTransport.php new file mode 100644 index 000000000..bada5dfb7 --- /dev/null +++ b/lib/php/src/transport/TNullTransport.php @@ -0,0 +1,48 @@ +<?php +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @package thrift.transport + */ + + +/** + * Transport that only accepts writes and ignores them. + * This is useful for measuring the serialized size of structures. + * + * @package thrift.transport + */ +class TNullTransport extends TTransport { + + public function isOpen() { + return true; + } + + public function open() {} + + public function close() {} + + public function read($len) { + throw new TTransportException("Can't read from TNullTransport."); + } + + public function write($buf) {} + +} + +?> diff --git a/lib/php/src/transport/TPhpStream.php b/lib/php/src/transport/TPhpStream.php new file mode 100644 index 000000000..3a1c80b8f --- /dev/null +++ b/lib/php/src/transport/TPhpStream.php @@ -0,0 +1,111 @@ +<?php +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @package thrift.transport + */ + + +/** + * Php stream transport. Reads to and writes from the php standard streams + * php://input and php://output + * + * @package thrift.transport + */ +class TPhpStream extends TTransport { + + const MODE_R = 1; + const MODE_W = 2; + + private $inStream_ = null; + + private $outStream_ = null; + + private $read_ = false; + + private $write_ = false; + + public function __construct($mode) { + $this->read_ = $mode & self::MODE_R; + $this->write_ = $mode & self::MODE_W; + } + + public function open() { + if ($this->read_) { + $this->inStream_ = @fopen(self::inStreamName(), 'r'); + if (!is_resource($this->inStream_)) { + throw new TException('TPhpStream: Could not open php://input'); + } + } + if ($this->write_) { + $this->outStream_ = @fopen('php://output', 'w'); + if (!is_resource($this->outStream_)) { + throw new TException('TPhpStream: Could not open php://output'); + } + } + } + + public function close() { + if ($this->read_) { + @fclose($this->inStream_); + $this->inStream_ = null; + } + if ($this->write_) { + @fclose($this->outStream_); + $this->outStream_ = null; + } + } + + public function isOpen() { + return + (!$this->read_ || is_resource($this->inStream_)) && + (!$this->write_ || is_resource($this->outStream_)); + } + + public function read($len) { + $data = @fread($this->inStream_, $len); + if ($data === FALSE || $data === '') { + throw new TException('TPhpStream: Could not read '.$len.' bytes'); + } + return $data; + } + + public function write($buf) { + while (strlen($buf) > 0) { + $got = @fwrite($this->outStream_, $buf); + if ($got === 0 || $got === FALSE) { + throw new TException('TPhpStream: Could not write '.strlen($buf).' bytes'); + } + $buf = substr($buf, $got); + } + } + + public function flush() { + @fflush($this->outStream_); + } + + private static function inStreamName() { + if (php_sapi_name() == 'cli') { + return 'php://stdin'; + } + return 'php://input'; + } + +} + +?> diff --git a/lib/php/src/transport/TSocket.php b/lib/php/src/transport/TSocket.php new file mode 100644 index 000000000..ba3a63184 --- /dev/null +++ b/lib/php/src/transport/TSocket.php @@ -0,0 +1,312 @@ +<?php +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @package thrift.transport + */ + + +/** + * Sockets implementation of the TTransport interface. + * + * @package thrift.transport + */ +class TSocket extends TTransport { + + /** + * Handle to PHP socket + * + * @var resource + */ + private $handle_ = null; + + /** + * Remote hostname + * + * @var string + */ + protected $host_ = 'localhost'; + + /** + * Remote port + * + * @var int + */ + protected $port_ = '9090'; + + /** + * Send timeout in milliseconds + * + * @var int + */ + private $sendTimeout_ = 100; + + /** + * Recv timeout in milliseconds + * + * @var int + */ + private $recvTimeout_ = 750; + + /** + * Is send timeout set? + * + * @var bool + */ + private $sendTimeoutSet_ = FALSE; + + /** + * Persistent socket or plain? + * + * @var bool + */ + private $persist_ = FALSE; + + /** + * Debugging on? + * + * @var bool + */ + protected $debug_ = FALSE; + + /** + * Debug handler + * + * @var mixed + */ + protected $debugHandler_ = null; + + /** + * Socket constructor + * + * @param string $host Remote hostname + * @param int $port Remote port + * @param bool $persist Whether to use a persistent socket + * @param string $debugHandler Function to call for error logging + */ + public function __construct($host='localhost', + $port=9090, + $persist=FALSE, + $debugHandler=null) { + $this->host_ = $host; + $this->port_ = $port; + $this->persist_ = $persist; + $this->debugHandler_ = $debugHandler ? $debugHandler : 'error_log'; + } + + /** + * Sets the send timeout. + * + * @param int $timeout Timeout in milliseconds. + */ + public function setSendTimeout($timeout) { + $this->sendTimeout_ = $timeout; + } + + /** + * Sets the receive timeout. + * + * @param int $timeout Timeout in milliseconds. + */ + public function setRecvTimeout($timeout) { + $this->recvTimeout_ = $timeout; + } + + /** + * Sets debugging output on or off + * + * @param bool $debug + */ + public function setDebug($debug) { + $this->debug_ = $debug; + } + + /** + * Get the host that this socket is connected to + * + * @return string host + */ + public function getHost() { + return $this->host_; + } + + /** + * Get the remote port that this socket is connected to + * + * @return int port + */ + public function getPort() { + return $this->port_; + } + + /** + * Tests whether this is open + * + * @return bool true if the socket is open + */ + public function isOpen() { + return is_resource($this->handle_); + } + + /** + * Connects the socket. + */ + public function open() { + + if ($this->persist_) { + $this->handle_ = @pfsockopen($this->host_, + $this->port_, + $errno, + $errstr, + $this->sendTimeout_/1000.0); + } else { + $this->handle_ = @fsockopen($this->host_, + $this->port_, + $errno, + $errstr, + $this->sendTimeout_/1000.0); + } + + // Connect failed? + if ($this->handle_ === FALSE) { + $error = 'TSocket: Could not connect to '.$this->host_.':'.$this->port_.' ('.$errstr.' ['.$errno.'])'; + if ($this->debug_) { + call_user_func($this->debugHandler_, $error); + } + throw new TException($error); + } + + stream_set_timeout($this->handle_, 0, $this->sendTimeout_*1000); + $this->sendTimeoutSet_ = TRUE; + } + + /** + * Closes the socket. + */ + public function close() { + if (!$this->persist_) { + @fclose($this->handle_); + $this->handle_ = null; + } + } + + /** + * Uses stream get contents to do the reading + * + * @param int $len How many bytes + * @return string Binary data + */ + public function readAll($len) { + if ($this->sendTimeoutSet_) { + stream_set_timeout($this->handle_, 0, $this->recvTimeout_*1000); + $this->sendTimeoutSet_ = FALSE; + } + // This call does not obey stream_set_timeout values! + // $buf = @stream_get_contents($this->handle_, $len); + + $pre = null; + while (TRUE) { + $buf = @fread($this->handle_, $len); + if ($buf === FALSE || $buf === '') { + $md = stream_get_meta_data($this->handle_); + if ($md['timed_out']) { + throw new TException('TSocket: timed out reading '.$len.' bytes from '. + $this->host_.':'.$this->port_); + } else { + throw new TException('TSocket: Could not read '.$len.' bytes from '. + $this->host_.':'.$this->port_); + } + } else if (($sz = strlen($buf)) < $len) { + $md = stream_get_meta_data($this->handle_); + if ($md['timed_out']) { + throw new TException('TSocket: timed out reading '.$len.' bytes from '. + $this->host_.':'.$this->port_); + } else { + $pre .= $buf; + $len -= $sz; + } + } else { + return $pre.$buf; + } + } + } + + /** + * Read from the socket + * + * @param int $len How many bytes + * @return string Binary data + */ + public function read($len) { + if ($this->sendTimeoutSet_) { + stream_set_timeout($this->handle_, 0, $this->recvTimeout_*1000); + $this->sendTimeoutSet_ = FALSE; + } + $data = @fread($this->handle_, $len); + if ($data === FALSE || $data === '') { + $md = stream_get_meta_data($this->handle_); + if ($md['timed_out']) { + throw new TException('TSocket: timed out reading '.$len.' bytes from '. + $this->host_.':'.$this->port_); + } else { + throw new TException('TSocket: Could not read '.$len.' bytes from '. + $this->host_.':'.$this->port_); + } + } + return $data; + } + + /** + * Write to the socket. + * + * @param string $buf The data to write + */ + public function write($buf) { + if (!$this->sendTimeoutSet_) { + stream_set_timeout($this->handle_, 0, $this->sendTimeout_*1000); + $this->sendTimeoutSet_ = TRUE; + } + while (strlen($buf) > 0) { + $got = @fwrite($this->handle_, $buf); + if ($got === 0 || $got === FALSE) { + $md = stream_get_meta_data($this->handle_); + if ($md['timed_out']) { + throw new TException('TSocket: timed out writing '.strlen($buf).' bytes from '. + $this->host_.':'.$this->port_); + } else { + throw new TException('TSocket: Could not write '.strlen($buf).' bytes '. + $this->host_.':'.$this->port_); + } + } + $buf = substr($buf, $got); + } + } + + /** + * Flush output to the socket. + */ + public function flush() { + $ret = fflush($this->handle_); + if ($ret === FALSE) { + throw new TException('TSocket: Could not flush: '. + $this->host_.':'.$this->port_); + } + } +} + +?> diff --git a/lib/php/src/transport/TSocketPool.php b/lib/php/src/transport/TSocketPool.php new file mode 100644 index 000000000..7f1157cb4 --- /dev/null +++ b/lib/php/src/transport/TSocketPool.php @@ -0,0 +1,296 @@ +<?php +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @package thrift.transport + */ + + +/** Inherits from Socket */ +include_once $GLOBALS['THRIFT_ROOT'].'/transport/TSocket.php'; + +/** + * This library makes use of APC cache to make hosts as down in a web + * environment. If you are running from the CLI or on a system without APC + * installed, then these null functions will step in and act like cache + * misses. + */ +if (!function_exists('apc_fetch')) { + function apc_fetch($key) { return FALSE; } + function apc_store($key, $var, $ttl=0) { return FALSE; } +} + +/** + * Sockets implementation of the TTransport interface that allows connection + * to a pool of servers. + * + * @package thrift.transport + */ +class TSocketPool extends TSocket { + + /** + * Remote servers. Array of associative arrays with 'host' and 'port' keys + */ + private $servers_ = array(); + + /** + * How many times to retry each host in connect + * + * @var int + */ + private $numRetries_ = 1; + + /** + * Retry interval in seconds, how long to not try a host if it has been + * marked as down. + * + * @var int + */ + private $retryInterval_ = 60; + + /** + * Max consecutive failures before marking a host down. + * + * @var int + */ + private $maxConsecutiveFailures_ = 1; + + /** + * Try hosts in order? or Randomized? + * + * @var bool + */ + private $randomize_ = TRUE; + + /** + * Always try last host, even if marked down? + * + * @var bool + */ + private $alwaysTryLast_ = TRUE; + + /** + * Socket pool constructor + * + * @param array $hosts List of remote hostnames + * @param mixed $ports Array of remote ports, or a single common port + * @param bool $persist Whether to use a persistent socket + * @param mixed $debugHandler Function for error logging + */ + public function __construct($hosts=array('localhost'), + $ports=array(9090), + $persist=FALSE, + $debugHandler=null) { + parent::__construct(null, 0, $persist, $debugHandler); + + if (!is_array($ports)) { + $port = $ports; + $ports = array(); + foreach ($hosts as $key => $val) { + $ports[$key] = $port; + } + } + + foreach ($hosts as $key => $host) { + $this->servers_ []= array('host' => $host, + 'port' => $ports[$key]); + } + } + + /** + * Add a server to the pool + * + * This function does not prevent you from adding a duplicate server entry. + * + * @param string $host hostname or IP + * @param int $port port + */ + public function addServer($host, $port) { + $this->servers_[] = array('host' => $host, 'port' => $port); + } + + /** + * Sets how many time to keep retrying a host in the connect function. + * + * @param int $numRetries + */ + public function setNumRetries($numRetries) { + $this->numRetries_ = $numRetries; + } + + /** + * Sets how long to wait until retrying a host if it was marked down + * + * @param int $numRetries + */ + public function setRetryInterval($retryInterval) { + $this->retryInterval_ = $retryInterval; + } + + /** + * Sets how many time to keep retrying a host before marking it as down. + * + * @param int $numRetries + */ + public function setMaxConsecutiveFailures($maxConsecutiveFailures) { + $this->maxConsecutiveFailures_ = $maxConsecutiveFailures; + } + + /** + * Turns randomization in connect order on or off. + * + * @param bool $randomize + */ + public function setRandomize($randomize) { + $this->randomize_ = $randomize; + } + + /** + * Whether to always try the last server. + * + * @param bool $alwaysTryLast + */ + public function setAlwaysTryLast($alwaysTryLast) { + $this->alwaysTryLast_ = $alwaysTryLast; + } + + + /** + * Connects the socket by iterating through all the servers in the pool + * and trying to find one that works. + */ + public function open() { + // Check if we want order randomization + if ($this->randomize_) { + shuffle($this->servers_); + } + + // Count servers to identify the "last" one + $numServers = count($this->servers_); + + for ($i = 0; $i < $numServers; ++$i) { + + // This extracts the $host and $port variables + extract($this->servers_[$i]); + + // Check APC cache for a record of this server being down + $failtimeKey = 'thrift_failtime:'.$host.':'.$port.'~'; + + // Cache miss? Assume it's OK + $lastFailtime = apc_fetch($failtimeKey); + if ($lastFailtime === FALSE) { + $lastFailtime = 0; + } + + $retryIntervalPassed = FALSE; + + // Cache hit...make sure enough the retry interval has elapsed + if ($lastFailtime > 0) { + $elapsed = time() - $lastFailtime; + if ($elapsed > $this->retryInterval_) { + $retryIntervalPassed = TRUE; + if ($this->debug_) { + call_user_func($this->debugHandler_, + 'TSocketPool: retryInterval '. + '('.$this->retryInterval_.') '. + 'has passed for host '.$host.':'.$port); + } + } + } + + // Only connect if not in the middle of a fail interval, OR if this + // is the LAST server we are trying, just hammer away on it + $isLastServer = FALSE; + if ($this->alwaysTryLast_) { + $isLastServer = ($i == ($numServers - 1)); + } + + if (($lastFailtime === 0) || + ($isLastServer) || + ($lastFailtime > 0 && $retryIntervalPassed)) { + + // Set underlying TSocket params to this one + $this->host_ = $host; + $this->port_ = $port; + + // Try up to numRetries_ connections per server + for ($attempt = 0; $attempt < $this->numRetries_; $attempt++) { + try { + // Use the underlying TSocket open function + parent::open(); + + // Only clear the failure counts if required to do so + if ($lastFailtime > 0) { + apc_store($failtimeKey, 0); + } + + // Successful connection, return now + return; + + } catch (TException $tx) { + // Connection failed + } + } + + // Mark failure of this host in the cache + $consecfailsKey = 'thrift_consecfails:'.$host.':'.$port.'~'; + + // Ignore cache misses + $consecfails = apc_fetch($consecfailsKey); + if ($consecfails === FALSE) { + $consecfails = 0; + } + + // Increment by one + $consecfails++; + + // Log and cache this failure + if ($consecfails >= $this->maxConsecutiveFailures_) { + if ($this->debug_) { + call_user_func($this->debugHandler_, + 'TSocketPool: marking '.$host.':'.$port. + ' as down for '.$this->retryInterval_.' secs '. + 'after '.$consecfails.' failed attempts.'); + } + // Store the failure time + apc_store($failtimeKey, time()); + + // Clear the count of consecutive failures + apc_store($consecfailsKey, 0); + } else { + apc_store($consecfailsKey, $consecfails); + } + } + } + + // Holy shit we failed them all. The system is totally ill! + $error = 'TSocketPool: All hosts in pool are down. '; + $hosts = array(); + foreach ($this->servers_ as $server) { + $hosts []= $server['host'].':'.$server['port']; + } + $hostlist = implode(',', $hosts); + $error .= '('.$hostlist.')'; + if ($this->debug_) { + call_user_func($this->debugHandler_, $error); + } + throw new TException($error); + } +} + +?> diff --git a/lib/php/src/transport/TTransport.php b/lib/php/src/transport/TTransport.php new file mode 100644 index 000000000..e24452596 --- /dev/null +++ b/lib/php/src/transport/TTransport.php @@ -0,0 +1,108 @@ +<?php +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * @package thrift.transport + */ + + +/** + * Transport exceptions + */ +class TTransportException extends TException { + + const UNKNOWN = 0; + const NOT_OPEN = 1; + const ALREADY_OPEN = 2; + const TIMED_OUT = 3; + const END_OF_FILE = 4; + + function __construct($message=null, $code=0) { + parent::__construct($message, $code); + } +} + +/** + * Base interface for a transport agent. + * + * @package thrift.transport + */ +abstract class TTransport { + + /** + * Whether this transport is open. + * + * @return boolean true if open + */ + public abstract function isOpen(); + + /** + * Open the transport for reading/writing + * + * @throws TTransportException if cannot open + */ + public abstract function open(); + + /** + * Close the transport. + */ + public abstract function close(); + + /** + * Read some data into the array. + * + * @param int $len How much to read + * @return string The data that has been read + * @throws TTransportException if cannot read any more data + */ + public abstract function read($len); + + /** + * Guarantees that the full amount of data is read. + * + * @return string The data, of exact length + * @throws TTransportException if cannot read data + */ + public function readAll($len) { + // return $this->read($len); + + $data = ''; + $got = 0; + while (($got = strlen($data)) < $len) { + $data .= $this->read($len - $got); + } + return $data; + } + + /** + * Writes the given data out. + * + * @param string $buf The data to write + * @throws TTransportException if writing fails + */ + public abstract function write($buf); + + /** + * Flushes any pending data out of a buffer + * + * @throws TTransportException if a writing error occurs + */ + public function flush() {} +} + +?> diff --git a/lib/py/Makefile.am b/lib/py/Makefile.am new file mode 100644 index 000000000..6d3bbeaa3 --- /dev/null +++ b/lib/py/Makefile.am @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +DESTDIR ?= / +EXTRA_DIST = setup.py src + +all-local: + $(PYTHON) setup.py build + +# We're ignoring prefix here because site-packages seems to be +# the equivalent of /usr/local/lib in Python land. +# Old version (can't put inline because it's not portable). +#$(PYTHON) setup.py install --prefix=$(prefix) --root=$(DESTDIR) $(PYTHON_SETUPUTIL_ARGS) +install-exec-hook: + $(PYTHON) setup.py install --root=$(DESTDIR) --prefix=$(PY_PREFIX) $(PYTHON_SETUPUTIL_ARGS) + +clean-local: + $(RM) -r build + +check-local: all diff --git a/lib/py/README b/lib/py/README new file mode 100644 index 000000000..29b8c73c4 --- /dev/null +++ b/lib/py/README @@ -0,0 +1,35 @@ +Thrift Python Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Using Thrift with Python +======================== + +Thrift is provided as a set of Python packages. The top level package is +thrift, and there are subpackages for the protocol, transport, and server +code. Each package contains modules using standard Thrift naming conventions +(i.e. TProtocol, TTransport) and implementations in corresponding modules +(i.e. TSocket). There is also a subpackage reflection, which contains +the generated code for the reflection structures. + +The Python libraries can be installed manually using the provided setup.py +file, or automatically using the install hook provided via autoconf/automake. +To use the latter, become superuser and do make install. diff --git a/lib/py/setup.py b/lib/py/setup.py new file mode 100644 index 000000000..748371156 --- /dev/null +++ b/lib/py/setup.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from distutils.core import setup, Extension + +fastbinarymod = Extension('thrift.protocol.fastbinary', + sources = ['src/protocol/fastbinary.c'], + ) + +setup(name = 'Thrift', + version = '1.0', + description = 'Thrift Python Libraries', + author = ['Mark Slee'], + author_email = ['mcslee@facebook.com'], + url = 'http://code.facebook.com/thrift', + packages = [ + 'thrift', + 'thrift.protocol', + 'thrift.transport', + 'thrift.server', + ], + package_dir = {'thrift' : 'src'}, + ext_modules = [fastbinarymod], + ) + diff --git a/lib/py/src/TSCons.py b/lib/py/src/TSCons.py new file mode 100644 index 000000000..24046256c --- /dev/null +++ b/lib/py/src/TSCons.py @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from os import path +from SCons.Builder import Builder + +def scons_env(env, add=''): + opath = path.dirname(path.abspath('$TARGET')) + lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' + cppbuild = Builder(action = lstr) + env.Append(BUILDERS = {'ThriftCpp' : cppbuild}) + +def gen_cpp(env, dir, file): + scons_env(env) + suffixes = ['_types.h', '_types.cpp'] + targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) + return env.ThriftCpp(targets, dir+file+'.thrift') diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py new file mode 100644 index 000000000..21d7aa4e5 --- /dev/null +++ b/lib/py/src/Thrift.py @@ -0,0 +1,123 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +class TType: + STOP = 0 + VOID = 1 + BOOL = 2 + BYTE = 3 + I08 = 3 + DOUBLE = 4 + I16 = 6 + I32 = 8 + I64 = 10 + STRING = 11 + UTF7 = 11 + STRUCT = 12 + MAP = 13 + SET = 14 + LIST = 15 + UTF8 = 16 + UTF16 = 17 + +class TMessageType: + CALL = 1 + REPLY = 2 + EXCEPTION = 3 + ONEWAY = 4 + +class TProcessor: + + """Base class for procsessor, which works on two streams.""" + + def process(iprot, oprot): + pass + +class TException(Exception): + + """Base class for all thrift exceptions.""" + + def __init__(self, message=None): + Exception.__init__(self, message) + self.message = message + +class TApplicationException(TException): + + """Application level thrift exceptions.""" + + UNKNOWN = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type + + def __str__(self): + if self.message: + return self.message + elif self.type == UNKNOWN_METHOD: + return 'Unknown method' + elif self.type == INVALID_MESSAGE_TYPE: + return 'Invalid message type' + elif self.type == WRONG_METHOD_NAME: + return 'Wrong method name' + elif self.type == BAD_SEQUENCE_ID: + return 'Bad sequence ID' + elif self.type == MISSING_RESULT: + return 'Missing result' + else: + return 'Default (unknown) TApplicationException' + + def read(self, iprot): + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.STRING: + self.message = iprot.readString(); + else: + iprot.skip(ftype) + elif fid == 2: + if ftype == TType.I32: + self.type = iprot.readI32(); + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + oprot.writeStructBegin('TApplicationException') + if self.message != None: + oprot.writeFieldBegin('message', TType.STRING, 1) + oprot.writeString(self.message) + oprot.writeFieldEnd() + if self.type != None: + oprot.writeFieldBegin('type', TType.I32, 2) + oprot.writeI32(self.type) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() diff --git a/lib/py/src/__init__.py b/lib/py/src/__init__.py new file mode 100644 index 000000000..48d659c40 --- /dev/null +++ b/lib/py/src/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['Thrift', 'TSCons'] diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py new file mode 100644 index 000000000..db1a7a407 --- /dev/null +++ b/lib/py/src/protocol/TBinaryProtocol.py @@ -0,0 +1,259 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TProtocol import * +from struct import pack, unpack + +class TBinaryProtocol(TProtocolBase): + + """Binary implementation of the Thrift protocol driver.""" + + # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be + # positive, converting this into a long. If we hardcode the int value + # instead it'll stay in 32 bit-land. + + # VERSION_MASK = 0xffff0000 + VERSION_MASK = -65536 + + # VERSION_1 = 0x80010000 + VERSION_1 = -2147418112 + + TYPE_MASK = 0x000000ff + + def __init__(self, trans, strictRead=False, strictWrite=True): + TProtocolBase.__init__(self, trans) + self.strictRead = strictRead + self.strictWrite = strictWrite + + def writeMessageBegin(self, name, type, seqid): + if self.strictWrite: + self.writeI32(TBinaryProtocol.VERSION_1 | type) + self.writeString(name) + self.writeI32(seqid) + else: + self.writeString(name) + self.writeByte(type) + self.writeI32(seqid) + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + pass + + def writeStructEnd(self): + pass + + def writeFieldBegin(self, name, type, id): + self.writeByte(type) + self.writeI16(id) + + def writeFieldEnd(self): + pass + + def writeFieldStop(self): + self.writeByte(TType.STOP); + + def writeMapBegin(self, ktype, vtype, size): + self.writeByte(ktype) + self.writeByte(vtype) + self.writeI32(size) + + def writeMapEnd(self): + pass + + def writeListBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) + + def writeListEnd(self): + pass + + def writeSetBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) + + def writeSetEnd(self): + pass + + def writeBool(self, bool): + if bool: + self.writeByte(1) + else: + self.writeByte(0) + + def writeByte(self, byte): + buff = pack("!b", byte) + self.trans.write(buff) + + def writeI16(self, i16): + buff = pack("!h", i16) + self.trans.write(buff) + + def writeI32(self, i32): + buff = pack("!i", i32) + self.trans.write(buff) + + def writeI64(self, i64): + buff = pack("!q", i64) + self.trans.write(buff) + + def writeDouble(self, dub): + buff = pack("!d", dub) + self.trans.write(buff) + + def writeString(self, str): + self.writeI32(len(str)) + self.trans.write(str) + + def readMessageBegin(self): + sz = self.readI32() + if sz < 0: + version = sz & TBinaryProtocol.VERSION_MASK + if version != TBinaryProtocol.VERSION_1: + raise TProtocolException(TProtocolException.BAD_VERSION, 'Bad version in readMessageBegin: %d' % (sz)) + type = sz & TBinaryProtocol.TYPE_MASK + name = self.readString() + seqid = self.readI32() + else: + if self.strictRead: + raise TProtocolException(TProtocolException.BAD_VERSION, 'No protocol version header') + name = self.trans.readAll(sz) + type = self.readByte() + seqid = self.readI32() + return (name, type, seqid) + + def readMessageEnd(self): + pass + + def readStructBegin(self): + pass + + def readStructEnd(self): + pass + + def readFieldBegin(self): + type = self.readByte() + if type == TType.STOP: + return (None, type, 0) + id = self.readI16() + return (None, type, id) + + def readFieldEnd(self): + pass + + def readMapBegin(self): + ktype = self.readByte() + vtype = self.readByte() + size = self.readI32() + return (ktype, vtype, size) + + def readMapEnd(self): + pass + + def readListBegin(self): + etype = self.readByte() + size = self.readI32() + return (etype, size) + + def readListEnd(self): + pass + + def readSetBegin(self): + etype = self.readByte() + size = self.readI32() + return (etype, size) + + def readSetEnd(self): + pass + + def readBool(self): + byte = self.readByte() + if byte == 0: + return False + return True + + def readByte(self): + buff = self.trans.readAll(1) + val, = unpack('!b', buff) + return val + + def readI16(self): + buff = self.trans.readAll(2) + val, = unpack('!h', buff) + return val + + def readI32(self): + buff = self.trans.readAll(4) + val, = unpack('!i', buff) + return val + + def readI64(self): + buff = self.trans.readAll(8) + val, = unpack('!q', buff) + return val + + def readDouble(self): + buff = self.trans.readAll(8) + val, = unpack('!d', buff) + return val + + def readString(self): + len = self.readI32() + str = self.trans.readAll(len) + return str + + +class TBinaryProtocolFactory: + def __init__(self, strictRead=False, strictWrite=True): + self.strictRead = strictRead + self.strictWrite = strictWrite + + def getProtocol(self, trans): + prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite) + return prot + + +class TBinaryProtocolAccelerated(TBinaryProtocol): + + """C-Accelerated version of TBinaryProtocol. + + This class does not override any of TBinaryProtocol's methods, + but the generated code recognizes it directly and will call into + our C module to do the encoding, bypassing this object entirely. + We inherit from TBinaryProtocol so that the normal TBinaryProtocol + encoding can happen if the fastbinary module doesn't work for some + reason. (TODO(dreiss): Make this happen sanely in more cases.) + + In order to take advantage of the C module, just use + TBinaryProtocolAccelerated instead of TBinaryProtocol. + + NOTE: This code was contributed by an external developer. + The internal Thrift team has reviewed and tested it, + but we cannot guarantee that it is production-ready. + Please feel free to report bugs and/or success stories + to the public mailing list. + """ + + pass + + +class TBinaryProtocolAcceleratedFactory: + def getProtocol(self, trans): + return TBinaryProtocolAccelerated(trans) diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py new file mode 100644 index 000000000..be3cb1403 --- /dev/null +++ b/lib/py/src/protocol/TProtocol.py @@ -0,0 +1,205 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import * + +class TProtocolException(TException): + + """Custom Protocol Exception class""" + + UNKNOWN = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type + +class TProtocolBase: + + """Base class for Thrift protocol driver.""" + + def __init__(self, trans): + self.trans = trans + + def writeMessageBegin(self, name, type, seqid): + pass + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + pass + + def writeStructEnd(self): + pass + + def writeFieldBegin(self, name, type, id): + pass + + def writeFieldEnd(self): + pass + + def writeFieldStop(self): + pass + + def writeMapBegin(self, ktype, vtype, size): + pass + + def writeMapEnd(self): + pass + + def writeListBegin(self, etype, size): + pass + + def writeListEnd(self): + pass + + def writeSetBegin(self, etype, size): + pass + + def writeSetEnd(self): + pass + + def writeBool(self, bool): + pass + + def writeByte(self, byte): + pass + + def writeI16(self, i16): + pass + + def writeI32(self, i32): + pass + + def writeI64(self, i64): + pass + + def writeDouble(self, dub): + pass + + def writeString(self, str): + pass + + def readMessageBegin(self): + pass + + def readMessageEnd(self): + pass + + def readStructBegin(self): + pass + + def readStructEnd(self): + pass + + def readFieldBegin(self): + pass + + def readFieldEnd(self): + pass + + def readMapBegin(self): + pass + + def readMapEnd(self): + pass + + def readListBegin(self): + pass + + def readListEnd(self): + pass + + def readSetBegin(self): + pass + + def readSetEnd(self): + pass + + def readBool(self): + pass + + def readByte(self): + pass + + def readI16(self): + pass + + def readI32(self): + pass + + def readI64(self): + pass + + def readDouble(self): + pass + + def readString(self): + pass + + def skip(self, type): + if type == TType.STOP: + return + elif type == TType.BOOL: + self.readBool() + elif type == TType.BYTE: + self.readByte() + elif type == TType.I16: + self.readI16() + elif type == TType.I32: + self.readI32() + elif type == TType.I64: + self.readI64() + elif type == TType.DOUBLE: + self.readDouble() + elif type == TType.STRING: + self.readString() + elif type == TType.STRUCT: + name = self.readStructBegin() + while True: + (name, type, id) = self.readFieldBegin() + if type == TType.STOP: + break + self.skip(type) + self.readFieldEnd() + self.readStructEnd() + elif type == TType.MAP: + (ktype, vtype, size) = self.readMapBegin() + for i in range(size): + self.skip(ktype) + self.skip(vtype) + self.readMapEnd() + elif type == TType.SET: + (etype, size) = self.readSetBegin() + for i in range(size): + self.skip(etype) + self.readSetEnd() + elif type == TType.LIST: + (etype, size) = self.readListBegin() + for i in range(size): + self.skip(etype) + self.readListEnd() + +class TProtocolFactory: + def getProtocol(self, trans): + pass diff --git a/lib/py/src/protocol/__init__.py b/lib/py/src/protocol/__init__.py new file mode 100644 index 000000000..01bfe18e5 --- /dev/null +++ b/lib/py/src/protocol/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary'] diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c new file mode 100644 index 000000000..67b215a83 --- /dev/null +++ b/lib/py/src/protocol/fastbinary.c @@ -0,0 +1,1203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <Python.h> +#include "cStringIO.h" +#include <stdbool.h> +#include <stdint.h> +#include <netinet/in.h> + +/* Fix endianness issues on Solaris */ +#if defined (__SVR4) && defined (__sun) + #if defined(__i386) && !defined(__i386__) + #define __i386__ + #endif + + #ifndef BIG_ENDIAN + #define BIG_ENDIAN (4321) + #endif + #ifndef LITTLE_ENDIAN + #define LITTLE_ENDIAN (1234) + #endif + + /* I386 is LE, even on Solaris */ + #if !defined(BYTE_ORDER) && defined(__i386__) + #define BYTE_ORDER LITTLE_ENDIAN + #endif +#endif + +// TODO(dreiss): defval appears to be unused. Look into removing it. +// TODO(dreiss): Make parse_spec_args recursive, and cache the output +// permanently in the object. (Malloc and orphan.) +// TODO(dreiss): Why do we need cStringIO for reading, why not just char*? +// Can cStringIO let us work with a BufferedTransport? +// TODO(dreiss): Don't ignore the rv from cwrite (maybe). + +/* ====== BEGIN UTILITIES ====== */ + +#define INIT_OUTBUF_SIZE 128 + +// Stolen out of TProtocol.h. +// It would be a huge pain to have both get this from one place. +typedef enum TType { + T_STOP = 0, + T_VOID = 1, + T_BOOL = 2, + T_BYTE = 3, + T_I08 = 3, + T_I16 = 6, + T_I32 = 8, + T_U64 = 9, + T_I64 = 10, + T_DOUBLE = 4, + T_STRING = 11, + T_UTF7 = 11, + T_STRUCT = 12, + T_MAP = 13, + T_SET = 14, + T_LIST = 15, + T_UTF8 = 16, + T_UTF16 = 17 +} TType; + +#ifndef __BYTE_ORDER +# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN) +# define __BYTE_ORDER BYTE_ORDER +# define __LITTLE_ENDIAN LITTLE_ENDIAN +# define __BIG_ENDIAN BIG_ENDIAN +# else +# error "Cannot determine endianness" +# endif +#endif + +// Same comment as the enum. Sorry. +#if __BYTE_ORDER == __BIG_ENDIAN +# define ntohll(n) (n) +# define htonll(n) (n) +#elif __BYTE_ORDER == __LITTLE_ENDIAN +# if defined(__GNUC__) && defined(__GLIBC__) +# include <byteswap.h> +# define ntohll(n) bswap_64(n) +# define htonll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +# define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) ) +# define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) ) +# endif /* GNUC & GLIBC */ +#else /* __BYTE_ORDER */ +# error "Can't define htonll or ntohll!" +#endif + +// Doing a benchmark shows that interning actually makes a difference, amazingly. +#define INTERN_STRING(value) _intern_ ## value + +#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() ) +#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) ) + +// Py_ssize_t was not defined before Python 2.5 +#if (PY_VERSION_HEX < 0x02050000) +typedef int Py_ssize_t; +#endif + +/** + * A cache of the spec_args for a set or list, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + TType element_type; + PyObject* typeargs; +} SetListTypeArgs; + +/** + * A cache of the spec_args for a map, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + TType ktag; + TType vtag; + PyObject* ktypeargs; + PyObject* vtypeargs; +} MapTypeArgs; + +/** + * A cache of the spec_args for a struct, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + PyObject* klass; + PyObject* spec; +} StructTypeArgs; + +/** + * A cache of the item spec from a struct specification, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + int tag; + TType type; + PyObject* attrname; + PyObject* typeargs; + PyObject* defval; +} StructItemSpec; + +/** + * A cache of the two key attributes of a CReadableTransport, + * so we don't have to keep calling PyObject_GetAttr. + */ +typedef struct { + PyObject* stringiobuf; + PyObject* refill_callable; +} DecodeBuffer; + +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_buf); +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_refill); + +static inline bool +check_ssize_t_32(Py_ssize_t len) { + // error from getting the int + if (INT_CONV_ERROR_OCCURRED(len)) { + return false; + } + if (!CHECK_RANGE(len, 0, INT32_MAX)) { + PyErr_SetString(PyExc_OverflowError, "string size out of range"); + return false; + } + return true; +} + +static inline bool +parse_pyint(PyObject* o, int32_t* ret, int32_t min, int32_t max) { + long val = PyInt_AsLong(o); + + if (INT_CONV_ERROR_OCCURRED(val)) { + return false; + } + if (!CHECK_RANGE(val, min, max)) { + PyErr_SetString(PyExc_OverflowError, "int out of range"); + return false; + } + + *ret = (int32_t) val; + return true; +} + + +/* --- FUNCTIONS TO PARSE STRUCT SPECIFICATOINS --- */ + +static bool +parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 2) { + PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args"); + return false; + } + + dest->element_type = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->element_type)) { + return false; + } + + dest->typeargs = PyTuple_GET_ITEM(typeargs, 1); + + return true; +} + +static bool +parse_map_args(MapTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 4) { + PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map"); + return false; + } + + dest->ktag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->ktag)) { + return false; + } + + dest->vtag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2)); + if (INT_CONV_ERROR_OCCURRED(dest->vtag)) { + return false; + } + + dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1); + dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3); + + return true; +} + +static bool +parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 2) { + PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args"); + return false; + } + + dest->klass = PyTuple_GET_ITEM(typeargs, 0); + dest->spec = PyTuple_GET_ITEM(typeargs, 1); + + return true; +} + +static int +parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) { + + // i'd like to use ParseArgs here, but it seems to be a bottleneck. + if (PyTuple_Size(spec_tuple) != 5) { + PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple"); + return false; + } + + dest->tag = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->tag)) { + return false; + } + + dest->type = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1)); + if (INT_CONV_ERROR_OCCURRED(dest->type)) { + return false; + } + + dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2); + dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3); + dest->defval = PyTuple_GET_ITEM(spec_tuple, 4); + return true; +} + +/* ====== END UTILITIES ====== */ + + +/* ====== BEGIN WRITING FUNCTIONS ====== */ + +/* --- LOW-LEVEL WRITING FUNCTIONS --- */ + +static void writeByte(PyObject* outbuf, int8_t val) { + int8_t net = val; + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int8_t)); +} + +static void writeI16(PyObject* outbuf, int16_t val) { + int16_t net = (int16_t)htons(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int16_t)); +} + +static void writeI32(PyObject* outbuf, int32_t val) { + int32_t net = (int32_t)htonl(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int32_t)); +} + +static void writeI64(PyObject* outbuf, int64_t val) { + int64_t net = (int64_t)htonll(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int64_t)); +} + +static void writeDouble(PyObject* outbuf, double dub) { + // Unfortunately, bitwise_cast doesn't work in C. Bad C! + union { + double f; + int64_t t; + } transfer; + transfer.f = dub; + writeI64(outbuf, transfer.t); +} + + +/* --- MAIN RECURSIVE OUTPUT FUCNTION -- */ + +static int +output_val(PyObject* output, PyObject* value, TType type, PyObject* typeargs) { + /* + * Refcounting Strategy: + * + * We assume that elements of the thrift_spec tuple are not going to be + * mutated, so we don't ref count those at all. Other than that, we try to + * keep a reference to all the user-created objects while we work with them. + * output_val assumes that a reference is already held. The *caller* is + * responsible for handling references + */ + + switch (type) { + + case T_BOOL: { + int v = PyObject_IsTrue(value); + if (v == -1) { + return false; + } + + writeByte(output, (int8_t) v); + break; + } + case T_I08: { + int32_t val; + + if (!parse_pyint(value, &val, INT8_MIN, INT8_MAX)) { + return false; + } + + writeByte(output, (int8_t) val); + break; + } + case T_I16: { + int32_t val; + + if (!parse_pyint(value, &val, INT16_MIN, INT16_MAX)) { + return false; + } + + writeI16(output, (int16_t) val); + break; + } + case T_I32: { + int32_t val; + + if (!parse_pyint(value, &val, INT32_MIN, INT32_MAX)) { + return false; + } + + writeI32(output, val); + break; + } + case T_I64: { + int64_t nval = PyLong_AsLongLong(value); + + if (INT_CONV_ERROR_OCCURRED(nval)) { + return false; + } + + if (!CHECK_RANGE(nval, INT64_MIN, INT64_MAX)) { + PyErr_SetString(PyExc_OverflowError, "int out of range"); + return false; + } + + writeI64(output, nval); + break; + } + + case T_DOUBLE: { + double nval = PyFloat_AsDouble(value); + if (nval == -1.0 && PyErr_Occurred()) { + return false; + } + + writeDouble(output, nval); + break; + } + + case T_STRING: { + Py_ssize_t len = PyString_Size(value); + + if (!check_ssize_t_32(len)) { + return false; + } + + writeI32(output, (int32_t) len); + PycStringIO->cwrite(output, PyString_AsString(value), (int32_t) len); + break; + } + + case T_LIST: + case T_SET: { + Py_ssize_t len; + SetListTypeArgs parsedargs; + PyObject *item; + PyObject *iterator; + + if (!parse_set_list_args(&parsedargs, typeargs)) { + return false; + } + + len = PyObject_Length(value); + + if (!check_ssize_t_32(len)) { + return false; + } + + writeByte(output, parsedargs.element_type); + writeI32(output, (int32_t) len); + + iterator = PyObject_GetIter(value); + if (iterator == NULL) { + return false; + } + + while ((item = PyIter_Next(iterator))) { + if (!output_val(output, item, parsedargs.element_type, parsedargs.typeargs)) { + Py_DECREF(item); + Py_DECREF(iterator); + return false; + } + Py_DECREF(item); + } + + Py_DECREF(iterator); + + if (PyErr_Occurred()) { + return false; + } + + break; + } + + case T_MAP: { + PyObject *k, *v; + Py_ssize_t pos = 0; + Py_ssize_t len; + + MapTypeArgs parsedargs; + + len = PyDict_Size(value); + if (!check_ssize_t_32(len)) { + return false; + } + + if (!parse_map_args(&parsedargs, typeargs)) { + return false; + } + + writeByte(output, parsedargs.ktag); + writeByte(output, parsedargs.vtag); + writeI32(output, len); + + // TODO(bmaurer): should support any mapping, not just dicts + while (PyDict_Next(value, &pos, &k, &v)) { + // TODO(dreiss): Think hard about whether these INCREFs actually + // turn any unsafe scenarios into safe scenarios. + Py_INCREF(k); + Py_INCREF(v); + + if (!output_val(output, k, parsedargs.ktag, parsedargs.ktypeargs) + || !output_val(output, v, parsedargs.vtag, parsedargs.vtypeargs)) { + Py_DECREF(k); + Py_DECREF(v); + return false; + } + Py_DECREF(k); + Py_DECREF(v); + } + break; + } + + // TODO(dreiss): Consider breaking this out as a function + // the way we did for decode_struct. + case T_STRUCT: { + StructTypeArgs parsedargs; + Py_ssize_t nspec; + Py_ssize_t i; + + if (!parse_struct_args(&parsedargs, typeargs)) { + return false; + } + + nspec = PyTuple_Size(parsedargs.spec); + + if (nspec == -1) { + return false; + } + + for (i = 0; i < nspec; i++) { + StructItemSpec parsedspec; + PyObject* spec_tuple; + PyObject* instval = NULL; + + spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i); + if (spec_tuple == Py_None) { + continue; + } + + if (!parse_struct_item_spec (&parsedspec, spec_tuple)) { + return false; + } + + instval = PyObject_GetAttr(value, parsedspec.attrname); + + if (!instval) { + return false; + } + + if (instval == Py_None) { + Py_DECREF(instval); + continue; + } + + writeByte(output, (int8_t) parsedspec.type); + writeI16(output, parsedspec.tag); + + if (!output_val(output, instval, parsedspec.type, parsedspec.typeargs)) { + Py_DECREF(instval); + return false; + } + + Py_DECREF(instval); + } + + writeByte(output, (int8_t)T_STOP); + break; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_SetString(PyExc_TypeError, "Unexpected TType"); + return false; + + } + + return true; +} + + +/* --- TOP-LEVEL WRAPPER FOR OUTPUT -- */ + +static PyObject * +encode_binary(PyObject *self, PyObject *args) { + PyObject* enc_obj; + PyObject* type_args; + PyObject* buf; + PyObject* ret = NULL; + + if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) { + return NULL; + } + + buf = PycStringIO->NewOutput(INIT_OUTBUF_SIZE); + if (output_val(buf, enc_obj, T_STRUCT, type_args)) { + ret = PycStringIO->cgetvalue(buf); + } + + Py_DECREF(buf); + return ret; +} + +/* ====== END WRITING FUNCTIONS ====== */ + + +/* ====== BEGIN READING FUNCTIONS ====== */ + +/* --- LOW-LEVEL READING FUNCTIONS --- */ + +static void +free_decodebuf(DecodeBuffer* d) { + Py_XDECREF(d->stringiobuf); + Py_XDECREF(d->refill_callable); +} + +static bool +decode_buffer_from_obj(DecodeBuffer* dest, PyObject* obj) { + dest->stringiobuf = PyObject_GetAttr(obj, INTERN_STRING(cstringio_buf)); + if (!dest->stringiobuf) { + return false; + } + + if (!PycStringIO_InputCheck(dest->stringiobuf)) { + free_decodebuf(dest); + PyErr_SetString(PyExc_TypeError, "expecting stringio input"); + return false; + } + + dest->refill_callable = PyObject_GetAttr(obj, INTERN_STRING(cstringio_refill)); + + if(!dest->refill_callable) { + free_decodebuf(dest); + return false; + } + + if (!PyCallable_Check(dest->refill_callable)) { + free_decodebuf(dest); + PyErr_SetString(PyExc_TypeError, "expecting callable"); + return false; + } + + return true; +} + +static bool readBytes(DecodeBuffer* input, char** output, int len) { + int read; + + // TODO(dreiss): Don't fear the malloc. Think about taking a copy of + // the partial read instead of forcing the transport + // to prepend it to its buffer. + + read = PycStringIO->cread(input->stringiobuf, output, len); + + if (read == len) { + return true; + } else if (read == -1) { + return false; + } else { + PyObject* newiobuf; + + // using building functions as this is a rare codepath + newiobuf = PyObject_CallFunction( + input->refill_callable, "s#i", *output, read, len, NULL); + if (newiobuf == NULL) { + return false; + } + + // must do this *AFTER* the call so that we don't deref the io buffer + Py_CLEAR(input->stringiobuf); + input->stringiobuf = newiobuf; + + read = PycStringIO->cread(input->stringiobuf, output, len); + + if (read == len) { + return true; + } else if (read == -1) { + return false; + } else { + // TODO(dreiss): This could be a valid code path for big binary blobs. + PyErr_SetString(PyExc_TypeError, + "refill claimed to have refilled the buffer, but didn't!!"); + return false; + } + } +} + +static int8_t readByte(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int8_t))) { + return -1; + } + + return *(int8_t*) buf; +} + +static int16_t readI16(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int16_t))) { + return -1; + } + + return (int16_t) ntohs(*(int16_t*) buf); +} + +static int32_t readI32(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int32_t))) { + return -1; + } + return (int32_t) ntohl(*(int32_t*) buf); +} + + +static int64_t readI64(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int64_t))) { + return -1; + } + + return (int64_t) ntohll(*(int64_t*) buf); +} + +static double readDouble(DecodeBuffer* input) { + union { + int64_t f; + double t; + } transfer; + + transfer.f = readI64(input); + if (transfer.f == -1) { + return -1; + } + return transfer.t; +} + +static bool +checkTypeByte(DecodeBuffer* input, TType expected) { + TType got = readByte(input); + if (INT_CONV_ERROR_OCCURRED(got)) { + return false; + } + + if (expected != got) { + PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field"); + return false; + } + return true; +} + +static bool +skip(DecodeBuffer* input, TType type) { +#define SKIPBYTES(n) \ + do { \ + if (!readBytes(input, &dummy_buf, (n))) { \ + return false; \ + } \ + } while(0) + + char* dummy_buf; + + switch (type) { + + case T_BOOL: + case T_I08: SKIPBYTES(1); break; + case T_I16: SKIPBYTES(2); break; + case T_I32: SKIPBYTES(4); break; + case T_I64: + case T_DOUBLE: SKIPBYTES(8); break; + + case T_STRING: { + // TODO(dreiss): Find out if these check_ssize_t32s are really necessary. + int len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + SKIPBYTES(len); + break; + } + + case T_LIST: + case T_SET: { + TType etype; + int len, i; + + etype = readByte(input); + if (etype == -1) { + return false; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + + for (i = 0; i < len; i++) { + if (!skip(input, etype)) { + return false; + } + } + break; + } + + case T_MAP: { + TType ktype, vtype; + int len, i; + + ktype = readByte(input); + if (ktype == -1) { + return false; + } + + vtype = readByte(input); + if (vtype == -1) { + return false; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + + for (i = 0; i < len; i++) { + if (!(skip(input, ktype) && skip(input, vtype))) { + return false; + } + } + break; + } + + case T_STRUCT: { + while (true) { + TType type; + + type = readByte(input); + if (type == -1) { + return false; + } + + if (type == T_STOP) + break; + + SKIPBYTES(2); // tag + if (!skip(input, type)) { + return false; + } + } + break; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_SetString(PyExc_TypeError, "Unexpected TType"); + return false; + + } + + return true; + +#undef SKIPBYTES +} + + +/* --- HELPER FUNCTION FOR DECODE_VAL --- */ + +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs); + +static bool +decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) { + int spec_seq_len = PyTuple_Size(spec_seq); + if (spec_seq_len == -1) { + return false; + } + + while (true) { + TType type; + int16_t tag; + PyObject* item_spec; + PyObject* fieldval = NULL; + StructItemSpec parsedspec; + + type = readByte(input); + if (type == -1) { + return false; + } + if (type == T_STOP) { + break; + } + tag = readI16(input); + if (INT_CONV_ERROR_OCCURRED(tag)) { + return false; + } + if (tag >= 0 && tag < spec_seq_len) { + item_spec = PyTuple_GET_ITEM(spec_seq, tag); + } else { + item_spec = Py_None; + } + + if (item_spec == Py_None) { + if (!skip(input, type)) { + return false; + } else { + continue; + } + } + + if (!parse_struct_item_spec(&parsedspec, item_spec)) { + return false; + } + if (parsedspec.type != type) { + if (!skip(input, type)) { + PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped"); + return false; + } else { + continue; + } + } + + fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs); + if (fieldval == NULL) { + return false; + } + + if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) { + Py_DECREF(fieldval); + return false; + } + Py_DECREF(fieldval); + } + return true; +} + + +/* --- MAIN RECURSIVE INPUT FUCNTION --- */ + +// Returns a new reference. +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { + switch (type) { + + case T_BOOL: { + int8_t v = readByte(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + + switch (v) { + case 0: Py_RETURN_FALSE; + case 1: Py_RETURN_TRUE; + // Don't laugh. This is a potentially serious issue. + default: PyErr_SetString(PyExc_TypeError, "boolean out of range"); return NULL; + } + break; + } + case T_I08: { + int8_t v = readByte(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + + return PyInt_FromLong(v); + } + case T_I16: { + int16_t v = readI16(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + return PyInt_FromLong(v); + } + case T_I32: { + int32_t v = readI32(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + return PyInt_FromLong(v); + } + + case T_I64: { + int64_t v = readI64(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + // TODO(dreiss): Find out if we can take this fastpath always when + // sizeof(long) == sizeof(long long). + if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) { + return PyInt_FromLong((long) v); + } + + return PyLong_FromLongLong(v); + } + + case T_DOUBLE: { + double v = readDouble(input); + if (v == -1.0 && PyErr_Occurred()) { + return false; + } + return PyFloat_FromDouble(v); + } + + case T_STRING: { + Py_ssize_t len = readI32(input); + char* buf; + if (!readBytes(input, &buf, len)) { + return NULL; + } + + return PyString_FromStringAndSize(buf, len); + } + + case T_LIST: + case T_SET: { + SetListTypeArgs parsedargs; + int32_t len; + PyObject* ret = NULL; + int i; + + if (!parse_set_list_args(&parsedargs, typeargs)) { + return NULL; + } + + if (!checkTypeByte(input, parsedargs.element_type)) { + return NULL; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return NULL; + } + + ret = PyList_New(len); + if (!ret) { + return NULL; + } + + for (i = 0; i < len; i++) { + PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs); + if (!item) { + Py_DECREF(ret); + return NULL; + } + PyList_SET_ITEM(ret, i, item); + } + + // TODO(dreiss): Consider biting the bullet and making two separate cases + // for list and set, avoiding this post facto conversion. + if (type == T_SET) { + PyObject* setret; +#if (PY_VERSION_HEX < 0x02050000) + // hack needed for older versions + setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL); +#else + // official version + setret = PySet_New(ret); +#endif + Py_DECREF(ret); + return setret; + } + return ret; + } + + case T_MAP: { + int32_t len; + int i; + MapTypeArgs parsedargs; + PyObject* ret = NULL; + + if (!parse_map_args(&parsedargs, typeargs)) { + return NULL; + } + + if (!checkTypeByte(input, parsedargs.ktag)) { + return NULL; + } + if (!checkTypeByte(input, parsedargs.vtag)) { + return NULL; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + + ret = PyDict_New(); + if (!ret) { + goto error; + } + + for (i = 0; i < len; i++) { + PyObject* k = NULL; + PyObject* v = NULL; + k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs); + if (k == NULL) { + goto loop_error; + } + v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs); + if (v == NULL) { + goto loop_error; + } + if (PyDict_SetItem(ret, k, v) == -1) { + goto loop_error; + } + + Py_DECREF(k); + Py_DECREF(v); + continue; + + // Yuck! Destructors, anyone? + loop_error: + Py_XDECREF(k); + Py_XDECREF(v); + goto error; + } + + return ret; + + error: + Py_XDECREF(ret); + return NULL; + } + + case T_STRUCT: { + StructTypeArgs parsedargs; + if (!parse_struct_args(&parsedargs, typeargs)) { + return NULL; + } + + PyObject* ret = PyObject_CallObject(parsedargs.klass, NULL); + if (!ret) { + return NULL; + } + + if (!decode_struct(input, ret, parsedargs.spec)) { + Py_DECREF(ret); + return NULL; + } + + return ret; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_SetString(PyExc_TypeError, "Unexpected TType"); + return NULL; + } +} + + +/* --- TOP-LEVEL WRAPPER FOR INPUT -- */ + +static PyObject* +decode_binary(PyObject *self, PyObject *args) { + PyObject* output_obj = NULL; + PyObject* transport = NULL; + PyObject* typeargs = NULL; + StructTypeArgs parsedargs; + DecodeBuffer input = {}; + + if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) { + return NULL; + } + + if (!parse_struct_args(&parsedargs, typeargs)) { + return NULL; + } + + if (!decode_buffer_from_obj(&input, transport)) { + return NULL; + } + + if (!decode_struct(&input, output_obj, parsedargs.spec)) { + free_decodebuf(&input); + return NULL; + } + + free_decodebuf(&input); + + Py_RETURN_NONE; +} + +/* ====== END READING FUNCTIONS ====== */ + + +/* -- PYTHON MODULE SETUP STUFF --- */ + +static PyMethodDef ThriftFastBinaryMethods[] = { + + {"encode_binary", encode_binary, METH_VARARGS, ""}, + {"decode_binary", decode_binary, METH_VARARGS, ""}, + + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +PyMODINIT_FUNC +initfastbinary(void) { +#define INIT_INTERN_STRING(value) \ + do { \ + INTERN_STRING(value) = PyString_InternFromString(#value); \ + if(!INTERN_STRING(value)) return; \ + } while(0) + + INIT_INTERN_STRING(cstringio_buf); + INIT_INTERN_STRING(cstringio_refill); +#undef INIT_INTERN_STRING + + PycString_IMPORT; + if (PycStringIO == NULL) return; + + (void) Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods); +} diff --git a/lib/py/src/server/THttpServer.py b/lib/py/src/server/THttpServer.py new file mode 100644 index 000000000..21fc31418 --- /dev/null +++ b/lib/py/src/server/THttpServer.py @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import BaseHTTPServer + +from thrift.server import TServer +from thrift.transport import TTransport + +class THttpServer(TServer.TServer): + """A simple HTTP-based Thrift server + + This class is not very performant, but it is useful (for example) for + acting as a mock version of an Apache-based PHP Thrift endpoint.""" + + def __init__(self, processor, server_address, + inputProtocolFactory, outputProtocolFactory = None): + """Set up protocol factories and HTTP server. + + See BaseHTTPServer for server_address. + See TServer for protocol factories.""" + + if outputProtocolFactory is None: + outputProtocolFactory = inputProtocolFactory + + TServer.TServer.__init__(self, processor, None, None, None, + inputProtocolFactory, outputProtocolFactory) + + thttpserver = self + + class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler): + def do_POST(self): + # Don't care about the request path. + self.send_response(200) + self.send_header("content-type", "application/x-thrift") + self.end_headers() + + itrans = TTransport.TFileObjectTransport(self.rfile) + otrans = TTransport.TFileObjectTransport(self.wfile) + iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) + oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) + thttpserver.processor.process(iprot, oprot) + otrans.flush() + + self.httpd = BaseHTTPServer.HTTPServer(server_address, RequestHander) + + def serve(self): + self.httpd.serve_forever() diff --git a/lib/py/src/server/TNonblockingServer.py b/lib/py/src/server/TNonblockingServer.py new file mode 100644 index 000000000..deec708a1 --- /dev/null +++ b/lib/py/src/server/TNonblockingServer.py @@ -0,0 +1,309 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""Implementation of non-blocking server. + +The main idea of the server is reciving and sending requests +only from main thread. + +It also makes thread pool server in tasks terms, not connections. +""" +import threading +import socket +import Queue +import select +import struct +import logging + +from thrift.transport import TTransport +from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory + +__all__ = ['TNonblockingServer'] + +class Worker(threading.Thread): + """Worker is a small helper to process incoming connection.""" + def __init__(self, queue): + threading.Thread.__init__(self) + self.queue = queue + + def run(self): + """Process queries from task queue, stop if processor is None.""" + while True: + try: + processor, iprot, oprot, otrans, callback = self.queue.get() + if processor is None: + break + processor.process(iprot, oprot) + callback(True, otrans.getvalue()) + except Exception: + logging.exception("Exception while processing request") + callback(False, '') + +WAIT_LEN = 0 +WAIT_MESSAGE = 1 +WAIT_PROCESS = 2 +SEND_ANSWER = 3 +CLOSED = 4 + +def locked(func): + "Decorator which locks self.lock." + def nested(self, *args, **kwargs): + self.lock.acquire() + try: + return func(self, *args, **kwargs) + finally: + self.lock.release() + return nested + +def socket_exception(func): + "Decorator close object on socket.error." + def read(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except socket.error: + self.close() + return read + +class Connection: + """Basic class is represented connection. + + It can be in state: + WAIT_LEN --- connection is reading request len. + WAIT_MESSAGE --- connection is reading request. + WAIT_PROCESS --- connection has just read whole request and + waits for call ready routine. + SEND_ANSWER --- connection is sending answer string (including length + of answer). + CLOSED --- socket was closed and connection should be deleted. + """ + def __init__(self, new_socket, wake_up): + self.socket = new_socket + self.socket.setblocking(False) + self.status = WAIT_LEN + self.len = 0 + self.message = '' + self.lock = threading.Lock() + self.wake_up = wake_up + + def _read_len(self): + """Reads length of request. + + It's really paranoic routine and it may be replaced by + self.socket.recv(4).""" + read = self.socket.recv(4 - len(self.message)) + if len(read) == 0: + # if we read 0 bytes and self.message is empty, it means client close + # connection + if len(self.message) != 0: + logging.error("can't read frame size from socket") + self.close() + return + self.message += read + if len(self.message) == 4: + self.len, = struct.unpack('!i', self.message) + if self.len < 0: + logging.error("negative frame size, it seems client"\ + " doesn't use FramedTransport") + self.close() + elif self.len == 0: + logging.error("empty frame, it's really strange") + self.close() + else: + self.message = '' + self.status = WAIT_MESSAGE + + @socket_exception + def read(self): + """Reads data from stream and switch state.""" + assert self.status in (WAIT_LEN, WAIT_MESSAGE) + if self.status == WAIT_LEN: + self._read_len() + # go back to the main loop here for simplicity instead of + # falling through, even though there is a good chance that + # the message is already available + elif self.status == WAIT_MESSAGE: + read = self.socket.recv(self.len - len(self.message)) + if len(read) == 0: + logging.error("can't read frame from socket (get %d of %d bytes)" % + (len(self.message), self.len)) + self.close() + return + self.message += read + if len(self.message) == self.len: + self.status = WAIT_PROCESS + + @socket_exception + def write(self): + """Writes data from socket and switch state.""" + assert self.status == SEND_ANSWER + sent = self.socket.send(self.message) + if sent == len(self.message): + self.status = WAIT_LEN + self.message = '' + self.len = 0 + else: + self.message = self.message[sent:] + + @locked + def ready(self, all_ok, message): + """Callback function for switching state and waking up main thread. + + This function is the only function witch can be called asynchronous. + + The ready can switch Connection to three states: + WAIT_LEN if request was oneway. + SEND_ANSWER if request was processed in normal way. + CLOSED if request throws unexpected exception. + + The one wakes up main thread. + """ + assert self.status == WAIT_PROCESS + if not all_ok: + self.close() + self.wake_up() + return + self.len = '' + self.message = struct.pack('!i', len(message)) + message + if len(message) == 0: + # it was a oneway request, do not write answer + self.status = WAIT_LEN + else: + self.status = SEND_ANSWER + self.wake_up() + + @locked + def is_writeable(self): + "Returns True if connection should be added to write list of select." + return self.status == SEND_ANSWER + + # it's not necessary, but... + @locked + def is_readable(self): + "Returns True if connection should be added to read list of select." + return self.status in (WAIT_LEN, WAIT_MESSAGE) + + @locked + def is_closed(self): + "Returns True if connection is closed." + return self.status == CLOSED + + def fileno(self): + "Returns the file descriptor of the associated socket." + return self.socket.fileno() + + def close(self): + "Closes connection" + self.status = CLOSED + self.socket.close() + +class TNonblockingServer: + """Non-blocking server.""" + def __init__(self, processor, lsocket, inputProtocolFactory=None, + outputProtocolFactory=None, threads=10): + self.processor = processor + self.socket = lsocket + self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory() + self.out_protocol = outputProtocolFactory or self.in_protocol + self.threads = int(threads) + self.clients = {} + self.tasks = Queue.Queue() + self._read, self._write = socket.socketpair() + self.prepared = False + + def setNumThreads(self, num): + """Set the number of worker threads that should be created.""" + # implement ThreadPool interface + assert not self.prepared, "You can't change number of threads for working server" + self.threads = num + + def prepare(self): + """Prepares server for serve requests.""" + self.socket.listen() + for _ in xrange(self.threads): + thread = Worker(self.tasks) + thread.setDaemon(True) + thread.start() + self.prepared = True + + def wake_up(self): + """Wake up main thread. + + The server usualy waits in select call in we should terminate one. + The simplest way is using socketpair. + + Select always wait to read from the first socket of socketpair. + + In this case, we can just write anything to the second socket from + socketpair.""" + self._write.send('1') + + def _select(self): + """Does select on open connections.""" + readable = [self.socket.handle.fileno(), self._read.fileno()] + writable = [] + for i, connection in self.clients.items(): + if connection.is_readable(): + readable.append(connection.fileno()) + if connection.is_writeable(): + writable.append(connection.fileno()) + if connection.is_closed(): + del self.clients[i] + return select.select(readable, writable, readable) + + def handle(self): + """Handle requests. + + WARNING! You must call prepare BEFORE calling handle. + """ + assert self.prepared, "You have to call prepare before handle" + rset, wset, xset = self._select() + for readable in rset: + if readable == self._read.fileno(): + # don't care i just need to clean readable flag + self._read.recv(1024) + elif readable == self.socket.handle.fileno(): + client = self.socket.accept().handle + self.clients[client.fileno()] = Connection(client, self.wake_up) + else: + connection = self.clients[readable] + connection.read() + if connection.status == WAIT_PROCESS: + itransport = TTransport.TMemoryBuffer(connection.message) + otransport = TTransport.TMemoryBuffer() + iprot = self.in_protocol.getProtocol(itransport) + oprot = self.out_protocol.getProtocol(otransport) + self.tasks.put([self.processor, iprot, oprot, + otransport, connection.ready]) + for writeable in wset: + self.clients[writeable].write() + for oob in xset: + self.clients[oob].close() + del self.clients[oob] + + def close(self): + """Closes the server.""" + for _ in xrange(self.threads): + self.tasks.put([None, None, None, None, None]) + self.socket.close() + self.prepared = False + + def serve(self): + """Serve forever.""" + self.prepare() + while True: + self.handle() diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py new file mode 100644 index 000000000..615291114 --- /dev/null +++ b/lib/py/src/server/TServer.py @@ -0,0 +1,270 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import logging +import sys +import os +import traceback +import threading +import Queue + +from thrift.Thrift import TProcessor +from thrift.transport import TTransport +from thrift.protocol import TBinaryProtocol + +class TServer: + + """Base interface for a server, which must have a serve method.""" + + """ 3 constructors for all servers: + 1) (processor, serverTransport) + 2) (processor, serverTransport, transportFactory, protocolFactory) + 3) (processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory)""" + def __init__(self, *args): + if (len(args) == 2): + self.__initArgs__(args[0], args[1], + TTransport.TTransportFactoryBase(), + TTransport.TTransportFactoryBase(), + TBinaryProtocol.TBinaryProtocolFactory(), + TBinaryProtocol.TBinaryProtocolFactory()) + elif (len(args) == 4): + self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3]) + elif (len(args) == 6): + self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) + + def __initArgs__(self, processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory): + self.processor = processor + self.serverTransport = serverTransport + self.inputTransportFactory = inputTransportFactory + self.outputTransportFactory = outputTransportFactory + self.inputProtocolFactory = inputProtocolFactory + self.outputProtocolFactory = outputProtocolFactory + + def serve(self): + pass + +class TSimpleServer(TServer): + + """Simple single-threaded server that just pumps around one transport.""" + + def __init__(self, *args): + TServer.__init__(self, *args) + + def serve(self): + self.serverTransport.listen() + while True: + client = self.serverTransport.accept() + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + itrans.close() + otrans.close() + +class TThreadedServer(TServer): + + """Threaded server that spawns a new thread per each connection.""" + + def __init__(self, *args): + TServer.__init__(self, *args) + + def serve(self): + self.serverTransport.listen() + while True: + try: + client = self.serverTransport.accept() + t = threading.Thread(target = self.handle, args=(client,)) + t.start() + except KeyboardInterrupt: + raise + except Exception, x: + logging.exception(x) + + def handle(self, client): + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + itrans.close() + otrans.close() + +class TThreadPoolServer(TServer): + + """Server with a fixed size pool of threads which service requests.""" + + def __init__(self, *args): + TServer.__init__(self, *args) + self.clients = Queue.Queue() + self.threads = 10 + + def setNumThreads(self, num): + """Set the number of worker threads that should be created""" + self.threads = num + + def serveThread(self): + """Loop around getting clients from the shared queue and process them.""" + while True: + try: + client = self.clients.get() + self.serveClient(client) + except Exception, x: + logging.exception(x) + + def serveClient(self, client): + """Process input/output from a client for as long as possible""" + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + itrans.close() + otrans.close() + + def serve(self): + """Start a fixed number of worker threads and put client into a queue""" + for i in range(self.threads): + try: + t = threading.Thread(target = self.serveThread) + t.start() + except Exception, x: + logging.exception(x) + + # Pump the socket for clients + self.serverTransport.listen() + while True: + try: + client = self.serverTransport.accept() + self.clients.put(client) + except Exception, x: + logging.exception(x) + + +class TForkingServer(TServer): + + """A Thrift server that forks a new process for each request""" + """ + This is more scalable than the threaded server as it does not cause + GIL contention. + + Note that this has different semantics from the threading server. + Specifically, updates to shared variables will no longer be shared. + It will also not work on windows. + + This code is heavily inspired by SocketServer.ForkingMixIn in the + Python stdlib. + """ + + def __init__(self, *args): + TServer.__init__(self, *args) + self.children = [] + + def serve(self): + def try_close(file): + try: + file.close() + except IOError, e: + logging.warning(e, exc_info=True) + + + self.serverTransport.listen() + while True: + client = self.serverTransport.accept() + try: + pid = os.fork() + + if pid: # parent + # add before collect, otherwise you race w/ waitpid + self.children.append(pid) + self.collect_children() + + # Parent must close socket or the connection may not get + # closed promptly + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + try_close(itrans) + try_close(otrans) + else: + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + ecode = 0 + try: + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException, tx: + pass + except Exception, e: + logging.exception(e) + ecode = 1 + finally: + try_close(itrans) + try_close(otrans) + + os._exit(ecode) + + except TTransport.TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + + def collect_children(self): + while self.children: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except os.error: + pid = None + + if pid: + self.children.remove(pid) + else: + break + + diff --git a/lib/py/src/server/__init__.py b/lib/py/src/server/__init__.py new file mode 100644 index 000000000..1bf6e254e --- /dev/null +++ b/lib/py/src/server/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TServer', 'TNonblockingServer'] diff --git a/lib/py/src/transport/THttpClient.py b/lib/py/src/transport/THttpClient.py new file mode 100644 index 000000000..5086032be --- /dev/null +++ b/lib/py/src/transport/THttpClient.py @@ -0,0 +1,100 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TTransport import * +from cStringIO import StringIO + +import urlparse +import httplib +import warnings + +class THttpClient(TTransportBase): + + """Http implementation of TTransport base.""" + + def __init__(self, uri_or_host, port=None, path=None): + """THttpClient supports two different types constructor parameters. + + THttpClient(host, port, path) - deprecated + THttpClient(uri) + + Only the second supports https.""" + + if port is not None: + warnings.warn("Please use the THttpClient('http://host:port/path') syntax", DeprecationWarning, stacklevel=2) + self.host = uri_or_host + self.port = port + assert path + self.path = path + self.scheme = 'http' + else: + parsed = urlparse.urlparse(uri_or_host) + self.scheme = parsed.scheme + assert self.scheme in ('http', 'https') + if self.scheme == 'http': + self.port = parsed.port or httplib.HTTP_PORT + elif self.scheme == 'https': + self.port = parsed.port or httplib.HTTPS_PORT + self.host = parsed.hostname + self.path = parsed.path + self.__wbuf = StringIO() + self.__http = None + + def open(self): + if self.scheme == 'http': + self.__http = httplib.HTTP(self.host, self.port) + else: + self.__http = httplib.HTTPS(self.host, self.port) + + def close(self): + self.__http.close() + self.__http = None + + def isOpen(self): + return self.__http != None + + def read(self, sz): + return self.__http.file.read(sz) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + if self.isOpen(): + self.close() + self.open(); + + # Pull data out of buffer + data = self.__wbuf.getvalue() + self.__wbuf = StringIO() + + # HTTP request + self.__http.putrequest('POST', self.path) + + # Write headers + self.__http.putheader('Host', self.host) + self.__http.putheader('Content-Type', 'application/x-thrift') + self.__http.putheader('Content-Length', str(len(data))) + self.__http.endheaders() + + # Write payload + self.__http.send(data) + + # Get reply to flush the request + self.code, self.message, self.headers = self.__http.getreply() diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py new file mode 100644 index 000000000..4645a023f --- /dev/null +++ b/lib/py/src/transport/TSocket.py @@ -0,0 +1,147 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TTransport import * +import os +import errno +import socket + +class TSocketBase(TTransportBase): + def _resolveAddr(self): + if self._unix_socket is not None: + return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, self._unix_socket)] + else: + return socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE | socket.AI_ADDRCONFIG) + + def close(self): + if self.handle: + self.handle.close() + self.handle = None + +class TSocket(TSocketBase): + """Socket implementation of TTransport base.""" + + def __init__(self, host='localhost', port=9090, unix_socket=None): + """Initialize a TSocket + + @param host(str) The host to connect to. + @param port(int) The (TCP) port to connect to. + @param unix_socket(str) The filename of a unix socket to connect to. + (host and port will be ignored.) + """ + + self.host = host + self.port = port + self.handle = None + self._unix_socket = unix_socket + self._timeout = None + + def setHandle(self, h): + self.handle = h + + def isOpen(self): + return self.handle != None + + def setTimeout(self, ms): + if ms is None: + self._timeout = None + else: + self._timeout = ms/1000.0 + + if (self.handle != None): + self.handle.settimeout(self._timeout) + + def open(self): + try: + res0 = self._resolveAddr() + for res in res0: + self.handle = socket.socket(res[0], res[1]) + self.handle.settimeout(self._timeout) + try: + self.handle.connect(res[4]) + except socket.error, e: + if res is not res0[-1]: + continue + else: + raise e + break + except socket.error, e: + if self._unix_socket: + message = 'Could not connect to socket %s' % self._unix_socket + else: + message = 'Could not connect to %s:%d' % (self.host, self.port) + raise TTransportException(TTransportException.NOT_OPEN, message) + + def read(self, sz): + buff = self.handle.recv(sz) + if len(buff) == 0: + raise TTransportException('TSocket read 0 bytes') + return buff + + def write(self, buff): + sent = 0 + have = len(buff) + while sent < have: + plus = self.handle.send(buff) + if plus == 0: + raise TTransportException('TSocket sent 0 bytes') + sent += plus + buff = buff[plus:] + + def flush(self): + pass + +class TServerSocket(TSocketBase, TServerTransportBase): + """Socket implementation of TServerTransport base.""" + + def __init__(self, port=9090, unix_socket=None): + self.host = None + self.port = port + self._unix_socket = unix_socket + self.handle = None + + def listen(self): + res0 = self._resolveAddr() + for res in res0: + if res[0] is socket.AF_INET6 or res is res0[-1]: + break + + # We need remove the old unix socket if the file exists and + # nobody is listening on it. + if self._unix_socket: + tmp = socket.socket(res[0], res[1]) + try: + tmp.connect(res[4]) + except socket.error, err: + eno, message = err.args + if eno == errno.ECONNREFUSED: + os.unlink(res[4]) + + self.handle = socket.socket(res[0], res[1]) + self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(self.handle, 'set_timeout'): + self.handle.set_timeout(None) + self.handle.bind(res[4]) + self.handle.listen(128) + + def accept(self): + client, addr = self.handle.accept() + result = TSocket() + result.setHandle(client) + return result diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py new file mode 100644 index 000000000..32553bddf --- /dev/null +++ b/lib/py/src/transport/TTransport.py @@ -0,0 +1,326 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from cStringIO import StringIO +from struct import pack,unpack +from thrift.Thrift import TException + +class TTransportException(TException): + + """Custom Transport Exception class""" + + UNKNOWN = 0 + NOT_OPEN = 1 + ALREADY_OPEN = 2 + TIMED_OUT = 3 + END_OF_FILE = 4 + + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type + +class TTransportBase: + + """Base class for Thrift transport layer.""" + + def isOpen(self): + pass + + def open(self): + pass + + def close(self): + pass + + def read(self, sz): + pass + + def readAll(self, sz): + buff = '' + have = 0 + while (have < sz): + chunk = self.read(sz-have) + have += len(chunk) + buff += chunk + + if len(chunk) == 0: + raise EOFError() + + return buff + + def write(self, buf): + pass + + def flush(self): + pass + +# This class should be thought of as an interface. +class CReadableTransport: + """base class for transports that are readable from C""" + + # TODO(dreiss): Think about changing this interface to allow us to use + # a (Python, not c) StringIO instead, because it allows + # you to write after reading. + + # NOTE: This is a classic class, so properties will NOT work + # correctly for setting. + @property + def cstringio_buf(self): + """A cStringIO buffer that contains the current chunk we are reading.""" + pass + + def cstringio_refill(self, partialread, reqlen): + """Refills cstringio_buf. + + Returns the currently used buffer (which can but need not be the same as + the old cstringio_buf). partialread is what the C code has read from the + buffer, and should be inserted into the buffer before any more reads. The + return value must be a new, not borrowed reference. Something along the + lines of self._buf should be fine. + + If reqlen bytes can't be read, throw EOFError. + """ + pass + +class TServerTransportBase: + + """Base class for Thrift server transports.""" + + def listen(self): + pass + + def accept(self): + pass + + def close(self): + pass + +class TTransportFactoryBase: + + """Base class for a Transport Factory""" + + def getTransport(self, trans): + return trans + +class TBufferedTransportFactory: + + """Factory transport that builds buffered transports""" + + def getTransport(self, trans): + buffered = TBufferedTransport(trans) + return buffered + + +class TBufferedTransport(TTransportBase,CReadableTransport): + + """Class that wraps another transport and buffers its I/O.""" + + DEFAULT_BUFFER = 4096 + + def __init__(self, trans): + self.__trans = trans + self.__wbuf = StringIO() + self.__rbuf = StringIO("") + + def isOpen(self): + return self.__trans.isOpen() + + def open(self): + return self.__trans.open() + + def close(self): + return self.__trans.close() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + + self.__rbuf = StringIO(self.__trans.read(max(sz, self.DEFAULT_BUFFER))) + return self.__rbuf.read(sz) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + out = self.__wbuf.getvalue() + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = StringIO() + self.__trans.write(out) + self.__trans.flush() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, partialread, reqlen): + retstring = partialread + if reqlen < self.DEFAULT_BUFFER: + # try to make a read of as much as we can. + retstring += self.__trans.read(self.DEFAULT_BUFFER) + + # but make sure we do read reqlen bytes. + if len(retstring) < reqlen: + retstring += self.__trans.readAll(reqlen - len(retstring)) + + self.__rbuf = StringIO(retstring) + return self.__rbuf + +class TMemoryBuffer(TTransportBase, CReadableTransport): + """Wraps a cStringIO object as a TTransport. + + NOTE: Unlike the C++ version of this class, you cannot write to it + then immediately read from it. If you want to read from a + TMemoryBuffer, you must either pass a string to the constructor. + TODO(dreiss): Make this work like the C++ version. + """ + + def __init__(self, value=None): + """value -- a value to read from for stringio + + If value is set, this will be a transport for reading, + otherwise, it is for writing""" + if value is not None: + self._buffer = StringIO(value) + else: + self._buffer = StringIO() + + def isOpen(self): + return not self._buffer.closed + + def open(self): + pass + + def close(self): + self._buffer.close() + + def read(self, sz): + return self._buffer.read(sz) + + def write(self, buf): + self._buffer.write(buf) + + def flush(self): + pass + + def getvalue(self): + return self._buffer.getvalue() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self._buffer + + def cstringio_refill(self, partialread, reqlen): + # only one shot at reading... + raise EOFError() + +class TFramedTransportFactory: + + """Factory transport that builds framed transports""" + + def getTransport(self, trans): + framed = TFramedTransport(trans) + return framed + + +class TFramedTransport(TTransportBase, CReadableTransport): + + """Class that wraps another transport and frames its I/O when writing.""" + + def __init__(self, trans,): + self.__trans = trans + self.__rbuf = StringIO() + self.__wbuf = StringIO() + + def isOpen(self): + return self.__trans.isOpen() + + def open(self): + return self.__trans.open() + + def close(self): + return self.__trans.close() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + + self.readFrame() + return self.__rbuf.read(sz) + + def readFrame(self): + buff = self.__trans.readAll(4) + sz, = unpack('!i', buff) + self.__rbuf = StringIO(self.__trans.readAll(sz)) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + wout = self.__wbuf.getvalue() + wsz = len(wout) + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = StringIO() + # N.B.: Doing this string concatenation is WAY cheaper than making + # two separate calls to the underlying socket object. Socket writes in + # Python turn out to be REALLY expensive, but it seems to do a pretty + # good job of managing string buffer operations without excessive copies + buf = pack("!i", wsz) + wout + self.__trans.write(buf) + self.__trans.flush() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastbinary doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. + while len(prefix) < reqlen: + readFrame() + prefix += self.__rbuf.getvalue() + self.__rbuf = StringIO(prefix) + return self.__rbuf + + +class TFileObjectTransport(TTransportBase): + """Wraps a file-like object to make it work as a Thrift transport.""" + + def __init__(self, fileobj): + self.fileobj = fileobj + + def isOpen(self): + return True + + def close(self): + self.fileobj.close() + + def read(self, sz): + return self.fileobj.read(sz) + + def write(self, buf): + self.fileobj.write(buf) + + def flush(self): + self.fileobj.flush() diff --git a/lib/py/src/transport/TTwisted.py b/lib/py/src/transport/TTwisted.py new file mode 100644 index 000000000..b5c2147b2 --- /dev/null +++ b/lib/py/src/transport/TTwisted.py @@ -0,0 +1,177 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from zope.interface import implements, Interface, Attribute +from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \ + connectionDone +from twisted.internet import defer +from twisted.protocols import basic +from twisted.python import log + + +from thrift.transport import TTransport +from cStringIO import StringIO + + +class TMessageSenderTransport(TTransport.TTransportBase): + + def __init__(self): + self.__wbuf = StringIO() + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + msg = self.__wbuf.getvalue() + self.__wbuf = StringIO() + self.sendMessage(msg) + + def sendMessage(self, message): + raise NotImplementedError + + +class TCallbackTransport(TMessageSenderTransport): + + def __init__(self, func): + TMessageSenderTransport.__init__(self) + self.func = func + + def sendMessage(self, message): + self.func(message) + + +class ThriftClientProtocol(basic.Int32StringReceiver): + + def __init__(self, client_class, iprot_factory, oprot_factory=None): + self._client_class = client_class + self._iprot_factory = iprot_factory + if oprot_factory is None: + self._oprot_factory = iprot_factory + else: + self._oprot_factory = oprot_factory + + self.recv_map = {} + self.started = defer.Deferred() + + def dispatch(self, msg): + self.sendString(msg) + + def connectionMade(self): + tmo = TCallbackTransport(self.dispatch) + self.client = self._client_class(tmo, self._oprot_factory) + self.started.callback(self.client) + + def connectionLost(self, reason=connectionDone): + for k,v in self.client._reqs.iteritems(): + tex = TTransport.TTransportException( + type=TTransport.TTransportException.END_OF_FILE, + message='Connection closed') + v.errback(tex) + + def stringReceived(self, frame): + tr = TTransport.TMemoryBuffer(frame) + iprot = self._iprot_factory.getProtocol(tr) + (fname, mtype, rseqid) = iprot.readMessageBegin() + + try: + method = self.recv_map[fname] + except KeyError: + method = getattr(self.client, 'recv_' + fname) + self.recv_map[fname] = method + + method(iprot, mtype, rseqid) + + +class ThriftServerProtocol(basic.Int32StringReceiver): + + def dispatch(self, msg): + self.sendString(msg) + + def processError(self, error): + self.transport.loseConnection() + + def processOk(self, _, tmo): + msg = tmo.getvalue() + + if len(msg) > 0: + self.dispatch(msg) + + def stringReceived(self, frame): + tmi = TTransport.TMemoryBuffer(frame) + tmo = TTransport.TMemoryBuffer() + + iprot = self.factory.iprot_factory.getProtocol(tmi) + oprot = self.factory.oprot_factory.getProtocol(tmo) + + d = self.factory.processor.process(iprot, oprot) + d.addCallbacks(self.processOk, self.processError, + callbackArgs=(tmo,)) + + +class IThriftServerFactory(Interface): + + processor = Attribute("Thrift processor") + + iprot_factory = Attribute("Input protocol factory") + + oprot_factory = Attribute("Output protocol factory") + + +class IThriftClientFactory(Interface): + + client_class = Attribute("Thrift client class") + + iprot_factory = Attribute("Input protocol factory") + + oprot_factory = Attribute("Output protocol factory") + + +class ThriftServerFactory(ServerFactory): + + implements(IThriftServerFactory) + + protocol = ThriftServerProtocol + + def __init__(self, processor, iprot_factory, oprot_factory=None): + self.processor = processor + self.iprot_factory = iprot_factory + if oprot_factory is None: + self.oprot_factory = iprot_factory + else: + self.oprot_factory = oprot_factory + + +class ThriftClientFactory(ClientFactory): + + implements(IThriftClientFactory) + + protocol = ThriftClientProtocol + + def __init__(self, client_class, iprot_factory, oprot_factory=None): + self.client_class = client_class + self.iprot_factory = iprot_factory + if oprot_factory is None: + self.oprot_factory = iprot_factory + else: + self.oprot_factory = oprot_factory + + def buildProtocol(self, addr): + p = self.protocol(self.client_class, self.iprot_factory, + self.oprot_factory) + p.factory = self + return p diff --git a/lib/py/src/transport/__init__.py b/lib/py/src/transport/__init__.py new file mode 100644 index 000000000..02c6048a9 --- /dev/null +++ b/lib/py/src/transport/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TTransport', 'TSocket', 'THttpClient'] diff --git a/lib/rb/CHANGELOG b/lib/rb/CHANGELOG new file mode 100644 index 000000000..b5dce2aee --- /dev/null +++ b/lib/rb/CHANGELOG @@ -0,0 +1 @@ +v0.0.1. Initial release diff --git a/lib/rb/Makefile.am b/lib/rb/Makefile.am new file mode 100644 index 000000000..9cfffc71c --- /dev/null +++ b/lib/rb/Makefile.am @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +EXTRA_DIST = \ + CHANGELOG \ + Rakefile \ + Manifest \ + setup.rb \ + lib \ + ext \ + benchmark \ + script \ + spec + +all-local: + $(RUBY) setup.rb config + $(RUBY) setup.rb setup + +install-exec-hook: + $(RUBY) setup.rb install + +# Make sure this doesn't fail if Ruby is not configured. +clean-local: + RUBY=$(RUBY) ; if test -z "$$RUBY" ; then RUBY=: ; fi ; \ + $$RUBY setup.rb clean + +check-local: all +if HAVE_RSPEC + rake spec +endif + diff --git a/lib/rb/Manifest b/lib/rb/Manifest new file mode 100644 index 000000000..7b4503fb4 --- /dev/null +++ b/lib/rb/Manifest @@ -0,0 +1,81 @@ +CHANGELOG +Manifest +Rakefile +README +setup.rb +benchmark/benchmark.rb +benchmark/Benchmark.thrift +benchmark/client.rb +benchmark/server.rb +benchmark/thin_server.rb +ext/binary_protocol_accelerated.c +ext/binary_protocol_accelerated.h +ext/compact_protocol.c +ext/compact_protocol.h +ext/constants.h +ext/extconf.rb +ext/macros.h +ext/memory_buffer.c +ext/memory_buffer.h +ext/protocol.c +ext/protocol.h +ext/struct.c +ext/struct.h +ext/thrift_native.c +lib/thrift.rb +lib/thrift/client.rb +lib/thrift/core_ext.rb +lib/thrift/exceptions.rb +lib/thrift/processor.rb +lib/thrift/struct.rb +lib/thrift/thrift_native.rb +lib/thrift/types.rb +lib/thrift/core_ext/fixnum.rb +lib/thrift/protocol/base_protocol.rb +lib/thrift/protocol/binary_protocol.rb +lib/thrift/protocol/binary_protocol_accelerated.rb +lib/thrift/protocol/compact_protocol.rb +lib/thrift/serializer/deserializer.rb +lib/thrift/serializer/serializer.rb +lib/thrift/server/base_server.rb +lib/thrift/server/mongrel_http_server.rb +lib/thrift/server/nonblocking_server.rb +lib/thrift/server/simple_server.rb +lib/thrift/server/thread_pool_server.rb +lib/thrift/server/threaded_server.rb +lib/thrift/transport/base_server_transport.rb +lib/thrift/transport/base_transport.rb +lib/thrift/transport/buffered_transport.rb +lib/thrift/transport/framed_transport.rb +lib/thrift/transport/http_client_transport.rb +lib/thrift/transport/io_stream_transport.rb +lib/thrift/transport/memory_buffer_transport.rb +lib/thrift/transport/server_socket.rb +lib/thrift/transport/socket.rb +lib/thrift/transport/unix_server_socket.rb +lib/thrift/transport/unix_socket.rb +script/proto_benchmark.rb +script/read_struct.rb +script/write_struct.rb +spec/base_protocol_spec.rb +spec/base_transport_spec.rb +spec/binary_protocol_accelerated_spec.rb +spec/binary_protocol_spec.rb +spec/binary_protocol_spec_shared.rb +spec/client_spec.rb +spec/compact_protocol_spec.rb +spec/exception_spec.rb +spec/http_client_spec.rb +spec/mongrel_http_server_spec.rb +spec/nonblocking_server_spec.rb +spec/processor_spec.rb +spec/serializer_spec.rb +spec/server_socket_spec.rb +spec/server_spec.rb +spec/socket_spec.rb +spec/socket_spec_shared.rb +spec/spec_helper.rb +spec/struct_spec.rb +spec/ThriftSpec.thrift +spec/types_spec.rb +spec/unix_socket_spec.rb diff --git a/lib/rb/README b/lib/rb/README new file mode 100644 index 000000000..d78e35272 --- /dev/null +++ b/lib/rb/README @@ -0,0 +1,43 @@ +Thrift Ruby Software Library + http://incubator.apache.org/thrift/ + +== LICENSE: + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +== DESCRIPTION: + +Thrift is a strongly-typed language-agnostic RPC system. +This library is the ruby implementation for both clients and servers. + +== INSTALL: + + $ gem install thrift + +== CAVEATS: + +This library provides the client and server implementations of thrift. +It does <em>not</em> provide the compiler for the .thrift files. To compile +.thrift files into language-specific implementations, please download the full +thrift software package. + +== USAGE: + +This section should get written by someone with the time and inclination. +In the meantime, look at existing code, such as the benchmark or the tutorial +in the full thrift distribution. diff --git a/lib/rb/Rakefile b/lib/rb/Rakefile new file mode 100644 index 000000000..1a9467a57 --- /dev/null +++ b/lib/rb/Rakefile @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'rubygems' +require 'rake' +require 'spec/rake/spectask' + +THRIFT = '../../compiler/cpp/thrift' + +task :default => [:spec] + +task :spec => [:'gen-rb', :realspec] + +Spec::Rake::SpecTask.new(:realspec) do |t| + t.spec_files = FileList['spec/**/*_spec.rb'] + t.spec_opts = ['--color'] +end + +Spec::Rake::SpecTask.new(:'spec:rcov') do |t| + t.spec_files = FileList['spec/**/*_spec.rb'] + t.spec_opts = ['--color'] + t.rcov = true + t.rcov_opts = ['--exclude', '^spec,/gems/'] +end + +desc 'Run the compiler tests (requires full thrift checkout)' +task :test do + # ensure this is a full thrift checkout and not a tarball of the ruby libs + cmd = 'head -1 ../../README 2>/dev/null | grep Thrift >/dev/null 2>/dev/null' + system(cmd) or fail "rake test requires a full thrift checkout" + sh 'make', '-C', File.dirname(__FILE__) + "/../../test/rb", "check" +end + +desc 'Compile the .thrift files for the specs' +task :'gen-rb' => [:'gen-rb:spec', :'gen-rb:benchmark', :'gen-rb:debug_proto'] + +namespace :'gen-rb' do + task :'spec' do + dir = File.dirname(__FILE__) + '/spec' + sh THRIFT, '--gen', 'rb', '-o', dir, "#{dir}/ThriftSpec.thrift" + end + + task :'benchmark' do + dir = File.dirname(__FILE__) + '/benchmark' + sh THRIFT, '--gen', 'rb', '-o', dir, "#{dir}/Benchmark.thrift" + end + + task :'debug_proto' do + sh "mkdir", "-p", "debug_proto_test" + sh THRIFT, '--gen', 'rb', "-o", "debug_proto_test", "../../test/DebugProtoTest.thrift" + end +end + +desc 'Run benchmarking of NonblockingServer' +task :benchmark do + ruby 'benchmark/benchmark.rb' +end + + +begin + require 'echoe' + + Echoe.new('thrift') do |p| + p.author = ['Kevin Ballard', 'Kevin Clark', 'Mark Slee'] + p.email = ['kevin@sb.org', 'kevin.clark@gmail.com', 'mcslee@facebook.com'] + p.summary = "Ruby libraries for Thrift (a language-agnostic RPC system)" + p.url = "http://incubator.apache.org/thrift/" + p.include_rakefile = true + p.version = "0.1.0" + end + + task :install => [:check_site_lib] + + require 'rbconfig' + task :check_site_lib do + if File.exist?(File.join(Config::CONFIG['sitelibdir'], 'thrift.rb')) + fail "thrift is already installed in site_ruby" + end + end +rescue LoadError + [:install, :package].each do |t| + desc "Stub for #{t}" + task t do + fail "The Echoe gem is required for this task" + end + end +end diff --git a/lib/rb/benchmark/Benchmark.thrift b/lib/rb/benchmark/Benchmark.thrift new file mode 100644 index 000000000..eb5ae38e6 --- /dev/null +++ b/lib/rb/benchmark/Benchmark.thrift @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +namespace rb ThriftBenchmark + +service BenchmarkService { + i32 fibonacci(1:byte n) +} diff --git a/lib/rb/benchmark/benchmark.rb b/lib/rb/benchmark/benchmark.rb new file mode 100644 index 000000000..3dc67dd8c --- /dev/null +++ b/lib/rb/benchmark/benchmark.rb @@ -0,0 +1,271 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'rubygems' +$:.unshift File.dirname(__FILE__) + '/../lib' +require 'thrift' +require 'stringio' + +HOST = '127.0.0.1' +PORT = 42587 + +############### +## Server +############### + +class Server + attr_accessor :serverclass + attr_accessor :interpreter + attr_accessor :host + attr_accessor :port + + def initialize(opts) + @serverclass = opts.fetch(:class, Thrift::NonblockingServer) + @interpreter = opts.fetch(:interpreter, "ruby") + @host = opts.fetch(:host, ::HOST) + @port = opts.fetch(:port, ::PORT) + end + + def start + return if @serverclass == Object + args = (File.basename(@interpreter) == "jruby" ? "-J-server" : "") + @pipe = IO.popen("#{@interpreter} #{args} #{File.dirname(__FILE__)}/server.rb #{@host} #{@port} #{@serverclass.name}", "r+") + Marshal.load(@pipe) # wait until the server has started + sleep 0.4 # give the server time to actually start spawning sockets + end + + def shutdown + return unless @pipe + Marshal.dump(:shutdown, @pipe) + begin + @pipe.read(10) # block until the server shuts down + rescue EOFError + end + @pipe.close + @pipe = nil + end +end + +class BenchmarkManager + def initialize(opts, server) + @socket = opts.fetch(:socket) do + @host = opts.fetch(:host, 'localhost') + @port = opts.fetch(:port) + nil + end + @num_processes = opts.fetch(:num_processes, 40) + @clients_per_process = opts.fetch(:clients_per_process, 10) + @calls_per_client = opts.fetch(:calls_per_client, 50) + @interpreter = opts.fetch(:interpreter, "ruby") + @server = server + @log_exceptions = opts.fetch(:log_exceptions, false) + end + + def run + @pool = [] + @benchmark_start = Time.now + puts "Spawning benchmark processes..." + @num_processes.times do + spawn + sleep 0.02 # space out spawns + end + collect_output + @benchmark_end = Time.now # we know the procs are done here + translate_output + analyze_output + report_output + end + + def spawn + pipe = IO.popen("#{@interpreter} #{File.dirname(__FILE__)}/client.rb #{"-log-exceptions" if @log_exceptions} #{@host} #{@port} #{@clients_per_process} #{@calls_per_client}") + @pool << pipe + end + + def socket_class + if @socket + Thrift::UNIXSocket + else + Thrift::Socket + end + end + + def collect_output + puts "Collecting output..." + # read from @pool until all sockets are closed + @buffers = Hash.new { |h,k| h[k] = '' } + until @pool.empty? + rd, = select(@pool) + next if rd.nil? + rd.each do |fd| + begin + @buffers[fd] << fd.readpartial(4096) + rescue EOFError + @pool.delete fd + end + end + end + end + + def translate_output + puts "Translating output..." + @output = [] + @buffers.each do |fd, buffer| + strio = StringIO.new(buffer) + logs = [] + begin + loop do + logs << Marshal.load(strio) + end + rescue EOFError + @output << logs + end + end + end + + def analyze_output + puts "Analyzing output..." + call_times = [] + client_times = [] + connection_failures = [] + connection_errors = [] + shortest_call = 0 + shortest_client = 0 + longest_call = 0 + longest_client = 0 + @output.each do |logs| + cur_call, cur_client = nil + logs.each do |tok, time| + case tok + when :start + cur_client = time + when :call_start + cur_call = time + when :call_end + delta = time - cur_call + call_times << delta + longest_call = delta unless longest_call > delta + shortest_call = delta if shortest_call == 0 or delta < shortest_call + cur_call = nil + when :end + delta = time - cur_client + client_times << delta + longest_client = delta unless longest_client > delta + shortest_client = delta if shortest_client == 0 or delta < shortest_client + cur_client = nil + when :connection_failure + connection_failures << time + when :connection_error + connection_errors << time + end + end + end + @report = {} + @report[:total_calls] = call_times.inject(0.0) { |a,t| a += t } + @report[:avg_calls] = @report[:total_calls] / call_times.size + @report[:total_clients] = client_times.inject(0.0) { |a,t| a += t } + @report[:avg_clients] = @report[:total_clients] / client_times.size + @report[:connection_failures] = connection_failures.size + @report[:connection_errors] = connection_errors.size + @report[:shortest_call] = shortest_call + @report[:shortest_client] = shortest_client + @report[:longest_call] = longest_call + @report[:longest_client] = longest_client + @report[:total_benchmark_time] = @benchmark_end - @benchmark_start + @report[:fastthread] = $".include?('fastthread.bundle') + end + + def report_output + fmt = "%.4f seconds" + puts + tabulate "%d", + [["Server class", "%s"], @server.serverclass == Object ? "" : @server.serverclass], + [["Server interpreter", "%s"], @server.interpreter], + [["Client interpreter", "%s"], @interpreter], + [["Socket class", "%s"], socket_class], + ["Number of processes", @num_processes], + ["Clients per process", @clients_per_process], + ["Calls per client", @calls_per_client], + [["Using fastthread", "%s"], @report[:fastthread] ? "yes" : "no"] + puts + failures = (@report[:connection_failures] > 0) + tabulate fmt, + [["Connection failures", "%d", [:red, :bold]], @report[:connection_failures]], + [["Connection errors", "%d", [:red, :bold]], @report[:connection_errors]], + ["Average time per call", @report[:avg_calls]], + ["Average time per client (%d calls)" % @calls_per_client, @report[:avg_clients]], + ["Total time for all calls", @report[:total_calls]], + ["Real time for benchmarking", @report[:total_benchmark_time]], + ["Shortest call time", @report[:shortest_call]], + ["Longest call time", @report[:longest_call]], + ["Shortest client time (%d calls)" % @calls_per_client, @report[:shortest_client]], + ["Longest client time (%d calls)" % @calls_per_client, @report[:longest_client]] + end + + ANSI = { + :reset => 0, + :bold => 1, + :black => 30, + :red => 31, + :green => 32, + :yellow => 33, + :blue => 34, + :magenta => 35, + :cyan => 36, + :white => 37 + } + + def tabulate(fmt, *labels_and_values) + labels = labels_and_values.map { |l| Array === l ? l.first : l } + label_width = labels.inject(0) { |w,l| l.size > w ? l.size : w } + labels_and_values.each do |(l,v)| + f = fmt + l, f, c = l if Array === l + fmtstr = "%-#{label_width+1}s #{f}" + if STDOUT.tty? and c and v.to_i > 0 + fmtstr = "\e[#{[*c].map { |x| ANSI[x] } * ";"}m" + fmtstr + "\e[#{ANSI[:reset]}m" + end + puts fmtstr % [l+":", v] + end + end +end + +def resolve_const(const) + const and const.split('::').inject(Object) { |k,c| k.const_get(c) } +end + +puts "Starting server..." +args = {} +args[:interpreter] = ENV['THRIFT_SERVER_INTERPRETER'] || ENV['THRIFT_INTERPRETER'] || "ruby" +args[:class] = resolve_const(ENV['THRIFT_SERVER']) || Thrift::NonblockingServer +args[:host] = ENV['THRIFT_HOST'] || HOST +args[:port] = (ENV['THRIFT_PORT'] || PORT).to_i +server = Server.new(args) +server.start + +args = {} +args[:host] = ENV['THRIFT_HOST'] || HOST +args[:port] = (ENV['THRIFT_PORT'] || PORT).to_i +args[:num_processes] = (ENV['THRIFT_NUM_PROCESSES'] || 40).to_i +args[:clients_per_process] = (ENV['THRIFT_NUM_CLIENTS'] || 5).to_i +args[:calls_per_client] = (ENV['THRIFT_NUM_CALLS'] || 50).to_i +args[:interpreter] = ENV['THRIFT_CLIENT_INTERPRETER'] || ENV['THRIFT_INTERPRETER'] || "ruby" +args[:log_exceptions] = !!ENV['THRIFT_LOG_EXCEPTIONS'] +BenchmarkManager.new(args, server).run + +server.shutdown diff --git a/lib/rb/benchmark/client.rb b/lib/rb/benchmark/client.rb new file mode 100644 index 000000000..703dc8f52 --- /dev/null +++ b/lib/rb/benchmark/client.rb @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$:.unshift File.dirname(__FILE__) + '/../lib' +require 'thrift' +$:.unshift File.dirname(__FILE__) + "/gen-rb" +require 'benchmark_service' + +class Client + def initialize(host, port, clients_per_process, calls_per_client, log_exceptions) + @host = host + @port = port + @clients_per_process = clients_per_process + @calls_per_client = calls_per_client + @log_exceptions = log_exceptions + end + + def run + @clients_per_process.times do + socket = Thrift::Socket.new(@host, @port) + transport = Thrift::FramedTransport.new(socket) + protocol = Thrift::BinaryProtocol.new(transport) + client = ThriftBenchmark::BenchmarkService::Client.new(protocol) + begin + start = Time.now + transport.open + Marshal.dump [:start, start], STDOUT + rescue => e + Marshal.dump [:connection_failure, Time.now], STDOUT + print_exception e if @log_exceptions + else + begin + @calls_per_client.times do + Marshal.dump [:call_start, Time.now], STDOUT + client.fibonacci(15) + Marshal.dump [:call_end, Time.now], STDOUT + end + transport.close + Marshal.dump [:end, Time.now], STDOUT + rescue Thrift::TransportException => e + Marshal.dump [:connection_error, Time.now], STDOUT + print_exception e if @log_exceptions + end + end + end + end + + def print_exception(e) + STDERR.puts "ERROR: #{e.message}" + STDERR.puts "\t#{e.backtrace * "\n\t"}" + end +end + +log_exceptions = true if ARGV[0] == '-log-exceptions' and ARGV.shift + +host, port, clients_per_process, calls_per_client = ARGV + +Client.new(host, port.to_i, clients_per_process.to_i, calls_per_client.to_i, log_exceptions).run diff --git a/lib/rb/benchmark/server.rb b/lib/rb/benchmark/server.rb new file mode 100644 index 000000000..74e13f414 --- /dev/null +++ b/lib/rb/benchmark/server.rb @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$:.unshift File.dirname(__FILE__) + '/../lib' +require 'thrift' +$:.unshift File.dirname(__FILE__) + "/gen-rb" +require 'benchmark_service' + +module Server + include Thrift + + class BenchmarkHandler + # 1-based index into the fibonacci sequence + def fibonacci(n) + seq = [1, 1] + 3.upto(n) do + seq << seq[-1] + seq[-2] + end + seq[n-1] # n is 1-based + end + end + + def self.start_server(host, port, serverClass) + handler = BenchmarkHandler.new + processor = ThriftBenchmark::BenchmarkService::Processor.new(handler) + transport = ServerSocket.new(host, port) + transport_factory = FramedTransportFactory.new + args = [processor, transport, transport_factory, nil, 20] + if serverClass == NonblockingServer + logger = Logger.new(STDERR) + logger.level = Logger::WARN + args << logger + end + server = serverClass.new(*args) + @server_thread = Thread.new do + server.serve + end + @server = server + end + + def self.shutdown + return if @server.nil? + if @server.respond_to? :shutdown + @server.shutdown + else + @server_thread.kill + end + end +end + +def resolve_const(const) + const and const.split('::').inject(Object) { |k,c| k.const_get(c) } +end + +host, port, serverklass = ARGV + +Server.start_server(host, port.to_i, resolve_const(serverklass)) + +# let our host know that the interpreter has started +# ideally we'd wait until the server was serving, but we don't have a hook for that +Marshal.dump(:started, STDOUT) +STDOUT.flush + +Marshal.load(STDIN) # wait until we're instructed to shut down + +Server.shutdown diff --git a/lib/rb/benchmark/thin_server.rb b/lib/rb/benchmark/thin_server.rb new file mode 100644 index 000000000..4de2eef38 --- /dev/null +++ b/lib/rb/benchmark/thin_server.rb @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$:.unshift File.dirname(__FILE__) + '/../lib' +require 'thrift' +$:.unshift File.dirname(__FILE__) + "/gen-rb" +require 'benchmark_service' +HOST = 'localhost' +PORT = 42587 + +class BenchmarkHandler + # 1-based index into the fibonacci sequence + def fibonacci(n) + seq = [1, 1] + 3.upto(n) do + seq << seq[-1] + seq[-2] + end + seq[n-1] # n is 1-based + end +end + +handler = BenchmarkHandler.new +processor = ThriftBenchmark::BenchmarkService::Processor.new(handler) +transport = Thrift::ServerSocket.new(HOST, PORT) +transport_factory = Thrift::FramedTransportFactory.new +logger = Logger.new(STDERR) +logger.level = Logger::WARN +Thrift::NonblockingServer.new(processor, transport, transport_factory, nil, 20, logger).serve diff --git a/lib/rb/ext/binary_protocol_accelerated.c b/lib/rb/ext/binary_protocol_accelerated.c new file mode 100644 index 000000000..728a05725 --- /dev/null +++ b/lib/rb/ext/binary_protocol_accelerated.c @@ -0,0 +1,474 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <ruby.h> +#include <stdbool.h> +#include <stdint.h> +#include <constants.h> +#include <struct.h> +#include "macros.h" + +VALUE rb_thrift_binary_proto_native_qmark(VALUE self) { + return Qtrue; +} + + + +static int VERSION_1; +static int VERSION_MASK; +static int TYPE_MASK; +static int BAD_VERSION; + +static void write_byte_direct(VALUE trans, int8_t b) { + WRITE(trans, (char*)&b, 1); +} + +static void write_i16_direct(VALUE trans, int16_t value) { + char data[2]; + + data[1] = value; + data[0] = (value >> 8); + + WRITE(trans, data, 2); +} + +static void write_i32_direct(VALUE trans, int32_t value) { + char data[4]; + + data[3] = value; + data[2] = (value >> 8); + data[1] = (value >> 16); + data[0] = (value >> 24); + + WRITE(trans, data, 4); +} + + +static void write_i64_direct(VALUE trans, int64_t value) { + char data[8]; + + data[7] = value; + data[6] = (value >> 8); + data[5] = (value >> 16); + data[4] = (value >> 24); + data[3] = (value >> 32); + data[2] = (value >> 40); + data[1] = (value >> 48); + data[0] = (value >> 56); + + WRITE(trans, data, 8); +} + +static void write_string_direct(VALUE trans, VALUE str) { + write_i32_direct(trans, RSTRING_LEN(str)); + rb_funcall(trans, write_method_id, 1, str); +} + +//-------------------------------- +// interface writing methods +//-------------------------------- + +VALUE rb_thrift_binary_proto_write_message_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_struct_begin(VALUE self, VALUE name) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_struct_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_field_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_map_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_list_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_set_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_message_begin(VALUE self, VALUE name, VALUE type, VALUE seqid) { + VALUE trans = GET_TRANSPORT(self); + VALUE strict_write = GET_STRICT_WRITE(self); + + if (strict_write == Qtrue) { + write_i32_direct(trans, VERSION_1 | FIX2INT(type)); + write_string_direct(trans, name); + write_i32_direct(trans, FIX2INT(seqid)); + } else { + write_string_direct(trans, name); + write_byte_direct(trans, FIX2INT(type)); + write_i32_direct(trans, FIX2INT(seqid)); + } + + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_field_begin(VALUE self, VALUE name, VALUE type, VALUE id) { + VALUE trans = GET_TRANSPORT(self); + write_byte_direct(trans, FIX2INT(type)); + write_i16_direct(trans, FIX2INT(id)); + + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_field_stop(VALUE self) { + write_byte_direct(GET_TRANSPORT(self), TTYPE_STOP); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_map_begin(VALUE self, VALUE ktype, VALUE vtype, VALUE size) { + VALUE trans = GET_TRANSPORT(self); + write_byte_direct(trans, FIX2INT(ktype)); + write_byte_direct(trans, FIX2INT(vtype)); + write_i32_direct(trans, FIX2INT(size)); + + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_list_begin(VALUE self, VALUE etype, VALUE size) { + VALUE trans = GET_TRANSPORT(self); + write_byte_direct(trans, FIX2INT(etype)); + write_i32_direct(trans, FIX2INT(size)); + + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_set_begin(VALUE self, VALUE etype, VALUE size) { + rb_thrift_binary_proto_write_list_begin(self, etype, size); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_bool(VALUE self, VALUE b) { + write_byte_direct(GET_TRANSPORT(self), RTEST(b) ? 1 : 0); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_byte(VALUE self, VALUE byte) { + CHECK_NIL(byte); + write_byte_direct(GET_TRANSPORT(self), NUM2INT(byte)); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_i16(VALUE self, VALUE i16) { + CHECK_NIL(i16); + write_i16_direct(GET_TRANSPORT(self), FIX2INT(i16)); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_i32(VALUE self, VALUE i32) { + CHECK_NIL(i32); + write_i32_direct(GET_TRANSPORT(self), NUM2INT(i32)); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_i64(VALUE self, VALUE i64) { + CHECK_NIL(i64); + write_i64_direct(GET_TRANSPORT(self), NUM2LL(i64)); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_double(VALUE self, VALUE dub) { + CHECK_NIL(dub); + // Unfortunately, bitwise_cast doesn't work in C. Bad C! + union { + double f; + int64_t t; + } transfer; + transfer.f = RFLOAT_VALUE(rb_Float(dub)); + write_i64_direct(GET_TRANSPORT(self), transfer.t); + + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_string(VALUE self, VALUE str) { + CHECK_NIL(str); + VALUE trans = GET_TRANSPORT(self); + write_string_direct(trans, str); + return Qnil; +} + +//--------------------------------------- +// interface reading methods +//--------------------------------------- + +VALUE rb_thrift_binary_proto_read_string(VALUE self); +VALUE rb_thrift_binary_proto_read_byte(VALUE self); +VALUE rb_thrift_binary_proto_read_i32(VALUE self); +VALUE rb_thrift_binary_proto_read_i16(VALUE self); + +static char read_byte_direct(VALUE self) { + VALUE buf = READ(self, 1); + return RSTRING_PTR(buf)[0]; +} + +static int16_t read_i16_direct(VALUE self) { + VALUE buf = READ(self, 2); + return (int16_t)(((uint8_t)(RSTRING_PTR(buf)[1])) | ((uint16_t)((RSTRING_PTR(buf)[0]) << 8))); +} + +static int32_t read_i32_direct(VALUE self) { + VALUE buf = READ(self, 4); + return ((uint8_t)(RSTRING_PTR(buf)[3])) | + (((uint8_t)(RSTRING_PTR(buf)[2])) << 8) | + (((uint8_t)(RSTRING_PTR(buf)[1])) << 16) | + (((uint8_t)(RSTRING_PTR(buf)[0])) << 24); +} + +static int64_t read_i64_direct(VALUE self) { + uint64_t hi = read_i32_direct(self); + uint32_t lo = read_i32_direct(self); + return (hi << 32) | lo; +} + +static VALUE get_protocol_exception(VALUE code, VALUE message) { + VALUE args[2]; + args[0] = code; + args[1] = message; + return rb_class_new_instance(2, (VALUE*)&args, protocol_exception_class); +} + +VALUE rb_thrift_binary_proto_read_message_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_binary_proto_read_struct_begin(VALUE self) { + return Qnil; +} + +VALUE rb_thift_binary_proto_read_struct_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_binary_proto_read_field_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_binary_proto_read_map_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_binary_proto_read_list_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_binary_proto_read_set_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_read_message_begin(VALUE self) { + VALUE strict_read = GET_STRICT_READ(self); + VALUE name, seqid; + int type; + + int version = read_i32_direct(self); + + if (version < 0) { + if ((version & VERSION_MASK) != VERSION_1) { + rb_exc_raise(get_protocol_exception(INT2FIX(BAD_VERSION), rb_str_new2("Missing version identifier"))); + } + type = version & TYPE_MASK; + name = rb_thrift_binary_proto_read_string(self); + seqid = rb_thrift_binary_proto_read_i32(self); + } else { + if (strict_read == Qtrue) { + rb_exc_raise(get_protocol_exception(INT2FIX(BAD_VERSION), rb_str_new2("No version identifier, old protocol client?"))); + } + name = READ(self, version); + type = read_byte_direct(self); + seqid = rb_thrift_binary_proto_read_i32(self); + } + + return rb_ary_new3(3, name, INT2FIX(type), seqid); +} + +VALUE rb_thrift_binary_proto_read_field_begin(VALUE self) { + int type = read_byte_direct(self); + if (type == TTYPE_STOP) { + return rb_ary_new3(3, Qnil, INT2FIX(type), INT2FIX(0)); + } else { + VALUE id = rb_thrift_binary_proto_read_i16(self); + return rb_ary_new3(3, Qnil, INT2FIX(type), id); + } +} + +VALUE rb_thrift_binary_proto_read_map_begin(VALUE self) { + VALUE ktype = rb_thrift_binary_proto_read_byte(self); + VALUE vtype = rb_thrift_binary_proto_read_byte(self); + VALUE size = rb_thrift_binary_proto_read_i32(self); + return rb_ary_new3(3, ktype, vtype, size); +} + +VALUE rb_thrift_binary_proto_read_list_begin(VALUE self) { + VALUE etype = rb_thrift_binary_proto_read_byte(self); + VALUE size = rb_thrift_binary_proto_read_i32(self); + return rb_ary_new3(2, etype, size); +} + +VALUE rb_thrift_binary_proto_read_set_begin(VALUE self) { + return rb_thrift_binary_proto_read_list_begin(self); +} + +VALUE rb_thrift_binary_proto_read_bool(VALUE self) { + char byte = read_byte_direct(self); + return byte != 0 ? Qtrue : Qfalse; +} + +VALUE rb_thrift_binary_proto_read_byte(VALUE self) { + return INT2FIX(read_byte_direct(self)); +} + +VALUE rb_thrift_binary_proto_read_i16(VALUE self) { + return INT2FIX(read_i16_direct(self)); +} + +VALUE rb_thrift_binary_proto_read_i32(VALUE self) { + return INT2NUM(read_i32_direct(self)); +} + +VALUE rb_thrift_binary_proto_read_i64(VALUE self) { + return LL2NUM(read_i64_direct(self)); +} + +VALUE rb_thrift_binary_proto_read_double(VALUE self) { + union { + double f; + int64_t t; + } transfer; + transfer.t = read_i64_direct(self); + return rb_float_new(transfer.f); +} + +VALUE rb_thrift_binary_proto_read_string(VALUE self) { + int size = read_i32_direct(self); + return READ(self, size); +} + +void Init_binary_protocol_accelerated() { + VALUE thrift_binary_protocol_class = rb_const_get(thrift_module, rb_intern("BinaryProtocol")); + + VERSION_1 = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("VERSION_1"))); + VERSION_MASK = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("VERSION_MASK"))); + TYPE_MASK = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("TYPE_MASK"))); + + VALUE bpa_class = rb_define_class_under(thrift_module, "BinaryProtocolAccelerated", thrift_binary_protocol_class); + + rb_define_method(bpa_class, "native?", rb_thrift_binary_proto_native_qmark, 0); + + rb_define_method(bpa_class, "write_message_begin", rb_thrift_binary_proto_write_message_begin, 3); + rb_define_method(bpa_class, "write_field_begin", rb_thrift_binary_proto_write_field_begin, 3); + rb_define_method(bpa_class, "write_field_stop", rb_thrift_binary_proto_write_field_stop, 0); + rb_define_method(bpa_class, "write_map_begin", rb_thrift_binary_proto_write_map_begin, 3); + rb_define_method(bpa_class, "write_list_begin", rb_thrift_binary_proto_write_list_begin, 2); + rb_define_method(bpa_class, "write_set_begin", rb_thrift_binary_proto_write_set_begin, 2); + rb_define_method(bpa_class, "write_byte", rb_thrift_binary_proto_write_byte, 1); + rb_define_method(bpa_class, "write_bool", rb_thrift_binary_proto_write_bool, 1); + rb_define_method(bpa_class, "write_i16", rb_thrift_binary_proto_write_i16, 1); + rb_define_method(bpa_class, "write_i32", rb_thrift_binary_proto_write_i32, 1); + rb_define_method(bpa_class, "write_i64", rb_thrift_binary_proto_write_i64, 1); + rb_define_method(bpa_class, "write_double", rb_thrift_binary_proto_write_double, 1); + rb_define_method(bpa_class, "write_string", rb_thrift_binary_proto_write_string, 1); + // unused methods + rb_define_method(bpa_class, "write_message_end", rb_thrift_binary_proto_write_message_end, 0); + rb_define_method(bpa_class, "write_struct_begin", rb_thrift_binary_proto_write_struct_begin, 1); + rb_define_method(bpa_class, "write_struct_end", rb_thrift_binary_proto_write_struct_end, 0); + rb_define_method(bpa_class, "write_field_end", rb_thrift_binary_proto_write_field_end, 0); + rb_define_method(bpa_class, "write_map_end", rb_thrift_binary_proto_write_map_end, 0); + rb_define_method(bpa_class, "write_list_end", rb_thrift_binary_proto_write_list_end, 0); + rb_define_method(bpa_class, "write_set_end", rb_thrift_binary_proto_write_set_end, 0); + + + + rb_define_method(bpa_class, "read_message_begin", rb_thrift_binary_proto_read_message_begin, 0); + rb_define_method(bpa_class, "read_field_begin", rb_thrift_binary_proto_read_field_begin, 0); + rb_define_method(bpa_class, "read_map_begin", rb_thrift_binary_proto_read_map_begin, 0); + rb_define_method(bpa_class, "read_list_begin", rb_thrift_binary_proto_read_list_begin, 0); + rb_define_method(bpa_class, "read_set_begin", rb_thrift_binary_proto_read_set_begin, 0); + rb_define_method(bpa_class, "read_byte", rb_thrift_binary_proto_read_byte, 0); + rb_define_method(bpa_class, "read_bool", rb_thrift_binary_proto_read_bool, 0); + rb_define_method(bpa_class, "read_i16", rb_thrift_binary_proto_read_i16, 0); + rb_define_method(bpa_class, "read_i32", rb_thrift_binary_proto_read_i32, 0); + rb_define_method(bpa_class, "read_i64", rb_thrift_binary_proto_read_i64, 0); + rb_define_method(bpa_class, "read_double", rb_thrift_binary_proto_read_double, 0); + rb_define_method(bpa_class, "read_string", rb_thrift_binary_proto_read_string, 0); + // unused methods + rb_define_method(bpa_class, "read_message_end", rb_thrift_binary_proto_read_message_end, 0); + rb_define_method(bpa_class, "read_struct_begin", rb_thift_binary_proto_read_struct_begin, 0); + rb_define_method(bpa_class, "read_struct_end", rb_thift_binary_proto_read_struct_end, 0); + rb_define_method(bpa_class, "read_field_end", rb_thift_binary_proto_read_field_end, 0); + rb_define_method(bpa_class, "read_map_end", rb_thift_binary_proto_read_map_end, 0); + rb_define_method(bpa_class, "read_list_end", rb_thift_binary_proto_read_list_end, 0); + rb_define_method(bpa_class, "read_set_end", rb_thift_binary_proto_read_set_end, 0); + + // set up native method table + native_proto_method_table *npmt; + npmt = ALLOC(native_proto_method_table); + + npmt->write_field_begin = rb_thrift_binary_proto_write_field_begin; + npmt->write_field_stop = rb_thrift_binary_proto_write_field_stop; + npmt->write_map_begin = rb_thrift_binary_proto_write_map_begin; + npmt->write_list_begin = rb_thrift_binary_proto_write_list_begin; + npmt->write_set_begin = rb_thrift_binary_proto_write_set_begin; + npmt->write_byte = rb_thrift_binary_proto_write_byte; + npmt->write_bool = rb_thrift_binary_proto_write_bool; + npmt->write_i16 = rb_thrift_binary_proto_write_i16; + npmt->write_i32 = rb_thrift_binary_proto_write_i32; + npmt->write_i64 = rb_thrift_binary_proto_write_i64; + npmt->write_double = rb_thrift_binary_proto_write_double; + npmt->write_string = rb_thrift_binary_proto_write_string; + npmt->write_message_end = rb_thrift_binary_proto_write_message_end; + npmt->write_struct_begin = rb_thrift_binary_proto_write_struct_begin; + npmt->write_struct_end = rb_thrift_binary_proto_write_struct_end; + npmt->write_field_end = rb_thrift_binary_proto_write_field_end; + npmt->write_map_end = rb_thrift_binary_proto_write_map_end; + npmt->write_list_end = rb_thrift_binary_proto_write_list_end; + npmt->write_set_end = rb_thrift_binary_proto_write_set_end; + + npmt->read_message_begin = rb_thrift_binary_proto_read_message_begin; + npmt->read_field_begin = rb_thrift_binary_proto_read_field_begin; + npmt->read_map_begin = rb_thrift_binary_proto_read_map_begin; + npmt->read_list_begin = rb_thrift_binary_proto_read_list_begin; + npmt->read_set_begin = rb_thrift_binary_proto_read_set_begin; + npmt->read_byte = rb_thrift_binary_proto_read_byte; + npmt->read_bool = rb_thrift_binary_proto_read_bool; + npmt->read_i16 = rb_thrift_binary_proto_read_i16; + npmt->read_i32 = rb_thrift_binary_proto_read_i32; + npmt->read_i64 = rb_thrift_binary_proto_read_i64; + npmt->read_double = rb_thrift_binary_proto_read_double; + npmt->read_string = rb_thrift_binary_proto_read_string; + npmt->read_message_end = rb_thrift_binary_proto_read_message_end; + npmt->read_struct_begin = rb_thift_binary_proto_read_struct_begin; + npmt->read_struct_end = rb_thift_binary_proto_read_struct_end; + npmt->read_field_end = rb_thift_binary_proto_read_field_end; + npmt->read_map_end = rb_thift_binary_proto_read_map_end; + npmt->read_list_end = rb_thift_binary_proto_read_list_end; + npmt->read_set_end = rb_thift_binary_proto_read_set_end; + + VALUE method_table_object = Data_Wrap_Struct(rb_cObject, 0, free, npmt); + rb_const_set(bpa_class, rb_intern("@native_method_table"), method_table_object); +} diff --git a/lib/rb/ext/binary_protocol_accelerated.h b/lib/rb/ext/binary_protocol_accelerated.h new file mode 100644 index 000000000..37baf4142 --- /dev/null +++ b/lib/rb/ext/binary_protocol_accelerated.h @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +void Init_binary_protocol_accelerated(); diff --git a/lib/rb/ext/compact_protocol.c b/lib/rb/ext/compact_protocol.c new file mode 100644 index 000000000..7966d3e3f --- /dev/null +++ b/lib/rb/ext/compact_protocol.c @@ -0,0 +1,665 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <ruby.h> +#include <stdbool.h> +#include <stdint.h> +#include <constants.h> +#include <struct.h> +#include "macros.h" + +#define LAST_ID(obj) FIX2INT(rb_ary_pop(rb_ivar_get(obj, last_field_id))) +#define SET_LAST_ID(obj, val) rb_ary_push(rb_ivar_get(obj, last_field_id), val) + +VALUE rb_thrift_compact_proto_native_qmark(VALUE self) { + return Qtrue; +} + +static ID last_field_id; +static ID boolean_field_id; +static ID bool_value_id; + +static int VERSION; +static int VERSION_MASK; +static int TYPE_MASK; +static int TYPE_SHIFT_AMOUNT; +static int PROTOCOL_ID; + +static VALUE thrift_compact_protocol_class; + +static int CTYPE_BOOLEAN_TRUE = 0x01; +static int CTYPE_BOOLEAN_FALSE = 0x02; +static int CTYPE_BYTE = 0x03; +static int CTYPE_I16 = 0x04; +static int CTYPE_I32 = 0x05; +static int CTYPE_I64 = 0x06; +static int CTYPE_DOUBLE = 0x07; +static int CTYPE_BINARY = 0x08; +static int CTYPE_LIST = 0x09; +static int CTYPE_SET = 0x0A; +static int CTYPE_MAP = 0x0B; +static int CTYPE_STRUCT = 0x0C; + +VALUE rb_thrift_compact_proto_write_i16(VALUE self, VALUE i16); + +// TODO: implement this +static int get_compact_type(VALUE type_value) { + int type = FIX2INT(type_value); + if (type == TTYPE_BOOL) { + return CTYPE_BOOLEAN_TRUE; + } else if (type == TTYPE_BYTE) { + return CTYPE_BYTE; + } else if (type == TTYPE_I16) { + return CTYPE_I16; + } else if (type == TTYPE_I32) { + return CTYPE_I32; + } else if (type == TTYPE_I64) { + return CTYPE_I64; + } else if (type == TTYPE_DOUBLE) { + return CTYPE_DOUBLE; + } else if (type == TTYPE_STRING) { + return CTYPE_BINARY; + } else if (type == TTYPE_LIST) { + return CTYPE_LIST; + } else if (type == TTYPE_SET) { + return CTYPE_SET; + } else if (type == TTYPE_MAP) { + return CTYPE_MAP; + } else if (type == TTYPE_STRUCT) { + return CTYPE_STRUCT; + } else { + char str[50]; + sprintf(str, "don't know what type: %d", type); + rb_raise(rb_eStandardError, str); + return 0; + } +} + +static void write_byte_direct(VALUE transport, int8_t b) { + WRITE(transport, (char*)&b, 1); +} + +static void write_field_begin_internal(VALUE self, VALUE type, VALUE id_value, VALUE type_override) { + int id = FIX2INT(id_value); + int last_id = LAST_ID(self); + VALUE transport = GET_TRANSPORT(self); + + // if there's a type override, use that. + int8_t type_to_write = RTEST(type_override) ? FIX2INT(type_override) : get_compact_type(type); + // check if we can use delta encoding for the field id + int diff = id - last_id; + if (diff > 0 && diff <= 15) { + // write them together + write_byte_direct(transport, diff << 4 | (type_to_write & 0x0f)); + } else { + // write them separate + write_byte_direct(transport, type_to_write & 0x0f); + rb_thrift_compact_proto_write_i16(self, id_value); + } + + SET_LAST_ID(self, id_value); +} + +static int32_t int_to_zig_zag(int32_t n) { + return (n << 1) ^ (n >> 31); +} + +static uint64_t ll_to_zig_zag(int64_t n) { + return (n << 1) ^ (n >> 63); +} + +static void write_varint32(VALUE transport, uint32_t n) { + while (true) { + if ((n & ~0x7F) == 0) { + write_byte_direct(transport, n & 0x7f); + break; + } else { + write_byte_direct(transport, (n & 0x7F) | 0x80); + n = n >> 7; + } + } +} + +static void write_varint64(VALUE transport, uint64_t n) { + while (true) { + if ((n & ~0x7F) == 0) { + write_byte_direct(transport, n & 0x7f); + break; + } else { + write_byte_direct(transport, (n & 0x7F) | 0x80); + n = n >> 7; + } + } +} + +static void write_collection_begin(VALUE transport, VALUE elem_type, VALUE size_value) { + int size = FIX2INT(size_value); + if (size <= 14) { + write_byte_direct(transport, size << 4 | get_compact_type(elem_type)); + } else { + write_byte_direct(transport, 0xf0 | get_compact_type(elem_type)); + write_varint32(transport, size); + } +} + + +//-------------------------------- +// interface writing methods +//-------------------------------- + +VALUE rb_thrift_compact_proto_write_i32(VALUE self, VALUE i32); +VALUE rb_thrift_compact_proto_write_string(VALUE self, VALUE str); + +VALUE rb_thrift_compact_proto_write_message_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_struct_begin(VALUE self, VALUE name) { + rb_ary_push(rb_ivar_get(self, last_field_id), INT2FIX(0)); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_struct_end(VALUE self) { + rb_ary_pop(rb_ivar_get(self, last_field_id)); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_field_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_map_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_list_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_set_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_message_begin(VALUE self, VALUE name, VALUE type, VALUE seqid) { + VALUE transport = GET_TRANSPORT(self); + write_byte_direct(transport, PROTOCOL_ID); + write_byte_direct(transport, (VERSION & VERSION_MASK) | ((FIX2INT(type) << TYPE_SHIFT_AMOUNT) & TYPE_MASK)); + write_varint32(transport, FIX2INT(seqid)); + rb_thrift_compact_proto_write_string(self, name); + + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_field_begin(VALUE self, VALUE name, VALUE type, VALUE id) { + if (FIX2INT(type) == TTYPE_BOOL) { + // we want to possibly include the value, so we'll wait. + rb_ivar_set(self, boolean_field_id, rb_ary_new3(2, type, id)); + } else { + write_field_begin_internal(self, type, id, Qnil); + } + + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_field_stop(VALUE self) { + write_byte_direct(GET_TRANSPORT(self), TTYPE_STOP); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_map_begin(VALUE self, VALUE ktype, VALUE vtype, VALUE size_value) { + int size = FIX2INT(size_value); + VALUE transport = GET_TRANSPORT(self); + if (size == 0) { + write_byte_direct(transport, 0); + } else { + write_varint32(transport, size); + write_byte_direct(transport, get_compact_type(ktype) << 4 | get_compact_type(vtype)); + } + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_list_begin(VALUE self, VALUE etype, VALUE size) { + write_collection_begin(GET_TRANSPORT(self), etype, size); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_set_begin(VALUE self, VALUE etype, VALUE size) { + write_collection_begin(GET_TRANSPORT(self), etype, size); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_bool(VALUE self, VALUE b) { + int8_t type = b == Qtrue ? CTYPE_BOOLEAN_TRUE : CTYPE_BOOLEAN_FALSE; + VALUE boolean_field = rb_ivar_get(self, boolean_field_id); + if (NIL_P(boolean_field)) { + // we're not part of a field, so just write the value. + write_byte_direct(GET_TRANSPORT(self), type); + } else { + // we haven't written the field header yet + write_field_begin_internal(self, rb_ary_entry(boolean_field, 0), rb_ary_entry(boolean_field, 1), INT2FIX(type)); + rb_ivar_set(self, boolean_field_id, Qnil); + } + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_byte(VALUE self, VALUE byte) { + CHECK_NIL(byte); + write_byte_direct(GET_TRANSPORT(self), FIX2INT(byte)); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_i16(VALUE self, VALUE i16) { + rb_thrift_compact_proto_write_i32(self, i16); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_i32(VALUE self, VALUE i32) { + CHECK_NIL(i32); + write_varint32(GET_TRANSPORT(self), int_to_zig_zag(NUM2INT(i32))); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_i64(VALUE self, VALUE i64) { + CHECK_NIL(i64); + write_varint64(GET_TRANSPORT(self), ll_to_zig_zag(NUM2LL(i64))); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_double(VALUE self, VALUE dub) { + CHECK_NIL(dub); + // Unfortunately, bitwise_cast doesn't work in C. Bad C! + union { + double f; + int64_t l; + } transfer; + transfer.f = RFLOAT_VALUE(rb_Float(dub)); + char buf[8]; + buf[0] = transfer.l & 0xff; + buf[1] = (transfer.l >> 8) & 0xff; + buf[2] = (transfer.l >> 16) & 0xff; + buf[3] = (transfer.l >> 24) & 0xff; + buf[4] = (transfer.l >> 32) & 0xff; + buf[5] = (transfer.l >> 40) & 0xff; + buf[6] = (transfer.l >> 48) & 0xff; + buf[7] = (transfer.l >> 56) & 0xff; + WRITE(GET_TRANSPORT(self), buf, 8); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_string(VALUE self, VALUE str) { + VALUE transport = GET_TRANSPORT(self); + write_varint32(transport, RSTRING_LEN(str)); + WRITE(transport, RSTRING_PTR(str), RSTRING_LEN(str)); + return Qnil; +} + +//--------------------------------------- +// interface reading methods +//--------------------------------------- + +#define is_bool_type(ctype) (((ctype) & 0x0F) == CTYPE_BOOLEAN_TRUE || ((ctype) & 0x0F) == CTYPE_BOOLEAN_FALSE) + +VALUE rb_thrift_compact_proto_read_string(VALUE self); +VALUE rb_thrift_compact_proto_read_byte(VALUE self); +VALUE rb_thrift_compact_proto_read_i32(VALUE self); +VALUE rb_thrift_compact_proto_read_i16(VALUE self); + +static int8_t get_ttype(int8_t ctype) { + if (ctype == TTYPE_STOP) { + return TTYPE_STOP; + } else if (ctype == CTYPE_BOOLEAN_TRUE || ctype == CTYPE_BOOLEAN_FALSE) { + return TTYPE_BOOL; + } else if (ctype == CTYPE_BYTE) { + return TTYPE_BYTE; + } else if (ctype == CTYPE_I16) { + return TTYPE_I16; + } else if (ctype == CTYPE_I32) { + return TTYPE_I32; + } else if (ctype == CTYPE_I64) { + return TTYPE_I64; + } else if (ctype == CTYPE_DOUBLE) { + return TTYPE_DOUBLE; + } else if (ctype == CTYPE_BINARY) { + return TTYPE_STRING; + } else if (ctype == CTYPE_LIST) { + return TTYPE_LIST; + } else if (ctype == CTYPE_SET) { + return TTYPE_SET; + } else if (ctype == CTYPE_MAP) { + return TTYPE_MAP; + } else if (ctype == CTYPE_STRUCT) { + return TTYPE_STRUCT; + } else { + char str[50]; + sprintf(str, "don't know what type: %d", ctype); + rb_raise(rb_eStandardError, str); + return 0; + } +} + +static char read_byte_direct(VALUE self) { + VALUE buf = READ(self, 1); + return RSTRING_PTR(buf)[0]; +} + +static int64_t zig_zag_to_ll(int64_t n) { + return (((uint64_t)n) >> 1) ^ -(n & 1); +} + +static int32_t zig_zag_to_int(int32_t n) { + return (((uint32_t)n) >> 1) ^ -(n & 1); +} + +static int64_t read_varint64(VALUE self) { + int shift = 0; + int64_t result = 0; + while (true) { + int8_t b = read_byte_direct(self); + result = result | ((uint64_t)(b & 0x7f) << shift); + if ((b & 0x80) != 0x80) { + break; + } + shift += 7; + } + return result; +} + +static int16_t read_i16(VALUE self) { + return zig_zag_to_int((int32_t)read_varint64(self)); +} + +static VALUE get_protocol_exception(VALUE code, VALUE message) { + VALUE args[2]; + args[0] = code; + args[1] = message; + return rb_class_new_instance(2, (VALUE*)&args, protocol_exception_class); +} + +VALUE rb_thrift_compact_proto_read_message_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_struct_begin(VALUE self) { + rb_ary_push(rb_ivar_get(self, last_field_id), INT2FIX(0)); + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_struct_end(VALUE self) { + rb_ary_pop(rb_ivar_get(self, last_field_id)); + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_field_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_map_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_list_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_set_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_message_begin(VALUE self) { + int8_t protocol_id = read_byte_direct(self); + if (protocol_id != PROTOCOL_ID) { + char buf[100]; + int len = sprintf(buf, "Expected protocol id %d but got %d", PROTOCOL_ID, protocol_id); + buf[len] = 0; + rb_exc_raise(get_protocol_exception(INT2FIX(-1), rb_str_new2(buf))); + } + + int8_t version_and_type = read_byte_direct(self); + int8_t version = version_and_type & VERSION_MASK; + if (version != VERSION) { + char buf[100]; + int len = sprintf(buf, "Expected version id %d but got %d", version, VERSION); + buf[len] = 0; + rb_exc_raise(get_protocol_exception(INT2FIX(-1), rb_str_new2(buf))); + } + + int8_t type = (version_and_type >> TYPE_SHIFT_AMOUNT) & 0x03; + int32_t seqid = read_varint64(self); + VALUE messageName = rb_thrift_compact_proto_read_string(self); + return rb_ary_new3(3, messageName, INT2FIX(type), INT2NUM(seqid)); +} + +VALUE rb_thrift_compact_proto_read_field_begin(VALUE self) { + int8_t type = read_byte_direct(self); + // if it's a stop, then we can return immediately, as the struct is over. + if ((type & 0x0f) == TTYPE_STOP) { + return rb_ary_new3(3, Qnil, INT2FIX(0), INT2FIX(0)); + } else { + int field_id = 0; + + // mask off the 4 MSB of the type header. it could contain a field id delta. + uint8_t modifier = ((type & 0xf0) >> 4); + + if (modifier == 0) { + // not a delta. look ahead for the zigzag varint field id. + field_id = read_i16(self); + } else { + // has a delta. add the delta to the last read field id. + field_id = LAST_ID(self) + modifier; + } + + // if this happens to be a boolean field, the value is encoded in the type + if (is_bool_type(type)) { + // save the boolean value in a special instance variable. + rb_ivar_set(self, bool_value_id, (type & 0x0f) == CTYPE_BOOLEAN_TRUE ? Qtrue : Qfalse); + } + + // push the new field onto the field stack so we can keep the deltas going. + SET_LAST_ID(self, INT2FIX(field_id)); + return rb_ary_new3(3, Qnil, INT2FIX(get_ttype(type & 0x0f)), INT2FIX(field_id)); + } +} + +VALUE rb_thrift_compact_proto_read_map_begin(VALUE self) { + int32_t size = read_varint64(self); + uint8_t key_and_value_type = size == 0 ? 0 : read_byte_direct(self); + return rb_ary_new3(3, INT2FIX(get_ttype(key_and_value_type >> 4)), INT2FIX(get_ttype(key_and_value_type & 0xf)), INT2FIX(size)); +} + +VALUE rb_thrift_compact_proto_read_list_begin(VALUE self) { + uint8_t size_and_type = read_byte_direct(self); + int32_t size = (size_and_type >> 4) & 0x0f; + if (size == 15) { + size = read_varint64(self); + } + uint8_t type = get_ttype(size_and_type & 0x0f); + return rb_ary_new3(2, INT2FIX(type), INT2FIX(size)); +} + +VALUE rb_thrift_compact_proto_read_set_begin(VALUE self) { + return rb_thrift_compact_proto_read_list_begin(self); +} + +VALUE rb_thrift_compact_proto_read_bool(VALUE self) { + VALUE bool_value = rb_ivar_get(self, bool_value_id); + if (NIL_P(bool_value)) { + return read_byte_direct(self) == CTYPE_BOOLEAN_TRUE ? Qtrue : Qfalse; + } else { + rb_ivar_set(self, bool_value_id, Qnil); + return bool_value; + } +} + +VALUE rb_thrift_compact_proto_read_byte(VALUE self) { + return INT2FIX(read_byte_direct(self)); +} + +VALUE rb_thrift_compact_proto_read_i16(VALUE self) { + return INT2FIX(read_i16(self)); +} + +VALUE rb_thrift_compact_proto_read_i32(VALUE self) { + return INT2NUM(zig_zag_to_int(read_varint64(self))); +} + +VALUE rb_thrift_compact_proto_read_i64(VALUE self) { + return LL2NUM(zig_zag_to_ll(read_varint64(self))); +} + +VALUE rb_thrift_compact_proto_read_double(VALUE self) { + union { + double f; + int64_t l; + } transfer; + VALUE bytes = READ(self, 8); + uint32_t lo = ((uint8_t)(RSTRING_PTR(bytes)[0])) + | (((uint8_t)(RSTRING_PTR(bytes)[1])) << 8) + | (((uint8_t)(RSTRING_PTR(bytes)[2])) << 16) + | (((uint8_t)(RSTRING_PTR(bytes)[3])) << 24); + uint64_t hi = (((uint8_t)(RSTRING_PTR(bytes)[4]))) + | (((uint8_t)(RSTRING_PTR(bytes)[5])) << 8) + | (((uint8_t)(RSTRING_PTR(bytes)[6])) << 16) + | (((uint8_t)(RSTRING_PTR(bytes)[7])) << 24); + transfer.l = (hi << 32) | lo; + + return rb_float_new(transfer.f); +} + +VALUE rb_thrift_compact_proto_read_string(VALUE self) { + int64_t size = read_varint64(self); + return READ(self, size); +} + +static void Init_constants() { + thrift_compact_protocol_class = rb_const_get(thrift_module, rb_intern("CompactProtocol")); + + VERSION = rb_num2ll(rb_const_get(thrift_compact_protocol_class, rb_intern("VERSION"))); + VERSION_MASK = rb_num2ll(rb_const_get(thrift_compact_protocol_class, rb_intern("VERSION_MASK"))); + TYPE_MASK = rb_num2ll(rb_const_get(thrift_compact_protocol_class, rb_intern("TYPE_MASK"))); + TYPE_SHIFT_AMOUNT = FIX2INT(rb_const_get(thrift_compact_protocol_class, rb_intern("TYPE_SHIFT_AMOUNT"))); + PROTOCOL_ID = FIX2INT(rb_const_get(thrift_compact_protocol_class, rb_intern("PROTOCOL_ID"))); + + last_field_id = rb_intern("@last_field"); + boolean_field_id = rb_intern("@boolean_field"); + bool_value_id = rb_intern("@bool_value"); +} + +static void Init_rb_methods() { + rb_define_method(thrift_compact_protocol_class, "native?", rb_thrift_compact_proto_native_qmark, 0); + + rb_define_method(thrift_compact_protocol_class, "write_message_begin", rb_thrift_compact_proto_write_message_begin, 3); + rb_define_method(thrift_compact_protocol_class, "write_field_begin", rb_thrift_compact_proto_write_field_begin, 3); + rb_define_method(thrift_compact_protocol_class, "write_field_stop", rb_thrift_compact_proto_write_field_stop, 0); + rb_define_method(thrift_compact_protocol_class, "write_map_begin", rb_thrift_compact_proto_write_map_begin, 3); + rb_define_method(thrift_compact_protocol_class, "write_list_begin", rb_thrift_compact_proto_write_list_begin, 2); + rb_define_method(thrift_compact_protocol_class, "write_set_begin", rb_thrift_compact_proto_write_set_begin, 2); + rb_define_method(thrift_compact_protocol_class, "write_byte", rb_thrift_compact_proto_write_byte, 1); + rb_define_method(thrift_compact_protocol_class, "write_bool", rb_thrift_compact_proto_write_bool, 1); + rb_define_method(thrift_compact_protocol_class, "write_i16", rb_thrift_compact_proto_write_i16, 1); + rb_define_method(thrift_compact_protocol_class, "write_i32", rb_thrift_compact_proto_write_i32, 1); + rb_define_method(thrift_compact_protocol_class, "write_i64", rb_thrift_compact_proto_write_i64, 1); + rb_define_method(thrift_compact_protocol_class, "write_double", rb_thrift_compact_proto_write_double, 1); + rb_define_method(thrift_compact_protocol_class, "write_string", rb_thrift_compact_proto_write_string, 1); + + rb_define_method(thrift_compact_protocol_class, "write_message_end", rb_thrift_compact_proto_write_message_end, 0); + rb_define_method(thrift_compact_protocol_class, "write_struct_begin", rb_thrift_compact_proto_write_struct_begin, 1); + rb_define_method(thrift_compact_protocol_class, "write_struct_end", rb_thrift_compact_proto_write_struct_end, 0); + rb_define_method(thrift_compact_protocol_class, "write_field_end", rb_thrift_compact_proto_write_field_end, 0); + rb_define_method(thrift_compact_protocol_class, "write_map_end", rb_thrift_compact_proto_write_map_end, 0); + rb_define_method(thrift_compact_protocol_class, "write_list_end", rb_thrift_compact_proto_write_list_end, 0); + rb_define_method(thrift_compact_protocol_class, "write_set_end", rb_thrift_compact_proto_write_set_end, 0); + + + rb_define_method(thrift_compact_protocol_class, "read_message_begin", rb_thrift_compact_proto_read_message_begin, 0); + rb_define_method(thrift_compact_protocol_class, "read_field_begin", rb_thrift_compact_proto_read_field_begin, 0); + rb_define_method(thrift_compact_protocol_class, "read_map_begin", rb_thrift_compact_proto_read_map_begin, 0); + rb_define_method(thrift_compact_protocol_class, "read_list_begin", rb_thrift_compact_proto_read_list_begin, 0); + rb_define_method(thrift_compact_protocol_class, "read_set_begin", rb_thrift_compact_proto_read_set_begin, 0); + rb_define_method(thrift_compact_protocol_class, "read_byte", rb_thrift_compact_proto_read_byte, 0); + rb_define_method(thrift_compact_protocol_class, "read_bool", rb_thrift_compact_proto_read_bool, 0); + rb_define_method(thrift_compact_protocol_class, "read_i16", rb_thrift_compact_proto_read_i16, 0); + rb_define_method(thrift_compact_protocol_class, "read_i32", rb_thrift_compact_proto_read_i32, 0); + rb_define_method(thrift_compact_protocol_class, "read_i64", rb_thrift_compact_proto_read_i64, 0); + rb_define_method(thrift_compact_protocol_class, "read_double", rb_thrift_compact_proto_read_double, 0); + rb_define_method(thrift_compact_protocol_class, "read_string", rb_thrift_compact_proto_read_string, 0); + + rb_define_method(thrift_compact_protocol_class, "read_message_end", rb_thrift_compact_proto_read_message_end, 0); + rb_define_method(thrift_compact_protocol_class, "read_struct_begin", rb_thrift_compact_proto_read_struct_begin, 0); + rb_define_method(thrift_compact_protocol_class, "read_struct_end", rb_thrift_compact_proto_read_struct_end, 0); + rb_define_method(thrift_compact_protocol_class, "read_field_end", rb_thrift_compact_proto_read_field_end, 0); + rb_define_method(thrift_compact_protocol_class, "read_map_end", rb_thrift_compact_proto_read_map_end, 0); + rb_define_method(thrift_compact_protocol_class, "read_list_end", rb_thrift_compact_proto_read_list_end, 0); + rb_define_method(thrift_compact_protocol_class, "read_set_end", rb_thrift_compact_proto_read_set_end, 0); +} + +static void Init_npmt() { + native_proto_method_table *npmt; + npmt = ALLOC(native_proto_method_table); + + npmt->write_field_begin = rb_thrift_compact_proto_write_field_begin; + npmt->write_field_stop = rb_thrift_compact_proto_write_field_stop; + npmt->write_map_begin = rb_thrift_compact_proto_write_map_begin; + npmt->write_list_begin = rb_thrift_compact_proto_write_list_begin; + npmt->write_set_begin = rb_thrift_compact_proto_write_set_begin; + npmt->write_byte = rb_thrift_compact_proto_write_byte; + npmt->write_bool = rb_thrift_compact_proto_write_bool; + npmt->write_i16 = rb_thrift_compact_proto_write_i16; + npmt->write_i32 = rb_thrift_compact_proto_write_i32; + npmt->write_i64 = rb_thrift_compact_proto_write_i64; + npmt->write_double = rb_thrift_compact_proto_write_double; + npmt->write_string = rb_thrift_compact_proto_write_string; + npmt->write_message_end = rb_thrift_compact_proto_write_message_end; + npmt->write_struct_begin = rb_thrift_compact_proto_write_struct_begin; + npmt->write_struct_end = rb_thrift_compact_proto_write_struct_end; + npmt->write_field_end = rb_thrift_compact_proto_write_field_end; + npmt->write_map_end = rb_thrift_compact_proto_write_map_end; + npmt->write_list_end = rb_thrift_compact_proto_write_list_end; + npmt->write_set_end = rb_thrift_compact_proto_write_set_end; + + npmt->read_message_begin = rb_thrift_compact_proto_read_message_begin; + npmt->read_field_begin = rb_thrift_compact_proto_read_field_begin; + npmt->read_map_begin = rb_thrift_compact_proto_read_map_begin; + npmt->read_list_begin = rb_thrift_compact_proto_read_list_begin; + npmt->read_set_begin = rb_thrift_compact_proto_read_set_begin; + npmt->read_byte = rb_thrift_compact_proto_read_byte; + npmt->read_bool = rb_thrift_compact_proto_read_bool; + npmt->read_i16 = rb_thrift_compact_proto_read_i16; + npmt->read_i32 = rb_thrift_compact_proto_read_i32; + npmt->read_i64 = rb_thrift_compact_proto_read_i64; + npmt->read_double = rb_thrift_compact_proto_read_double; + npmt->read_string = rb_thrift_compact_proto_read_string; + npmt->read_message_end = rb_thrift_compact_proto_read_message_end; + npmt->read_struct_begin = rb_thrift_compact_proto_read_struct_begin; + npmt->read_struct_end = rb_thrift_compact_proto_read_struct_end; + npmt->read_field_end = rb_thrift_compact_proto_read_field_end; + npmt->read_map_end = rb_thrift_compact_proto_read_map_end; + npmt->read_list_end = rb_thrift_compact_proto_read_list_end; + npmt->read_set_end = rb_thrift_compact_proto_read_set_end; + + VALUE method_table_object = Data_Wrap_Struct(rb_cObject, 0, free, npmt); + rb_const_set(thrift_compact_protocol_class, rb_intern("@native_method_table"), method_table_object); +} + + + +void Init_compact_protocol() { + Init_constants(); + Init_rb_methods(); + Init_npmt(); +} diff --git a/lib/rb/ext/compact_protocol.h b/lib/rb/ext/compact_protocol.h new file mode 100644 index 000000000..163915e94 --- /dev/null +++ b/lib/rb/ext/compact_protocol.h @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +void Init_compact_protocol(); diff --git a/lib/rb/ext/constants.h b/lib/rb/ext/constants.h new file mode 100644 index 000000000..57df544b5 --- /dev/null +++ b/lib/rb/ext/constants.h @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern int TTYPE_STOP; +extern int TTYPE_BOOL; +extern int TTYPE_BYTE; +extern int TTYPE_I16; +extern int TTYPE_I32; +extern int TTYPE_I64; +extern int TTYPE_DOUBLE; +extern int TTYPE_STRING; +extern int TTYPE_MAP; +extern int TTYPE_SET; +extern int TTYPE_LIST; +extern int TTYPE_STRUCT; + +extern ID validate_method_id; +extern ID write_struct_begin_method_id; +extern ID write_struct_end_method_id; +extern ID write_field_begin_method_id; +extern ID write_field_end_method_id; +extern ID write_boolean_method_id; +extern ID write_byte_method_id; +extern ID write_i16_method_id; +extern ID write_i32_method_id; +extern ID write_i64_method_id; +extern ID write_double_method_id; +extern ID write_string_method_id; +extern ID write_map_begin_method_id; +extern ID write_map_end_method_id; +extern ID write_list_begin_method_id; +extern ID write_list_end_method_id; +extern ID write_set_begin_method_id; +extern ID write_set_end_method_id; +extern ID size_method_id; +extern ID read_bool_method_id; +extern ID read_byte_method_id; +extern ID read_i16_method_id; +extern ID read_i32_method_id; +extern ID read_i64_method_id; +extern ID read_string_method_id; +extern ID read_double_method_id; +extern ID read_map_begin_method_id; +extern ID read_map_end_method_id; +extern ID read_list_begin_method_id; +extern ID read_list_end_method_id; +extern ID read_set_begin_method_id; +extern ID read_set_end_method_id; +extern ID read_struct_begin_method_id; +extern ID read_struct_end_method_id; +extern ID read_field_begin_method_id; +extern ID read_field_end_method_id; +extern ID keys_method_id; +extern ID entries_method_id; +extern ID name_method_id; +extern ID sort_method_id; +extern ID write_field_stop_method_id; +extern ID skip_method_id; +extern ID write_method_id; +extern ID read_all_method_id; +extern ID native_qmark_method_id; + +extern ID fields_const_id; +extern ID transport_ivar_id; +extern ID strict_read_ivar_id; +extern ID strict_write_ivar_id; + +extern VALUE type_sym; +extern VALUE name_sym; +extern VALUE key_sym; +extern VALUE value_sym; +extern VALUE element_sym; +extern VALUE class_sym; + +extern VALUE rb_cSet; +extern VALUE thrift_module; +extern VALUE thrift_types_module; +extern VALUE class_thrift_protocol; +extern VALUE protocol_exception_class; diff --git a/lib/rb/ext/extconf.rb b/lib/rb/ext/extconf.rb new file mode 100644 index 000000000..07558b8aa --- /dev/null +++ b/lib/rb/ext/extconf.rb @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'mkmf' + +$CFLAGS = "-g -O2 -Wall -Werror" + +have_func("strlcpy", "string.h") + +create_makefile 'thrift_native' diff --git a/lib/rb/ext/macros.h b/lib/rb/ext/macros.h new file mode 100644 index 000000000..265f6930d --- /dev/null +++ b/lib/rb/ext/macros.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#define GET_TRANSPORT(obj) rb_ivar_get(obj, transport_ivar_id) +#define GET_STRICT_READ(obj) rb_ivar_get(obj, strict_read_ivar_id) +#define GET_STRICT_WRITE(obj) rb_ivar_get(obj, strict_write_ivar_id) +#define WRITE(obj, data, length) rb_funcall(obj, write_method_id, 1, rb_str_new(data, length)) +#define CHECK_NIL(obj) if (NIL_P(obj)) { rb_raise(rb_eStandardError, "nil argument not allowed!");} +#define READ(obj, length) rb_funcall(GET_TRANSPORT(obj), read_all_method_id, 1, INT2FIX(length)) + +#ifndef RFLOAT_VALUE +# define RFLOAT_VALUE(v) RFLOAT(rb_Float(v))->value +#endif + +#ifndef RSTRING_LEN +# define RSTRING_LEN(v) RSTRING(rb_String(v))->len +#endif + +#ifndef RSTRING_PTR +# define RSTRING_PTR(v) RSTRING(rb_String(v))->ptr +#endif + +#ifndef RARRAY_LEN +# define RARRAY_LEN(v) RARRAY(rb_Array(v))->len +#endif diff --git a/lib/rb/ext/memory_buffer.c b/lib/rb/ext/memory_buffer.c new file mode 100644 index 000000000..624012d4b --- /dev/null +++ b/lib/rb/ext/memory_buffer.c @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <ruby.h> +#include <constants.h> +#include "macros.h" + +ID buf_ivar_id; +ID index_ivar_id; + +ID slice_method_id; + +int GARBAGE_BUFFER_SIZE; + +#define GET_BUF(self) rb_ivar_get(self, buf_ivar_id) + +VALUE rb_thrift_memory_buffer_write(VALUE self, VALUE str) { + VALUE buf = GET_BUF(self); + rb_str_buf_cat(buf, RSTRING_PTR(str), RSTRING_LEN(str)); + return Qnil; +} + +VALUE rb_thrift_memory_buffer_read(VALUE self, VALUE length_value) { + int length = FIX2INT(length_value); + + VALUE index_value = rb_ivar_get(self, index_ivar_id); + int index = FIX2INT(index_value); + + VALUE buf = GET_BUF(self); + VALUE data = rb_funcall(buf, slice_method_id, 2, index_value, length_value); + + index += length; + if (index > RSTRING_LEN(buf)) { + index = RSTRING_LEN(buf); + } + if (index >= GARBAGE_BUFFER_SIZE) { + rb_ivar_set(self, buf_ivar_id, rb_funcall(buf, slice_method_id, 2, INT2FIX(index), INT2FIX(RSTRING_LEN(buf) - 1))); + index = 0; + } + + rb_ivar_set(self, index_ivar_id, INT2FIX(index)); + return data; +} + +void Init_memory_buffer() { + VALUE thrift_memory_buffer_class = rb_const_get(thrift_module, rb_intern("MemoryBufferTransport")); + rb_define_method(thrift_memory_buffer_class, "write", rb_thrift_memory_buffer_write, 1); + rb_define_method(thrift_memory_buffer_class, "read", rb_thrift_memory_buffer_read, 1); + + buf_ivar_id = rb_intern("@buf"); + index_ivar_id = rb_intern("@index"); + + slice_method_id = rb_intern("slice"); + + GARBAGE_BUFFER_SIZE = FIX2INT(rb_const_get(thrift_memory_buffer_class, rb_intern("GARBAGE_BUFFER_SIZE"))); +} diff --git a/lib/rb/ext/memory_buffer.h b/lib/rb/ext/memory_buffer.h new file mode 100644 index 000000000..b277fa6f6 --- /dev/null +++ b/lib/rb/ext/memory_buffer.h @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +void Init_memory_buffer(); diff --git a/lib/rb/ext/protocol.c b/lib/rb/ext/protocol.c new file mode 100644 index 000000000..c1876541c --- /dev/null +++ b/lib/rb/ext/protocol.c @@ -0,0 +1,185 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <ruby.h> +#include <protocol.h> +#include <stdbool.h> +#include <constants.h> +#include <struct.h> + +static VALUE skip(VALUE self, int ttype) { + if (ttype == TTYPE_STOP) { + return Qnil; + } else if (ttype == TTYPE_BOOL) { + rb_funcall(self, read_bool_method_id, 0); + } else if (ttype == TTYPE_BYTE) { + rb_funcall(self, read_byte_method_id, 0); + } else if (ttype == TTYPE_I16) { + rb_funcall(self, read_i16_method_id, 0); + } else if (ttype == TTYPE_I32) { + rb_funcall(self, read_i32_method_id, 0); + } else if (ttype == TTYPE_I64) { + rb_funcall(self, read_i64_method_id, 0); + } else if (ttype == TTYPE_DOUBLE) { + rb_funcall(self, read_double_method_id, 0); + } else if (ttype == TTYPE_STRING) { + rb_funcall(self, read_string_method_id, 0); + } else if (ttype == TTYPE_STRUCT) { + rb_funcall(self, read_struct_begin_method_id, 0); + while (true) { + VALUE field_header = rb_funcall(self, read_field_begin_method_id, 0); + if (NIL_P(field_header) || FIX2INT(rb_ary_entry(field_header, 1)) == TTYPE_STOP ) { + break; + } + skip(self, FIX2INT(rb_ary_entry(field_header, 1))); + rb_funcall(self, read_field_end_method_id, 0); + } + rb_funcall(self, read_struct_end_method_id, 0); + } else if (ttype == TTYPE_MAP) { + int i; + VALUE map_header = rb_funcall(self, read_map_begin_method_id, 0); + int ktype = FIX2INT(rb_ary_entry(map_header, 0)); + int vtype = FIX2INT(rb_ary_entry(map_header, 1)); + int size = FIX2INT(rb_ary_entry(map_header, 2)); + + for (i = 0; i < size; i++) { + skip(self, ktype); + skip(self, vtype); + } + rb_funcall(self, read_map_end_method_id, 0); + } else if (ttype == TTYPE_LIST || ttype == TTYPE_SET) { + int i; + VALUE collection_header = rb_funcall(self, ttype == TTYPE_LIST ? read_list_begin_method_id : read_set_begin_method_id, 0); + int etype = FIX2INT(rb_ary_entry(collection_header, 0)); + int size = FIX2INT(rb_ary_entry(collection_header, 1)); + for (i = 0; i < size; i++) { + skip(self, etype); + } + rb_funcall(self, ttype == TTYPE_LIST ? read_list_end_method_id : read_set_end_method_id, 0); + } else { + rb_raise(rb_eNotImpError, "don't know how to skip type %d", ttype); + } + + return Qnil; +} + +VALUE rb_thrift_protocol_native_qmark(VALUE self) { + return Qfalse; +} + +VALUE rb_thrift_protocol_skip(VALUE protocol, VALUE ttype) { + return skip(protocol, FIX2INT(ttype)); +} + +VALUE rb_thrift_write_message_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_write_struct_begin(VALUE self, VALUE name) { + return Qnil; +} + +VALUE rb_thrift_write_struct_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_write_field_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_write_map_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_write_list_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_write_set_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_read_message_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_read_struct_begin(VALUE self) { + return Qnil; +} + +VALUE rb_thift_read_struct_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_read_field_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_read_map_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_read_list_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_read_set_end(VALUE self) { + return Qnil; +} + +void Init_protocol() { + VALUE c_protocol = rb_const_get(thrift_module, rb_intern("BaseProtocol")); + + rb_define_method(c_protocol, "skip", rb_thrift_protocol_skip, 1); + rb_define_method(c_protocol, "write_message_end", rb_thrift_write_message_end, 0); + rb_define_method(c_protocol, "write_struct_begin", rb_thrift_write_struct_begin, 1); + rb_define_method(c_protocol, "write_struct_end", rb_thrift_write_struct_end, 0); + rb_define_method(c_protocol, "write_field_end", rb_thrift_write_field_end, 0); + rb_define_method(c_protocol, "write_map_end", rb_thrift_write_map_end, 0); + rb_define_method(c_protocol, "write_list_end", rb_thrift_write_list_end, 0); + rb_define_method(c_protocol, "write_set_end", rb_thrift_write_set_end, 0); + rb_define_method(c_protocol, "read_message_end", rb_thrift_read_message_end, 0); + rb_define_method(c_protocol, "read_struct_begin", rb_thift_read_struct_begin, 0); + rb_define_method(c_protocol, "read_struct_end", rb_thift_read_struct_end, 0); + rb_define_method(c_protocol, "read_field_end", rb_thift_read_field_end, 0); + rb_define_method(c_protocol, "read_map_end", rb_thift_read_map_end, 0); + rb_define_method(c_protocol, "read_list_end", rb_thift_read_list_end, 0); + rb_define_method(c_protocol, "read_set_end", rb_thift_read_set_end, 0); + rb_define_method(c_protocol, "native?", rb_thrift_protocol_native_qmark, 0); + + // native_proto_method_table *npmt; + // npmt = ALLOC(native_proto_method_table); + // npmt->write_message_end = rb_thrift_write_message_end; + // npmt->write_struct_begin = rb_thrift_write_struct_begin; + // npmt->write_struct_end = rb_thrift_write_struct_end; + // npmt->write_field_end = rb_thrift_write_field_end; + // npmt->write_map_end = rb_thrift_write_map_end; + // npmt->write_list_end = rb_thrift_write_list_end; + // npmt->write_set_end = rb_thrift_write_set_end; + // npmt->read_message_end = rb_thrift_read_message_end; + // npmt->read_struct_begin = rb_thift_read_struct_begin; + // npmt->read_struct_end = rb_thift_read_struct_end; + // npmt->read_field_end = rb_thift_read_field_end; + // npmt->read_map_end = rb_thift_read_map_end; + // npmt->read_list_end = rb_thift_read_list_end; + // npmt->read_set_end = rb_thift_read_set_end; + // + // VALUE method_table_object = Data_Wrap_Struct(rb_cObject, 0, free, npmt); + // rb_const_set(c_protocol, rb_intern("@native_method_table"), method_table_object); +} diff --git a/lib/rb/ext/protocol.h b/lib/rb/ext/protocol.h new file mode 100644 index 000000000..536953036 --- /dev/null +++ b/lib/rb/ext/protocol.h @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +void Init_protocol(); diff --git a/lib/rb/ext/struct.c b/lib/rb/ext/struct.c new file mode 100644 index 000000000..fee285e97 --- /dev/null +++ b/lib/rb/ext/struct.c @@ -0,0 +1,605 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <struct.h> +#include <constants.h> +#include "macros.h" + +#ifndef HAVE_STRLCPY + +static +size_t +strlcpy (char *dst, const char *src, size_t dst_sz) +{ + size_t n; + + for (n = 0; n < dst_sz; n++) { + if ((*dst++ = *src++) == '\0') + break; + } + + if (n < dst_sz) + return n; + if (n > 0) + *(dst - 1) = '\0'; + return n + strlen (src); +} + +#endif + +static native_proto_method_table *mt; +static native_proto_method_table *default_mt; +static VALUE last_proto_class = Qnil; + +#define IS_CONTAINER(ttype) ((ttype) == TTYPE_MAP || (ttype) == TTYPE_LIST || (ttype) == TTYPE_SET) +#define STRUCT_FIELDS(obj) rb_const_get(CLASS_OF(obj), fields_const_id) + +static void set_native_proto_function_pointers(VALUE protocol) { + VALUE method_table_object = rb_const_get(CLASS_OF(protocol), rb_intern("@native_method_table")); + // TODO: check nil? + Data_Get_Struct(method_table_object, native_proto_method_table, mt); +} + +static void check_native_proto_method_table(VALUE protocol) { + VALUE protoclass = CLASS_OF(protocol); + if (protoclass != last_proto_class) { + last_proto_class = protoclass; + if (rb_funcall(protocol, native_qmark_method_id, 0) == Qtrue) { + set_native_proto_function_pointers(protocol); + } else { + mt = default_mt; + } + } +} + +//------------------------------------------- +// Writing section +//------------------------------------------- + +// default fn pointers for protocol stuff here + +VALUE default_write_bool(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_boolean_method_id, 1, value); + return Qnil; +} + +VALUE default_write_byte(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_byte_method_id, 1, value); + return Qnil; +} + +VALUE default_write_i16(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_i16_method_id, 1, value); + return Qnil; +} + +VALUE default_write_i32(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_i32_method_id, 1, value); + return Qnil; +} + +VALUE default_write_i64(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_i64_method_id, 1, value); + return Qnil; +} + +VALUE default_write_double(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_double_method_id, 1, value); + return Qnil; +} + +VALUE default_write_string(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_string_method_id, 1, value); + return Qnil; +} + +VALUE default_write_list_begin(VALUE protocol, VALUE etype, VALUE length) { + rb_funcall(protocol, write_list_begin_method_id, 2, etype, length); + return Qnil; +} + +VALUE default_write_list_end(VALUE protocol) { + rb_funcall(protocol, write_list_end_method_id, 0); + return Qnil; +} + +VALUE default_write_set_begin(VALUE protocol, VALUE etype, VALUE length) { + rb_funcall(protocol, write_set_begin_method_id, 2, etype, length); + return Qnil; +} + +VALUE default_write_set_end(VALUE protocol) { + rb_funcall(protocol, write_set_end_method_id, 0); + return Qnil; +} + +VALUE default_write_map_begin(VALUE protocol, VALUE ktype, VALUE vtype, VALUE length) { + rb_funcall(protocol, write_map_begin_method_id, 3, ktype, vtype, length); + return Qnil; +} + +VALUE default_write_map_end(VALUE protocol) { + rb_funcall(protocol, write_map_end_method_id, 0); + return Qnil; +} + +VALUE default_write_struct_begin(VALUE protocol, VALUE struct_name) { + rb_funcall(protocol, write_struct_begin_method_id, 1, struct_name); + return Qnil; +} + +VALUE default_write_struct_end(VALUE protocol) { + rb_funcall(protocol, write_struct_end_method_id, 0); + return Qnil; +} + +VALUE default_write_field_begin(VALUE protocol, VALUE name, VALUE type, VALUE id) { + rb_funcall(protocol, write_field_begin_method_id, 3, name, type, id); + return Qnil; +} + +VALUE default_write_field_end(VALUE protocol) { + rb_funcall(protocol, write_field_end_method_id, 0); + return Qnil; +} + +VALUE default_write_field_stop(VALUE protocol) { + rb_funcall(protocol, write_field_stop_method_id, 0); + return Qnil; +} + +VALUE default_read_field_begin(VALUE protocol) { + return rb_funcall(protocol, read_field_begin_method_id, 0); +} + +VALUE default_read_field_end(VALUE protocol) { + return rb_funcall(protocol, read_field_end_method_id, 0); +} + +VALUE default_read_map_begin(VALUE protocol) { + return rb_funcall(protocol, read_map_begin_method_id, 0); +} + +VALUE default_read_map_end(VALUE protocol) { + return rb_funcall(protocol, read_map_end_method_id, 0); +} + +VALUE default_read_list_begin(VALUE protocol) { + return rb_funcall(protocol, read_list_begin_method_id, 0); +} + +VALUE default_read_list_end(VALUE protocol) { + return rb_funcall(protocol, read_list_end_method_id, 0); +} + +VALUE default_read_set_begin(VALUE protocol) { + return rb_funcall(protocol, read_set_begin_method_id, 0); +} + +VALUE default_read_set_end(VALUE protocol) { + return rb_funcall(protocol, read_set_end_method_id, 0); +} + +VALUE default_read_byte(VALUE protocol) { + return rb_funcall(protocol, read_byte_method_id, 0); +} + +VALUE default_read_bool(VALUE protocol) { + return rb_funcall(protocol, read_bool_method_id, 0); +} + +VALUE default_read_i16(VALUE protocol) { + return rb_funcall(protocol, read_i16_method_id, 0); +} + +VALUE default_read_i32(VALUE protocol) { + return rb_funcall(protocol, read_i32_method_id, 0); +} + +VALUE default_read_i64(VALUE protocol) { + return rb_funcall(protocol, read_i64_method_id, 0); +} + +VALUE default_read_double(VALUE protocol) { + return rb_funcall(protocol, read_double_method_id, 0); +} + +VALUE default_read_string(VALUE protocol) { + return rb_funcall(protocol, read_string_method_id, 0); +} + +VALUE default_read_struct_begin(VALUE protocol) { + return rb_funcall(protocol, read_struct_begin_method_id, 0); +} + +VALUE default_read_struct_end(VALUE protocol) { + return rb_funcall(protocol, read_struct_end_method_id, 0); +} + +static void set_default_proto_function_pointers() { + default_mt = ALLOC(native_proto_method_table); + + default_mt->write_field_begin = default_write_field_begin; + default_mt->write_field_stop = default_write_field_stop; + default_mt->write_map_begin = default_write_map_begin; + default_mt->write_map_end = default_write_map_end; + default_mt->write_list_begin = default_write_list_begin; + default_mt->write_list_end = default_write_list_end; + default_mt->write_set_begin = default_write_set_begin; + default_mt->write_set_end = default_write_set_end; + default_mt->write_byte = default_write_byte; + default_mt->write_bool = default_write_bool; + default_mt->write_i16 = default_write_i16; + default_mt->write_i32 = default_write_i32; + default_mt->write_i64 = default_write_i64; + default_mt->write_double = default_write_double; + default_mt->write_string = default_write_string; + default_mt->write_struct_begin = default_write_struct_begin; + default_mt->write_struct_end = default_write_struct_end; + default_mt->write_field_end = default_write_field_end; + + default_mt->read_struct_begin = default_read_struct_begin; + default_mt->read_struct_end = default_read_struct_end; + default_mt->read_field_begin = default_read_field_begin; + default_mt->read_field_end = default_read_field_end; + default_mt->read_map_begin = default_read_map_begin; + default_mt->read_map_end = default_read_map_end; + default_mt->read_list_begin = default_read_list_begin; + default_mt->read_list_end = default_read_list_end; + default_mt->read_set_begin = default_read_set_begin; + default_mt->read_set_end = default_read_set_end; + default_mt->read_byte = default_read_byte; + default_mt->read_bool = default_read_bool; + default_mt->read_i16 = default_read_i16; + default_mt->read_i32 = default_read_i32; + default_mt->read_i64 = default_read_i64; + default_mt->read_double = default_read_double; + default_mt->read_string = default_read_string; +} + +// end default protocol methods + + +static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol); +static void write_anything(int ttype, VALUE value, VALUE protocol, VALUE field_info); + +VALUE get_field_value(VALUE obj, VALUE field_name) { + char name_buf[RSTRING_LEN(field_name) + 1]; + + name_buf[0] = '@'; + strlcpy(&name_buf[1], RSTRING_PTR(field_name), sizeof(name_buf)); + + VALUE value = rb_ivar_get(obj, rb_intern(name_buf)); + + return value; +} + +static void write_container(int ttype, VALUE field_info, VALUE value, VALUE protocol) { + int sz, i; + + if (ttype == TTYPE_MAP) { + VALUE keys; + VALUE key; + VALUE val; + + Check_Type(value, T_HASH); + + VALUE key_info = rb_hash_aref(field_info, key_sym); + VALUE keytype_value = rb_hash_aref(key_info, type_sym); + int keytype = FIX2INT(keytype_value); + + VALUE value_info = rb_hash_aref(field_info, value_sym); + VALUE valuetype_value = rb_hash_aref(value_info, type_sym); + int valuetype = FIX2INT(valuetype_value); + + keys = rb_funcall(value, keys_method_id, 0); + + sz = RARRAY_LEN(keys); + + mt->write_map_begin(protocol, keytype_value, valuetype_value, INT2FIX(sz)); + + for (i = 0; i < sz; i++) { + key = rb_ary_entry(keys, i); + val = rb_hash_aref(value, key); + + if (IS_CONTAINER(keytype)) { + write_container(keytype, key_info, key, protocol); + } else { + write_anything(keytype, key, protocol, key_info); + } + + if (IS_CONTAINER(valuetype)) { + write_container(valuetype, value_info, val, protocol); + } else { + write_anything(valuetype, val, protocol, value_info); + } + } + + mt->write_map_end(protocol); + } else if (ttype == TTYPE_LIST) { + Check_Type(value, T_ARRAY); + + sz = RARRAY_LEN(value); + + VALUE element_type_info = rb_hash_aref(field_info, element_sym); + VALUE element_type_value = rb_hash_aref(element_type_info, type_sym); + int element_type = FIX2INT(element_type_value); + + mt->write_list_begin(protocol, element_type_value, INT2FIX(sz)); + for (i = 0; i < sz; ++i) { + VALUE val = rb_ary_entry(value, i); + if (IS_CONTAINER(element_type)) { + write_container(element_type, element_type_info, val, protocol); + } else { + write_anything(element_type, val, protocol, element_type_info); + } + } + mt->write_list_end(protocol); + } else if (ttype == TTYPE_SET) { + VALUE items; + + if (TYPE(value) == T_ARRAY) { + items = value; + } else { + if (rb_cSet == CLASS_OF(value)) { + items = rb_funcall(value, entries_method_id, 0); + } else { + Check_Type(value, T_HASH); + items = rb_funcall(value, keys_method_id, 0); + } + } + + sz = RARRAY_LEN(items); + + VALUE element_type_info = rb_hash_aref(field_info, element_sym); + VALUE element_type_value = rb_hash_aref(element_type_info, type_sym); + int element_type = FIX2INT(element_type_value); + + mt->write_set_begin(protocol, element_type_value, INT2FIX(sz)); + + for (i = 0; i < sz; i++) { + VALUE val = rb_ary_entry(items, i); + if (IS_CONTAINER(element_type)) { + write_container(element_type, element_type_info, val, protocol); + } else { + write_anything(element_type, val, protocol, element_type_info); + } + } + + mt->write_set_end(protocol); + } else { + rb_raise(rb_eNotImpError, "can't write container of type: %d", ttype); + } +} + +static void write_anything(int ttype, VALUE value, VALUE protocol, VALUE field_info) { + if (ttype == TTYPE_BOOL) { + mt->write_bool(protocol, value); + } else if (ttype == TTYPE_BYTE) { + mt->write_byte(protocol, value); + } else if (ttype == TTYPE_I16) { + mt->write_i16(protocol, value); + } else if (ttype == TTYPE_I32) { + mt->write_i32(protocol, value); + } else if (ttype == TTYPE_I64) { + mt->write_i64(protocol, value); + } else if (ttype == TTYPE_DOUBLE) { + mt->write_double(protocol, value); + } else if (ttype == TTYPE_STRING) { + mt->write_string(protocol, value); + } else if (IS_CONTAINER(ttype)) { + write_container(ttype, field_info, value, protocol); + } else if (ttype == TTYPE_STRUCT) { + rb_thrift_struct_write(value, protocol); + } else { + rb_raise(rb_eNotImpError, "Unknown type for binary_encoding: %d", ttype); + } +} + +static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol) { + // call validate + rb_funcall(self, validate_method_id, 0); + + check_native_proto_method_table(protocol); + + // write struct begin + mt->write_struct_begin(protocol, rb_class_name(CLASS_OF(self))); + + // iterate through all the fields here + VALUE struct_fields = STRUCT_FIELDS(self); + VALUE struct_field_ids_unordered = rb_funcall(struct_fields, keys_method_id, 0); + VALUE struct_field_ids_ordered = rb_funcall(struct_field_ids_unordered, sort_method_id, 0); + + int i = 0; + for (i=0; i < RARRAY_LEN(struct_field_ids_ordered); i++) { + VALUE field_id = rb_ary_entry(struct_field_ids_ordered, i); + VALUE field_info = rb_hash_aref(struct_fields, field_id); + + VALUE ttype_value = rb_hash_aref(field_info, type_sym); + int ttype = FIX2INT(ttype_value); + VALUE field_name = rb_hash_aref(field_info, name_sym); + VALUE field_value = get_field_value(self, field_name); + + if (!NIL_P(field_value)) { + mt->write_field_begin(protocol, field_name, ttype_value, field_id); + + write_anything(ttype, field_value, protocol, field_info); + + mt->write_field_end(protocol); + } + } + + mt->write_field_stop(protocol); + + // write struct end + mt->write_struct_end(protocol); + + return Qnil; +} + +//------------------------------------------- +// Reading section +//------------------------------------------- + +static VALUE rb_thrift_struct_read(VALUE self, VALUE protocol); + +static void set_field_value(VALUE obj, VALUE field_name, VALUE value) { + char name_buf[RSTRING_LEN(field_name) + 1]; + + name_buf[0] = '@'; + strlcpy(&name_buf[1], RSTRING_PTR(field_name), sizeof(name_buf)); + + rb_ivar_set(obj, rb_intern(name_buf), value); +} + +static VALUE read_anything(VALUE protocol, int ttype, VALUE field_info) { + VALUE result = Qnil; + + if (ttype == TTYPE_BOOL) { + result = mt->read_bool(protocol); + } else if (ttype == TTYPE_BYTE) { + result = mt->read_byte(protocol); + } else if (ttype == TTYPE_I16) { + result = mt->read_i16(protocol); + } else if (ttype == TTYPE_I32) { + result = mt->read_i32(protocol); + } else if (ttype == TTYPE_I64) { + result = mt->read_i64(protocol); + } else if (ttype == TTYPE_STRING) { + result = mt->read_string(protocol); + } else if (ttype == TTYPE_DOUBLE) { + result = mt->read_double(protocol); + } else if (ttype == TTYPE_STRUCT) { + VALUE klass = rb_hash_aref(field_info, class_sym); + result = rb_class_new_instance(0, NULL, klass); + rb_thrift_struct_read(result, protocol); + } else if (ttype == TTYPE_MAP) { + int i; + + VALUE map_header = mt->read_map_begin(protocol); + int key_ttype = FIX2INT(rb_ary_entry(map_header, 0)); + int value_ttype = FIX2INT(rb_ary_entry(map_header, 1)); + int num_entries = FIX2INT(rb_ary_entry(map_header, 2)); + + VALUE key_info = rb_hash_aref(field_info, key_sym); + VALUE value_info = rb_hash_aref(field_info, value_sym); + + result = rb_hash_new(); + + for (i = 0; i < num_entries; ++i) { + VALUE key, val; + + key = read_anything(protocol, key_ttype, key_info); + val = read_anything(protocol, value_ttype, value_info); + + rb_hash_aset(result, key, val); + } + + mt->read_map_end(protocol); + } else if (ttype == TTYPE_LIST) { + int i; + + VALUE list_header = mt->read_list_begin(protocol); + int element_ttype = FIX2INT(rb_ary_entry(list_header, 0)); + int num_elements = FIX2INT(rb_ary_entry(list_header, 1)); + result = rb_ary_new2(num_elements); + + for (i = 0; i < num_elements; ++i) { + rb_ary_push(result, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym))); + } + + + mt->read_list_end(protocol); + } else if (ttype == TTYPE_SET) { + VALUE items; + int i; + + VALUE set_header = mt->read_set_begin(protocol); + int element_ttype = FIX2INT(rb_ary_entry(set_header, 0)); + int num_elements = FIX2INT(rb_ary_entry(set_header, 1)); + items = rb_ary_new2(num_elements); + + for (i = 0; i < num_elements; ++i) { + rb_ary_push(items, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym))); + } + + + mt->read_set_end(protocol); + + result = rb_class_new_instance(1, &items, rb_cSet); + } else { + rb_raise(rb_eNotImpError, "read_anything not implemented for type %d!", ttype); + } + + return result; +} + +static VALUE rb_thrift_struct_read(VALUE self, VALUE protocol) { + check_native_proto_method_table(protocol); + + // read struct begin + mt->read_struct_begin(protocol); + + VALUE struct_fields = STRUCT_FIELDS(self); + + // read each field + while (true) { + VALUE field_header = mt->read_field_begin(protocol); + VALUE field_type_value = rb_ary_entry(field_header, 1); + int field_type = FIX2INT(field_type_value); + + if (field_type == TTYPE_STOP) { + break; + } + + // make sure we got a type we expected + VALUE field_info = rb_hash_aref(struct_fields, rb_ary_entry(field_header, 2)); + + if (!NIL_P(field_info)) { + int specified_type = FIX2INT(rb_hash_aref(field_info, type_sym)); + if (field_type == specified_type) { + // read the value + VALUE name = rb_hash_aref(field_info, name_sym); + set_field_value(self, name, read_anything(protocol, field_type, field_info)); + } else { + rb_funcall(protocol, skip_method_id, 1, field_type_value); + } + } else { + rb_funcall(protocol, skip_method_id, 1, field_type_value); + } + + // read field end + mt->read_field_end(protocol); + } + + // read struct end + mt->read_struct_end(protocol); + + return Qnil; +} + +void Init_struct() { + VALUE struct_module = rb_const_get(thrift_module, rb_intern("Struct")); + + rb_define_method(struct_module, "write", rb_thrift_struct_write, 1); + rb_define_method(struct_module, "read", rb_thrift_struct_read, 1); + + set_default_proto_function_pointers(); +} + diff --git a/lib/rb/ext/struct.h b/lib/rb/ext/struct.h new file mode 100644 index 000000000..37b1b35b2 --- /dev/null +++ b/lib/rb/ext/struct.h @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <stdbool.h> +#include <ruby.h> + +typedef struct native_proto_method_table { + VALUE (*write_bool)(VALUE, VALUE); + VALUE (*write_byte)(VALUE, VALUE); + VALUE (*write_i16)(VALUE, VALUE); + VALUE (*write_i32)(VALUE, VALUE); + VALUE (*write_i64)(VALUE, VALUE); + VALUE (*write_double)(VALUE, VALUE); + VALUE (*write_string)(VALUE, VALUE); + VALUE (*write_list_begin)(VALUE, VALUE, VALUE); + VALUE (*write_list_end)(VALUE); + VALUE (*write_set_begin)(VALUE, VALUE, VALUE); + VALUE (*write_set_end)(VALUE); + VALUE (*write_map_begin)(VALUE, VALUE, VALUE, VALUE); + VALUE (*write_map_end)(VALUE); + VALUE (*write_struct_begin)(VALUE, VALUE); + VALUE (*write_struct_end)(VALUE); + VALUE (*write_field_begin)(VALUE, VALUE, VALUE, VALUE); + VALUE (*write_field_end)(VALUE); + VALUE (*write_field_stop)(VALUE); + VALUE (*write_message_begin)(VALUE, VALUE, VALUE, VALUE); + VALUE (*write_message_end)(VALUE); + + VALUE (*read_message_begin)(VALUE); + VALUE (*read_message_end)(VALUE); + VALUE (*read_field_begin)(VALUE); + VALUE (*read_field_end)(VALUE); + VALUE (*read_map_begin)(VALUE); + VALUE (*read_map_end)(VALUE); + VALUE (*read_list_begin)(VALUE); + VALUE (*read_list_end)(VALUE); + VALUE (*read_set_begin)(VALUE); + VALUE (*read_set_end)(VALUE); + VALUE (*read_byte)(VALUE); + VALUE (*read_bool)(VALUE); + VALUE (*read_i16)(VALUE); + VALUE (*read_i32)(VALUE); + VALUE (*read_i64)(VALUE); + VALUE (*read_double)(VALUE); + VALUE (*read_string)(VALUE); + VALUE (*read_struct_begin)(VALUE); + VALUE (*read_struct_end)(VALUE); + +} native_proto_method_table; + +void Init_struct(); diff --git a/lib/rb/ext/thrift_native.c b/lib/rb/ext/thrift_native.c new file mode 100644 index 000000000..effa202c4 --- /dev/null +++ b/lib/rb/ext/thrift_native.c @@ -0,0 +1,194 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <ruby.h> +#include <struct.h> +#include <binary_protocol_accelerated.h> +#include <compact_protocol.h> +#include <protocol.h> +#include <memory_buffer.h> + +// cached classes/modules +VALUE rb_cSet; +VALUE thrift_module; +VALUE thrift_types_module; + +// TType constants +int TTYPE_STOP; +int TTYPE_BOOL; +int TTYPE_BYTE; +int TTYPE_I16; +int TTYPE_I32; +int TTYPE_I64; +int TTYPE_DOUBLE; +int TTYPE_STRING; +int TTYPE_MAP; +int TTYPE_SET; +int TTYPE_LIST; +int TTYPE_STRUCT; + +// method ids +ID validate_method_id; +ID write_struct_begin_method_id; +ID write_struct_end_method_id; +ID write_field_begin_method_id; +ID write_field_end_method_id; +ID write_boolean_method_id; +ID write_byte_method_id; +ID write_i16_method_id; +ID write_i32_method_id; +ID write_i64_method_id; +ID write_double_method_id; +ID write_string_method_id; +ID write_map_begin_method_id; +ID write_map_end_method_id; +ID write_list_begin_method_id; +ID write_list_end_method_id; +ID write_set_begin_method_id; +ID write_set_end_method_id; +ID size_method_id; +ID read_bool_method_id; +ID read_byte_method_id; +ID read_i16_method_id; +ID read_i32_method_id; +ID read_i64_method_id; +ID read_string_method_id; +ID read_double_method_id; +ID read_map_begin_method_id; +ID read_map_end_method_id; +ID read_list_begin_method_id; +ID read_list_end_method_id; +ID read_set_begin_method_id; +ID read_set_end_method_id; +ID read_struct_begin_method_id; +ID read_struct_end_method_id; +ID read_field_begin_method_id; +ID read_field_end_method_id; +ID keys_method_id; +ID entries_method_id; +ID name_method_id; +ID sort_method_id; +ID write_field_stop_method_id; +ID skip_method_id; +ID write_method_id; +ID read_all_method_id; +ID native_qmark_method_id; + +// constant ids +ID fields_const_id; +ID transport_ivar_id; +ID strict_read_ivar_id; +ID strict_write_ivar_id; + +// cached symbols +VALUE type_sym; +VALUE name_sym; +VALUE key_sym; +VALUE value_sym; +VALUE element_sym; +VALUE class_sym; +VALUE protocol_exception_class; + +void Init_thrift_native() { + // cached classes + thrift_module = rb_const_get(rb_cObject, rb_intern("Thrift")); + thrift_types_module = rb_const_get(thrift_module, rb_intern("Types")); + rb_cSet = rb_const_get(rb_cObject, rb_intern("Set")); + protocol_exception_class = rb_const_get(thrift_module, rb_intern("ProtocolException")); + + // Init ttype constants + TTYPE_BOOL = FIX2INT(rb_const_get(thrift_types_module, rb_intern("BOOL"))); + TTYPE_BYTE = FIX2INT(rb_const_get(thrift_types_module, rb_intern("BYTE"))); + TTYPE_I16 = FIX2INT(rb_const_get(thrift_types_module, rb_intern("I16"))); + TTYPE_I32 = FIX2INT(rb_const_get(thrift_types_module, rb_intern("I32"))); + TTYPE_I64 = FIX2INT(rb_const_get(thrift_types_module, rb_intern("I64"))); + TTYPE_DOUBLE = FIX2INT(rb_const_get(thrift_types_module, rb_intern("DOUBLE"))); + TTYPE_STRING = FIX2INT(rb_const_get(thrift_types_module, rb_intern("STRING"))); + TTYPE_MAP = FIX2INT(rb_const_get(thrift_types_module, rb_intern("MAP"))); + TTYPE_SET = FIX2INT(rb_const_get(thrift_types_module, rb_intern("SET"))); + TTYPE_LIST = FIX2INT(rb_const_get(thrift_types_module, rb_intern("LIST"))); + TTYPE_STRUCT = FIX2INT(rb_const_get(thrift_types_module, rb_intern("STRUCT"))); + + // method ids + validate_method_id = rb_intern("validate"); + write_struct_begin_method_id = rb_intern("write_struct_begin"); + write_struct_end_method_id = rb_intern("write_struct_end"); + write_field_begin_method_id = rb_intern("write_field_begin"); + write_field_end_method_id = rb_intern("write_field_end"); + write_boolean_method_id = rb_intern("write_bool"); + write_byte_method_id = rb_intern("write_byte"); + write_i16_method_id = rb_intern("write_i16"); + write_i32_method_id = rb_intern("write_i32"); + write_i64_method_id = rb_intern("write_i64"); + write_double_method_id = rb_intern("write_double"); + write_string_method_id = rb_intern("write_string"); + write_map_begin_method_id = rb_intern("write_map_begin"); + write_map_end_method_id = rb_intern("write_map_end"); + write_list_begin_method_id = rb_intern("write_list_begin"); + write_list_end_method_id = rb_intern("write_list_end"); + write_set_begin_method_id = rb_intern("write_set_begin"); + write_set_end_method_id = rb_intern("write_set_end"); + size_method_id = rb_intern("size"); + read_bool_method_id = rb_intern("read_bool"); + read_byte_method_id = rb_intern("read_byte"); + read_i16_method_id = rb_intern("read_i16"); + read_i32_method_id = rb_intern("read_i32"); + read_i64_method_id = rb_intern("read_i64"); + read_string_method_id = rb_intern("read_string"); + read_double_method_id = rb_intern("read_double"); + read_map_begin_method_id = rb_intern("read_map_begin"); + read_map_end_method_id = rb_intern("read_map_end"); + read_list_begin_method_id = rb_intern("read_list_begin"); + read_list_end_method_id = rb_intern("read_list_end"); + read_set_begin_method_id = rb_intern("read_set_begin"); + read_set_end_method_id = rb_intern("read_set_end"); + read_struct_begin_method_id = rb_intern("read_struct_begin"); + read_struct_end_method_id = rb_intern("read_struct_end"); + read_field_begin_method_id = rb_intern("read_field_begin"); + read_field_end_method_id = rb_intern("read_field_end"); + keys_method_id = rb_intern("keys"); + entries_method_id = rb_intern("entries"); + name_method_id = rb_intern("name"); + sort_method_id = rb_intern("sort"); + write_field_stop_method_id = rb_intern("write_field_stop"); + skip_method_id = rb_intern("skip"); + write_method_id = rb_intern("write"); + read_all_method_id = rb_intern("read_all"); + native_qmark_method_id = rb_intern("native?"); + + // constant ids + fields_const_id = rb_intern("FIELDS"); + transport_ivar_id = rb_intern("@trans"); + strict_read_ivar_id = rb_intern("@strict_read"); + strict_write_ivar_id = rb_intern("@strict_write"); + + // cached symbols + type_sym = ID2SYM(rb_intern("type")); + name_sym = ID2SYM(rb_intern("name")); + key_sym = ID2SYM(rb_intern("key")); + value_sym = ID2SYM(rb_intern("value")); + element_sym = ID2SYM(rb_intern("element")); + class_sym = ID2SYM(rb_intern("class")); + + Init_protocol(); + Init_struct(); + Init_binary_protocol_accelerated(); + Init_compact_protocol(); + Init_memory_buffer(); +} diff --git a/lib/rb/lib/thrift.rb b/lib/rb/lib/thrift.rb new file mode 100644 index 000000000..88562e137 --- /dev/null +++ b/lib/rb/lib/thrift.rb @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$:.unshift File.dirname(__FILE__) + +require 'thrift/core_ext' +require 'thrift/exceptions' +require 'thrift/types' +require 'thrift/processor' +require 'thrift/client' +require 'thrift/struct' + +# serializer +require 'thrift/serializer/serializer' +require 'thrift/serializer/deserializer' + +# protocol +require 'thrift/protocol/base_protocol' +require 'thrift/protocol/binary_protocol' +require 'thrift/protocol/binary_protocol_accelerated' +require 'thrift/protocol/compact_protocol' + +# transport +require 'thrift/transport/base_transport' +require 'thrift/transport/base_server_transport' +require 'thrift/transport/socket' +require 'thrift/transport/server_socket' +require 'thrift/transport/unix_socket' +require 'thrift/transport/unix_server_socket' +require 'thrift/transport/buffered_transport' +require 'thrift/transport/framed_transport' +require 'thrift/transport/http_client_transport' +require 'thrift/transport/io_stream_transport' +require 'thrift/transport/memory_buffer_transport' + +# server +require 'thrift/server/base_server' +require 'thrift/server/nonblocking_server' +require 'thrift/server/simple_server' +require 'thrift/server/threaded_server' +require 'thrift/server/thread_pool_server' + +require 'thrift/thrift_native'
\ No newline at end of file diff --git a/lib/rb/lib/thrift/client.rb b/lib/rb/lib/thrift/client.rb new file mode 100644 index 000000000..5b30f0150 --- /dev/null +++ b/lib/rb/lib/thrift/client.rb @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + module Client + def initialize(iprot, oprot=nil) + @iprot = iprot + @oprot = oprot || iprot + @seqid = 0 + end + + def send_message(name, args_class, args = {}) + @oprot.write_message_begin(name, MessageTypes::CALL, @seqid) + data = args_class.new + args.each do |k, v| + data.send("#{k.to_s}=", v) + end + begin + data.write(@oprot) + rescue StandardError => e + @oprot.trans.close + raise e + end + @oprot.write_message_end + @oprot.trans.flush + end + + def receive_message(result_klass) + fname, mtype, rseqid = @iprot.read_message_begin + handle_exception(mtype) + result = result_klass.new + result.read(@iprot) + @iprot.read_message_end + result + end + + def handle_exception(mtype) + if mtype == MessageTypes::EXCEPTION + x = ApplicationException.new + x.read(@iprot) + @iprot.read_message_end + raise x + end + end + end +end diff --git a/lib/rb/lib/thrift/core_ext.rb b/lib/rb/lib/thrift/core_ext.rb new file mode 100644 index 000000000..f763cd534 --- /dev/null +++ b/lib/rb/lib/thrift/core_ext.rb @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +Dir[File.dirname(__FILE__) + "/core_ext/*.rb"].each do |file| + name = File.basename(file, '.rb') + require "thrift/core_ext/#{name}" +end diff --git a/lib/rb/lib/thrift/core_ext/fixnum.rb b/lib/rb/lib/thrift/core_ext/fixnum.rb new file mode 100644 index 000000000..b4fc90dd6 --- /dev/null +++ b/lib/rb/lib/thrift/core_ext/fixnum.rb @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# Versions of ruby pre 1.8.7 do not have an .ord method available in the Fixnum +# class. +# +if RUBY_VERSION < "1.8.7" + class Fixnum + def ord + self + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/exceptions.rb b/lib/rb/lib/thrift/exceptions.rb new file mode 100644 index 000000000..dda70894b --- /dev/null +++ b/lib/rb/lib/thrift/exceptions.rb @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class Exception < StandardError + def initialize(message) + super + @message = message + end + + attr_reader :message + end + + class ApplicationException < Exception + + UNKNOWN = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + + attr_reader :type + + def initialize(type=UNKNOWN, message=nil) + super(message) + @type = type + end + + def read(iprot) + iprot.read_struct_begin + while true + fname, ftype, fid = iprot.read_field_begin + if ftype == Types::STOP + break + end + if fid == 1 and ftype == Types::STRING + @message = iprot.read_string + elsif fid == 2 and ftype == Types::I32 + @type = iprot.read_i32 + else + iprot.skip(ftype) + end + iprot.read_field_end + end + iprot.read_struct_end + end + + def write(oprot) + oprot.write_struct_begin('Thrift::ApplicationException') + unless @message.nil? + oprot.write_field_begin('message', Types::STRING, 1) + oprot.write_string(@message) + oprot.write_field_end + end + unless @type.nil? + oprot.write_field_begin('type', Types::I32, 2) + oprot.write_i32(@type) + oprot.write_field_end + end + oprot.write_field_stop + oprot.write_struct_end + end + + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/processor.rb b/lib/rb/lib/thrift/processor.rb new file mode 100644 index 000000000..5d9e0a11c --- /dev/null +++ b/lib/rb/lib/thrift/processor.rb @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + module Processor + def initialize(handler) + @handler = handler + end + + def process(iprot, oprot) + name, type, seqid = iprot.read_message_begin + if respond_to?("process_#{name}") + send("process_#{name}", seqid, iprot, oprot) + true + else + iprot.skip(Types::STRUCT) + iprot.read_message_end + x = ApplicationException.new(ApplicationException::UNKNOWN_METHOD, 'Unknown function '+name) + oprot.write_message_begin(name, MessageTypes::EXCEPTION, seqid) + x.write(oprot) + oprot.write_message_end + oprot.trans.flush + false + end + end + + def read_args(iprot, args_class) + args = args_class.new + args.read(iprot) + iprot.read_message_end + args + end + + def write_result(result, oprot, name, seqid) + oprot.write_message_begin(name, MessageTypes::REPLY, seqid) + result.write(oprot) + oprot.write_message_end + oprot.trans.flush + end + end +end diff --git a/lib/rb/lib/thrift/protocol/base_protocol.rb b/lib/rb/lib/thrift/protocol/base_protocol.rb new file mode 100644 index 000000000..b19909d5f --- /dev/null +++ b/lib/rb/lib/thrift/protocol/base_protocol.rb @@ -0,0 +1,290 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# this require is to make generated struct definitions happy +require 'set' + +module Thrift + class ProtocolException < Exception + + UNKNOWN = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + + attr_reader :type + + def initialize(type=UNKNOWN, message=nil) + super(message) + @type = type + end + end + + class BaseProtocol + + attr_reader :trans + + def initialize(trans) + @trans = trans + end + + def native? + puts "wrong method is being called!" + false + end + + def write_message_begin(name, type, seqid) + raise NotImplementedError + end + + def write_message_end; nil; end + + def write_struct_begin(name) + raise NotImplementedError + end + + def write_struct_end; nil; end + + def write_field_begin(name, type, id) + raise NotImplementedError + end + + def write_field_end; nil; end + + def write_field_stop + raise NotImplementedError + end + + def write_map_begin(ktype, vtype, size) + raise NotImplementedError + end + + def write_map_end; nil; end + + def write_list_begin(etype, size) + raise NotImplementedError + end + + def write_list_end; nil; end + + def write_set_begin(etype, size) + raise NotImplementedError + end + + def write_set_end; nil; end + + def write_bool(bool) + raise NotImplementedError + end + + def write_byte(byte) + raise NotImplementedError + end + + def write_i16(i16) + raise NotImplementedError + end + + def write_i32(i32) + raise NotImplementedError + end + + def write_i64(i64) + raise NotImplementedError + end + + def write_double(dub) + raise NotImplementedError + end + + def write_string(str) + raise NotImplementedError + end + + def read_message_begin + raise NotImplementedError + end + + def read_message_end; nil; end + + def read_struct_begin + raise NotImplementedError + end + + def read_struct_end; nil; end + + def read_field_begin + raise NotImplementedError + end + + def read_field_end; nil; end + + def read_map_begin + raise NotImplementedError + end + + def read_map_end; nil; end + + def read_list_begin + raise NotImplementedError + end + + def read_list_end; nil; end + + def read_set_begin + raise NotImplementedError + end + + def read_set_end; nil; end + + def read_bool + raise NotImplementedError + end + + def read_byte + raise NotImplementedError + end + + def read_i16 + raise NotImplementedError + end + + def read_i32 + raise NotImplementedError + end + + def read_i64 + raise NotImplementedError + end + + def read_double + raise NotImplementedError + end + + def read_string + raise NotImplementedError + end + + def write_field(name, type, fid, value) + write_field_begin(name, type, fid) + write_type(type, value) + write_field_end + end + + def write_type(type, value) + case type + when Types::BOOL + write_bool(value) + when Types::BYTE + write_byte(value) + when Types::DOUBLE + write_double(value) + when Types::I16 + write_i16(value) + when Types::I32 + write_i32(value) + when Types::I64 + write_i64(value) + when Types::STRING + write_string(value) + when Types::STRUCT + value.write(self) + else + raise NotImplementedError + end + end + + def read_type(type) + case type + when Types::BOOL + read_bool + when Types::BYTE + read_byte + when Types::DOUBLE + read_double + when Types::I16 + read_i16 + when Types::I32 + read_i32 + when Types::I64 + read_i64 + when Types::STRING + read_string + else + raise NotImplementedError + end + end + + def skip(type) + case type + when Types::STOP + nil + when Types::BOOL + read_bool + when Types::BYTE + read_byte + when Types::I16 + read_i16 + when Types::I32 + read_i32 + when Types::I64 + read_i64 + when Types::DOUBLE + read_double + when Types::STRING + read_string + when Types::STRUCT + read_struct_begin + while true + name, type, id = read_field_begin + break if type == Types::STOP + skip(type) + read_field_end + end + read_struct_end + when Types::MAP + ktype, vtype, size = read_map_begin + size.times do + skip(ktype) + skip(vtype) + end + read_map_end + when Types::SET + etype, size = read_set_begin + size.times do + skip(etype) + end + read_set_end + when Types::LIST + etype, size = read_list_begin + size.times do + skip(etype) + end + read_list_end + end + end + end + + class BaseProtocolFactory + def get_protocol(trans) + raise NotImplementedError + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/protocol/binary_protocol.rb b/lib/rb/lib/thrift/protocol/binary_protocol.rb new file mode 100644 index 000000000..04d149ac2 --- /dev/null +++ b/lib/rb/lib/thrift/protocol/binary_protocol.rb @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class BinaryProtocol < BaseProtocol + VERSION_MASK = 0xffff0000 + VERSION_1 = 0x80010000 + TYPE_MASK = 0x000000ff + + attr_reader :strict_read, :strict_write + + def initialize(trans, strict_read=true, strict_write=true) + super(trans) + @strict_read = strict_read + @strict_write = strict_write + end + + def write_message_begin(name, type, seqid) + # this is necessary because we added (needed) bounds checking to + # write_i32, and 0x80010000 is too big for that. + if strict_write + write_i16(VERSION_1 >> 16) + write_i16(type) + write_string(name) + write_i32(seqid) + else + write_string(name) + write_byte(type) + write_i32(seqid) + end + end + + def write_struct_begin(name); nil; end + + def write_field_begin(name, type, id) + write_byte(type) + write_i16(id) + end + + def write_field_stop + write_byte(Thrift::Types::STOP) + end + + def write_map_begin(ktype, vtype, size) + write_byte(ktype) + write_byte(vtype) + write_i32(size) + end + + def write_list_begin(etype, size) + write_byte(etype) + write_i32(size) + end + + def write_set_begin(etype, size) + write_byte(etype) + write_i32(size) + end + + def write_bool(bool) + write_byte(bool ? 1 : 0) + end + + def write_byte(byte) + raise RangeError if byte < -2**31 || byte >= 2**32 + trans.write([byte].pack('c')) + end + + def write_i16(i16) + trans.write([i16].pack('n')) + end + + def write_i32(i32) + raise RangeError if i32 < -2**31 || i32 >= 2**31 + trans.write([i32].pack('N')) + end + + def write_i64(i64) + raise RangeError if i64 < -2**63 || i64 >= 2**64 + hi = i64 >> 32 + lo = i64 & 0xffffffff + trans.write([hi, lo].pack('N2')) + end + + def write_double(dub) + trans.write([dub].pack('G')) + end + + def write_string(str) + write_i32(str.length) + trans.write(str) + end + + def read_message_begin + version = read_i32 + if version < 0 + if (version & VERSION_MASK != VERSION_1) + raise ProtocolException.new(ProtocolException::BAD_VERSION, 'Missing version identifier') + end + type = version & TYPE_MASK + name = read_string + seqid = read_i32 + [name, type, seqid] + else + if strict_read + raise ProtocolException.new(ProtocolException::BAD_VERSION, 'No version identifier, old protocol client?') + end + name = trans.read_all(version) + type = read_byte + seqid = read_i32 + [name, type, seqid] + end + end + + def read_struct_begin; nil; end + + def read_field_begin + type = read_byte + if (type == Types::STOP) + [nil, type, 0] + else + id = read_i16 + [nil, type, id] + end + end + + def read_map_begin + ktype = read_byte + vtype = read_byte + size = read_i32 + [ktype, vtype, size] + end + + def read_list_begin + etype = read_byte + size = read_i32 + [etype, size] + end + + def read_set_begin + etype = read_byte + size = read_i32 + [etype, size] + end + + def read_bool + byte = read_byte + byte != 0 + end + + def read_byte + dat = trans.read_all(1) + val = dat[0].ord + if (val > 0x7f) + val = 0 - ((val - 1) ^ 0xff) + end + val + end + + def read_i16 + dat = trans.read_all(2) + val, = dat.unpack('n') + if (val > 0x7fff) + val = 0 - ((val - 1) ^ 0xffff) + end + val + end + + def read_i32 + dat = trans.read_all(4) + val, = dat.unpack('N') + if (val > 0x7fffffff) + val = 0 - ((val - 1) ^ 0xffffffff) + end + val + end + + def read_i64 + dat = trans.read_all(8) + hi, lo = dat.unpack('N2') + if (hi > 0x7fffffff) + hi ^= 0xffffffff + lo ^= 0xffffffff + 0 - (hi << 32) - lo - 1 + else + (hi << 32) + lo + end + end + + def read_double + dat = trans.read_all(8) + val = dat.unpack('G').first + val + end + + def read_string + sz = read_i32 + dat = trans.read_all(sz) + dat + end + + end + + class BinaryProtocolFactory < BaseProtocolFactory + def get_protocol(trans) + return Thrift::BinaryProtocol.new(trans) + end + end +end diff --git a/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb b/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb new file mode 100644 index 000000000..eaf64f6be --- /dev/null +++ b/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +=begin +The only change required for a transport to support BinaryProtocolAccelerated is to implement 2 methods: + * borrow(size), which takes an optional argument and returns atleast _size_ bytes from the transport, + or the default buffer size if no argument is given + * consume!(size), which removes size bytes from the front of the buffer + +See MemoryBuffer and BufferedTransport for examples. +=end + +module Thrift + class BinaryProtocolAcceleratedFactory < BaseProtocolFactory + def get_protocol(trans) + BinaryProtocolAccelerated.new(trans) + end + end +end diff --git a/lib/rb/lib/thrift/protocol/compact_protocol.rb b/lib/rb/lib/thrift/protocol/compact_protocol.rb new file mode 100644 index 000000000..c8f436559 --- /dev/null +++ b/lib/rb/lib/thrift/protocol/compact_protocol.rb @@ -0,0 +1,422 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class CompactProtocol < BaseProtocol + + PROTOCOL_ID = [0x82].pack('c').unpack('c').first + VERSION = 1 + VERSION_MASK = 0x1f + TYPE_MASK = 0xE0 + TYPE_SHIFT_AMOUNT = 5 + + TSTOP = ["", Types::STOP, 0] + + # + # All of the on-wire type codes. + # + class CompactTypes + BOOLEAN_TRUE = 0x01 + BOOLEAN_FALSE = 0x02 + BYTE = 0x03 + I16 = 0x04 + I32 = 0x05 + I64 = 0x06 + DOUBLE = 0x07 + BINARY = 0x08 + LIST = 0x09 + SET = 0x0A + MAP = 0x0B + STRUCT = 0x0C + + def self.is_bool_type?(b) + (b & 0x0f) == BOOLEAN_TRUE || (b & 0x0f) == BOOLEAN_FALSE + end + + COMPACT_TO_TTYPE = { + Types::STOP => Types::STOP, + BOOLEAN_FALSE => Types::BOOL, + BOOLEAN_TRUE => Types::BOOL, + BYTE => Types::BYTE, + I16 => Types::I16, + I32 => Types::I32, + I64 => Types::I64, + DOUBLE => Types::DOUBLE, + BINARY => Types::STRING, + LIST => Types::LIST, + SET => Types::SET, + MAP => Types::MAP, + STRUCT => Types::STRUCT + } + + TTYPE_TO_COMPACT = { + Types::STOP => Types::STOP, + Types::BOOL => BOOLEAN_TRUE, + Types::BYTE => BYTE, + Types::I16 => I16, + Types::I32 => I32, + Types::I64 => I64, + Types::DOUBLE => DOUBLE, + Types::STRING => BINARY, + Types::LIST => LIST, + Types::SET => SET, + Types::MAP => MAP, + Types::STRUCT => STRUCT + } + + def self.get_ttype(compact_type) + val = COMPACT_TO_TTYPE[compact_type & 0x0f] + raise "don't know what type: #{compact_type & 0x0f}" unless val + val + end + + def self.get_compact_type(ttype) + val = TTYPE_TO_COMPACT[ttype] + raise "don't know what type: #{ttype & 0x0f}" unless val + val + end + end + + def initialize(transport) + super(transport) + + @last_field = [0] + @boolean_value = nil + end + + def write_message_begin(name, type, seqid) + write_byte(PROTOCOL_ID) + write_byte((VERSION & VERSION_MASK) | ((type << TYPE_SHIFT_AMOUNT) & TYPE_MASK)) + write_varint32(seqid) + write_string(name) + nil + end + + def write_struct_begin(name) + @last_field.push(0) + nil + end + + def write_struct_end + @last_field.pop + nil + end + + def write_field_begin(name, type, id) + if type == Types::BOOL + # we want to possibly include the value, so we'll wait. + @boolean_field = [type, id] + else + write_field_begin_internal(type, id) + end + nil + end + + # + # The workhorse of writeFieldBegin. It has the option of doing a + # 'type override' of the type header. This is used specifically in the + # boolean field case. + # + def write_field_begin_internal(type, id, type_override=nil) + last_id = @last_field.pop + + # if there's a type override, use that. + typeToWrite = type_override || CompactTypes.get_compact_type(type) + + # check if we can use delta encoding for the field id + if id > last_id && id - last_id <= 15 + # write them together + write_byte((id - last_id) << 4 | typeToWrite) + else + # write them separate + write_byte(typeToWrite) + write_i16(id) + end + + @last_field.push(id) + nil + end + + def write_field_stop + write_byte(Types::STOP) + end + + def write_map_begin(ktype, vtype, size) + if (size == 0) + write_byte(0) + else + write_varint32(size) + write_byte(CompactTypes.get_compact_type(ktype) << 4 | CompactTypes.get_compact_type(vtype)) + end + end + + def write_list_begin(etype, size) + write_collection_begin(etype, size) + end + + def write_set_begin(etype, size) + write_collection_begin(etype, size); + end + + def write_bool(bool) + type = bool ? CompactTypes::BOOLEAN_TRUE : CompactTypes::BOOLEAN_FALSE + unless @boolean_field.nil? + # we haven't written the field header yet + write_field_begin_internal(@boolean_field.first, @boolean_field.last, type) + @boolean_field = nil + else + # we're not part of a field, so just write the value. + write_byte(type) + end + end + + def write_byte(byte) + @trans.write([byte].pack('c')) + end + + def write_i16(i16) + write_varint32(int_to_zig_zag(i16)) + end + + def write_i32(i32) + write_varint32(int_to_zig_zag(i32)) + end + + def write_i64(i64) + write_varint64(long_to_zig_zag(i64)) + end + + def write_double(dub) + @trans.write([dub].pack("G").reverse) + end + + def write_string(str) + write_varint32(str.length) + @trans.write(str) + end + + def read_message_begin + protocol_id = read_byte() + if protocol_id != PROTOCOL_ID + raise ProtocolException.new("Expected protocol id #{PROTOCOL_ID} but got #{protocol_id}") + end + + version_and_type = read_byte() + version = version_and_type & VERSION_MASK + if (version != VERSION) + raise ProtocolException.new("Expected version #{VERSION} but got #{version}"); + end + + type = (version_and_type >> TYPE_SHIFT_AMOUNT) & 0x03 + seqid = read_varint32() + messageName = read_string() + [messageName, type, seqid] + end + + def read_struct_begin + @last_field.push(0) + "" + end + + def read_struct_end + @last_field.pop() + nil + end + + def read_field_begin + type = read_byte() + + # if it's a stop, then we can return immediately, as the struct is over. + if (type & 0x0f) == Types::STOP + TSTOP + else + field_id = nil + + # mask off the 4 MSB of the type header. it could contain a field id delta. + modifier = (type & 0xf0) >> 4 + if modifier == 0 + # not a delta. look ahead for the zigzag varint field id. + field_id = read_i16() + else + # has a delta. add the delta to the last read field id. + field_id = @last_field.pop + modifier + end + + # if this happens to be a boolean field, the value is encoded in the type + if CompactTypes.is_bool_type?(type) + # save the boolean value in a special instance variable. + @bool_value = (type & 0x0f) == CompactTypes::BOOLEAN_TRUE + end + + # push the new field onto the field stack so we can keep the deltas going. + @last_field.push(field_id) + ["", CompactTypes.get_ttype(type & 0x0f), field_id] + end + end + + def read_map_begin + size = read_varint32() + key_and_value_type = size == 0 ? 0 : read_byte() + [CompactTypes.get_ttype(key_and_value_type >> 4), CompactTypes.get_ttype(key_and_value_type & 0xf), size] + end + + def read_list_begin + size_and_type = read_byte() + size = (size_and_type >> 4) & 0x0f + if size == 15 + size = read_varint32() + end + type = CompactTypes.get_ttype(size_and_type) + [type, size] + end + + def read_set_begin + read_list_begin + end + + def read_bool + unless @bool_value.nil? + bv = @bool_value + @bool_value = nil + bv + else + read_byte() == CompactTypes::BOOLEAN_TRUE + end + end + + def read_byte + dat = trans.read_all(1) + val = dat[0] + if (val > 0x7f) + val = 0 - ((val - 1) ^ 0xff) + end + val + end + + def read_i16 + zig_zag_to_int(read_varint32()) + end + + def read_i32 + zig_zag_to_int(read_varint32()) + end + + def read_i64 + zig_zag_to_long(read_varint64()) + end + + def read_double + dat = trans.read_all(8) + val = dat.reverse.unpack('G').first + val + end + + def read_string + size = read_varint32() + trans.read_all(size) + end + + + private + + # + # Abstract method for writing the start of lists and sets. List and sets on + # the wire differ only by the type indicator. + # + def write_collection_begin(elem_type, size) + if size <= 14 + write_byte(size << 4 | CompactTypes.get_compact_type(elem_type)) + else + write_byte(0xf0 | CompactTypes.get_compact_type(elem_type)) + write_varint32(size) + end + end + + def write_varint32(n) + # int idx = 0; + while true + if (n & ~0x7F) == 0 + # i32buf[idx++] = (byte)n; + write_byte(n) + break + # return; + else + # i32buf[idx++] = (byte)((n & 0x7F) | 0x80); + write_byte((n & 0x7F) | 0x80) + n = n >> 7 + end + end + # trans_.write(i32buf, 0, idx); + end + + SEVEN_BIT_MASK = 0x7F + EVERYTHING_ELSE_MASK = ~SEVEN_BIT_MASK + + def write_varint64(n) + while true + if (n & EVERYTHING_ELSE_MASK) == 0 #TODO need to find a way to make this into a long... + write_byte(n) + break + else + write_byte((n & SEVEN_BIT_MASK) | 0x80) + n >>= 7 + end + end + end + + def read_varint32() + read_varint64() + end + + def read_varint64() + shift = 0 + result = 0 + while true + b = read_byte() + result |= (b & 0x7f) << shift + break if (b & 0x80) != 0x80 + shift += 7 + end + result + end + + def int_to_zig_zag(n) + (n << 1) ^ (n >> 31) + end + + def long_to_zig_zag(l) + # puts "zz encoded #{l} to #{(l << 1) ^ (l >> 63)}" + (l << 1) ^ (l >> 63) + end + + def zig_zag_to_int(n) + (n >> 1) ^ -(n & 1) + end + + def zig_zag_to_long(n) + (n >> 1) ^ -(n & 1) + end + end + + class CompactProtocolFactory < BaseProtocolFactory + def get_protocol(trans) + CompactProtocol.new(trans) + end + end +end diff --git a/lib/rb/lib/thrift/serializer/deserializer.rb b/lib/rb/lib/thrift/serializer/deserializer.rb new file mode 100644 index 000000000..d2ee325a5 --- /dev/null +++ b/lib/rb/lib/thrift/serializer/deserializer.rb @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class Deserializer + def initialize(protocol_factory = BinaryProtocolFactory.new) + @transport = MemoryBufferTransport.new + @protocol = protocol_factory.get_protocol(@transport) + end + + def deserialize(base, buffer) + @transport.reset_buffer(buffer) + base.read(@protocol) + base + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/serializer/serializer.rb b/lib/rb/lib/thrift/serializer/serializer.rb new file mode 100644 index 000000000..22316395d --- /dev/null +++ b/lib/rb/lib/thrift/serializer/serializer.rb @@ -0,0 +1,34 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class Serializer + def initialize(protocol_factory = BinaryProtocolFactory.new) + @transport = MemoryBufferTransport.new + @protocol = protocol_factory.get_protocol(@transport) + end + + def serialize(base) + @transport.reset_buffer + base.write(@protocol) + @transport.read(@transport.available) + end + end +end + diff --git a/lib/rb/lib/thrift/server/base_server.rb b/lib/rb/lib/thrift/server/base_server.rb new file mode 100644 index 000000000..1ee121333 --- /dev/null +++ b/lib/rb/lib/thrift/server/base_server.rb @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class BaseServer + def initialize(processor, server_transport, transport_factory=nil, protocol_factory=nil) + @processor = processor + @server_transport = server_transport + @transport_factory = transport_factory ? transport_factory : Thrift::BaseTransportFactory.new + @protocol_factory = protocol_factory ? protocol_factory : Thrift::BinaryProtocolFactory.new + end + + def serve; nil; end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/server/mongrel_http_server.rb b/lib/rb/lib/thrift/server/mongrel_http_server.rb new file mode 100644 index 000000000..84eacf0dc --- /dev/null +++ b/lib/rb/lib/thrift/server/mongrel_http_server.rb @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'mongrel' + +## Sticks a service on a URL, using mongrel to do the HTTP work +module Thrift + class MongrelHTTPServer < BaseServer + class Handler < Mongrel::HttpHandler + def initialize(processor, protocol_factory) + @processor = processor + @protocol_factory = protocol_factory + end + + def process(request, response) + if request.params["REQUEST_METHOD"] == "POST" + response.start(200) do |head, out| + head["Content-Type"] = "application/x-thrift" + transport = IOStreamTransport.new request.body, out + protocol = @protocol_factory.get_protocol transport + @processor.process protocol, protocol + end + else + response.start(404) { } + end + end + end + + def initialize(processor, opts={}) + port = opts[:port] || 80 + ip = opts[:ip] || "0.0.0.0" + path = opts[:path] || "" + protocol_factory = opts[:protocol_factory] || BinaryProtocolFactory.new + @server = Mongrel::HttpServer.new ip, port + @server.register "/#{path}", Handler.new(processor, protocol_factory) + end + + def serve + @server.run.join + end + end +end diff --git a/lib/rb/lib/thrift/server/nonblocking_server.rb b/lib/rb/lib/thrift/server/nonblocking_server.rb new file mode 100644 index 000000000..5425f6de1 --- /dev/null +++ b/lib/rb/lib/thrift/server/nonblocking_server.rb @@ -0,0 +1,296 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'logger' +require 'thread' + +module Thrift + # this class expects to always use a FramedTransport for reading messages + class NonblockingServer < BaseServer + def initialize(processor, server_transport, transport_factory=nil, protocol_factory=nil, num=20, logger=nil) + super(processor, server_transport, transport_factory, protocol_factory) + @num_threads = num + if logger.nil? + @logger = Logger.new(STDERR) + @logger.level = Logger::WARN + else + @logger = logger + end + @shutdown_semaphore = Mutex.new + @transport_semaphore = Mutex.new + end + + def serve + @logger.info "Starting #{self}" + @server_transport.listen + @io_manager = start_io_manager + + begin + loop do + break if @server_transport.closed? + rd, = select([@server_transport], nil, nil, 0.1) + next if rd.nil? + socket = @server_transport.accept + @logger.debug "Accepted socket: #{socket.inspect}" + @io_manager.add_connection socket + end + rescue IOError => e + end + # we must be shutting down + @logger.info "#{self} is shutting down, goodbye" + ensure + @transport_semaphore.synchronize do + @server_transport.close + end + @io_manager.ensure_closed unless @io_manager.nil? + end + + def shutdown(timeout = 0, block = true) + @shutdown_semaphore.synchronize do + return if @is_shutdown + @is_shutdown = true + end + # nonblocking is intended for calling from within a Handler + # but we can't change the order of operations here, so lets thread + shutdown_proc = lambda do + @io_manager.shutdown(timeout) + @transport_semaphore.synchronize do + @server_transport.close # this will break the accept loop + end + end + if block + shutdown_proc.call + else + Thread.new &shutdown_proc + end + end + + private + + def start_io_manager + iom = IOManager.new(@processor, @server_transport, @transport_factory, @protocol_factory, @num_threads, @logger) + iom.spawn + iom + end + + class IOManager # :nodoc: + DEFAULT_BUFFER = 2**20 + + def initialize(processor, server_transport, transport_factory, protocol_factory, num, logger) + @processor = processor + @server_transport = server_transport + @transport_factory = transport_factory + @protocol_factory = protocol_factory + @num_threads = num + @logger = logger + @connections = [] + @buffers = Hash.new { |h,k| h[k] = '' } + @signal_queue = Queue.new + @signal_pipes = IO.pipe + @signal_pipes[1].sync = true + @worker_queue = Queue.new + @shutdown_queue = Queue.new + end + + def add_connection(socket) + signal [:connection, socket] + end + + def spawn + @iom_thread = Thread.new do + @logger.debug "Starting #{self}" + run + end + end + + def shutdown(timeout = 0) + @logger.debug "#{self} is shutting down workers" + @worker_queue.clear + @num_threads.times { @worker_queue.push [:shutdown] } + signal [:shutdown, timeout] + @shutdown_queue.pop + @signal_pipes[0].close + @signal_pipes[1].close + @logger.debug "#{self} is shutting down, goodbye" + end + + def ensure_closed + kill_worker_threads if @worker_threads + @iom_thread.kill + end + + private + + def run + spin_worker_threads + + loop do + rd, = select([@signal_pipes[0], *@connections]) + if rd.delete @signal_pipes[0] + break if read_signals == :shutdown + end + rd.each do |fd| + if fd.handle.eof? + remove_connection fd + else + read_connection fd + end + end + end + join_worker_threads(@shutdown_timeout) + ensure + @shutdown_queue.push :shutdown + end + + def read_connection(fd) + @buffers[fd] << fd.read(DEFAULT_BUFFER) + frame = slice_frame!(@buffers[fd]) + if frame + @logger.debug "#{self} is processing a frame" + @worker_queue.push [:frame, fd, frame] + end + end + + def spin_worker_threads + @logger.debug "#{self} is spinning up worker threads" + @worker_threads = [] + @num_threads.times do + @worker_threads << spin_thread + end + end + + def spin_thread + Worker.new(@processor, @transport_factory, @protocol_factory, @logger, @worker_queue).spawn + end + + def signal(msg) + @signal_queue << msg + @signal_pipes[1].write " " + end + + def read_signals + # clear the signal pipe + # note that since read_nonblock is broken in jruby, + # we can only read up to a set number of signals at once + sigstr = @signal_pipes[0].readpartial(1024) + # now read the signals + begin + sigstr.length.times do + signal, obj = @signal_queue.pop(true) + case signal + when :connection + @connections << obj + when :shutdown + @shutdown_timeout = obj + return :shutdown + end + end + rescue ThreadError + # out of signals + # note that in a perfect world this would never happen, since we're + # only reading the number of signals pushed on the pipe, but given the lack + # of locks, in theory we could clear the pipe/queue while a new signal is being + # placed on the pipe, at which point our next read_signals would hit this error + end + end + + def remove_connection(fd) + # don't explicitly close it, a thread may still be writing to it + @connections.delete fd + @buffers.delete fd + end + + def join_worker_threads(shutdown_timeout) + start = Time.now + @worker_threads.each do |t| + if shutdown_timeout > 0 + timeout = (start + shutdown_timeout) - Time.now + break if timeout <= 0 + t.join(timeout) + else + t.join + end + end + kill_worker_threads + end + + def kill_worker_threads + @worker_threads.each do |t| + t.kill if t.status + end + @worker_threads.clear + end + + def slice_frame!(buf) + if buf.length >= 4 + size = buf.unpack('N').first + if buf.length >= size + 4 + buf.slice!(0, size + 4) + else + nil + end + else + nil + end + end + + class Worker # :nodoc: + def initialize(processor, transport_factory, protocol_factory, logger, queue) + @processor = processor + @transport_factory = transport_factory + @protocol_factory = protocol_factory + @logger = logger + @queue = queue + end + + def spawn + Thread.new do + @logger.debug "#{self} is spawning" + run + end + end + + private + + def run + loop do + cmd, *args = @queue.pop + case cmd + when :shutdown + @logger.debug "#{self} is shutting down, goodbye" + break + when :frame + fd, frame = args + begin + otrans = @transport_factory.get_transport(fd) + oprot = @protocol_factory.get_protocol(otrans) + membuf = MemoryBufferTransport.new(frame) + itrans = @transport_factory.get_transport(membuf) + iprot = @protocol_factory.get_protocol(itrans) + @processor.process(iprot, oprot) + rescue => e + @logger.error "#{Thread.current.inspect} raised error: #{e.inspect}\n#{e.backtrace.join("\n")}" + end + end + end + end + end + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/server/simple_server.rb b/lib/rb/lib/thrift/server/simple_server.rb new file mode 100644 index 000000000..21e865926 --- /dev/null +++ b/lib/rb/lib/thrift/server/simple_server.rb @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class SimpleServer < BaseServer + def serve + begin + @server_transport.listen + loop do + client = @server_transport.accept + trans = @transport_factory.get_transport(client) + prot = @protocol_factory.get_protocol(trans) + begin + loop do + @processor.process(prot, prot) + end + rescue Thrift::TransportException, Thrift::ProtocolException + ensure + trans.close + end + end + ensure + @server_transport.close + end + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/server/thread_pool_server.rb b/lib/rb/lib/thrift/server/thread_pool_server.rb new file mode 100644 index 000000000..8cec805a9 --- /dev/null +++ b/lib/rb/lib/thrift/server/thread_pool_server.rb @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'thread' + +module Thrift + class ThreadPoolServer < BaseServer + def initialize(processor, server_transport, transport_factory=nil, protocol_factory=nil, num=20) + super(processor, server_transport, transport_factory, protocol_factory) + @thread_q = SizedQueue.new(num) + @exception_q = Queue.new + @running = false + end + + ## exceptions that happen in worker threads will be relayed here and + ## must be caught. 'retry' can be used to continue. (threads will + ## continue to run while the exception is being handled.) + def rescuable_serve + Thread.new { serve } unless @running + @running = true + raise @exception_q.pop + end + + ## exceptions that happen in worker threads simply cause that thread + ## to die and another to be spawned in its place. + def serve + @server_transport.listen + + begin + loop do + @thread_q.push(:token) + Thread.new do + begin + loop do + client = @server_transport.accept + trans = @transport_factory.get_transport(client) + prot = @protocol_factory.get_protocol(trans) + begin + loop do + @processor.process(prot, prot) + end + rescue Thrift::TransportException, Thrift::ProtocolException => e + ensure + trans.close + end + end + rescue => e + @exception_q.push(e) + ensure + @thread_q.pop # thread died! + end + end + end + ensure + @server_transport.close + end + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/server/threaded_server.rb b/lib/rb/lib/thrift/server/threaded_server.rb new file mode 100644 index 000000000..a2c917cb8 --- /dev/null +++ b/lib/rb/lib/thrift/server/threaded_server.rb @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'thread' + +module Thrift + class ThreadedServer < BaseServer + def serve + begin + @server_transport.listen + loop do + client = @server_transport.accept + trans = @transport_factory.get_transport(client) + prot = @protocol_factory.get_protocol(trans) + Thread.new(prot, trans) do |p, t| + begin + loop do + @processor.process(p, p) + end + rescue Thrift::TransportException, Thrift::ProtocolException + ensure + t.close + end + end + end + ensure + @server_transport.close + end + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/struct.rb b/lib/rb/lib/thrift/struct.rb new file mode 100644 index 000000000..01aae56b7 --- /dev/null +++ b/lib/rb/lib/thrift/struct.rb @@ -0,0 +1,294 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'set' + +module Thrift + module Struct + def initialize(d={}) + # get a copy of the default values to work on, removing defaults in favor of arguments + fields_with_defaults = fields_with_default_values.dup + + # check if the defaults is empty, or if there are no parameters for this + # instantiation, and if so, don't bother overriding defaults. + unless fields_with_defaults.empty? || d.empty? + d.each_key do |name| + fields_with_defaults.delete(name.to_s) + end + end + + # assign all the user-specified arguments + unless d.empty? + d.each do |name, value| + unless name_to_id(name.to_s) + raise Exception, "Unknown key given to #{self.class}.new: #{name}" + end + Thrift.check_type(value, struct_fields[name_to_id(name.to_s)], name) if Thrift.type_checking + instance_variable_set("@#{name}", value) + end + end + + # assign all the default values + unless fields_with_defaults.empty? + fields_with_defaults.each do |name, default_value| + instance_variable_set("@#{name}", (default_value.dup rescue default_value)) + end + end + end + + def fields_with_default_values + fields_with_default_values = self.class.instance_variable_get("@fields_with_default_values") + unless fields_with_default_values + fields_with_default_values = {} + struct_fields.each do |fid, field_def| + unless field_def[:default].nil? + fields_with_default_values[field_def[:name]] = field_def[:default] + end + end + self.class.instance_variable_set("@fields_with_default_values", fields_with_default_values) + end + fields_with_default_values + end + + def name_to_id(name) + names_to_ids = self.class.instance_variable_get("@names_to_ids") + unless names_to_ids + names_to_ids = {} + struct_fields.each do |fid, field_def| + names_to_ids[field_def[:name]] = fid + end + self.class.instance_variable_set("@names_to_ids", names_to_ids) + end + names_to_ids[name] + end + + def each_field + struct_fields.keys.sort.each do |fid| + data = struct_fields[fid] + yield fid, data + end + end + + def inspect(skip_optional_nulls = true) + fields = [] + each_field do |fid, field_info| + name = field_info[:name] + value = instance_variable_get("@#{name}") + unless skip_optional_nulls && field_info[:optional] && value.nil? + fields << "#{name}:#{value.inspect}" + end + end + "<#{self.class} #{fields.join(", ")}>" + end + + def read(iprot) + iprot.read_struct_begin + loop do + fname, ftype, fid = iprot.read_field_begin + break if (ftype == Types::STOP) + handle_message(iprot, fid, ftype) + iprot.read_field_end + end + iprot.read_struct_end + validate + end + + def write(oprot) + validate + oprot.write_struct_begin(self.class.name) + each_field do |fid, field_info| + name = field_info[:name] + type = field_info[:type] + if (value = instance_variable_get("@#{name}")) + if is_container? type + oprot.write_field_begin(name, type, fid) + write_container(oprot, value, field_info) + oprot.write_field_end + else + oprot.write_field(name, type, fid, value) + end + end + end + oprot.write_field_stop + oprot.write_struct_end + end + + def ==(other) + each_field do |fid, field_info| + name = field_info[:name] + return false unless self.instance_variable_get("@#{name}") == other.instance_variable_get("@#{name}") + end + true + end + + def eql?(other) + self.class == other.class && self == other + end + + # for the time being, we're ok with a naive hash. this could definitely be improved upon. + def hash + 0 + end + + def differences(other) + diffs = [] + unless other.is_a?(self.class) + diffs << "Different class!" + else + each_field do |fid, field_info| + name = field_info[:name] + diffs << "#{name} differs!" unless self.instance_variable_get("@#{name}") == other.instance_variable_get("@#{name}") + end + end + diffs + end + + def self.field_accessor(klass, *fields) + fields.each do |field| + klass.send :attr_reader, field + klass.send :define_method, "#{field}=" do |value| + Thrift.check_type(value, klass::FIELDS.values.find { |f| f[:name].to_s == field.to_s }, field) if Thrift.type_checking + instance_variable_set("@#{field}", value) + end + end + end + + protected + + def self.append_features(mod) + if mod.ancestors.include? ::Exception + mod.send :class_variable_set, :'@@__thrift_struct_real_initialize', mod.instance_method(:initialize) + super + # set up our custom initializer so `raise Xception, 'message'` works + mod.send :define_method, :struct_initialize, mod.instance_method(:initialize) + mod.send :define_method, :initialize, mod.instance_method(:exception_initialize) + else + super + end + end + + def exception_initialize(*args, &block) + if args.size == 1 and args.first.is_a? Hash + # looks like it's a regular Struct initialize + method(:struct_initialize).call(args.first) + else + # call the Struct initializer first with no args + # this will set our field default values + method(:struct_initialize).call() + # now give it to the exception + self.class.send(:class_variable_get, :'@@__thrift_struct_real_initialize').bind(self).call(*args, &block) if args.size > 0 + # self.class.instance_method(:initialize).bind(self).call(*args, &block) + end + end + + def handle_message(iprot, fid, ftype) + field = struct_fields[fid] + if field and field[:type] == ftype + value = read_field(iprot, field) + instance_variable_set("@#{field[:name]}", value) + else + iprot.skip(ftype) + end + end + + def read_field(iprot, field = {}) + case field[:type] + when Types::STRUCT + value = field[:class].new + value.read(iprot) + when Types::MAP + key_type, val_type, size = iprot.read_map_begin + value = {} + size.times do + k = read_field(iprot, field_info(field[:key])) + v = read_field(iprot, field_info(field[:value])) + value[k] = v + end + iprot.read_map_end + when Types::LIST + e_type, size = iprot.read_list_begin + value = Array.new(size) do |n| + read_field(iprot, field_info(field[:element])) + end + iprot.read_list_end + when Types::SET + e_type, size = iprot.read_set_begin + value = Set.new + size.times do + element = read_field(iprot, field_info(field[:element])) + value << element + end + iprot.read_set_end + else + value = iprot.read_type(field[:type]) + end + value + end + + def write_data(oprot, value, field) + if is_container? field[:type] + write_container(oprot, value, field) + else + oprot.write_type(field[:type], value) + end + end + + def write_container(oprot, value, field = {}) + case field[:type] + when Types::MAP + oprot.write_map_begin(field[:key][:type], field[:value][:type], value.size) + value.each do |k, v| + write_data(oprot, k, field[:key]) + write_data(oprot, v, field[:value]) + end + oprot.write_map_end + when Types::LIST + oprot.write_list_begin(field[:element][:type], value.size) + value.each do |elem| + write_data(oprot, elem, field[:element]) + end + oprot.write_list_end + when Types::SET + oprot.write_set_begin(field[:element][:type], value.size) + value.each do |v,| # the , is to preserve compatibility with the old Hash-style sets + write_data(oprot, v, field[:element]) + end + oprot.write_set_end + else + raise "Not a container type: #{field[:type]}" + end + end + + CONTAINER_TYPES = [] + CONTAINER_TYPES[Types::LIST] = true + CONTAINER_TYPES[Types::MAP] = true + CONTAINER_TYPES[Types::SET] = true + def is_container?(type) + CONTAINER_TYPES[type] + end + + def field_info(field) + { :type => field[:type], + :class => field[:class], + :key => field[:key], + :value => field[:value], + :element => field[:element] } + end + end +end diff --git a/lib/rb/lib/thrift/thrift_native.rb b/lib/rb/lib/thrift/thrift_native.rb new file mode 100644 index 000000000..4d8df61ff --- /dev/null +++ b/lib/rb/lib/thrift/thrift_native.rb @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +begin + require "thrift_native" +rescue LoadError + puts "Unable to load thrift_native extension. Defaulting to pure Ruby libraries." +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/base_server_transport.rb b/lib/rb/lib/thrift/transport/base_server_transport.rb new file mode 100644 index 000000000..68c5af076 --- /dev/null +++ b/lib/rb/lib/thrift/transport/base_server_transport.rb @@ -0,0 +1,37 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class BaseServerTransport + def listen + raise NotImplementedError + end + + def accept + raise NotImplementedError + end + + def close; nil; end + + def closed? + raise NotImplementedError + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/base_transport.rb b/lib/rb/lib/thrift/transport/base_transport.rb new file mode 100644 index 000000000..08a71dab7 --- /dev/null +++ b/lib/rb/lib/thrift/transport/base_transport.rb @@ -0,0 +1,70 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class TransportException < Exception + UNKNOWN = 0 + NOT_OPEN = 1 + ALREADY_OPEN = 2 + TIMED_OUT = 3 + END_OF_FILE = 4 + + attr_reader :type + + def initialize(type=UNKNOWN, message=nil) + super(message) + @type = type + end + end + + class BaseTransport + def open?; end + + def open; end + + def close; end + + def read(sz) + raise NotImplementedError + end + + def read_all(size) + buf = '' + + while (buf.length < size) + chunk = read(size - buf.length) + buf << chunk + end + + buf + end + + def write(buf); end + alias_method :<<, :write + + def flush; end + end + + class BaseTransportFactory + def get_transport(trans) + return trans + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/buffered_transport.rb b/lib/rb/lib/thrift/transport/buffered_transport.rb new file mode 100644 index 000000000..8dead4e0b --- /dev/null +++ b/lib/rb/lib/thrift/transport/buffered_transport.rb @@ -0,0 +1,77 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class BufferedTransport < BaseTransport + DEFAULT_BUFFER = 4096 + + def initialize(transport) + @transport = transport + @wbuf = '' + @rbuf = '' + @index = 0 + end + + def open? + return @transport.open? + end + + def open + @transport.open + end + + def close + flush + @transport.close + end + + def read(sz) + @index += sz + ret = @rbuf.slice(@index - sz, sz) || '' + + if ret.length == 0 + @rbuf = @transport.read([sz, DEFAULT_BUFFER].max) + @index = sz + ret = @rbuf.slice(0, sz) || '' + end + + ret + end + + def write(buf) + @wbuf << buf + end + + def flush + if @wbuf != '' + @transport.write(@wbuf) + @wbuf = '' + end + + @transport.flush + end + end + + class BufferedTransportFactory < BaseTransportFactory + def get_transport(transport) + return BufferedTransport.new(transport) + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/framed_transport.rb b/lib/rb/lib/thrift/transport/framed_transport.rb new file mode 100644 index 000000000..558af7444 --- /dev/null +++ b/lib/rb/lib/thrift/transport/framed_transport.rb @@ -0,0 +1,90 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class FramedTransport < BaseTransport + def initialize(transport, read=true, write=true) + @transport = transport + @rbuf = '' + @wbuf = '' + @read = read + @write = write + @index = 0 + end + + def open? + @transport.open? + end + + def open + @transport.open + end + + def close + @transport.close + end + + def read(sz) + return @transport.read(sz) unless @read + + return '' if sz <= 0 + + read_frame if @index >= @rbuf.length + + @index += sz + @rbuf.slice(@index - sz, sz) || '' + end + + def write(buf,sz=nil) + return @transport.write(buf) unless @write + + @wbuf << (sz ? buf[0...sz] : buf) + end + + # + # Writes the output buffer to the stream in the format of a 4-byte length + # followed by the actual data. + # + def flush + return @transport.flush unless @write + + out = [@wbuf.length].pack('N') + out << @wbuf + @transport.write(out) + @transport.flush + @wbuf = '' + end + + private + + def read_frame + sz = @transport.read_all(4).unpack('N').first + + @index = 0 + @rbuf = @transport.read_all(sz) + end + end + + class FramedTransportFactory < BaseTransportFactory + def get_transport(transport) + return FramedTransport.new(transport) + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/http_client_transport.rb b/lib/rb/lib/thrift/transport/http_client_transport.rb new file mode 100644 index 000000000..a190a9830 --- /dev/null +++ b/lib/rb/lib/thrift/transport/http_client_transport.rb @@ -0,0 +1,45 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'net/http' +require 'net/https' +require 'uri' +require 'stringio' + +module Thrift + class HTTPClientTransport < BaseTransport + def initialize(url) + @url = URI url + @outbuf = "" + end + + def open?; true end + def read(sz); @inbuf.read sz end + def write(buf); @outbuf << buf end + def flush + http = Net::HTTP.new @url.host, @url.port + http.use_ssl = @url.scheme == "https" + headers = { 'Content-Type' => 'application/x-thrift' } + resp, data = http.post(@url.path, @outbuf, headers) + @inbuf = StringIO.new data + @outbuf = "" + end + end +end diff --git a/lib/rb/lib/thrift/transport/io_stream_transport.rb b/lib/rb/lib/thrift/transport/io_stream_transport.rb new file mode 100644 index 000000000..be348aa09 --- /dev/null +++ b/lib/rb/lib/thrift/transport/io_stream_transport.rb @@ -0,0 +1,39 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# Very very simple implementation of wrapping two objects, one with a #read +# method and one with a #write method, into a transport for thrift. +# +# Assumes both objects are open, remain open, don't require flushing, etc. +# +module Thrift + class IOStreamTransport < BaseTransport + def initialize(input, output) + @input = input + @output = output + end + + def open?; not @input.closed? or not @output.closed? end + def read(sz); @input.read(sz) end + def write(buf); @output.write(buf) end + def close; @input.close; @output.close end + def to_io; @input end # we're assuming this is used in a IO.select for reading + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/memory_buffer_transport.rb b/lib/rb/lib/thrift/transport/memory_buffer_transport.rb new file mode 100644 index 000000000..33d732d13 --- /dev/null +++ b/lib/rb/lib/thrift/transport/memory_buffer_transport.rb @@ -0,0 +1,93 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class MemoryBufferTransport < BaseTransport + GARBAGE_BUFFER_SIZE = 4*(2**10) # 4kB + + # If you pass a string to this, you should #dup that string + # unless you want it to be modified by #read and #write + #-- + # this behavior is no longer required. If you wish to change it + # go ahead, just make sure the specs pass + def initialize(buffer = nil) + @buf = buffer || '' + @index = 0 + end + + def open? + return true + end + + def open + end + + def close + end + + def peek + @index < @buf.size + end + + # this method does not use the passed object directly but copies it + def reset_buffer(new_buf = '') + @buf.replace new_buf + @index = 0 + end + + def available + @buf.length - @index + end + + def read(len) + data = @buf.slice(@index, len) + @index += len + @index = @buf.size if @index > @buf.size + if @index >= GARBAGE_BUFFER_SIZE + @buf = @buf.slice(@index..-1) + @index = 0 + end + data + end + + def write(wbuf) + @buf << wbuf + end + + def flush + end + + def inspect_buffer + out = [] + for idx in 0...(@buf.size) + # if idx != 0 + # out << " " + # end + + if idx == @index + out << ">" + end + + out << @buf[idx].to_s(16) + end + out.join(" ") + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/server_socket.rb b/lib/rb/lib/thrift/transport/server_socket.rb new file mode 100644 index 000000000..7feb9ab0d --- /dev/null +++ b/lib/rb/lib/thrift/transport/server_socket.rb @@ -0,0 +1,63 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'socket' + +module Thrift + class ServerSocket < BaseServerTransport + # call-seq: initialize(host = nil, port) + def initialize(host_or_port, port = nil) + if port + @host = host_or_port + @port = port + else + @host = nil + @port = host_or_port + end + @handle = nil + end + + attr_reader :handle + + def listen + @handle = TCPServer.new(@host, @port) + end + + def accept + unless @handle.nil? + sock = @handle.accept + trans = Socket.new + trans.handle = sock + trans + end + end + + def close + @handle.close unless @handle.nil? or @handle.closed? + @handle = nil + end + + def closed? + @handle.nil? or @handle.closed? + end + + alias to_io handle + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/socket.rb b/lib/rb/lib/thrift/transport/socket.rb new file mode 100644 index 000000000..06c937e58 --- /dev/null +++ b/lib/rb/lib/thrift/transport/socket.rb @@ -0,0 +1,136 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'socket' + +module Thrift + class Socket < BaseTransport + def initialize(host='localhost', port=9090, timeout=nil) + @host = host + @port = port + @timeout = timeout + @desc = "#{host}:#{port}" + @handle = nil + end + + attr_accessor :handle, :timeout + + def open + begin + addrinfo = ::Socket::getaddrinfo(@host, @port).first + @handle = ::Socket.new(addrinfo[4], ::Socket::SOCK_STREAM, 0) + sockaddr = ::Socket.sockaddr_in(addrinfo[1], addrinfo[3]) + begin + @handle.connect_nonblock(sockaddr) + rescue Errno::EINPROGRESS + unless IO.select(nil, [ @handle ], nil, @timeout) + raise TransportException.new(TransportException::NOT_OPEN, "Connection timeout to #{@desc}") + end + begin + @handle.connect_nonblock(sockaddr) + rescue Errno::EISCONN + end + end + @handle + rescue StandardError => e + raise TransportException.new(TransportException::NOT_OPEN, "Could not connect to #{@desc}: #{e}") + end + end + + def open? + !@handle.nil? and !@handle.closed? + end + + def write(str) + raise IOError, "closed stream" unless open? + begin + if @timeout.nil? or @timeout == 0 + @handle.write(str) + else + len = 0 + start = Time.now + while Time.now - start < @timeout + rd, wr, = IO.select(nil, [@handle], nil, @timeout) + if wr and not wr.empty? + len += @handle.write_nonblock(str[len..-1]) + break if len >= str.length + end + end + if len < str.length + raise TransportException.new(TransportException::TIMED_OUT, "Socket: Timed out writing #{str.length} bytes to #{@desc}") + else + len + end + end + rescue TransportException => e + # pass this on + raise e + rescue StandardError => e + @handle.close + @handle = nil + raise TransportException.new(TransportException::NOT_OPEN, e.message) + end + end + + def read(sz) + raise IOError, "closed stream" unless open? + + begin + if @timeout.nil? or @timeout == 0 + data = @handle.readpartial(sz) + else + # it's possible to interrupt select for something other than the timeout + # so we need to ensure we've waited long enough + start = Time.now + rd = nil # scoping + loop do + rd, = IO.select([@handle], nil, nil, @timeout) + break if (rd and not rd.empty?) or Time.now - start >= @timeout + end + if rd.nil? or rd.empty? + raise TransportException.new(TransportException::TIMED_OUT, "Socket: Timed out reading #{sz} bytes from #{@desc}") + else + data = @handle.readpartial(sz) + end + end + rescue TransportException => e + # don't let this get caught by the StandardError handler + raise e + rescue StandardError => e + @handle.close unless @handle.closed? + @handle = nil + raise TransportException.new(TransportException::NOT_OPEN, e.message) + end + if (data.nil? or data.length == 0) + raise TransportException.new(TransportException::UNKNOWN, "Socket: Could not read #{sz} bytes from #{@desc}") + end + data + end + + def close + @handle.close unless @handle.nil? or @handle.closed? + @handle = nil + end + + def to_io + @handle + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/unix_server_socket.rb b/lib/rb/lib/thrift/transport/unix_server_socket.rb new file mode 100644 index 000000000..a135d25f2 --- /dev/null +++ b/lib/rb/lib/thrift/transport/unix_server_socket.rb @@ -0,0 +1,60 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'socket' + +module Thrift + class UNIXServerSocket < BaseServerTransport + def initialize(path) + @path = path + @handle = nil + end + + attr_accessor :handle + + def listen + @handle = ::UNIXServer.new(@path) + end + + def accept + unless @handle.nil? + sock = @handle.accept + trans = UNIXSocket.new(nil) + trans.handle = sock + trans + end + end + + def close + if @handle + @handle.close unless @handle.closed? + @handle = nil + # UNIXServer doesn't delete the socket file, so we have to do it ourselves + File.delete(@path) + end + end + + def closed? + @handle.nil? or @handle.closed? + end + + alias to_io handle + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/unix_socket.rb b/lib/rb/lib/thrift/transport/unix_socket.rb new file mode 100644 index 000000000..8f692e4c8 --- /dev/null +++ b/lib/rb/lib/thrift/transport/unix_socket.rb @@ -0,0 +1,40 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'socket' + +module Thrift + class UNIXSocket < Socket + def initialize(path, timeout=nil) + @path = path + @timeout = timeout + @desc = @path # for read()'s error + @handle = nil + end + + def open + begin + @handle = ::UNIXSocket.new(@path) + rescue StandardError + raise TransportException.new(TransportException::NOT_OPEN, "Could not open UNIX socket at #{@path}") + end + end + end +end
\ No newline at end of file diff --git a/lib/rb/lib/thrift/types.rb b/lib/rb/lib/thrift/types.rb new file mode 100644 index 000000000..20e4ca2c1 --- /dev/null +++ b/lib/rb/lib/thrift/types.rb @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'set' + +module Thrift + module Types + STOP = 0 + VOID = 1 + BOOL = 2 + BYTE = 3 + DOUBLE = 4 + I16 = 6 + I32 = 8 + I64 = 10 + STRING = 11 + STRUCT = 12 + MAP = 13 + SET = 14 + LIST = 15 + end + + class << self + attr_accessor :type_checking + end + + class TypeError < Exception + end + + def self.check_type(value, field, name, skip_nil=true) + return if value.nil? and skip_nil + klasses = case field[:type] + when Types::VOID + NilClass + when Types::BOOL + [TrueClass, FalseClass] + when Types::BYTE, Types::I16, Types::I32, Types::I64 + Integer + when Types::DOUBLE + Float + when Types::STRING + String + when Types::STRUCT + Struct + when Types::MAP + Hash + when Types::SET + Set + when Types::LIST + Array + end + valid = klasses && [*klasses].any? { |klass| klass === value } + raise TypeError, "Expected #{type_name(field[:type])}, received #{value.class} for field #{name}" unless valid + # check elements now + case field[:type] + when Types::MAP + value.each_pair do |k,v| + check_type(k, field[:key], "#{name}.key", false) + check_type(v, field[:value], "#{name}.value", false) + end + when Types::SET, Types::LIST + value.each do |el| + check_type(el, field[:element], "#{name}.element", false) + end + when Types::STRUCT + raise TypeError, "Expected #{field[:class]}, received #{value.class} for field #{name}" unless field[:class] == value.class + end + end + + def self.type_name(type) + Types.constants.each do |const| + return "Types::#{const}" if Types.const_get(const) == type + end + nil + end + + module MessageTypes + CALL = 1 + REPLY = 2 + EXCEPTION = 3 + ONEWAY = 4 + end +end + +Thrift.type_checking = false if Thrift.type_checking.nil? diff --git a/lib/rb/script/proto_benchmark.rb b/lib/rb/script/proto_benchmark.rb new file mode 100644 index 000000000..bb49e2e42 --- /dev/null +++ b/lib/rb/script/proto_benchmark.rb @@ -0,0 +1,121 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + "/../spec/spec_helper.rb" + +require "benchmark" +# require "ruby-prof" + +obj = Fixtures::COMPACT_PROTOCOL_TEST_STRUCT + +HOW_MANY = 1_000 + +binser = Thrift::Serializer.new +bin_data = binser.serialize(obj) +bindeser = Thrift::Deserializer.new +accel_bin_ser = Thrift::Serializer.new(Thrift::BinaryProtocolAcceleratedFactory.new) +accel_bin_deser = Thrift::Deserializer.new(Thrift::BinaryProtocolAcceleratedFactory.new) + +compact_ser = Thrift::Serializer.new(Thrift::CompactProtocolFactory.new) +compact_data = compact_ser.serialize(obj) +compact_deser = Thrift::Deserializer.new(Thrift::CompactProtocolFactory.new) + +Benchmark.bm(60) do |reporter| + reporter.report("binary protocol, write") do + HOW_MANY.times do + binser.serialize(obj) + end + end + + reporter.report("accelerated binary protocol, write") do + HOW_MANY.times do + accel_bin_ser.serialize(obj) + end + end + + reporter.report("compact protocol, write") do + # RubyProf.start + HOW_MANY.times do + compact_ser.serialize(obj) + end + # result = RubyProf.stop + # printer = RubyProf::GraphHtmlPrinter.new(result) + # file = File.open("profile.html", "w+") + # printer.print(file, 0) + # file.close + end + + reporter.report("binary protocol, read") do + HOW_MANY.times do + bindeser.deserialize(obj, bin_data) + end + end + + reporter.report("accelerated binary protocol, read") do + HOW_MANY.times do + accel_bin_deser.deserialize(obj, bin_data) + end + end + + reporter.report("compact protocol, read") do + HOW_MANY.times do + compact_deser.deserialize(obj, compact_data) + end + end + + + # f = File.new("/tmp/testfile", "w") + # proto = Thrift::BinaryProtocolAccelerated.new(Thrift::IOStreamTransport.new(Thrift::MemoryBufferTransport.new, f)) + # reporter.report("accelerated binary protocol, write (to disk)") do + # HOW_MANY.times do + # obj.write(proto) + # end + # f.flush + # end + # f.close + # + # f = File.new("/tmp/testfile", "r") + # proto = Thrift::BinaryProtocolAccelerated.new(Thrift::IOStreamTransport.new(f, Thrift::MemoryBufferTransport.new)) + # reporter.report("accelerated binary protocol, read (from disk)") do + # HOW_MANY.times do + # obj.read(proto) + # end + # end + # f.close + # + # f = File.new("/tmp/testfile", "w") + # reporter.report("compact protocol, write (to disk)") do + # proto = Thrift::CompactProtocol.new(Thrift::IOStreamTransport.new(Thrift::MemoryBufferTransport.new, f)) + # HOW_MANY.times do + # obj.write(proto) + # end + # f.flush + # end + # f.close + # + # f = File.new("/tmp/testfile", "r") + # reporter.report("compact protocol, read (from disk)") do + # proto = Thrift::CompactProtocol.new(Thrift::IOStreamTransport.new(f, Thrift::MemoryBufferTransport.new)) + # HOW_MANY.times do + # obj.read(proto) + # end + # end + # f.close + +end diff --git a/lib/rb/script/read_struct.rb b/lib/rb/script/read_struct.rb new file mode 100644 index 000000000..831fcec90 --- /dev/null +++ b/lib/rb/script/read_struct.rb @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require "spec/spec_helper" + +path, factory_class = ARGV + +factory = eval(factory_class).new + +deser = Thrift::Deserializer.new(factory) + +cpts = CompactProtoTestStruct.new +CompactProtoTestStruct.constants.each do |const| + cpts.instance_variable_set("@#{const}", nil) +end + +data = File.read(path) + +deser.deserialize(cpts, data) + +if cpts == Fixtures::COMPACT_PROTOCOL_TEST_STRUCT + puts "Object verified successfully!" +else + puts "Object failed verification! Expected #{Fixtures::COMPACT_PROTOCOL_TEST_STRUCT.inspect} but got #{cpts.inspect}" + + puts cpts.differences(Fixtures::COMPACT_PROTOCOL_TEST_STRUCT) +end diff --git a/lib/rb/script/write_struct.rb b/lib/rb/script/write_struct.rb new file mode 100644 index 000000000..da1421975 --- /dev/null +++ b/lib/rb/script/write_struct.rb @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require "spec/spec_helper" + +path, factory_class = ARGV + +factory = eval(factory_class).new + +ser = Thrift::Serializer.new(factory) + +File.open(path, "w") do |file| + file.write(ser.serialize(Fixtures::COMPACT_PROTOCOL_TEST_STRUCT)) +end
\ No newline at end of file diff --git a/lib/rb/setup.rb b/lib/rb/setup.rb new file mode 100644 index 000000000..9f0c8267a --- /dev/null +++ b/lib/rb/setup.rb @@ -0,0 +1,1585 @@ +# +# setup.rb +# +# Copyright (c) 2000-2005 Minero Aoki +# +# This program is free software. +# You can distribute/modify this program under the terms of +# the GNU LGPL, Lesser General Public License version 2.1. +# + +unless Enumerable.method_defined?(:map) # Ruby 1.4.6 + module Enumerable + alias map collect + end +end + +unless File.respond_to?(:read) # Ruby 1.6 + def File.read(fname) + open(fname) {|f| + return f.read + } + end +end + +unless Errno.const_defined?(:ENOTEMPTY) # Windows? + module Errno + class ENOTEMPTY + # We do not raise this exception, implementation is not needed. + end + end +end + +def File.binread(fname) + open(fname, 'rb') {|f| + return f.read + } +end + +# for corrupted Windows' stat(2) +def File.dir?(path) + File.directory?((path[-1,1] == '/') ? path : path + '/') +end + + +class ConfigTable + + include Enumerable + + def initialize(rbconfig) + @rbconfig = rbconfig + @items = [] + @table = {} + # options + @install_prefix = nil + @config_opt = nil + @verbose = true + @no_harm = false + end + + attr_accessor :install_prefix + attr_accessor :config_opt + + attr_writer :verbose + + def verbose? + @verbose + end + + attr_writer :no_harm + + def no_harm? + @no_harm + end + + def [](key) + lookup(key).resolve(self) + end + + def []=(key, val) + lookup(key).set val + end + + def names + @items.map {|i| i.name } + end + + def each(&block) + @items.each(&block) + end + + def key?(name) + @table.key?(name) + end + + def lookup(name) + @table[name] or setup_rb_error "no such config item: #{name}" + end + + def add(item) + @items.push item + @table[item.name] = item + end + + def remove(name) + item = lookup(name) + @items.delete_if {|i| i.name == name } + @table.delete_if {|name, i| i.name == name } + item + end + + def load_script(path, inst = nil) + if File.file?(path) + MetaConfigEnvironment.new(self, inst).instance_eval File.read(path), path + end + end + + def savefile + '.config' + end + + def load_savefile + begin + File.foreach(savefile()) do |line| + k, v = *line.split(/=/, 2) + self[k] = v.strip + end + rescue Errno::ENOENT + setup_rb_error $!.message + "\n#{File.basename($0)} config first" + end + end + + def save + @items.each {|i| i.value } + File.open(savefile(), 'w') {|f| + @items.each do |i| + f.printf "%s=%s\n", i.name, i.value if i.value? and i.value + end + } + end + + def load_standard_entries + standard_entries(@rbconfig).each do |ent| + add ent + end + end + + def standard_entries(rbconfig) + c = rbconfig + + rubypath = File.join(c['bindir'], c['ruby_install_name'] + c['EXEEXT']) + + major = c['MAJOR'].to_i + minor = c['MINOR'].to_i + teeny = c['TEENY'].to_i + version = "#{major}.#{minor}" + + # ruby ver. >= 1.4.4? + newpath_p = ((major >= 2) or + ((major == 1) and + ((minor >= 5) or + ((minor == 4) and (teeny >= 4))))) + + if c['rubylibdir'] + # V > 1.6.3 + libruby = "#{c['prefix']}/lib/ruby" + librubyver = c['rubylibdir'] + librubyverarch = c['archdir'] + siteruby = c['sitedir'] + siterubyver = c['sitelibdir'] + siterubyverarch = c['sitearchdir'] + elsif newpath_p + # 1.4.4 <= V <= 1.6.3 + libruby = "#{c['prefix']}/lib/ruby" + librubyver = "#{c['prefix']}/lib/ruby/#{version}" + librubyverarch = "#{c['prefix']}/lib/ruby/#{version}/#{c['arch']}" + siteruby = c['sitedir'] + siterubyver = "$siteruby/#{version}" + siterubyverarch = "$siterubyver/#{c['arch']}" + else + # V < 1.4.4 + libruby = "#{c['prefix']}/lib/ruby" + librubyver = "#{c['prefix']}/lib/ruby/#{version}" + librubyverarch = "#{c['prefix']}/lib/ruby/#{version}/#{c['arch']}" + siteruby = "#{c['prefix']}/lib/ruby/#{version}/site_ruby" + siterubyver = siteruby + siterubyverarch = "$siterubyver/#{c['arch']}" + end + parameterize = lambda {|path| + path.sub(/\A#{Regexp.quote(c['prefix'])}/, '$prefix') + } + + if arg = c['configure_args'].split.detect {|arg| /--with-make-prog=/ =~ arg } + makeprog = arg.sub(/'/, '').split(/=/, 2)[1] + else + makeprog = 'make' + end + + [ + ExecItem.new('installdirs', 'std/site/home', + 'std: install under libruby; site: install under site_ruby; home: install under $HOME')\ + {|val, table| + case val + when 'std' + table['rbdir'] = '$librubyver' + table['sodir'] = '$librubyverarch' + when 'site' + table['rbdir'] = '$siterubyver' + table['sodir'] = '$siterubyverarch' + when 'home' + setup_rb_error '$HOME was not set' unless ENV['HOME'] + table['prefix'] = ENV['HOME'] + table['rbdir'] = '$libdir/ruby' + table['sodir'] = '$libdir/ruby' + end + }, + PathItem.new('prefix', 'path', c['prefix'], + 'path prefix of target environment'), + PathItem.new('bindir', 'path', parameterize.call(c['bindir']), + 'the directory for commands'), + PathItem.new('libdir', 'path', parameterize.call(c['libdir']), + 'the directory for libraries'), + PathItem.new('datadir', 'path', parameterize.call(c['datadir']), + 'the directory for shared data'), + PathItem.new('mandir', 'path', parameterize.call(c['mandir']), + 'the directory for man pages'), + PathItem.new('sysconfdir', 'path', parameterize.call(c['sysconfdir']), + 'the directory for system configuration files'), + PathItem.new('localstatedir', 'path', parameterize.call(c['localstatedir']), + 'the directory for local state data'), + PathItem.new('libruby', 'path', libruby, + 'the directory for ruby libraries'), + PathItem.new('librubyver', 'path', librubyver, + 'the directory for standard ruby libraries'), + PathItem.new('librubyverarch', 'path', librubyverarch, + 'the directory for standard ruby extensions'), + PathItem.new('siteruby', 'path', siteruby, + 'the directory for version-independent aux ruby libraries'), + PathItem.new('siterubyver', 'path', siterubyver, + 'the directory for aux ruby libraries'), + PathItem.new('siterubyverarch', 'path', siterubyverarch, + 'the directory for aux ruby binaries'), + PathItem.new('rbdir', 'path', '$siterubyver', + 'the directory for ruby scripts'), + PathItem.new('sodir', 'path', '$siterubyverarch', + 'the directory for ruby extentions'), + PathItem.new('rubypath', 'path', rubypath, + 'the path to set to #! line'), + ProgramItem.new('rubyprog', 'name', rubypath, + 'the ruby program using for installation'), + ProgramItem.new('makeprog', 'name', makeprog, + 'the make program to compile ruby extentions'), + SelectItem.new('shebang', 'all/ruby/never', 'ruby', + 'shebang line (#!) editing mode'), + BoolItem.new('without-ext', 'yes/no', 'no', + 'does not compile/install ruby extentions') + ] + end + private :standard_entries + + def load_multipackage_entries + multipackage_entries().each do |ent| + add ent + end + end + + def multipackage_entries + [ + PackageSelectionItem.new('with', 'name,name...', '', 'ALL', + 'package names that you want to install'), + PackageSelectionItem.new('without', 'name,name...', '', 'NONE', + 'package names that you do not want to install') + ] + end + private :multipackage_entries + + ALIASES = { + 'std-ruby' => 'librubyver', + 'stdruby' => 'librubyver', + 'rubylibdir' => 'librubyver', + 'archdir' => 'librubyverarch', + 'site-ruby-common' => 'siteruby', # For backward compatibility + 'site-ruby' => 'siterubyver', # For backward compatibility + 'bin-dir' => 'bindir', + 'bin-dir' => 'bindir', + 'rb-dir' => 'rbdir', + 'so-dir' => 'sodir', + 'data-dir' => 'datadir', + 'ruby-path' => 'rubypath', + 'ruby-prog' => 'rubyprog', + 'ruby' => 'rubyprog', + 'make-prog' => 'makeprog', + 'make' => 'makeprog' + } + + def fixup + ALIASES.each do |ali, name| + @table[ali] = @table[name] + end + @items.freeze + @table.freeze + @options_re = /\A--(#{@table.keys.join('|')})(?:=(.*))?\z/ + end + + def parse_opt(opt) + m = @options_re.match(opt) or setup_rb_error "config: unknown option #{opt}" + m.to_a[1,2] + end + + def dllext + @rbconfig['DLEXT'] + end + + def value_config?(name) + lookup(name).value? + end + + class Item + def initialize(name, template, default, desc) + @name = name.freeze + @template = template + @value = default + @default = default + @description = desc + end + + attr_reader :name + attr_reader :description + + attr_accessor :default + alias help_default default + + def help_opt + "--#{@name}=#{@template}" + end + + def value? + true + end + + def value + @value + end + + def resolve(table) + @value.gsub(%r<\$([^/]+)>) { table[$1] } + end + + def set(val) + @value = check(val) + end + + private + + def check(val) + setup_rb_error "config: --#{name} requires argument" unless val + val + end + end + + class BoolItem < Item + def config_type + 'bool' + end + + def help_opt + "--#{@name}" + end + + private + + def check(val) + return 'yes' unless val + case val + when /\Ay(es)?\z/i, /\At(rue)?\z/i then 'yes' + when /\An(o)?\z/i, /\Af(alse)\z/i then 'no' + else + setup_rb_error "config: --#{@name} accepts only yes/no for argument" + end + end + end + + class PathItem < Item + def config_type + 'path' + end + + private + + def check(path) + setup_rb_error "config: --#{@name} requires argument" unless path + path[0,1] == '$' ? path : File.expand_path(path) + end + end + + class ProgramItem < Item + def config_type + 'program' + end + end + + class SelectItem < Item + def initialize(name, selection, default, desc) + super + @ok = selection.split('/') + end + + def config_type + 'select' + end + + private + + def check(val) + unless @ok.include?(val.strip) + setup_rb_error "config: use --#{@name}=#{@template} (#{val})" + end + val.strip + end + end + + class ExecItem < Item + def initialize(name, selection, desc, &block) + super name, selection, nil, desc + @ok = selection.split('/') + @action = block + end + + def config_type + 'exec' + end + + def value? + false + end + + def resolve(table) + setup_rb_error "$#{name()} wrongly used as option value" + end + + undef set + + def evaluate(val, table) + v = val.strip.downcase + unless @ok.include?(v) + setup_rb_error "invalid option --#{@name}=#{val} (use #{@template})" + end + @action.call v, table + end + end + + class PackageSelectionItem < Item + def initialize(name, template, default, help_default, desc) + super name, template, default, desc + @help_default = help_default + end + + attr_reader :help_default + + def config_type + 'package' + end + + private + + def check(val) + unless File.dir?("packages/#{val}") + setup_rb_error "config: no such package: #{val}" + end + val + end + end + + class MetaConfigEnvironment + def initialize(config, installer) + @config = config + @installer = installer + end + + def config_names + @config.names + end + + def config?(name) + @config.key?(name) + end + + def bool_config?(name) + @config.lookup(name).config_type == 'bool' + end + + def path_config?(name) + @config.lookup(name).config_type == 'path' + end + + def value_config?(name) + @config.lookup(name).config_type != 'exec' + end + + def add_config(item) + @config.add item + end + + def add_bool_config(name, default, desc) + @config.add BoolItem.new(name, 'yes/no', default ? 'yes' : 'no', desc) + end + + def add_path_config(name, default, desc) + @config.add PathItem.new(name, 'path', default, desc) + end + + def set_config_default(name, default) + @config.lookup(name).default = default + end + + def remove_config(name) + @config.remove(name) + end + + # For only multipackage + def packages + raise '[setup.rb fatal] multi-package metaconfig API packages() called for single-package; contact application package vendor' unless @installer + @installer.packages + end + + # For only multipackage + def declare_packages(list) + raise '[setup.rb fatal] multi-package metaconfig API declare_packages() called for single-package; contact application package vendor' unless @installer + @installer.packages = list + end + end + +end # class ConfigTable + + +# This module requires: #verbose?, #no_harm? +module FileOperations + + def mkdir_p(dirname, prefix = nil) + dirname = prefix + File.expand_path(dirname) if prefix + $stderr.puts "mkdir -p #{dirname}" if verbose? + return if no_harm? + + # Does not check '/', it's too abnormal. + dirs = File.expand_path(dirname).split(%r<(?=/)>) + if /\A[a-z]:\z/i =~ dirs[0] + disk = dirs.shift + dirs[0] = disk + dirs[0] + end + dirs.each_index do |idx| + path = dirs[0..idx].join('') + Dir.mkdir path unless File.dir?(path) + end + end + + def rm_f(path) + $stderr.puts "rm -f #{path}" if verbose? + return if no_harm? + force_remove_file path + end + + def rm_rf(path) + $stderr.puts "rm -rf #{path}" if verbose? + return if no_harm? + remove_tree path + end + + def remove_tree(path) + if File.symlink?(path) + remove_file path + elsif File.dir?(path) + remove_tree0 path + else + force_remove_file path + end + end + + def remove_tree0(path) + Dir.foreach(path) do |ent| + next if ent == '.' + next if ent == '..' + entpath = "#{path}/#{ent}" + if File.symlink?(entpath) + remove_file entpath + elsif File.dir?(entpath) + remove_tree0 entpath + else + force_remove_file entpath + end + end + begin + Dir.rmdir path + rescue Errno::ENOTEMPTY + # directory may not be empty + end + end + + def move_file(src, dest) + force_remove_file dest + begin + File.rename src, dest + rescue + File.open(dest, 'wb') {|f| + f.write File.binread(src) + } + File.chmod File.stat(src).mode, dest + File.unlink src + end + end + + def force_remove_file(path) + begin + remove_file path + rescue + end + end + + def remove_file(path) + File.chmod 0777, path + File.unlink path + end + + def install(from, dest, mode, prefix = nil) + $stderr.puts "install #{from} #{dest}" if verbose? + return if no_harm? + + realdest = prefix ? prefix + File.expand_path(dest) : dest + realdest = File.join(realdest, File.basename(from)) if File.dir?(realdest) + str = File.binread(from) + if diff?(str, realdest) + verbose_off { + rm_f realdest if File.exist?(realdest) + } + File.open(realdest, 'wb') {|f| + f.write str + } + File.chmod mode, realdest + + File.open("#{objdir_root()}/InstalledFiles", 'a') {|f| + if prefix + f.puts realdest.sub(prefix, '') + else + f.puts realdest + end + } + end + end + + def diff?(new_content, path) + return true unless File.exist?(path) + new_content != File.binread(path) + end + + def command(*args) + $stderr.puts args.join(' ') if verbose? + system(*args) or raise RuntimeError, + "system(#{args.map{|a| a.inspect }.join(' ')}) failed" + end + + def ruby(*args) + command config('rubyprog'), *args + end + + def make(task = nil) + command(*[config('makeprog'), task].compact) + end + + def extdir?(dir) + File.exist?("#{dir}/MANIFEST") or File.exist?("#{dir}/extconf.rb") + end + + def files_of(dir) + Dir.open(dir) {|d| + return d.select {|ent| File.file?("#{dir}/#{ent}") } + } + end + + DIR_REJECT = %w( . .. CVS SCCS RCS CVS.adm .svn ) + + def directories_of(dir) + Dir.open(dir) {|d| + return d.select {|ent| File.dir?("#{dir}/#{ent}") } - DIR_REJECT + } + end + +end + + +# This module requires: #srcdir_root, #objdir_root, #relpath +module HookScriptAPI + + def get_config(key) + @config[key] + end + + alias config get_config + + # obsolete: use metaconfig to change configuration + def set_config(key, val) + @config[key] = val + end + + # + # srcdir/objdir (works only in the package directory) + # + + def curr_srcdir + "#{srcdir_root()}/#{relpath()}" + end + + def curr_objdir + "#{objdir_root()}/#{relpath()}" + end + + def srcfile(path) + "#{curr_srcdir()}/#{path}" + end + + def srcexist?(path) + File.exist?(srcfile(path)) + end + + def srcdirectory?(path) + File.dir?(srcfile(path)) + end + + def srcfile?(path) + File.file?(srcfile(path)) + end + + def srcentries(path = '.') + Dir.open("#{curr_srcdir()}/#{path}") {|d| + return d.to_a - %w(. ..) + } + end + + def srcfiles(path = '.') + srcentries(path).select {|fname| + File.file?(File.join(curr_srcdir(), path, fname)) + } + end + + def srcdirectories(path = '.') + srcentries(path).select {|fname| + File.dir?(File.join(curr_srcdir(), path, fname)) + } + end + +end + + +class ToplevelInstaller + + Version = '3.4.1' + Copyright = 'Copyright (c) 2000-2005 Minero Aoki' + + TASKS = [ + [ 'all', 'do config, setup, then install' ], + [ 'config', 'saves your configurations' ], + [ 'show', 'shows current configuration' ], + [ 'setup', 'compiles ruby extentions and others' ], + [ 'install', 'installs files' ], + [ 'test', 'run all tests in test/' ], + [ 'clean', "does `make clean' for each extention" ], + [ 'distclean',"does `make distclean' for each extention" ] + ] + + def ToplevelInstaller.invoke + config = ConfigTable.new(load_rbconfig()) + config.load_standard_entries + config.load_multipackage_entries if multipackage? + config.fixup + klass = (multipackage?() ? ToplevelInstallerMulti : ToplevelInstaller) + klass.new(File.dirname($0), config).invoke + end + + def ToplevelInstaller.multipackage? + File.dir?(File.dirname($0) + '/packages') + end + + def ToplevelInstaller.load_rbconfig + if arg = ARGV.detect {|arg| /\A--rbconfig=/ =~ arg } + ARGV.delete(arg) + load File.expand_path(arg.split(/=/, 2)[1]) + $".push 'rbconfig.rb' + else + require 'rbconfig' + end + ::Config::CONFIG + end + + def initialize(ardir_root, config) + @ardir = File.expand_path(ardir_root) + @config = config + # cache + @valid_task_re = nil + end + + def config(key) + @config[key] + end + + def inspect + "#<#{self.class} #{__id__()}>" + end + + def invoke + run_metaconfigs + case task = parsearg_global() + when nil, 'all' + parsearg_config + init_installers + exec_config + exec_setup + exec_install + else + case task + when 'config', 'test' + ; + when 'clean', 'distclean' + @config.load_savefile if File.exist?(@config.savefile) + else + @config.load_savefile + end + __send__ "parsearg_#{task}" + init_installers + __send__ "exec_#{task}" + end + end + + def run_metaconfigs + @config.load_script "#{@ardir}/metaconfig" + end + + def init_installers + @installer = Installer.new(@config, @ardir, File.expand_path('.')) + end + + # + # Hook Script API bases + # + + def srcdir_root + @ardir + end + + def objdir_root + '.' + end + + def relpath + '.' + end + + # + # Option Parsing + # + + def parsearg_global + while arg = ARGV.shift + case arg + when /\A\w+\z/ + setup_rb_error "invalid task: #{arg}" unless valid_task?(arg) + return arg + when '-q', '--quiet' + @config.verbose = false + when '--verbose' + @config.verbose = true + when '--help' + print_usage $stdout + exit 0 + when '--version' + puts "#{File.basename($0)} version #{Version}" + exit 0 + when '--copyright' + puts Copyright + exit 0 + else + setup_rb_error "unknown global option '#{arg}'" + end + end + nil + end + + def valid_task?(t) + valid_task_re() =~ t + end + + def valid_task_re + @valid_task_re ||= /\A(?:#{TASKS.map {|task,desc| task }.join('|')})\z/ + end + + def parsearg_no_options + unless ARGV.empty? + task = caller(0).first.slice(%r<`parsearg_(\w+)'>, 1) + setup_rb_error "#{task}: unknown options: #{ARGV.join(' ')}" + end + end + + alias parsearg_show parsearg_no_options + alias parsearg_setup parsearg_no_options + alias parsearg_test parsearg_no_options + alias parsearg_clean parsearg_no_options + alias parsearg_distclean parsearg_no_options + + def parsearg_config + evalopt = [] + set = [] + @config.config_opt = [] + while i = ARGV.shift + if /\A--?\z/ =~ i + @config.config_opt = ARGV.dup + break + end + name, value = *@config.parse_opt(i) + if @config.value_config?(name) + @config[name] = value + else + evalopt.push [name, value] + end + set.push name + end + evalopt.each do |name, value| + @config.lookup(name).evaluate value, @config + end + # Check if configuration is valid + set.each do |n| + @config[n] if @config.value_config?(n) + end + end + + def parsearg_install + @config.no_harm = false + @config.install_prefix = '' + while a = ARGV.shift + case a + when '--no-harm' + @config.no_harm = true + when /\A--prefix=/ + path = a.split(/=/, 2)[1] + path = File.expand_path(path) unless path[0,1] == '/' + @config.install_prefix = path + else + setup_rb_error "install: unknown option #{a}" + end + end + end + + def print_usage(out) + out.puts 'Typical Installation Procedure:' + out.puts " $ ruby #{File.basename $0} config" + out.puts " $ ruby #{File.basename $0} setup" + out.puts " # ruby #{File.basename $0} install (may require root privilege)" + out.puts + out.puts 'Detailed Usage:' + out.puts " ruby #{File.basename $0} <global option>" + out.puts " ruby #{File.basename $0} [<global options>] <task> [<task options>]" + + fmt = " %-24s %s\n" + out.puts + out.puts 'Global options:' + out.printf fmt, '-q,--quiet', 'suppress message outputs' + out.printf fmt, ' --verbose', 'output messages verbosely' + out.printf fmt, ' --help', 'print this message' + out.printf fmt, ' --version', 'print version and quit' + out.printf fmt, ' --copyright', 'print copyright and quit' + out.puts + out.puts 'Tasks:' + TASKS.each do |name, desc| + out.printf fmt, name, desc + end + + fmt = " %-24s %s [%s]\n" + out.puts + out.puts 'Options for CONFIG or ALL:' + @config.each do |item| + out.printf fmt, item.help_opt, item.description, item.help_default + end + out.printf fmt, '--rbconfig=path', 'rbconfig.rb to load',"running ruby's" + out.puts + out.puts 'Options for INSTALL:' + out.printf fmt, '--no-harm', 'only display what to do if given', 'off' + out.printf fmt, '--prefix=path', 'install path prefix', '' + out.puts + end + + # + # Task Handlers + # + + def exec_config + @installer.exec_config + @config.save # must be final + end + + def exec_setup + @installer.exec_setup + end + + def exec_install + @installer.exec_install + end + + def exec_test + @installer.exec_test + end + + def exec_show + @config.each do |i| + printf "%-20s %s\n", i.name, i.value if i.value? + end + end + + def exec_clean + @installer.exec_clean + end + + def exec_distclean + @installer.exec_distclean + end + +end # class ToplevelInstaller + + +class ToplevelInstallerMulti < ToplevelInstaller + + include FileOperations + + def initialize(ardir_root, config) + super + @packages = directories_of("#{@ardir}/packages") + raise 'no package exists' if @packages.empty? + @root_installer = Installer.new(@config, @ardir, File.expand_path('.')) + end + + def run_metaconfigs + @config.load_script "#{@ardir}/metaconfig", self + @packages.each do |name| + @config.load_script "#{@ardir}/packages/#{name}/metaconfig" + end + end + + attr_reader :packages + + def packages=(list) + raise 'package list is empty' if list.empty? + list.each do |name| + raise "directory packages/#{name} does not exist"\ + unless File.dir?("#{@ardir}/packages/#{name}") + end + @packages = list + end + + def init_installers + @installers = {} + @packages.each do |pack| + @installers[pack] = Installer.new(@config, + "#{@ardir}/packages/#{pack}", + "packages/#{pack}") + end + with = extract_selection(config('with')) + without = extract_selection(config('without')) + @selected = @installers.keys.select {|name| + (with.empty? or with.include?(name)) \ + and not without.include?(name) + } + end + + def extract_selection(list) + a = list.split(/,/) + a.each do |name| + setup_rb_error "no such package: #{name}" unless @installers.key?(name) + end + a + end + + def print_usage(f) + super + f.puts 'Inluded packages:' + f.puts ' ' + @packages.sort.join(' ') + f.puts + end + + # + # Task Handlers + # + + def exec_config + run_hook 'pre-config' + each_selected_installers {|inst| inst.exec_config } + run_hook 'post-config' + @config.save # must be final + end + + def exec_setup + run_hook 'pre-setup' + each_selected_installers {|inst| inst.exec_setup } + run_hook 'post-setup' + end + + def exec_install + run_hook 'pre-install' + each_selected_installers {|inst| inst.exec_install } + run_hook 'post-install' + end + + def exec_test + run_hook 'pre-test' + each_selected_installers {|inst| inst.exec_test } + run_hook 'post-test' + end + + def exec_clean + rm_f @config.savefile + run_hook 'pre-clean' + each_selected_installers {|inst| inst.exec_clean } + run_hook 'post-clean' + end + + def exec_distclean + rm_f @config.savefile + run_hook 'pre-distclean' + each_selected_installers {|inst| inst.exec_distclean } + run_hook 'post-distclean' + end + + # + # lib + # + + def each_selected_installers + Dir.mkdir 'packages' unless File.dir?('packages') + @selected.each do |pack| + $stderr.puts "Processing the package `#{pack}' ..." if verbose? + Dir.mkdir "packages/#{pack}" unless File.dir?("packages/#{pack}") + Dir.chdir "packages/#{pack}" + yield @installers[pack] + Dir.chdir '../..' + end + end + + def run_hook(id) + @root_installer.run_hook id + end + + # module FileOperations requires this + def verbose? + @config.verbose? + end + + # module FileOperations requires this + def no_harm? + @config.no_harm? + end + +end # class ToplevelInstallerMulti + + +class Installer + + FILETYPES = %w( bin lib ext data conf man ) + + include FileOperations + include HookScriptAPI + + def initialize(config, srcroot, objroot) + @config = config + @srcdir = File.expand_path(srcroot) + @objdir = File.expand_path(objroot) + @currdir = '.' + end + + def inspect + "#<#{self.class} #{File.basename(@srcdir)}>" + end + + def noop(rel) + end + + # + # Hook Script API base methods + # + + def srcdir_root + @srcdir + end + + def objdir_root + @objdir + end + + def relpath + @currdir + end + + # + # Config Access + # + + # module FileOperations requires this + def verbose? + @config.verbose? + end + + # module FileOperations requires this + def no_harm? + @config.no_harm? + end + + def verbose_off + begin + save, @config.verbose = @config.verbose?, false + yield + ensure + @config.verbose = save + end + end + + # + # TASK config + # + + def exec_config + exec_task_traverse 'config' + end + + alias config_dir_bin noop + alias config_dir_lib noop + + def config_dir_ext(rel) + extconf if extdir?(curr_srcdir()) + end + + alias config_dir_data noop + alias config_dir_conf noop + alias config_dir_man noop + + def extconf + ruby "#{curr_srcdir()}/extconf.rb", *@config.config_opt + end + + # + # TASK setup + # + + def exec_setup + exec_task_traverse 'setup' + end + + def setup_dir_bin(rel) + files_of(curr_srcdir()).each do |fname| + update_shebang_line "#{curr_srcdir()}/#{fname}" + end + end + + alias setup_dir_lib noop + + def setup_dir_ext(rel) + make if extdir?(curr_srcdir()) + end + + alias setup_dir_data noop + alias setup_dir_conf noop + alias setup_dir_man noop + + def update_shebang_line(path) + return if no_harm? + return if config('shebang') == 'never' + old = Shebang.load(path) + if old + $stderr.puts "warning: #{path}: Shebang line includes too many args. It is not portable and your program may not work." if old.args.size > 1 + new = new_shebang(old) + return if new.to_s == old.to_s + else + return unless config('shebang') == 'all' + new = Shebang.new(config('rubypath')) + end + $stderr.puts "updating shebang: #{File.basename(path)}" if verbose? + open_atomic_writer(path) {|output| + File.open(path, 'rb') {|f| + f.gets if old # discard + output.puts new.to_s + output.print f.read + } + } + end + + def new_shebang(old) + if /\Aruby/ =~ File.basename(old.cmd) + Shebang.new(config('rubypath'), old.args) + elsif File.basename(old.cmd) == 'env' and old.args.first == 'ruby' + Shebang.new(config('rubypath'), old.args[1..-1]) + else + return old unless config('shebang') == 'all' + Shebang.new(config('rubypath')) + end + end + + def open_atomic_writer(path, &block) + tmpfile = File.basename(path) + '.tmp' + begin + File.open(tmpfile, 'wb', &block) + File.rename tmpfile, File.basename(path) + ensure + File.unlink tmpfile if File.exist?(tmpfile) + end + end + + class Shebang + def Shebang.load(path) + line = nil + File.open(path) {|f| + line = f.gets + } + return nil unless /\A#!/ =~ line + parse(line) + end + + def Shebang.parse(line) + cmd, *args = *line.strip.sub(/\A\#!/, '').split(' ') + new(cmd, args) + end + + def initialize(cmd, args = []) + @cmd = cmd + @args = args + end + + attr_reader :cmd + attr_reader :args + + def to_s + "#! #{@cmd}" + (@args.empty? ? '' : " #{@args.join(' ')}") + end + end + + # + # TASK install + # + + def exec_install + rm_f 'InstalledFiles' + exec_task_traverse 'install' + end + + def install_dir_bin(rel) + install_files targetfiles(), "#{config('bindir')}/#{rel}", 0755 + end + + def install_dir_lib(rel) + install_files libfiles(), "#{config('rbdir')}/#{rel}", 0644 + end + + def install_dir_ext(rel) + return unless extdir?(curr_srcdir()) + install_files rubyextentions('.'), + "#{config('sodir')}/#{File.dirname(rel)}", + 0555 + end + + def install_dir_data(rel) + install_files targetfiles(), "#{config('datadir')}/#{rel}", 0644 + end + + def install_dir_conf(rel) + # FIXME: should not remove current config files + # (rename previous file to .old/.org) + install_files targetfiles(), "#{config('sysconfdir')}/#{rel}", 0644 + end + + def install_dir_man(rel) + install_files targetfiles(), "#{config('mandir')}/#{rel}", 0644 + end + + def install_files(list, dest, mode) + mkdir_p dest, @config.install_prefix + list.each do |fname| + install fname, dest, mode, @config.install_prefix + end + end + + def libfiles + glob_reject(%w(*.y *.output), targetfiles()) + end + + def rubyextentions(dir) + ents = glob_select("*.#{@config.dllext}", targetfiles()) + if ents.empty? + setup_rb_error "no ruby extention exists: 'ruby #{$0} setup' first" + end + ents + end + + def targetfiles + mapdir(existfiles() - hookfiles()) + end + + def mapdir(ents) + ents.map {|ent| + if File.exist?(ent) + then ent # objdir + else "#{curr_srcdir()}/#{ent}" # srcdir + end + } + end + + # picked up many entries from cvs-1.11.1/src/ignore.c + JUNK_FILES = %w( + core RCSLOG tags TAGS .make.state + .nse_depinfo #* .#* cvslog.* ,* .del-* *.olb + *~ *.old *.bak *.BAK *.orig *.rej _$* *$ + + *.org *.in .* + ) + + def existfiles + glob_reject(JUNK_FILES, (files_of(curr_srcdir()) | files_of('.'))) + end + + def hookfiles + %w( pre-%s post-%s pre-%s.rb post-%s.rb ).map {|fmt| + %w( config setup install clean ).map {|t| sprintf(fmt, t) } + }.flatten + end + + def glob_select(pat, ents) + re = globs2re([pat]) + ents.select {|ent| re =~ ent } + end + + def glob_reject(pats, ents) + re = globs2re(pats) + ents.reject {|ent| re =~ ent } + end + + GLOB2REGEX = { + '.' => '\.', + '$' => '\$', + '#' => '\#', + '*' => '.*' + } + + def globs2re(pats) + /\A(?:#{ + pats.map {|pat| pat.gsub(/[\.\$\#\*]/) {|ch| GLOB2REGEX[ch] } }.join('|') + })\z/ + end + + # + # TASK test + # + + TESTDIR = 'test' + + def exec_test + unless File.directory?('test') + $stderr.puts 'no test in this package' if verbose? + return + end + $stderr.puts 'Running tests...' if verbose? + begin + require 'test/unit' + rescue LoadError + setup_rb_error 'test/unit cannot loaded. You need Ruby 1.8 or later to invoke this task.' + end + runner = Test::Unit::AutoRunner.new(true) + runner.to_run << TESTDIR + runner.run + end + + # + # TASK clean + # + + def exec_clean + exec_task_traverse 'clean' + rm_f @config.savefile + rm_f 'InstalledFiles' + end + + alias clean_dir_bin noop + alias clean_dir_lib noop + alias clean_dir_data noop + alias clean_dir_conf noop + alias clean_dir_man noop + + def clean_dir_ext(rel) + return unless extdir?(curr_srcdir()) + make 'clean' if File.file?('Makefile') + end + + # + # TASK distclean + # + + def exec_distclean + exec_task_traverse 'distclean' + rm_f @config.savefile + rm_f 'InstalledFiles' + end + + alias distclean_dir_bin noop + alias distclean_dir_lib noop + + def distclean_dir_ext(rel) + return unless extdir?(curr_srcdir()) + make 'distclean' if File.file?('Makefile') + end + + alias distclean_dir_data noop + alias distclean_dir_conf noop + alias distclean_dir_man noop + + # + # Traversing + # + + def exec_task_traverse(task) + run_hook "pre-#{task}" + FILETYPES.each do |type| + if type == 'ext' and config('without-ext') == 'yes' + $stderr.puts 'skipping ext/* by user option' if verbose? + next + end + traverse task, type, "#{task}_dir_#{type}" + end + run_hook "post-#{task}" + end + + def traverse(task, rel, mid) + dive_into(rel) { + run_hook "pre-#{task}" + __send__ mid, rel.sub(%r[\A.*?(?:/|\z)], '') + directories_of(curr_srcdir()).each do |d| + traverse task, "#{rel}/#{d}", mid + end + run_hook "post-#{task}" + } + end + + def dive_into(rel) + return unless File.dir?("#{@srcdir}/#{rel}") + + dir = File.basename(rel) + Dir.mkdir dir unless File.dir?(dir) + prevdir = Dir.pwd + Dir.chdir dir + $stderr.puts '---> ' + rel if verbose? + @currdir = rel + yield + Dir.chdir prevdir + $stderr.puts '<--- ' + rel if verbose? + @currdir = File.dirname(rel) + end + + def run_hook(id) + path = [ "#{curr_srcdir()}/#{id}", + "#{curr_srcdir()}/#{id}.rb" ].detect {|cand| File.file?(cand) } + return unless path + begin + instance_eval File.read(path), path, 1 + rescue + raise if $DEBUG + setup_rb_error "hook #{path} failed:\n" + $!.message + end + end + +end # class Installer + + +class SetupError < StandardError; end + +def setup_rb_error(msg) + raise SetupError, msg +end + +if $0 == __FILE__ + begin + ToplevelInstaller.invoke + rescue SetupError + raise if $DEBUG + $stderr.puts $!.message + $stderr.puts "Try 'ruby #{$0} --help' for detailed usage." + exit 1 + end +end diff --git a/lib/rb/spec/ThriftSpec.thrift b/lib/rb/spec/ThriftSpec.thrift new file mode 100644 index 000000000..fe5a8aae4 --- /dev/null +++ b/lib/rb/spec/ThriftSpec.thrift @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +namespace rb SpecNamespace + +struct Hello { + 1: string greeting = "hello world" +} + +struct Foo { + 1: i32 simple = 53, + 2: string words = "words", + 3: Hello hello = {'greeting' : "hello, world!"}, + 4: list<i32> ints = [1, 2, 2, 3], + 5: map<i32, map<string, double>> complex, + 6: set<i16> shorts = [5, 17, 239], + 7: optional string opt_string +} + +struct BoolStruct { + 1: bool yesno = 1 +} + +struct SimpleList { + 1: list<bool> bools, + 2: list<byte> bytes, + 3: list<i16> i16s, + 4: list<i32> i32s, + 5: list<i64> i64s, + 6: list<double> doubles, + 7: list<string> strings, + 8: list<map<i16, i16>> maps, + 9: list<list<i16>> lists, + 10: list<set<i16>> sets, + 11: list<Hello> hellos +} + +exception Xception { + 1: string message, + 2: i32 code = 1 +} + +service NonblockingService { + Hello greeting(1:bool english) + bool block() + oneway void unblock(1:i32 n) + oneway void shutdown() + void sleep(1:double seconds) +} diff --git a/lib/rb/spec/base_protocol_spec.rb b/lib/rb/spec/base_protocol_spec.rb new file mode 100644 index 000000000..efb16d8c8 --- /dev/null +++ b/lib/rb/spec/base_protocol_spec.rb @@ -0,0 +1,160 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftBaseProtocolSpec < Spec::ExampleGroup + include Thrift + + before(:each) do + @trans = mock("MockTransport") + @prot = BaseProtocol.new(@trans) + end + + describe BaseProtocol do + # most of the methods are stubs, so we can ignore them + + it "should make trans accessible" do + @prot.trans.should eql(@trans) + end + + it "should write out a field nicely" do + @prot.should_receive(:write_field_begin).with('field', 'type', 'fid').ordered + @prot.should_receive(:write_type).with('type', 'value').ordered + @prot.should_receive(:write_field_end).ordered + @prot.write_field('field', 'type', 'fid', 'value') + end + + it "should write out the different types" do + @prot.should_receive(:write_bool).with('bool').ordered + @prot.should_receive(:write_byte).with('byte').ordered + @prot.should_receive(:write_double).with('double').ordered + @prot.should_receive(:write_i16).with('i16').ordered + @prot.should_receive(:write_i32).with('i32').ordered + @prot.should_receive(:write_i64).with('i64').ordered + @prot.should_receive(:write_string).with('string').ordered + struct = mock('Struct') + struct.should_receive(:write).with(@prot).ordered + @prot.write_type(Types::BOOL, 'bool') + @prot.write_type(Types::BYTE, 'byte') + @prot.write_type(Types::DOUBLE, 'double') + @prot.write_type(Types::I16, 'i16') + @prot.write_type(Types::I32, 'i32') + @prot.write_type(Types::I64, 'i64') + @prot.write_type(Types::STRING, 'string') + @prot.write_type(Types::STRUCT, struct) + # all other types are not implemented + [Types::STOP, Types::VOID, Types::MAP, Types::SET, Types::LIST].each do |type| + lambda { @prot.write_type(type, type.to_s) }.should raise_error(NotImplementedError) + end + end + + it "should read the different types" do + @prot.should_receive(:read_bool).ordered + @prot.should_receive(:read_byte).ordered + @prot.should_receive(:read_i16).ordered + @prot.should_receive(:read_i32).ordered + @prot.should_receive(:read_i64).ordered + @prot.should_receive(:read_double).ordered + @prot.should_receive(:read_string).ordered + @prot.read_type(Types::BOOL) + @prot.read_type(Types::BYTE) + @prot.read_type(Types::I16) + @prot.read_type(Types::I32) + @prot.read_type(Types::I64) + @prot.read_type(Types::DOUBLE) + @prot.read_type(Types::STRING) + # all other types are not implemented + [Types::STOP, Types::VOID, Types::MAP, Types::SET, Types::LIST].each do |type| + lambda { @prot.read_type(type) }.should raise_error(NotImplementedError) + end + end + + it "should skip the basic types" do + @prot.should_receive(:read_bool).ordered + @prot.should_receive(:read_byte).ordered + @prot.should_receive(:read_i16).ordered + @prot.should_receive(:read_i32).ordered + @prot.should_receive(:read_i64).ordered + @prot.should_receive(:read_double).ordered + @prot.should_receive(:read_string).ordered + @prot.skip(Types::BOOL) + @prot.skip(Types::BYTE) + @prot.skip(Types::I16) + @prot.skip(Types::I32) + @prot.skip(Types::I64) + @prot.skip(Types::DOUBLE) + @prot.skip(Types::STRING) + @prot.skip(Types::STOP) # should do absolutely nothing + end + + it "should skip structs" do + real_skip = @prot.method(:skip) + @prot.should_receive(:read_struct_begin).ordered + @prot.should_receive(:read_field_begin).exactly(4).times.and_return( + ['field 1', Types::STRING, 1], + ['field 2', Types::I32, 2], + ['field 3', Types::MAP, 3], + [nil, Types::STOP, 0] + ) + @prot.should_receive(:read_field_end).exactly(3).times + @prot.should_receive(:read_string).exactly(3).times + @prot.should_receive(:read_i32).ordered + @prot.should_receive(:read_map_begin).ordered.and_return([Types::STRING, Types::STRING, 1]) + # @prot.should_receive(:read_string).exactly(2).times + @prot.should_receive(:read_map_end).ordered + @prot.should_receive(:read_struct_end).ordered + real_skip.call(Types::STRUCT) + end + + it "should skip maps" do + real_skip = @prot.method(:skip) + @prot.should_receive(:read_map_begin).ordered.and_return([Types::STRING, Types::STRUCT, 1]) + @prot.should_receive(:read_string).ordered + @prot.should_receive(:read_struct_begin).ordered.and_return(["some_struct"]) + @prot.should_receive(:read_field_begin).ordered.and_return([nil, Types::STOP, nil]); + @prot.should_receive(:read_struct_end).ordered + @prot.should_receive(:read_map_end).ordered + real_skip.call(Types::MAP) + end + + it "should skip sets" do + real_skip = @prot.method(:skip) + @prot.should_receive(:read_set_begin).ordered.and_return([Types::I64, 9]) + @prot.should_receive(:read_i64).ordered.exactly(9).times + @prot.should_receive(:read_set_end) + real_skip.call(Types::SET) + end + + it "should skip lists" do + real_skip = @prot.method(:skip) + @prot.should_receive(:read_list_begin).ordered.and_return([Types::DOUBLE, 11]) + @prot.should_receive(:read_double).ordered.exactly(11).times + @prot.should_receive(:read_list_end) + real_skip.call(Types::LIST) + end + end + + describe BaseProtocolFactory do + it "should raise NotImplementedError" do + # returning nil since Protocol is just an abstract class + lambda {BaseProtocolFactory.new.get_protocol(mock("MockTransport"))}.should raise_error(NotImplementedError) + end + end +end diff --git a/lib/rb/spec/base_transport_spec.rb b/lib/rb/spec/base_transport_spec.rb new file mode 100644 index 000000000..71897759a --- /dev/null +++ b/lib/rb/spec/base_transport_spec.rb @@ -0,0 +1,344 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftBaseTransportSpec < Spec::ExampleGroup + include Thrift + + describe TransportException do + it "should make type accessible" do + exc = TransportException.new(TransportException::ALREADY_OPEN, "msg") + exc.type.should == TransportException::ALREADY_OPEN + exc.message.should == "msg" + end + end + + describe BaseTransport do + it "should read the specified size" do + transport = BaseTransport.new + transport.should_receive(:read).with(40).ordered.and_return("10 letters") + transport.should_receive(:read).with(30).ordered.and_return("fifteen letters") + transport.should_receive(:read).with(15).ordered.and_return("more characters") + transport.read_all(40).should == "10 lettersfifteen lettersmore characters" + end + + it "should stub out the rest of the methods" do + # can't test for stubbiness, so just make sure they're defined + [:open?, :open, :close, :read, :write, :flush].each do |sym| + BaseTransport.method_defined?(sym).should be_true + end + end + + it "should alias << to write" do + BaseTransport.instance_method(:<<).should == BaseTransport.instance_method(:write) + end + end + + describe BaseServerTransport do + it "should stub out its methods" do + [:listen, :accept, :close].each do |sym| + BaseServerTransport.method_defined?(sym).should be_true + end + end + end + + describe BaseTransportFactory do + it "should return the transport it's given" do + transport = mock("Transport") + BaseTransportFactory.new.get_transport(transport).should eql(transport) + end + end + + describe BufferedTransport do + it "should pass through everything but write/flush/read" do + trans = mock("Transport") + trans.should_receive(:open?).ordered.and_return("+ open?") + trans.should_receive(:open).ordered.and_return("+ open") + trans.should_receive(:flush).ordered # from the close + trans.should_receive(:close).ordered.and_return("+ close") + btrans = BufferedTransport.new(trans) + btrans.open?.should == "+ open?" + btrans.open.should == "+ open" + btrans.close.should == "+ close" + end + + it "should buffer reads in chunks of #{BufferedTransport::DEFAULT_BUFFER}" do + trans = mock("Transport") + trans.should_receive(:read).with(BufferedTransport::DEFAULT_BUFFER).and_return("lorum ipsum dolor emet") + btrans = BufferedTransport.new(trans) + btrans.read(6).should == "lorum " + btrans.read(6).should == "ipsum " + btrans.read(6).should == "dolor " + btrans.read(6).should == "emet" + end + + it "should buffer writes and send them on flush" do + trans = mock("Transport") + btrans = BufferedTransport.new(trans) + btrans.write("one/") + btrans.write("two/") + btrans.write("three/") + trans.should_receive(:write).with("one/two/three/").ordered + trans.should_receive(:flush).ordered + btrans.flush + end + + it "should only send buffered data once" do + trans = mock("Transport") + btrans = BufferedTransport.new(trans) + btrans.write("one/") + btrans.write("two/") + btrans.write("three/") + trans.should_receive(:write).with("one/two/three/") + trans.stub!(:flush) + btrans.flush + # Nothing to flush with no data + btrans.flush + end + + it "should flush on close" do + trans = mock("Transport") + trans.should_receive(:close) + btrans = BufferedTransport.new(trans) + btrans.should_receive(:flush) + btrans.close + end + + it "should not write to socket if there's no data" do + trans = mock("Transport") + trans.should_receive(:flush) + btrans = BufferedTransport.new(trans) + btrans.flush + end + end + + describe BufferedTransportFactory do + it "should wrap the given transport in a BufferedTransport" do + trans = mock("Transport") + btrans = mock("BufferedTransport") + BufferedTransport.should_receive(:new).with(trans).and_return(btrans) + BufferedTransportFactory.new.get_transport(trans).should == btrans + end + end + + describe FramedTransport do + before(:each) do + @trans = mock("Transport") + end + + it "should pass through open?/open/close" do + ftrans = FramedTransport.new(@trans) + @trans.should_receive(:open?).ordered.and_return("+ open?") + @trans.should_receive(:open).ordered.and_return("+ open") + @trans.should_receive(:close).ordered.and_return("+ close") + ftrans.open?.should == "+ open?" + ftrans.open.should == "+ open" + ftrans.close.should == "+ close" + end + + it "should pass through read when read is turned off" do + ftrans = FramedTransport.new(@trans, false, true) + @trans.should_receive(:read).with(17).ordered.and_return("+ read") + ftrans.read(17).should == "+ read" + end + + it "should pass through write/flush when write is turned off" do + ftrans = FramedTransport.new(@trans, true, false) + @trans.should_receive(:write).with("foo").ordered.and_return("+ write") + @trans.should_receive(:flush).ordered.and_return("+ flush") + ftrans.write("foo").should == "+ write" + ftrans.flush.should == "+ flush" + end + + it "should return a full frame if asked for >= the frame's length" do + frame = "this is a frame" + @trans.should_receive(:read_all).with(4).and_return("\000\000\000\017") + @trans.should_receive(:read_all).with(frame.length).and_return(frame) + FramedTransport.new(@trans).read(frame.length + 10).should == frame + end + + it "should return slices of the frame when asked for < the frame's length" do + frame = "this is a frame" + @trans.should_receive(:read_all).with(4).and_return("\000\000\000\017") + @trans.should_receive(:read_all).with(frame.length).and_return(frame) + ftrans = FramedTransport.new(@trans) + ftrans.read(4).should == "this" + ftrans.read(4).should == " is " + ftrans.read(16).should == "a frame" + end + + it "should return nothing if asked for <= 0" do + FramedTransport.new(@trans).read(-2).should == "" + end + + it "should pull a new frame when the first is exhausted" do + frame = "this is a frame" + frame2 = "yet another frame" + @trans.should_receive(:read_all).with(4).and_return("\000\000\000\017", "\000\000\000\021") + @trans.should_receive(:read_all).with(frame.length).and_return(frame) + @trans.should_receive(:read_all).with(frame2.length).and_return(frame2) + ftrans = FramedTransport.new(@trans) + ftrans.read(4).should == "this" + ftrans.read(8).should == " is a fr" + ftrans.read(6).should == "ame" + ftrans.read(4).should == "yet " + ftrans.read(16).should == "another frame" + end + + it "should buffer writes" do + ftrans = FramedTransport.new(@trans) + @trans.should_not_receive(:write) + ftrans.write("foo") + ftrans.write("bar") + ftrans.write("this is a frame") + end + + it "should write slices of the buffer" do + ftrans = FramedTransport.new(@trans) + ftrans.write("foobar", 3) + ftrans.write("barfoo", 1) + @trans.stub!(:flush) + @trans.should_receive(:write).with("\000\000\000\004foob") + ftrans.flush + end + + it "should flush frames with a 4-byte header" do + ftrans = FramedTransport.new(@trans) + @trans.should_receive(:write).with("\000\000\000\035one/two/three/this is a frame").ordered + @trans.should_receive(:flush).ordered + ftrans.write("one/") + ftrans.write("two/") + ftrans.write("three/") + ftrans.write("this is a frame") + ftrans.flush + end + + it "should not flush the same buffered data twice" do + ftrans = FramedTransport.new(@trans) + @trans.should_receive(:write).with("\000\000\000\007foo/bar") + @trans.stub!(:flush) + ftrans.write("foo") + ftrans.write("/bar") + ftrans.flush + @trans.should_receive(:write).with("\000\000\000\000") + ftrans.flush + end + end + + describe FramedTransportFactory do + it "should wrap the given transport in a FramedTransport" do + trans = mock("Transport") + FramedTransport.should_receive(:new).with(trans) + FramedTransportFactory.new.get_transport(trans) + end + end + + describe MemoryBufferTransport do + before(:each) do + @buffer = MemoryBufferTransport.new + end + + it "should accept a buffer on input and use it directly" do + s = "this is a test" + @buffer = MemoryBufferTransport.new(s) + @buffer.read(4).should == "this" + s.slice!(-4..-1) + @buffer.read(@buffer.available).should == " is a " + end + + it "should always remain open" do + @buffer.should be_open + @buffer.close + @buffer.should be_open + end + + it "should respond to peek and available" do + @buffer.write "some data" + @buffer.peek.should be_true + @buffer.available.should == 9 + @buffer.read(4) + @buffer.peek.should be_true + @buffer.available.should == 5 + @buffer.read(16) + @buffer.peek.should be_false + @buffer.available.should == 0 + end + + it "should be able to reset the buffer" do + @buffer.write "test data" + @buffer.reset_buffer("foobar") + @buffer.available.should == 6 + @buffer.read(10).should == "foobar" + @buffer.reset_buffer + @buffer.available.should == 0 + end + + it "should copy the given string whne resetting the buffer" do + s = "this is a test" + @buffer.reset_buffer(s) + @buffer.available.should == 14 + @buffer.read(10) + @buffer.available.should == 4 + s.should == "this is a test" + end + + it "should return from read what was given in write" do + @buffer.write "test data" + @buffer.read(4).should == "test" + @buffer.read(10).should == " data" + @buffer.read(10).should == "" + @buffer.write "foo" + @buffer.write " bar" + @buffer.read(10).should == "foo bar" + end + end + + describe IOStreamTransport do + before(:each) do + @input = mock("Input", :closed? => false) + @output = mock("Output", :closed? => false) + @trans = IOStreamTransport.new(@input, @output) + end + + it "should be open as long as both input or output are open" do + @trans.should be_open + @input.stub!(:closed?).and_return(true) + @trans.should be_open + @input.stub!(:closed?).and_return(false) + @output.stub!(:closed?).and_return(true) + @trans.should be_open + @input.stub!(:closed?).and_return(true) + @trans.should_not be_open + end + + it "should pass through read/write to input/output" do + @input.should_receive(:read).with(17).and_return("+ read") + @output.should_receive(:write).with("foobar").and_return("+ write") + @trans.read(17).should == "+ read" + @trans.write("foobar").should == "+ write" + end + + it "should close both input and output when closed" do + @input.should_receive(:close) + @output.should_receive(:close) + @trans.close + end + end +end diff --git a/lib/rb/spec/binary_protocol_accelerated_spec.rb b/lib/rb/spec/binary_protocol_accelerated_spec.rb new file mode 100644 index 000000000..0306cf5f2 --- /dev/null +++ b/lib/rb/spec/binary_protocol_accelerated_spec.rb @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + '/binary_protocol_spec_shared' +require File.dirname(__FILE__) + '/gen-rb/thrift_spec_types' + +class ThriftBinaryProtocolAcceleratedSpec < Spec::ExampleGroup + include Thrift + + describe Thrift::BinaryProtocolAccelerated do + # since BinaryProtocolAccelerated should be directly equivalent to + # BinaryProtocol, we don't need any custom specs! + it_should_behave_like 'a binary protocol' + + def protocol_class + BinaryProtocolAccelerated + end + end + + describe BinaryProtocolAcceleratedFactory do + it "should create a BinaryProtocolAccelerated" do + BinaryProtocolAcceleratedFactory.new.get_protocol(mock("MockTransport")).should be_instance_of(BinaryProtocolAccelerated) + end + end +end diff --git a/lib/rb/spec/binary_protocol_spec.rb b/lib/rb/spec/binary_protocol_spec.rb new file mode 100644 index 000000000..0abccb892 --- /dev/null +++ b/lib/rb/spec/binary_protocol_spec.rb @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + '/binary_protocol_spec_shared' + +class ThriftBinaryProtocolSpec < Spec::ExampleGroup + include Thrift + + describe BinaryProtocol do + it_should_behave_like 'a binary protocol' + + def protocol_class + BinaryProtocol + end + + it "should read a message header" do + @trans.should_receive(:read_all).exactly(2).times.and_return( + [protocol_class.const_get(:VERSION_1) | Thrift::MessageTypes::REPLY].pack('N'), + [42].pack('N') + ) + @prot.should_receive(:read_string).and_return('testMessage') + @prot.read_message_begin.should == ['testMessage', Thrift::MessageTypes::REPLY, 42] + end + + it "should raise an exception if the message header has the wrong version" do + @prot.should_receive(:read_i32).and_return(-1) + lambda { @prot.read_message_begin }.should raise_error(Thrift::ProtocolException, 'Missing version identifier') do |e| + e.type == Thrift::ProtocolException::BAD_VERSION + end + end + + it "should raise an exception if the message header does not exist and strict_read is enabled" do + @prot.should_receive(:read_i32).and_return(42) + @prot.should_receive(:strict_read).and_return(true) + lambda { @prot.read_message_begin }.should raise_error(Thrift::ProtocolException, 'No version identifier, old protocol client?') do |e| + e.type == Thrift::ProtocolException::BAD_VERSION + end + end + end + + describe BinaryProtocolFactory do + it "should create a BinaryProtocol" do + BinaryProtocolFactory.new.get_protocol(mock("MockTransport")).should be_instance_of(BinaryProtocol) + end + end +end diff --git a/lib/rb/spec/binary_protocol_spec_shared.rb b/lib/rb/spec/binary_protocol_spec_shared.rb new file mode 100644 index 000000000..c6608e01a --- /dev/null +++ b/lib/rb/spec/binary_protocol_spec_shared.rb @@ -0,0 +1,375 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +shared_examples_for 'a binary protocol' do + before(:each) do + @trans = Thrift::MemoryBufferTransport.new + @prot = protocol_class.new(@trans) + end + + it "should define the proper VERSION_1, VERSION_MASK AND TYPE_MASK" do + protocol_class.const_get(:VERSION_MASK).should == 0xffff0000 + protocol_class.const_get(:VERSION_1).should == 0x80010000 + protocol_class.const_get(:TYPE_MASK).should == 0x000000ff + end + + it "should make strict_read readable" do + @prot.strict_read.should eql(true) + end + + it "should make strict_write readable" do + @prot.strict_write.should eql(true) + end + + it "should write the message header" do + @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17) + @trans.read(1000).should == [protocol_class.const_get(:VERSION_1) | Thrift::MessageTypes::CALL, "testMessage".size, "testMessage", 17].pack("NNa11N") + end + + it "should write the message header without version when writes are not strict" do + @prot = protocol_class.new(@trans, true, false) # no strict write + @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17) + @trans.read(1000).should == "\000\000\000\vtestMessage\001\000\000\000\021" + end + + it "should write the message header with a version when writes are strict" do + @prot = protocol_class.new(@trans) # strict write + @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17) + @trans.read(1000).should == "\200\001\000\001\000\000\000\vtestMessage\000\000\000\021" + end + + + # message footer is a noop + + it "should write the field header" do + @prot.write_field_begin('foo', Thrift::Types::DOUBLE, 3) + @trans.read(1000).should == [Thrift::Types::DOUBLE, 3].pack("cn") + end + + # field footer is a noop + + it "should write the STOP field" do + @prot.write_field_stop + @trans.read(1).should == "\000" + end + + it "should write the map header" do + @prot.write_map_begin(Thrift::Types::STRING, Thrift::Types::LIST, 17) + @trans.read(1000).should == [Thrift::Types::STRING, Thrift::Types::LIST, 17].pack("ccN"); + end + + # map footer is a noop + + it "should write the list header" do + @prot.write_list_begin(Thrift::Types::I16, 42) + @trans.read(1000).should == [Thrift::Types::I16, 42].pack("cN") + end + + # list footer is a noop + + it "should write the set header" do + @prot.write_set_begin(Thrift::Types::I16, 42) + @trans.read(1000).should == [Thrift::Types::I16, 42].pack("cN") + end + + it "should write a bool" do + @prot.write_bool(true) + @prot.write_bool(false) + @trans.read(1000).should == "\001\000" + end + + it "should treat a nil bool as false" do + @prot.write_bool(nil) + @trans.read(1).should == "\000" + end + + it "should write a byte" do + # byte is small enough, let's check -128..127 + (-128..127).each do |i| + @prot.write_byte(i) + @trans.read(1).should == [i].pack('c') + end + # handing it numbers out of signed range should clip + @trans.rspec_verify + (128..255).each do |i| + @prot.write_byte(i) + @trans.read(1).should == [i].pack('c') + end + # and lastly, a Bignum is going to error out + lambda { @prot.write_byte(2**65) }.should raise_error(RangeError) + end + + it "should error gracefully when trying to write a nil byte" do + lambda { @prot.write_byte(nil) }.should raise_error + end + + it "should write an i16" do + # try a random scattering of values + # include the signed i16 minimum/maximum + [-2**15, -1024, 17, 0, -10000, 1723, 2**15-1].each do |i| + @prot.write_i16(i) + end + # and try something out of signed range, it should clip + @prot.write_i16(2**15 + 5) + + @trans.read(1000).should == "\200\000\374\000\000\021\000\000\330\360\006\273\177\377\200\005" + + # a Bignum should error + # lambda { @prot.write_i16(2**65) }.should raise_error(RangeError) + end + + it "should error gracefully when trying to write a nil i16" do + lambda { @prot.write_i16(nil) }.should raise_error + end + + it "should write an i32" do + # try a random scattering of values + # include the signed i32 minimum/maximum + [-2**31, -123123, -2532, -3, 0, 2351235, 12331, 2**31-1].each do |i| + @prot.write_i32(i) + end + # try something out of signed range, it should clip + @trans.read(1000).should == "\200\000\000\000" + "\377\376\037\r" + "\377\377\366\034" + "\377\377\377\375" + "\000\000\000\000" + "\000#\340\203" + "\000\0000+" + "\177\377\377\377" + [2 ** 31 + 5, 2 ** 65 + 5].each do |i| + lambda { @prot.write_i32(i) }.should raise_error(RangeError) + end + end + + it "should error gracefully when trying to write a nil i32" do + lambda { @prot.write_i32(nil) }.should raise_error + end + + it "should write an i64" do + # try a random scattering of values + # try the signed i64 minimum/maximum + [-2**63, -12356123612323, -23512351, -234, 0, 1231, 2351236, 12361236213, 2**63-1].each do |i| + @prot.write_i64(i) + end + # try something out of signed range, it should clip + @trans.read(1000).should == ["\200\000\000\000\000\000\000\000", + "\377\377\364\303\035\244+]", + "\377\377\377\377\376\231:\341", + "\377\377\377\377\377\377\377\026", + "\000\000\000\000\000\000\000\000", + "\000\000\000\000\000\000\004\317", + "\000\000\000\000\000#\340\204", + "\000\000\000\002\340\311~\365", + "\177\377\377\377\377\377\377\377"].join("") + lambda { @prot.write_i64(2 ** 65 + 5) }.should raise_error(RangeError) + end + + it "should error gracefully when trying to write a nil i64" do + lambda { @prot.write_i64(nil) }.should raise_error + end + + it "should write a double" do + # try a random scattering of values, including min/max + values = [Float::MIN,-1231.15325, -123123.23, -23.23515123, 0, 12351.1325, 523.23, Float::MAX] + values.each do |f| + @prot.write_double(f) + @trans.read(1000).should == [f].pack("G") + end + end + + it "should error gracefully when trying to write a nil double" do + lambda { @prot.write_double(nil) }.should raise_error + end + + it "should write a string" do + str = "hello world" + @prot.write_string(str) + @trans.read(1000).should == [str.size].pack("N") + str + end + + it "should error gracefully when trying to write a nil string" do + lambda { @prot.write_string(nil) }.should raise_error + end + + it "should write the message header without version when writes are not strict" do + @prot = protocol_class.new(@trans, true, false) # no strict write + @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17) + @trans.read(1000).should == "\000\000\000\vtestMessage\001\000\000\000\021" + end + + it "should write the message header with a version when writes are strict" do + @prot = protocol_class.new(@trans) # strict write + @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17) + @trans.read(1000).should == "\200\001\000\001\000\000\000\vtestMessage\000\000\000\021" + end + + # message footer is a noop + + it "should read a field header" do + @trans.write([Thrift::Types::STRING, 3].pack("cn")) + @prot.read_field_begin.should == [nil, Thrift::Types::STRING, 3] + end + + # field footer is a noop + + it "should read a stop field" do + @trans.write([Thrift::Types::STOP].pack("c")); + @prot.read_field_begin.should == [nil, Thrift::Types::STOP, 0] + end + + it "should read a map header" do + @trans.write([Thrift::Types::DOUBLE, Thrift::Types::I64, 42].pack("ccN")) + @prot.read_map_begin.should == [Thrift::Types::DOUBLE, Thrift::Types::I64, 42] + end + + # map footer is a noop + + it "should read a list header" do + @trans.write([Thrift::Types::STRING, 17].pack("cN")) + @prot.read_list_begin.should == [Thrift::Types::STRING, 17] + end + + # list footer is a noop + + it "should read a set header" do + @trans.write([Thrift::Types::STRING, 17].pack("cN")) + @prot.read_set_begin.should == [Thrift::Types::STRING, 17] + end + + # set footer is a noop + + it "should read a bool" do + @trans.write("\001\000"); + @prot.read_bool.should == true + @prot.read_bool.should == false + end + + it "should read a byte" do + [-128, -57, -3, 0, 17, 24, 127].each do |i| + @trans.write([i].pack("c")) + @prot.read_byte.should == i + end + end + + it "should read an i16" do + # try a scattering of values, including min/max + [-2**15, -5237, -353, 0, 1527, 2234, 2**15-1].each do |i| + @trans.write([i].pack("n")); + @prot.read_i16.should == i + end + end + + it "should read an i32" do + # try a scattering of values, including min/max + [-2**31, -235125, -6236, 0, 2351, 123123, 2**31-1].each do |i| + @trans.write([i].pack("N")) + @prot.read_i32.should == i + end + end + + it "should read an i64" do + # try a scattering of values, including min/max + [-2**63, -123512312, -6346, 0, 32, 2346322323, 2**63-1].each do |i| + @trans.write([i >> 32, i & 0xFFFFFFFF].pack("NN")) + @prot.read_i64.should == i + end + end + + it "should read a double" do + # try a random scattering of values, including min/max + [Float::MIN, -231231.12351, -323.233513, 0, 123.2351235, 2351235.12351235, Float::MAX].each do |f| + @trans.write([f].pack("G")); + @prot.read_double.should == f + end + end + + it "should read a string" do + str = "hello world" + @trans.write([str.size].pack("N") + str) + @prot.read_string.should == str + end + + it "should perform a complete rpc with no args or return" do + srv_test( + proc {|client| client.send_voidMethod()}, + proc {|client| client.recv_voidMethod.should == nil} + ) + end + + it "should perform a complete rpc with a primitive return type" do + srv_test( + proc {|client| client.send_primitiveMethod()}, + proc {|client| client.recv_primitiveMethod.should == 1} + ) + end + + it "should perform a complete rpc with a struct return type" do + srv_test( + proc {|client| client.send_structMethod()}, + proc {|client| + result = client.recv_structMethod + result.set_byte_map = nil + result.map_byte_map = nil + result.should == Fixtures::COMPACT_PROTOCOL_TEST_STRUCT + } + ) + end + + def get_socket_connection + server = Thrift::ServerSocket.new("localhost", 9090) + server.listen + + clientside = Thrift::Socket.new("localhost", 9090) + clientside.open + serverside = server.accept + [clientside, serverside, server] + end + + def srv_test(firstblock, secondblock) + clientside, serverside, server = get_socket_connection + + clientproto = protocol_class.new(clientside) + serverproto = protocol_class.new(serverside) + + processor = Srv::Processor.new(SrvHandler.new) + + client = Srv::Client.new(clientproto, clientproto) + + # first block + firstblock.call(client) + + processor.process(serverproto, serverproto) + + # second block + secondblock.call(client) + ensure + clientside.close + serverside.close + server.close + end + + class SrvHandler + def voidMethod() + end + + def primitiveMethod + 1 + end + + def structMethod + Fixtures::COMPACT_PROTOCOL_TEST_STRUCT + end + end +end diff --git a/lib/rb/spec/client_spec.rb b/lib/rb/spec/client_spec.rb new file mode 100644 index 000000000..e707d8163 --- /dev/null +++ b/lib/rb/spec/client_spec.rb @@ -0,0 +1,100 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftClientSpec < Spec::ExampleGroup + include Thrift + + class ClientSpec + include Thrift::Client + end + + before(:each) do + @prot = mock("MockProtocol") + @client = ClientSpec.new(@prot) + end + + describe Client do + it "should re-use iprot for oprot if not otherwise specified" do + @client.instance_variable_get(:'@iprot').should eql(@prot) + @client.instance_variable_get(:'@oprot').should eql(@prot) + end + + it "should send a test message" do + @prot.should_receive(:write_message_begin).with('testMessage', MessageTypes::CALL, 0) + mock_args = mock('#<TestMessage_args:mock>') + mock_args.should_receive(:foo=).with('foo') + mock_args.should_receive(:bar=).with(42) + mock_args.should_receive(:write).with(@prot) + @prot.should_receive(:write_message_end) + @prot.should_receive(:trans) do + mock('trans').tee do |trans| + trans.should_receive(:flush) + end + end + klass = stub("TestMessage_args", :new => mock_args) + @client.send_message('testMessage', klass, :foo => 'foo', :bar => 42) + end + + it "should increment the sequence id when sending messages" do + pending "it seems sequence ids are completely ignored right now" do + @prot.should_receive(:write_message_begin).with('testMessage', MessageTypes::CALL, 0).ordered + @prot.should_receive(:write_message_begin).with('testMessage2', MessageTypes::CALL, 1).ordered + @prot.should_receive(:write_message_begin).with('testMessage3', MessageTypes::CALL, 2).ordered + @prot.stub!(:write_message_end) + @prot.stub!(:trans).and_return mock("trans").as_null_object + @client.send_message('testMessage', mock("args class").as_null_object) + @client.send_message('testMessage2', mock("args class").as_null_object) + @client.send_message('testMessage3', mock("args class").as_null_object) + end + end + + it "should receive a test message" do + @prot.should_receive(:read_message_begin).and_return [nil, MessageTypes::CALL, 0] + @prot.should_receive(:read_message_end) + mock_klass = mock("#<MockClass:mock>") + mock_klass.should_receive(:read).with(@prot) + @client.receive_message(stub("MockClass", :new => mock_klass)) + end + + it "should handle received exceptions" do + @prot.should_receive(:read_message_begin).and_return [nil, MessageTypes::EXCEPTION, 0] + @prot.should_receive(:read_message_end) + ApplicationException.should_receive(:new).and_return do + StandardError.new.tee do |mock_exc| + mock_exc.should_receive(:read).with(@prot) + end + end + lambda { @client.receive_message(nil) }.should raise_error(StandardError) + end + + it "should close the transport if an error occurs while sending a message" do + @prot.stub!(:write_message_begin) + @prot.should_not_receive(:write_message_end) + mock_args = mock("#<TestMessage_args:mock>") + mock_args.should_receive(:write).with(@prot).and_raise(StandardError) + trans = mock("MockTransport") + @prot.stub!(:trans).and_return(trans) + trans.should_receive(:close) + klass = mock("TestMessage_args", :new => mock_args) + lambda { @client.send_message("testMessage", klass) }.should raise_error(StandardError) + end + end +end diff --git a/lib/rb/spec/compact_protocol_spec.rb b/lib/rb/spec/compact_protocol_spec.rb new file mode 100644 index 000000000..b9a798103 --- /dev/null +++ b/lib/rb/spec/compact_protocol_spec.rb @@ -0,0 +1,117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +describe Thrift::CompactProtocol do + TESTS = { + :byte => (-127..127).to_a, + :i16 => (0..14).map {|shift| [1 << shift, -(1 << shift)]}.flatten.sort, + :i32 => (0..30).map {|shift| [1 << shift, -(1 << shift)]}.flatten.sort, + :i64 => (0..62).map {|shift| [1 << shift, -(1 << shift)]}.flatten.sort, + :string => ["", "1", "short", "fourteen123456", "fifteen12345678", "1" * 127, "1" * 3000], + :binary => ["", "\001", "\001" * 5, "\001" * 14, "\001" * 15, "\001" * 127, "\001" * 3000], + :double => [0.0, 1.0, -1.0, 1.1, -1.1, 10000000.1, 1.0/0.0, -1.0/0.0], + :bool => [true, false] + } + + it "should encode and decode naked primitives correctly" do + TESTS.each_pair do |primitive_type, test_values| + test_values.each do |value| + # puts "testing #{value}" if primitive_type == :i64 + trans = Thrift::MemoryBufferTransport.new + proto = Thrift::CompactProtocol.new(trans) + + proto.send(writer(primitive_type), value) + # puts "buf: #{trans.inspect_buffer}" if primitive_type == :i64 + read_back = proto.send(reader(primitive_type)) + read_back.should == value + end + end + end + + it "should encode and decode primitives in fields correctly" do + TESTS.each_pair do |primitive_type, test_values| + final_primitive_type = primitive_type == :binary ? :string : primitive_type + thrift_type = Thrift::Types.const_get(final_primitive_type.to_s.upcase) + # puts primitive_type + test_values.each do |value| + trans = Thrift::MemoryBufferTransport.new + proto = Thrift::CompactProtocol.new(trans) + + proto.write_field_begin(nil, thrift_type, 15) + proto.send(writer(primitive_type), value) + proto.write_field_end + + proto = Thrift::CompactProtocol.new(trans) + name, type, id = proto.read_field_begin + type.should == thrift_type + id.should == 15 + read_back = proto.send(reader(primitive_type)) + read_back.should == value + proto.read_field_end + end + end + end + + it "should encode and decode a monster struct correctly" do + trans = Thrift::MemoryBufferTransport.new + proto = Thrift::CompactProtocol.new(trans) + + struct = CompactProtoTestStruct.new + # sets and maps don't hash well... not sure what to do here. + struct.write(proto) + + struct2 = CompactProtoTestStruct.new + struct2.read(proto) + struct2.should == struct + end + + it "should make method calls correctly" do + client_out_trans = Thrift::MemoryBufferTransport.new + client_out_proto = Thrift::CompactProtocol.new(client_out_trans) + + client_in_trans = Thrift::MemoryBufferTransport.new + client_in_proto = Thrift::CompactProtocol.new(client_in_trans) + + processor = Srv::Processor.new(JankyHandler.new) + + client = Srv::Client.new(client_in_proto, client_out_proto) + client.send_Janky(1) + # puts client_out_trans.inspect_buffer + processor.process(client_out_proto, client_in_proto) + client.recv_Janky.should == 2 + end + + class JankyHandler + def Janky(i32arg) + i32arg * 2 + end + end + + def writer(sym) + sym = sym == :binary ? :string : sym + "write_#{sym.to_s}" + end + + def reader(sym) + sym = sym == :binary ? :string : sym + "read_#{sym.to_s}" + end +end
\ No newline at end of file diff --git a/lib/rb/spec/exception_spec.rb b/lib/rb/spec/exception_spec.rb new file mode 100644 index 000000000..fc3213781 --- /dev/null +++ b/lib/rb/spec/exception_spec.rb @@ -0,0 +1,142 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftExceptionSpec < Spec::ExampleGroup + include Thrift + + describe Exception do + it "should have an accessible message" do + e = Exception.new("test message") + e.message.should == "test message" + end + end + + describe ApplicationException do + it "should inherit from Thrift::Exception" do + ApplicationException.superclass.should == Exception + end + + it "should have an accessible type and message" do + e = ApplicationException.new + e.type.should == ApplicationException::UNKNOWN + e.message.should be_nil + e = ApplicationException.new(ApplicationException::UNKNOWN_METHOD, "test message") + e.type.should == ApplicationException::UNKNOWN_METHOD + e.message.should == "test message" + end + + it "should read a struct off of a protocol" do + prot = mock("MockProtocol") + prot.should_receive(:read_struct_begin).ordered + prot.should_receive(:read_field_begin).exactly(3).times.and_return( + ["message", Types::STRING, 1], + ["type", Types::I32, 2], + [nil, Types::STOP, 0] + ) + prot.should_receive(:read_string).ordered.and_return "test message" + prot.should_receive(:read_i32).ordered.and_return ApplicationException::BAD_SEQUENCE_ID + prot.should_receive(:read_field_end).exactly(2).times + prot.should_receive(:read_struct_end).ordered + + e = ApplicationException.new + e.read(prot) + e.message.should == "test message" + e.type.should == ApplicationException::BAD_SEQUENCE_ID + end + + it "should skip bad fields when reading a struct" do + prot = mock("MockProtocol") + prot.should_receive(:read_struct_begin).ordered + prot.should_receive(:read_field_begin).exactly(5).times.and_return( + ["type", Types::I32, 2], + ["type", Types::STRING, 2], + ["message", Types::MAP, 1], + ["message", Types::STRING, 3], + [nil, Types::STOP, 0] + ) + prot.should_receive(:read_i32).and_return ApplicationException::INVALID_MESSAGE_TYPE + prot.should_receive(:skip).with(Types::STRING).twice + prot.should_receive(:skip).with(Types::MAP) + prot.should_receive(:read_field_end).exactly(4).times + prot.should_receive(:read_struct_end).ordered + + e = ApplicationException.new + e.read(prot) + e.message.should be_nil + e.type.should == ApplicationException::INVALID_MESSAGE_TYPE + end + + it "should write a Thrift::ApplicationException struct to the oprot" do + prot = mock("MockProtocol") + prot.should_receive(:write_struct_begin).with("Thrift::ApplicationException").ordered + prot.should_receive(:write_field_begin).with("message", Types::STRING, 1).ordered + prot.should_receive(:write_string).with("test message").ordered + prot.should_receive(:write_field_begin).with("type", Types::I32, 2).ordered + prot.should_receive(:write_i32).with(ApplicationException::UNKNOWN_METHOD).ordered + prot.should_receive(:write_field_end).twice + prot.should_receive(:write_field_stop).ordered + prot.should_receive(:write_struct_end).ordered + + e = ApplicationException.new(ApplicationException::UNKNOWN_METHOD, "test message") + e.write(prot) + end + + it "should skip nil fields when writing to the oprot" do + prot = mock("MockProtocol") + prot.should_receive(:write_struct_begin).with("Thrift::ApplicationException").ordered + prot.should_receive(:write_field_begin).with("message", Types::STRING, 1).ordered + prot.should_receive(:write_string).with("test message").ordered + prot.should_receive(:write_field_end).ordered + prot.should_receive(:write_field_stop).ordered + prot.should_receive(:write_struct_end).ordered + + e = ApplicationException.new(nil, "test message") + e.write(prot) + + prot = mock("MockProtocol") + prot.should_receive(:write_struct_begin).with("Thrift::ApplicationException").ordered + prot.should_receive(:write_field_begin).with("type", Types::I32, 2).ordered + prot.should_receive(:write_i32).with(ApplicationException::BAD_SEQUENCE_ID).ordered + prot.should_receive(:write_field_end).ordered + prot.should_receive(:write_field_stop).ordered + prot.should_receive(:write_struct_end).ordered + + e = ApplicationException.new(ApplicationException::BAD_SEQUENCE_ID) + e.write(prot) + + prot = mock("MockProtocol") + prot.should_receive(:write_struct_begin).with("Thrift::ApplicationException").ordered + prot.should_receive(:write_field_stop).ordered + prot.should_receive(:write_struct_end).ordered + + e = ApplicationException.new(nil) + e.write(prot) + end + end + + describe ProtocolException do + it "should have an accessible type" do + prot = ProtocolException.new(ProtocolException::SIZE_LIMIT, "message") + prot.type.should == ProtocolException::SIZE_LIMIT + prot.message.should == "message" + end + end +end diff --git a/lib/rb/spec/http_client_spec.rb b/lib/rb/spec/http_client_spec.rb new file mode 100644 index 000000000..94526deb7 --- /dev/null +++ b/lib/rb/spec/http_client_spec.rb @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftHTTPClientTransportSpec < Spec::ExampleGroup + include Thrift + + describe HTTPClientTransport do + before(:each) do + @client = HTTPClientTransport.new("http://my.domain.com/path/to/service") + end + + it "should always be open" do + @client.should be_open + @client.close + @client.should be_open + end + + it "should post via HTTP and return the results" do + @client.write "a test" + @client.write " frame" + Net::HTTP.should_receive(:new).with("my.domain.com", 80).and_return do + mock("Net::HTTP").tee do |http| + http.should_receive(:use_ssl=).with(false) + http.should_receive(:post).with("/path/to/service", "a test frame", {"Content-Type"=>"application/x-thrift"}).and_return([nil, "data"]) + end + end + @client.flush + @client.read(10).should == "data" + end + end +end diff --git a/lib/rb/spec/mongrel_http_server_spec.rb b/lib/rb/spec/mongrel_http_server_spec.rb new file mode 100644 index 000000000..c994491cb --- /dev/null +++ b/lib/rb/spec/mongrel_http_server_spec.rb @@ -0,0 +1,117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require 'thrift/server/mongrel_http_server' + +class ThriftHTTPServerSpec < Spec::ExampleGroup + include Thrift + + Handler = MongrelHTTPServer::Handler + + describe MongrelHTTPServer do + it "should have appropriate defaults" do + mock_factory = mock("BinaryProtocolFactory") + mock_proc = mock("Processor") + BinaryProtocolFactory.should_receive(:new).and_return(mock_factory) + Mongrel::HttpServer.should_receive(:new).with("0.0.0.0", 80).and_return do + mock("Mongrel::HttpServer").tee do |mock| + handler = mock("Handler") + Handler.should_receive(:new).with(mock_proc, mock_factory).and_return(handler) + mock.should_receive(:register).with("/", handler) + end + end + MongrelHTTPServer.new(mock_proc) + end + + it "should understand :ip, :port, :path, and :protocol_factory" do + mock_proc = mock("Processor") + mock_factory = mock("ProtocolFactory") + Mongrel::HttpServer.should_receive(:new).with("1.2.3.4", 1234).and_return do + mock("Mongrel::HttpServer").tee do |mock| + handler = mock("Handler") + Handler.should_receive(:new).with(mock_proc, mock_factory).and_return(handler) + mock.should_receive(:register).with("/foo", handler) + end + end + MongrelHTTPServer.new(mock_proc, :ip => "1.2.3.4", :port => 1234, :path => "foo", + :protocol_factory => mock_factory) + end + + it "should serve using Mongrel::HttpServer" do + BinaryProtocolFactory.stub!(:new) + Mongrel::HttpServer.should_receive(:new).and_return do + mock("Mongrel::HttpServer").tee do |mock| + Handler.stub!(:new) + mock.stub!(:register) + mock.should_receive(:run).and_return do + mock("Mongrel::HttpServer.run").tee do |runner| + runner.should_receive(:join) + end + end + end + end + MongrelHTTPServer.new(nil).serve + end + end + + describe MongrelHTTPServer::Handler do + before(:each) do + @processor = mock("Processor") + @factory = mock("ProtocolFactory") + @handler = Handler.new(@processor, @factory) + end + + it "should return 404 for non-POST requests" do + request = mock("request", :params => {"REQUEST_METHOD" => "GET"}) + response = mock("response") + response.should_receive(:start).with(404) + response.should_not_receive(:start).with(200) + @handler.process(request, response) + end + + it "should serve using application/x-thrift" do + request = mock("request", :params => {"REQUEST_METHOD" => "POST"}, :body => nil) + response = mock("response") + head = mock("head") + head.should_receive(:[]=).with("Content-Type", "application/x-thrift") + IOStreamTransport.stub!(:new) + @factory.stub!(:get_protocol) + @processor.stub!(:process) + response.should_receive(:start).with(200).and_yield(head, nil) + @handler.process(request, response) + end + + it "should use the IOStreamTransport" do + body = mock("body") + request = mock("request", :params => {"REQUEST_METHOD" => "POST"}, :body => body) + response = mock("response") + head = mock("head") + head.stub!(:[]=) + out = mock("out") + protocol = mock("protocol") + transport = mock("transport") + IOStreamTransport.should_receive(:new).with(body, out).and_return(transport) + @factory.should_receive(:get_protocol).with(transport).and_return(protocol) + @processor.should_receive(:process).with(protocol, protocol) + response.should_receive(:start).with(200).and_yield(head, out) + @handler.process(request, response) + end + end +end diff --git a/lib/rb/spec/nonblocking_server_spec.rb b/lib/rb/spec/nonblocking_server_spec.rb new file mode 100644 index 000000000..a0e86cf2e --- /dev/null +++ b/lib/rb/spec/nonblocking_server_spec.rb @@ -0,0 +1,266 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + '/gen-rb/nonblocking_service' + +class ThriftNonblockingServerSpec < Spec::ExampleGroup + include Thrift + include SpecNamespace + + class Handler + def initialize + @queue = Queue.new + end + + attr_accessor :server + + def greeting(english) + if english + SpecNamespace::Hello.new + else + SpecNamespace::Hello.new(:greeting => "Aloha!") + end + end + + def block + @queue.pop + end + + def unblock(n) + n.times { @queue.push true } + end + + def sleep(time) + Kernel.sleep time + end + + def shutdown + @server.shutdown(0, false) + end + end + + class SpecTransport < BaseTransport + def initialize(transport, queue) + @transport = transport + @queue = queue + @flushed = false + end + + def open? + @transport.open? + end + + def open + @transport.open + end + + def close + @transport.close + end + + def read(sz) + @transport.read(sz) + end + + def write(buf,sz=nil) + @transport.write(buf, sz) + end + + def flush + @queue.push :flushed unless @flushed or @queue.nil? + @flushed = true + @transport.flush + end + end + + class SpecServerSocket < ServerSocket + def initialize(host, port, queue) + super(host, port) + @queue = queue + end + + def listen + super + @queue.push :listen + end + end + + describe Thrift::NonblockingServer do + before(:each) do + @port = 43251 + handler = Handler.new + processor = NonblockingService::Processor.new(handler) + queue = Queue.new + @transport = SpecServerSocket.new('localhost', @port, queue) + transport_factory = FramedTransportFactory.new + logger = Logger.new(STDERR) + logger.level = Logger::WARN + @server = NonblockingServer.new(processor, @transport, transport_factory, nil, 5, logger) + handler.server = @server + @server_thread = Thread.new(Thread.current) do |master_thread| + begin + @server.serve + rescue => e + p e + puts e.backtrace * "\n" + master_thread.raise e + end + end + queue.pop + + @clients = [] + @catch_exceptions = false + end + + after(:each) do + @clients.each { |client, trans| trans.close } + # @server.shutdown(1) + @server_thread.kill + @transport.close + end + + def setup_client(queue = nil) + transport = SpecTransport.new(FramedTransport.new(Socket.new('localhost', @port)), queue) + protocol = BinaryProtocol.new(transport) + client = NonblockingService::Client.new(protocol) + transport.open + @clients << [client, transport] + client + end + + def setup_client_thread(result) + queue = Queue.new + Thread.new do + begin + client = setup_client + while (cmd = queue.pop) + msg, *args = cmd + case msg + when :block + result << client.block + when :unblock + client.unblock(args.first) + when :hello + result << client.greeting(true) # ignore result + when :sleep + client.sleep(args[0] || 0.5) + result << :slept + when :shutdown + client.shutdown + when :exit + result << :done + break + end + end + @clients.each { |c,t| t.close and break if c == client } #close the transport + rescue => e + raise e unless @catch_exceptions + end + end + queue + end + + it "should handle basic message passing" do + client = setup_client + client.greeting(true).should == Hello.new + client.greeting(false).should == Hello.new(:greeting => 'Aloha!') + @server.shutdown + end + + it "should handle concurrent clients" do + queue = Queue.new + trans_queue = Queue.new + 4.times do + Thread.new(Thread.current) do |main_thread| + begin + queue.push setup_client(trans_queue).block + rescue => e + main_thread.raise e + end + end + end + 4.times { trans_queue.pop } + setup_client.unblock(4) + 4.times { queue.pop.should be_true } + @server.shutdown + end + + it "should handle messages from more than 5 long-lived connections" do + queues = [] + result = Queue.new + 7.times do |i| + queues << setup_client_thread(result) + Thread.pass if i == 4 # give the server time to accept connections + end + client = setup_client + # block 4 connections + 4.times { |i| queues[i] << :block } + queues[4] << :hello + queues[5] << :hello + queues[6] << :hello + 3.times { result.pop.should == Hello.new } + client.greeting(true).should == Hello.new + queues[5] << [:unblock, 4] + 4.times { result.pop.should be_true } + queues[2] << :hello + result.pop.should == Hello.new + client.greeting(false).should == Hello.new(:greeting => 'Aloha!') + 7.times { queues.shift << :exit } + client.greeting(true).should == Hello.new + @server.shutdown + end + + it "should shut down when asked" do + # connect first to ensure it's running + client = setup_client + client.greeting(false) # force a message pass + @server.shutdown + @server_thread.join(2).should be_an_instance_of(Thread) + end + + it "should continue processing active messages when shutting down" do + result = Queue.new + client = setup_client_thread(result) + client << :sleep + sleep 0.1 # give the server time to start processing the client's message + @server.shutdown + @server_thread.join(2).should be_an_instance_of(Thread) + result.pop.should == :slept + end + + it "should kill active messages when they don't expire while shutting down" do + result = Queue.new + client = setup_client_thread(result) + client << [:sleep, 10] + sleep 0.1 # start processing the client's message + @server.shutdown(1) + @catch_exceptions = true + @server_thread.join(3).should_not be_nil + result.should be_empty + end + + it "should allow shutting down in response to a message" do + client = setup_client + client.greeting(true).should == Hello.new + client.shutdown + @server_thread.join(2).should_not be_nil + end + end +end diff --git a/lib/rb/spec/processor_spec.rb b/lib/rb/spec/processor_spec.rb new file mode 100644 index 000000000..d35f65284 --- /dev/null +++ b/lib/rb/spec/processor_spec.rb @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftProcessorSpec < Spec::ExampleGroup + include Thrift + + class ProcessorSpec + include Thrift::Processor + end + + describe "Processor" do + before(:each) do + @processor = ProcessorSpec.new(mock("MockHandler")) + @prot = mock("MockProtocol") + end + + def mock_trans(obj) + obj.should_receive(:trans).ordered.and_return do + mock("trans").tee do |trans| + trans.should_receive(:flush).ordered + end + end + end + + it "should call process_<message> when it receives that message" do + @prot.should_receive(:read_message_begin).ordered.and_return ['testMessage', MessageTypes::CALL, 17] + @processor.should_receive(:process_testMessage).with(17, @prot, @prot).ordered + @processor.process(@prot, @prot).should == true + end + + it "should raise an ApplicationException when the received message cannot be processed" do + @prot.should_receive(:read_message_begin).ordered.and_return ['testMessage', MessageTypes::CALL, 4] + @prot.should_receive(:skip).with(Types::STRUCT).ordered + @prot.should_receive(:read_message_end).ordered + @prot.should_receive(:write_message_begin).with('testMessage', MessageTypes::EXCEPTION, 4).ordered + ApplicationException.should_receive(:new).with(ApplicationException::UNKNOWN_METHOD, "Unknown function testMessage").and_return do + mock(ApplicationException).tee do |e| + e.should_receive(:write).with(@prot).ordered + end + end + @prot.should_receive(:write_message_end).ordered + mock_trans(@prot) + @processor.process(@prot, @prot) + end + + it "should pass args off to the args class" do + args_class = mock("MockArgsClass") + args = mock("#<MockArgsClass:mock>").tee do |args| + args.should_receive(:read).with(@prot).ordered + end + args_class.should_receive(:new).and_return args + @prot.should_receive(:read_message_end).ordered + @processor.read_args(@prot, args_class).should eql(args) + end + + it "should write out a reply when asked" do + @prot.should_receive(:write_message_begin).with('testMessage', MessageTypes::REPLY, 23).ordered + result = mock("MockResult") + result.should_receive(:write).with(@prot).ordered + @prot.should_receive(:write_message_end).ordered + mock_trans(@prot) + @processor.write_result(result, @prot, 'testMessage', 23) + end + end +end diff --git a/lib/rb/spec/serializer_spec.rb b/lib/rb/spec/serializer_spec.rb new file mode 100644 index 000000000..82f374b13 --- /dev/null +++ b/lib/rb/spec/serializer_spec.rb @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + '/gen-rb/thrift_spec_types' + +class ThriftSerializerSpec < Spec::ExampleGroup + include Thrift + include SpecNamespace + + describe Serializer do + it "should serialize structs to binary by default" do + serializer = Serializer.new(Thrift::BinaryProtocolAcceleratedFactory.new) + data = serializer.serialize(Hello.new(:greeting => "'Ello guv'nor!")) + data.should == "\x0B\x00\x01\x00\x00\x00\x0E'Ello guv'nor!\x00" + end + + it "should serialize structs to the given protocol" do + protocol = BaseProtocol.new(mock("transport")) + protocol.should_receive(:write_struct_begin).with("SpecNamespace::Hello") + protocol.should_receive(:write_field_begin).with("greeting", Types::STRING, 1) + protocol.should_receive(:write_string).with("Good day") + protocol.should_receive(:write_field_end) + protocol.should_receive(:write_field_stop) + protocol.should_receive(:write_struct_end) + protocol_factory = mock("ProtocolFactory") + protocol_factory.stub!(:get_protocol).and_return(protocol) + serializer = Serializer.new(protocol_factory) + serializer.serialize(Hello.new(:greeting => "Good day")) + end + end + + describe Deserializer do + it "should deserialize structs from binary by default" do + deserializer = Deserializer.new + data = "\x0B\x00\x01\x00\x00\x00\x0E'Ello guv'nor!\x00" + deserializer.deserialize(Hello.new, data).should == Hello.new(:greeting => "'Ello guv'nor!") + end + + it "should deserialize structs from the given protocol" do + protocol = BaseProtocol.new(mock("transport")) + protocol.should_receive(:read_struct_begin).and_return("SpecNamespace::Hello") + protocol.should_receive(:read_field_begin).and_return(["greeting", Types::STRING, 1], + [nil, Types::STOP, 0]) + protocol.should_receive(:read_string).and_return("Good day") + protocol.should_receive(:read_field_end) + protocol.should_receive(:read_struct_end) + protocol_factory = mock("ProtocolFactory") + protocol_factory.stub!(:get_protocol).and_return(protocol) + deserializer = Deserializer.new(protocol_factory) + deserializer.deserialize(Hello.new, "").should == Hello.new(:greeting => "Good day") + end + end +end diff --git a/lib/rb/spec/server_socket_spec.rb b/lib/rb/spec/server_socket_spec.rb new file mode 100644 index 000000000..fce501342 --- /dev/null +++ b/lib/rb/spec/server_socket_spec.rb @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + "/socket_spec_shared" + +class ThriftServerSocketSpec < Spec::ExampleGroup + include Thrift + + describe ServerSocket do + before(:each) do + @socket = ServerSocket.new(1234) + end + + it "should create a handle when calling listen" do + TCPServer.should_receive(:new).with(nil, 1234) + @socket.listen + end + + it "should accept an optional host argument" do + @socket = ServerSocket.new('localhost', 1234) + TCPServer.should_receive(:new).with('localhost', 1234) + @socket.listen + end + + it "should create a Thrift::Socket to wrap accepted sockets" do + handle = mock("TCPServer") + TCPServer.should_receive(:new).with(nil, 1234).and_return(handle) + @socket.listen + sock = mock("sock") + handle.should_receive(:accept).and_return(sock) + trans = mock("Socket") + Socket.should_receive(:new).and_return(trans) + trans.should_receive(:handle=).with(sock) + @socket.accept.should == trans + end + + it "should close the handle when closed" do + handle = mock("TCPServer", :closed? => false) + TCPServer.should_receive(:new).with(nil, 1234).and_return(handle) + @socket.listen + handle.should_receive(:close) + @socket.close + end + + it "should return nil when accepting if there is no handle" do + @socket.accept.should be_nil + end + + it "should return true for closed? when appropriate" do + handle = mock("TCPServer", :closed? => false) + TCPServer.stub!(:new).and_return(handle) + @socket.listen + @socket.should_not be_closed + handle.stub!(:close) + @socket.close + @socket.should be_closed + @socket.listen + @socket.should_not be_closed + handle.stub!(:closed?).and_return(true) + @socket.should be_closed + end + end +end diff --git a/lib/rb/spec/server_spec.rb b/lib/rb/spec/server_spec.rb new file mode 100644 index 000000000..ffe9bffa9 --- /dev/null +++ b/lib/rb/spec/server_spec.rb @@ -0,0 +1,160 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftServerSpec < Spec::ExampleGroup + include Thrift + + describe BaseServer do + it "should default to BaseTransportFactory and BinaryProtocolFactory when not specified" do + server = BaseServer.new(mock("Processor"), mock("BaseServerTransport")) + server.instance_variable_get(:'@transport_factory').should be_an_instance_of(BaseTransportFactory) + server.instance_variable_get(:'@protocol_factory').should be_an_instance_of(BinaryProtocolFactory) + end + + # serve is a noop, so can't test that + end + + shared_examples_for "servers" do + before(:each) do + @processor = mock("Processor") + @serverTrans = mock("ServerTransport") + @trans = mock("BaseTransport") + @prot = mock("BaseProtocol") + @client = mock("Client") + @server = server_type.new(@processor, @serverTrans, @trans, @prot) + end + end + + describe SimpleServer do + it_should_behave_like "servers" + + def server_type + SimpleServer + end + + it "should serve in the main thread" do + @serverTrans.should_receive(:listen).ordered + @serverTrans.should_receive(:accept).exactly(3).times.and_return(@client) + @trans.should_receive(:get_transport).exactly(3).times.with(@client).and_return(@trans) + @prot.should_receive(:get_protocol).exactly(3).times.with(@trans).and_return(@prot) + x = 0 + @processor.should_receive(:process).exactly(3).times.with(@prot, @prot).and_return do + case (x += 1) + when 1 then raise Thrift::TransportException + when 2 then raise Thrift::ProtocolException + when 3 then throw :stop + end + end + @trans.should_receive(:close).exactly(3).times + @serverTrans.should_receive(:close).ordered + lambda { @server.serve }.should throw_symbol(:stop) + end + end + + describe ThreadedServer do + it_should_behave_like "servers" + + def server_type + ThreadedServer + end + + it "should serve using threads" do + @serverTrans.should_receive(:listen).ordered + @serverTrans.should_receive(:accept).exactly(3).times.and_return(@client) + @trans.should_receive(:get_transport).exactly(3).times.with(@client).and_return(@trans) + @prot.should_receive(:get_protocol).exactly(3).times.with(@trans).and_return(@prot) + Thread.should_receive(:new).with(@prot, @trans).exactly(3).times.and_yield(@prot, @trans) + x = 0 + @processor.should_receive(:process).exactly(3).times.with(@prot, @prot).and_return do + case (x += 1) + when 1 then raise Thrift::TransportException + when 2 then raise Thrift::ProtocolException + when 3 then throw :stop + end + end + @trans.should_receive(:close).exactly(3).times + @serverTrans.should_receive(:close).ordered + lambda { @server.serve }.should throw_symbol(:stop) + end + end + + describe ThreadPoolServer do + it_should_behave_like "servers" + + def server_type + # put this stuff here so it runs before the server is created + @threadQ = mock("SizedQueue") + SizedQueue.should_receive(:new).with(20).and_return(@threadQ) + @excQ = mock("Queue") + Queue.should_receive(:new).and_return(@excQ) + ThreadPoolServer + end + + it "should set up the queues" do + @server.instance_variable_get(:'@thread_q').should be(@threadQ) + @server.instance_variable_get(:'@exception_q').should be(@excQ) + end + + it "should serve inside a thread" do + Thread.should_receive(:new).and_return do |block| + @server.should_receive(:serve) + block.call + @server.rspec_verify + end + @excQ.should_receive(:pop).and_throw(:popped) + lambda { @server.rescuable_serve }.should throw_symbol(:popped) + end + + it "should avoid running the server twice when retrying rescuable_serve" do + Thread.should_receive(:new).and_return do |block| + @server.should_receive(:serve) + block.call + @server.rspec_verify + end + @excQ.should_receive(:pop).twice.and_throw(:popped) + lambda { @server.rescuable_serve }.should throw_symbol(:popped) + lambda { @server.rescuable_serve }.should throw_symbol(:popped) + end + + it "should serve using a thread pool" do + @serverTrans.should_receive(:listen).ordered + @threadQ.should_receive(:push).with(:token) + @threadQ.should_receive(:pop) + Thread.should_receive(:new).and_yield + @serverTrans.should_receive(:accept).exactly(3).times.and_return(@client) + @trans.should_receive(:get_transport).exactly(3).times.and_return(@trans) + @prot.should_receive(:get_protocol).exactly(3).times.and_return(@prot) + x = 0 + error = RuntimeError.new("Stopped") + @processor.should_receive(:process).exactly(3).times.with(@prot, @prot).and_return do + case (x += 1) + when 1 then raise Thrift::TransportException + when 2 then raise Thrift::ProtocolException + when 3 then raise error + end + end + @trans.should_receive(:close).exactly(3).times + @excQ.should_receive(:push).with(error).and_throw(:stop) + @serverTrans.should_receive(:close) + lambda { @server.serve }.should throw_symbol(:stop) + end + end +end diff --git a/lib/rb/spec/socket_spec.rb b/lib/rb/spec/socket_spec.rb new file mode 100644 index 000000000..dd8b0f924 --- /dev/null +++ b/lib/rb/spec/socket_spec.rb @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + "/socket_spec_shared" + +class ThriftSocketSpec < Spec::ExampleGroup + include Thrift + + describe Socket do + before(:each) do + @socket = Socket.new + @handle = mock("Handle", :closed? => false) + @handle.stub!(:close) + @handle.stub!(:connect_nonblock) + ::Socket.stub!(:new).and_return(@handle) + end + + it_should_behave_like "a socket" + + it "should raise a TransportException when it cannot open a socket" do + ::Socket.should_receive(:new).and_raise(StandardError) + lambda { @socket.open }.should raise_error(Thrift::TransportException) { |e| e.type.should == Thrift::TransportException::NOT_OPEN } + end + + it "should open a ::Socket with default args" do + ::Socket.should_receive(:new).and_return(mock("Handle", :connect_nonblock => true)) + ::Socket.should_receive(:getaddrinfo).with("localhost", 9090).and_return([[]]) + ::Socket.should_receive(:sockaddr_in) + @socket.open + end + + it "should accept host/port options" do + ::Socket.should_receive(:new).and_return(mock("Handle", :connect_nonblock => true)) + ::Socket.should_receive(:getaddrinfo).with("my.domain", 1234).and_return([[]]) + ::Socket.should_receive(:sockaddr_in) + Socket.new('my.domain', 1234).open + end + + it "should accept an optional timeout" do + ::Socket.stub!(:new) + Socket.new('localhost', 8080, 5).timeout.should == 5 + end + end +end diff --git a/lib/rb/spec/socket_spec_shared.rb b/lib/rb/spec/socket_spec_shared.rb new file mode 100644 index 000000000..96b433b8c --- /dev/null +++ b/lib/rb/spec/socket_spec_shared.rb @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +shared_examples_for "a socket" do + it "should open a socket" do + @socket.open.should == @handle + end + + it "should be open whenever it has a handle" do + @socket.should_not be_open + @socket.open + @socket.should be_open + @socket.handle = nil + @socket.should_not be_open + @socket.handle = @handle + @socket.close + @socket.should_not be_open + end + + it "should write data to the handle" do + @socket.open + @handle.should_receive(:write).with("foobar") + @socket.write("foobar") + @handle.should_receive(:write).with("fail").and_raise(StandardError) + lambda { @socket.write("fail") }.should raise_error(Thrift::TransportException) { |e| e.type.should == Thrift::TransportException::NOT_OPEN } + end + + it "should raise an error when it cannot read from the handle" do + @socket.open + @handle.should_receive(:readpartial).with(17).and_raise(StandardError) + lambda { @socket.read(17) }.should raise_error(Thrift::TransportException) { |e| e.type.should == Thrift::TransportException::NOT_OPEN } + end + + it "should return the data read when reading from the handle works" do + @socket.open + @handle.should_receive(:readpartial).with(17).and_return("test data") + @socket.read(17).should == "test data" + end + + it "should declare itself as closed when it has an error" do + @socket.open + @handle.should_receive(:write).with("fail").and_raise(StandardError) + @socket.should be_open + lambda { @socket.write("fail") }.should raise_error + @socket.should_not be_open + end + + it "should raise an error when the stream is closed" do + @socket.open + @handle.stub!(:closed?).and_return(true) + @socket.should_not be_open + lambda { @socket.write("fail") }.should raise_error(IOError, "closed stream") + lambda { @socket.read(10) }.should raise_error(IOError, "closed stream") + end + + it "should support the timeout accessor for read" do + @socket.timeout = 3 + @socket.open + IO.should_receive(:select).with([@handle], nil, nil, 3).and_return([[@handle], [], []]) + @handle.should_receive(:readpartial).with(17).and_return("test data") + @socket.read(17).should == "test data" + end + + it "should support the timeout accessor for write" do + @socket.timeout = 3 + @socket.open + IO.should_receive(:select).with(nil, [@handle], nil, 3).twice.and_return([[], [@handle], []]) + @handle.should_receive(:write_nonblock).with("test data").and_return(4) + @handle.should_receive(:write_nonblock).with(" data").and_return(5) + @socket.write("test data").should == 9 + end + + it "should raise an error when read times out" do + @socket.timeout = 0.5 + @socket.open + IO.should_receive(:select).with([@handle], nil, nil, 0.5).at_least(1).times.and_return(nil) + lambda { @socket.read(17) }.should raise_error(Thrift::TransportException) { |e| e.type.should == Thrift::TransportException::TIMED_OUT } + end + + it "should raise an error when write times out" do + @socket.timeout = 0.5 + @socket.open + IO.should_receive(:select).with(nil, [@handle], nil, 0.5).any_number_of_times.and_return(nil) + lambda { @socket.write("test data") }.should raise_error(Thrift::TransportException) { |e| e.type.should == Thrift::TransportException::TIMED_OUT } + end +end diff --git a/lib/rb/spec/spec_helper.rb b/lib/rb/spec/spec_helper.rb new file mode 100644 index 000000000..3c20bf999 --- /dev/null +++ b/lib/rb/spec/spec_helper.rb @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'rubygems' +# require at least 1.1.4 to fix a bug with describing Modules +gem 'rspec', '>= 1.1.4' +require 'spec' + +$:.unshift File.join(File.dirname(__FILE__), *%w[.. ext]) + +# pretend we already loaded fastthread, otherwise the nonblocking_server_spec +# will get screwed up +# $" << 'fastthread.bundle' + +require File.dirname(__FILE__) + '/../lib/thrift' + +class Object + # tee is a useful method, so let's let our tests have it + def tee(&block) + block.call(self) + self + end +end + +Spec::Runner.configure do |configuration| + configuration.before(:each) do + Thrift.type_checking = true + end +end + +require File.dirname(__FILE__) + "/../debug_proto_test/gen-rb/Srv" +require File.dirname(__FILE__) + "/../debug_proto_test/gen-rb/debug_proto_test_constants" + +module Fixtures + COMPACT_PROTOCOL_TEST_STRUCT = COMPACT_TEST.dup + COMPACT_PROTOCOL_TEST_STRUCT.a_binary = [0,1,2,3,4,5,6,7,8].pack('c*') + COMPACT_PROTOCOL_TEST_STRUCT.set_byte_map = nil + COMPACT_PROTOCOL_TEST_STRUCT.map_byte_map = nil +end
\ No newline at end of file diff --git a/lib/rb/spec/struct_spec.rb b/lib/rb/spec/struct_spec.rb new file mode 100644 index 000000000..23a701ecf --- /dev/null +++ b/lib/rb/spec/struct_spec.rb @@ -0,0 +1,253 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + '/gen-rb/thrift_spec_types' + +class ThriftStructSpec < Spec::ExampleGroup + include Thrift + include SpecNamespace + + describe Struct do + it "should iterate over all fields properly" do + fields = {} + Foo.new.each_field { |fid,field_info| fields[fid] = field_info } + fields.should == Foo::FIELDS + end + + it "should initialize all fields to defaults" do + struct = Foo.new + struct.simple.should == 53 + struct.words.should == "words" + struct.hello.should == Hello.new(:greeting => 'hello, world!') + struct.ints.should == [1, 2, 2, 3] + struct.complex.should be_nil + struct.shorts.should == Set.new([5, 17, 239]) + end + + it "should not share default values between instances" do + begin + struct = Foo.new + struct.ints << 17 + Foo.new.ints.should == [1,2,2,3] + ensure + # ensure no leakage to other tests + Foo::FIELDS[4][:default] = [1,2,2,3] + end + end + + it "should properly initialize boolean values" do + struct = BoolStruct.new(:yesno => false) + struct.yesno.should be_false + end + + it "should have proper == semantics" do + Foo.new.should_not == Hello.new + Foo.new.should == Foo.new + Foo.new(:simple => 52).should_not == Foo.new + end + + it "should read itself off the wire" do + struct = Foo.new + prot = BaseProtocol.new(mock("transport")) + prot.should_receive(:read_struct_begin).twice + prot.should_receive(:read_struct_end).twice + prot.should_receive(:read_field_begin).and_return( + ['complex', Types::MAP, 5], # Foo + ['words', Types::STRING, 2], # Foo + ['hello', Types::STRUCT, 3], # Foo + ['greeting', Types::STRING, 1], # Hello + [nil, Types::STOP, 0], # Hello + ['simple', Types::I32, 1], # Foo + ['ints', Types::LIST, 4], # Foo + ['shorts', Types::SET, 6], # Foo + [nil, Types::STOP, 0] # Hello + ) + prot.should_receive(:read_field_end).exactly(7).times + prot.should_receive(:read_map_begin).and_return( + [Types::I32, Types::MAP, 2], # complex + [Types::STRING, Types::DOUBLE, 2], # complex/1/value + [Types::STRING, Types::DOUBLE, 1] # complex/2/value + ) + prot.should_receive(:read_map_end).exactly(3).times + prot.should_receive(:read_list_begin).and_return([Types::I32, 4]) + prot.should_receive(:read_list_end) + prot.should_receive(:read_set_begin).and_return([Types::I16, 2]) + prot.should_receive(:read_set_end) + prot.should_receive(:read_i32).and_return( + 1, 14, # complex keys + 42, # simple + 4, 23, 4, 29 # ints + ) + prot.should_receive(:read_string).and_return("pi", "e", "feigenbaum", "apple banana", "what's up?") + prot.should_receive(:read_double).and_return(Math::PI, Math::E, 4.669201609) + prot.should_receive(:read_i16).and_return(2, 3) + prot.should_not_receive(:skip) + struct.read(prot) + + struct.simple.should == 42 + struct.complex.should == {1 => {"pi" => Math::PI, "e" => Math::E}, 14 => {"feigenbaum" => 4.669201609}} + struct.hello.should == Hello.new(:greeting => "what's up?") + struct.words.should == "apple banana" + struct.ints.should == [4, 23, 4, 29] + struct.shorts.should == Set.new([3, 2]) + end + + it "should skip unexpected fields in structs and use default values" do + struct = Foo.new + prot = BaseProtocol.new(mock("transport")) + prot.should_receive(:read_struct_begin) + prot.should_receive(:read_struct_end) + prot.should_receive(:read_field_begin).and_return( + ['simple', Types::I32, 1], + ['complex', Types::STRUCT, 5], + ['thinz', Types::MAP, 7], + ['foobar', Types::I32, 3], + ['words', Types::STRING, 2], + [nil, Types::STOP, 0] + ) + prot.should_receive(:read_field_end).exactly(5).times + prot.should_receive(:read_i32).and_return(42) + prot.should_receive(:read_string).and_return("foobar") + prot.should_receive(:skip).with(Types::STRUCT) + prot.should_receive(:skip).with(Types::MAP) + # prot.should_receive(:read_map_begin).and_return([Types::I32, Types::I32, 0]) + # prot.should_receive(:read_map_end) + prot.should_receive(:skip).with(Types::I32) + struct.read(prot) + + struct.simple.should == 42 + struct.complex.should be_nil + struct.words.should == "foobar" + struct.hello.should == Hello.new(:greeting => 'hello, world!') + struct.ints.should == [1, 2, 2, 3] + struct.shorts.should == Set.new([5, 17, 239]) + end + + it "should write itself to the wire" do + prot = BaseProtocol.new(mock("transport")) #mock("Protocol") + prot.should_receive(:write_struct_begin).with("SpecNamespace::Foo") + prot.should_receive(:write_struct_begin).with("SpecNamespace::Hello") + prot.should_receive(:write_struct_end).twice + prot.should_receive(:write_field_begin).with('ints', Types::LIST, 4) + prot.should_receive(:write_i32).with(1) + prot.should_receive(:write_i32).with(2).twice + prot.should_receive(:write_i32).with(3) + prot.should_receive(:write_field_begin).with('complex', Types::MAP, 5) + prot.should_receive(:write_i32).with(5) + prot.should_receive(:write_string).with('foo') + prot.should_receive(:write_double).with(1.23) + prot.should_receive(:write_field_begin).with('shorts', Types::SET, 6) + prot.should_receive(:write_i16).with(5) + prot.should_receive(:write_i16).with(17) + prot.should_receive(:write_i16).with(239) + prot.should_receive(:write_field_stop).twice + prot.should_receive(:write_field_end).exactly(6).times + prot.should_receive(:write_field_begin).with('simple', Types::I32, 1) + prot.should_receive(:write_i32).with(53) + prot.should_receive(:write_field_begin).with('hello', Types::STRUCT, 3) + prot.should_receive(:write_field_begin).with('greeting', Types::STRING, 1) + prot.should_receive(:write_string).with('hello, world!') + prot.should_receive(:write_map_begin).with(Types::I32, Types::MAP, 1) + prot.should_receive(:write_map_begin).with(Types::STRING, Types::DOUBLE, 1) + prot.should_receive(:write_map_end).twice + prot.should_receive(:write_list_begin).with(Types::I32, 4) + prot.should_receive(:write_list_end) + prot.should_receive(:write_set_begin).with(Types::I16, 3) + prot.should_receive(:write_set_end) + + struct = Foo.new + struct.words = nil + struct.complex = {5 => {"foo" => 1.23}} + struct.write(prot) + end + + it "should raise an exception if presented with an unknown container" do + # yeah this is silly, but I'm going for code coverage here + struct = Foo.new + lambda { struct.send :write_container, nil, nil, {:type => "foo"} }.should raise_error(StandardError, "Not a container type: foo") + end + + it "should support optional type-checking in Thrift::Struct.new" do + Thrift.type_checking = true + begin + lambda { Hello.new(:greeting => 3) }.should raise_error(TypeError, "Expected Types::STRING, received Fixnum for field greeting") + ensure + Thrift.type_checking = false + end + lambda { Hello.new(:greeting => 3) }.should_not raise_error(TypeError) + end + + it "should support optional type-checking in field accessors" do + Thrift.type_checking = true + begin + hello = Hello.new + lambda { hello.greeting = 3 }.should raise_error(TypeError, "Expected Types::STRING, received Fixnum for field greeting") + ensure + Thrift.type_checking = false + end + lambda { hello.greeting = 3 }.should_not raise_error(TypeError) + end + + it "should raise an exception when unknown types are given to Thrift::Struct.new" do + lambda { Hello.new(:fish => 'salmon') }.should raise_error(Exception, "Unknown key given to SpecNamespace::Hello.new: fish") + end + + it "should support `raise Xception, 'message'` for Exception structs" do + begin + raise Xception, "something happened" + rescue Thrift::Exception => e + e.message.should == "something happened" + e.code.should == 1 + # ensure it gets serialized properly, this is the really important part + prot = BaseProtocol.new(mock("trans")) + prot.should_receive(:write_struct_begin).with("SpecNamespace::Xception") + prot.should_receive(:write_struct_end) + prot.should_receive(:write_field_begin).with('message', Types::STRING, 1)#, "something happened") + prot.should_receive(:write_string).with("something happened") + prot.should_receive(:write_field_begin).with('code', Types::I32, 2)#, 1) + prot.should_receive(:write_i32).with(1) + prot.should_receive(:write_field_stop) + prot.should_receive(:write_field_end).twice + + e.write(prot) + end + end + + it "should support the regular initializer for exception structs" do + begin + raise Xception, :message => "something happened", :code => 5 + rescue Thrift::Exception => e + e.message.should == "something happened" + e.code.should == 5 + prot = BaseProtocol.new(mock("trans")) + prot.should_receive(:write_struct_begin).with("SpecNamespace::Xception") + prot.should_receive(:write_struct_end) + prot.should_receive(:write_field_begin).with('message', Types::STRING, 1) + prot.should_receive(:write_string).with("something happened") + prot.should_receive(:write_field_begin).with('code', Types::I32, 2) + prot.should_receive(:write_i32).with(5) + prot.should_receive(:write_field_stop) + prot.should_receive(:write_field_end).twice + + e.write(prot) + end + end + end +end diff --git a/lib/rb/spec/types_spec.rb b/lib/rb/spec/types_spec.rb new file mode 100644 index 000000000..d979cfb10 --- /dev/null +++ b/lib/rb/spec/types_spec.rb @@ -0,0 +1,117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + '/gen-rb/thrift_spec_types' + +class ThriftTypesSpec < Spec::ExampleGroup + include Thrift + + before(:each) do + Thrift.type_checking = true + end + + after(:each) do + Thrift.type_checking = false + end + + describe "Type checking" do + it "should return the proper name for each type" do + Thrift.type_name(Types::I16).should == "Types::I16" + Thrift.type_name(Types::VOID).should == "Types::VOID" + Thrift.type_name(Types::LIST).should == "Types::LIST" + Thrift.type_name(42).should be_nil + end + + it "should check types properly" do + # lambda { Thrift.check_type(nil, Types::STOP) }.should raise_error(TypeError) + lambda { Thrift.check_type(3, {:type => Types::STOP}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::VOID}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(3, {:type => Types::VOID}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type(true, {:type => Types::BOOL}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(3, {:type => Types::BOOL}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type(42, {:type => Types::BYTE}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(42, {:type => Types::I16}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(42, {:type => Types::I32}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(42, {:type => Types::I64}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(3.14, {:type => Types::I32}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type(3.14, {:type => Types::DOUBLE}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(3, {:type => Types::DOUBLE}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type("3", {:type => Types::STRING}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(3, {:type => Types::STRING}, :foo) }.should raise_error(TypeError) + hello = SpecNamespace::Hello.new + lambda { Thrift.check_type(hello, {:type => Types::STRUCT, :class => SpecNamespace::Hello}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type("foo", {:type => Types::STRUCT}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type({:foo => 1}, {:type => Types::MAP}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type([1], {:type => Types::MAP}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type([1], {:type => Types::LIST}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type({:foo => 1}, {:type => Types::LIST}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type(Set.new([1,2]), {:type => Types::SET}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type([1,2], {:type => Types::SET}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type({:foo => true}, {:type => Types::SET}, :foo) }.should raise_error(TypeError) + end + + it "should error out if nil is passed and skip_types is false" do + lambda { Thrift.check_type(nil, {:type => Types::BOOL}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::BYTE}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::I16}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::I32}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::I64}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::DOUBLE}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::STRING}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::STRUCT}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::LIST}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::SET}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::MAP}, :foo, false) }.should raise_error(TypeError) + end + + it "should check element types on containers" do + field = {:type => Types::LIST, :element => {:type => Types::I32}} + lambda { Thrift.check_type([1, 2], field, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type([1, nil, 2], field, :foo) }.should raise_error(TypeError) + field = {:type => Types::MAP, :key => {:type => Types::I32}, :value => {:type => Types::STRING}} + lambda { Thrift.check_type({1 => "one", 2 => "two"}, field, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type({1 => "one", nil => "nil"}, field, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type({1 => nil, 2 => "two"}, field, :foo) }.should raise_error(TypeError) + field = {:type => Types::SET, :element => {:type => Types::I32}} + lambda { Thrift.check_type(Set.new([1, 2]), field, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(Set.new([1, nil, 2]), field, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type(Set.new([1, 2.3, 2]), field, :foo) }.should raise_error(TypeError) + + field = {:type => Types::STRUCT, :class => SpecNamespace::Hello} + lambda { Thrift.check_type(SpecNamespace::BoolStruct, field, :foo) }.should raise_error(TypeError) + end + + it "should give the TypeError a readable message" do + msg = "Expected Types::STRING, received Fixnum for field foo" + lambda { Thrift.check_type(3, {:type => Types::STRING}, :foo) }.should raise_error(TypeError, msg) + msg = "Expected Types::STRING, received Fixnum for field foo.element" + field = {:type => Types::LIST, :element => {:type => Types::STRING}} + lambda { Thrift.check_type([3], field, :foo) }.should raise_error(TypeError, msg) + msg = "Expected Types::I32, received NilClass for field foo.element.key" + field = {:type => Types::LIST, + :element => {:type => Types::MAP, + :key => {:type => Types::I32}, + :value => {:type => Types::I32}}} + lambda { Thrift.check_type([{nil => 3}], field, :foo) }.should raise_error(TypeError, msg) + msg = "Expected Types::I32, received NilClass for field foo.element.value" + lambda { Thrift.check_type([{1 => nil}], field, :foo) }.should raise_error(TypeError, msg) + end + end +end diff --git a/lib/rb/spec/unix_socket_spec.rb b/lib/rb/spec/unix_socket_spec.rb new file mode 100644 index 000000000..df239d7ab --- /dev/null +++ b/lib/rb/spec/unix_socket_spec.rb @@ -0,0 +1,108 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + "/socket_spec_shared" + +class ThriftUNIXSocketSpec < Spec::ExampleGroup + include Thrift + + describe UNIXSocket do + before(:each) do + @path = '/tmp/thrift_spec_socket' + @socket = UNIXSocket.new(@path) + @handle = mock("Handle", :closed? => false) + @handle.stub!(:close) + ::UNIXSocket.stub!(:new).and_return(@handle) + end + + it_should_behave_like "a socket" + + it "should raise a TransportException when it cannot open a socket" do + ::UNIXSocket.should_receive(:new).and_raise(StandardError) + lambda { @socket.open }.should raise_error(Thrift::TransportException) { |e| e.type.should == Thrift::TransportException::NOT_OPEN } + end + + it "should accept an optional timeout" do + ::UNIXSocket.stub!(:new) + UNIXSocket.new(@path, 5).timeout.should == 5 + end + end + + describe UNIXServerSocket do + before(:each) do + @path = '/tmp/thrift_spec_socket' + @socket = UNIXServerSocket.new(@path) + end + + it "should create a handle when calling listen" do + UNIXServer.should_receive(:new).with(@path) + @socket.listen + end + + it "should create a Thrift::UNIXSocket to wrap accepted sockets" do + handle = mock("UNIXServer") + UNIXServer.should_receive(:new).with(@path).and_return(handle) + @socket.listen + sock = mock("sock") + handle.should_receive(:accept).and_return(sock) + trans = mock("UNIXSocket") + UNIXSocket.should_receive(:new).and_return(trans) + trans.should_receive(:handle=).with(sock) + @socket.accept.should == trans + end + + it "should close the handle when closed" do + handle = mock("UNIXServer", :closed? => false) + UNIXServer.should_receive(:new).with(@path).and_return(handle) + @socket.listen + handle.should_receive(:close) + File.stub!(:delete) + @socket.close + end + + it "should delete the socket when closed" do + handle = mock("UNIXServer", :closed? => false) + UNIXServer.should_receive(:new).with(@path).and_return(handle) + @socket.listen + handle.stub!(:close) + File.should_receive(:delete).with(@path) + @socket.close + end + + it "should return nil when accepting if there is no handle" do + @socket.accept.should be_nil + end + + it "should return true for closed? when appropriate" do + handle = mock("UNIXServer", :closed? => false) + UNIXServer.stub!(:new).and_return(handle) + File.stub!(:delete) + @socket.listen + @socket.should_not be_closed + handle.stub!(:close) + @socket.close + @socket.should be_closed + @socket.listen + @socket.should_not be_closed + handle.stub!(:closed?).and_return(true) + @socket.should be_closed + end + end +end diff --git a/lib/st/README b/lib/st/README new file mode 100644 index 000000000..be865b8f9 --- /dev/null +++ b/lib/st/README @@ -0,0 +1,35 @@ +Thrift SmallTalk Software Library + +Last updated Nov 2007 + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Library +======= + +To get started, just file in thrift.st with Squeak, run thrift -st +on the tutorial .thrift files (and file in the resulting code), and +then: + +calc := CalculatorClient binaryOnHost: 'localhost' port: '9090' +calc addNum1: 10 num2: 15 + +Tested in Squeak 3.7, but should work fine with anything later. diff --git a/lib/st/thrift.st b/lib/st/thrift.st new file mode 100644 index 000000000..6883539ff --- /dev/null +++ b/lib/st/thrift.st @@ -0,0 +1,812 @@ +" +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +License); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +" + +SystemOrganization addCategory: #Thrift! +SystemOrganization addCategory: #'Thrift-Protocol'! +SystemOrganization addCategory: #'Thrift-Transport'! + +Error subclass: #TError + instanceVariableNames: 'code' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift'! + +!TError class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:28'! +signalWithCode: anInteger + self new code: anInteger; signal! ! + +!TError methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:28'! +code + ^ code! ! + +!TError methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:28'! +code: anInteger + code := anInteger! ! + +TError subclass: #TProtocolError + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TProtocolError class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 18:39'! +badVersion + ^ 4! ! + +!TProtocolError class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 18:39'! +invalidData + ^ 1! ! + +!TProtocolError class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 18:39'! +negativeSize + ^ 2! ! + +!TProtocolError class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 18:40'! +sizeLimit + ^ 3! ! + +!TProtocolError class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 18:40'! +unknown + ^ 0! ! + +TError subclass: #TTransportError + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Transport'! + +TTransportError subclass: #TTransportClosedError + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Transport'! + +Object subclass: #TClient + instanceVariableNames: 'iprot oprot seqid remoteSeqid' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift'! + +!TClient class methodsFor: 'as yet unclassified' stamp: 'pc 11/7/2007 06:00'! +binaryOnHost: aString port: anInteger + | sock | + sock := TSocket new host: aString; port: anInteger; open; yourself. + ^ self new + inProtocol: (TBinaryProtocol new transport: sock); + yourself! ! + +!TClient methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 23:03'! +inProtocol: aProtocol + iprot := aProtocol. + oprot ifNil: [oprot := aProtocol]! ! + +!TClient methodsFor: 'as yet unclassified' stamp: 'pc 10/26/2007 04:28'! +nextSeqid + ^ seqid + ifNil: [seqid := 0] + ifNotNil: [seqid := seqid + 1]! ! + +!TClient methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:51'! +outProtocol: aProtocol + oprot := aProtocol! ! + +!TClient methodsFor: 'as yet unclassified' stamp: 'pc 10/28/2007 15:32'! +validateRemoteMessage: aMsg + remoteSeqid + ifNil: [remoteSeqid := aMsg seqid] + ifNotNil: + [(remoteSeqid + 1) = aMsg seqid ifFalse: + [TProtocolError signal: 'Bad seqid: ', aMsg seqid asString, + '; wanted: ', remoteSeqid asString]. + remoteSeqid := aMsg seqid]! ! + +Object subclass: #TField + instanceVariableNames: 'name type id' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TField methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:05'! +id + ^ id ifNil: [0]! ! + +!TField methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:44'! +id: anInteger + id := anInteger! ! + +!TField methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:04'! +name + ^ name ifNil: ['']! ! + +!TField methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:44'! +name: anObject + name := anObject! ! + +!TField methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:05'! +type + ^ type ifNil: [TType stop]! ! + +!TField methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:44'! +type: anInteger + type := anInteger! ! + +Object subclass: #TMessage + instanceVariableNames: 'name seqid type' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +TMessage subclass: #TCallMessage + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TCallMessage methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:53'! +type + ^ 1! ! + +!TMessage methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:05'! +name + ^ name ifNil: ['']! ! + +!TMessage methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:35'! +name: aString + name := aString! ! + +!TMessage methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:05'! +seqid + ^ seqid ifNil: [0]! ! + +!TMessage methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:35'! +seqid: anInteger + seqid := anInteger! ! + +!TMessage methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:06'! +type + ^ type ifNil: [0]! ! + +!TMessage methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:35'! +type: anInteger + type := anInteger! ! + +Object subclass: #TProtocol + instanceVariableNames: 'transport' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +TProtocol subclass: #TBinaryProtocol + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 11/1/2007 04:24'! +intFromByteArray: buf + | vals | + vals := Array new: buf size. + 1 to: buf size do: [:n | vals at: n put: ((buf at: n) bitShift: (buf size - n) * 8)]. + ^ vals sum! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 18:46'! +readBool + ^ self readByte isZero not! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/25/2007 00:02'! +readByte + ^ (self transport read: 1) first! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/28/2007 16:24'! +readDouble + | val | + val := Float new: 2. + ^ val basicAt: 1 put: (self readRawInt: 4); + basicAt: 2 put: (self readRawInt: 4); + yourself! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 20:02'! +readFieldBegin + | field | + field := TField new type: self readByte. + + ^ field type = TType stop + ifTrue: [field] + ifFalse: [field id: self readI16; yourself]! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:15'! +readI16 + ^ self readInt: 2! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:20'! +readI32 + ^ self readInt: 4! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:20'! +readI64 + ^ self readInt: 8! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 11/1/2007 02:35'! +readInt: size + | buf val | + buf := transport read: size. + val := self intFromByteArray: buf. + ^ buf first > 16r7F + ifTrue: [self unsignedInt: val size: size] + ifFalse: [val]! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:57'! +readListBegin + ^ TList new + elemType: self readByte; + size: self readI32! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:58'! +readMapBegin + ^ TMap new + keyType: self readByte; + valueType: self readByte; + size: self readI32! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 11/1/2007 04:22'! +readMessageBegin + | version | + version := self readI32. + + (version bitAnd: self versionMask) = self version1 + ifFalse: [TProtocolError signalWithCode: TProtocolError badVersion]. + + ^ TMessage new + type: (version bitAnd: 16r000000FF); + name: self readString; + seqid: self readI32! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/28/2007 16:24'! +readRawInt: size + ^ self intFromByteArray: (transport read: size)! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 11/1/2007 00:59'! +readSetBegin + "element type, size" + ^ TSet new + elemType: self readByte; + size: self readI32! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 02/07/2009 19:00'! +readString +readString + | sz | + sz := self readI32. + ^ sz > 0 ifTrue: [(transport read: sz) asString] ifFalse: ['']! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 11/1/2007 04:22'! +unsignedInt: val size: size + ^ 0 - ((val - 1) bitXor: ((2 raisedTo: (size * 8)) - 1))! ! + +!TBinaryProtocol methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:13'! +version1 + ^ 16r80010000 ! ! + +!TBinaryProtocol methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 18:01'! +versionMask + ^ 16rFFFF0000! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 18:35'! +write: aString + transport write: aString! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:23'! +writeBool: bool + bool ifTrue: [self writeByte: 1] + ifFalse: [self writeByte: 0]! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/26/2007 09:31'! +writeByte: aNumber + aNumber > 16rFF ifTrue: [TError signal: 'writeByte too big']. + transport write: (Array with: aNumber)! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/28/2007 16:16'! +writeDouble: aDouble + self writeI32: (aDouble basicAt: 1); + writeI32: (aDouble basicAt: 2)! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:56'! +writeField: aField + self writeByte: aField type; + writeI16: aField id! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/25/2007 00:01'! +writeFieldBegin: aField + self writeByte: aField type. + self writeI16: aField id! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 18:04'! +writeFieldStop + self writeByte: TType stop! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 11/1/2007 02:06'! +writeI16: i16 + self writeInt: i16 size: 2! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 11/1/2007 02:06'! +writeI32: i32 + self writeInt: i32 size: 4! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 11/1/2007 02:06'! +writeI64: i64 + self writeInt: i64 size: 8! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 11/1/2007 04:23'! +writeInt: val size: size + 1 to: size do: [:n | self writeByte: ((val bitShift: (size negated + n) * 8) bitAnd: 16rFF)]! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 11/1/2007 00:48'! +writeListBegin: aList + self writeByte: aList elemType; writeI32: aList size! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:55'! +writeMapBegin: aMap + self writeByte: aMap keyType; + writeByte: aMap valueType; + writeI32: aMap size! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 20:36'! +writeMessageBegin: msg + self writeI32: (self version1 bitOr: msg type); + writeString: msg name; + writeI32: msg seqid! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 11/1/2007 00:56'! +writeSetBegin: aSet + self writeByte: aSet elemType; writeI32: aSet size! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 18:35'! +writeString: aString + self writeI32: aString size; + write: aString! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readBool! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readByte! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readDouble! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readFieldBegin! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readFieldEnd! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readI16! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readI32! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readI64! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readListBegin! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readListEnd! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readMapBegin! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readMapEnd! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:39'! +readMessageBegin! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:39'! +readMessageEnd! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readSetBegin! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readSetEnd! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/25/2007 16:10'! +readSimpleType: aType + aType = TType bool ifTrue: [^ self readBool]. + aType = TType byte ifTrue: [^ self readByte]. + aType = TType double ifTrue: [^ self readDouble]. + aType = TType i16 ifTrue: [^ self readI16]. + aType = TType i32 ifTrue: [^ self readI32]. + aType = TType i64 ifTrue: [^ self readI64]. + aType = TType list ifTrue: [^ self readBool].! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readString! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readStructBegin + ! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readStructEnd! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/26/2007 21:34'! +skip: aType + aType = TType stop ifTrue: [^ self]. + aType = TType bool ifTrue: [^ self readBool]. + aType = TType byte ifTrue: [^ self readByte]. + aType = TType i16 ifTrue: [^ self readI16]. + aType = TType i32 ifTrue: [^ self readI32]. + aType = TType i64 ifTrue: [^ self readI64]. + aType = TType string ifTrue: [^ self readString]. + aType = TType double ifTrue: [^ self readDouble]. + aType = TType struct ifTrue: + [| field | + self readStructBegin. + [(field := self readFieldBegin) type = TType stop] whileFalse: + [self skip: field type. self readFieldEnd]. + ^ self readStructEnd]. + aType = TType map ifTrue: + [| map | + map := self readMapBegin. + map size timesRepeat: [self skip: map keyType. self skip: map valueType]. + ^ self readMapEnd]. + aType = TType list ifTrue: + [| list | + list := self readListBegin. + list size timesRepeat: [self skip: list elemType]. + ^ self readListEnd]. + aType = TType set ifTrue: + [| set | + set := self readSetBegin. + set size timesRepeat: [self skip: set elemType]. + ^ self readSetEnd]. + + self error: 'Unknown type'! ! + +!TProtocol methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 23:02'! +transport + ^ transport! ! + +!TProtocol methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:28'! +transport: aTransport + transport := aTransport! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeBool: aBool! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeByte: aByte! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:38'! +writeDouble: aFloat! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:38'! +writeFieldBegin: aField! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeFieldEnd! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeFieldStop! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeI16: i16! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeI32: i32! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeI64: i64! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:39'! +writeListBegin: aList! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeListEnd! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:39'! +writeMapBegin: aMap! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeMapEnd! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:36'! +writeMessageBegin! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:36'! +writeMessageEnd! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:39'! +writeSetBegin: aSet! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeSetEnd! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:38'! +writeString: aString! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:38'! +writeStructBegin: aStruct! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeStructEnd! ! + +Object subclass: #TResult + instanceVariableNames: 'success oprot iprot exception' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift'! + +!TResult methodsFor: 'as yet unclassified' stamp: 'pc 10/26/2007 21:35'! +exception + ^ exception! ! + +!TResult methodsFor: 'as yet unclassified' stamp: 'pc 10/26/2007 21:35'! +exception: anError + exception := anError! ! + +!TResult methodsFor: 'as yet unclassified' stamp: 'pc 10/26/2007 14:43'! +success + ^ success! ! + +!TResult methodsFor: 'as yet unclassified' stamp: 'pc 10/26/2007 14:43'! +success: anObject + success := anObject! ! + +Object subclass: #TSizedObject + instanceVariableNames: 'size' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +TSizedObject subclass: #TList + instanceVariableNames: 'elemType' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TList methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:04'! +elemType + ^ elemType ifNil: [TType stop]! ! + +!TList methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:42'! +elemType: anInteger + elemType := anInteger! ! + +TList subclass: #TSet + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +TSizedObject subclass: #TMap + instanceVariableNames: 'keyType valueType' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TMap methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:04'! +keyType + ^ keyType ifNil: [TType stop]! ! + +!TMap methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:45'! +keyType: anInteger + keyType := anInteger! ! + +!TMap methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:04'! +valueType + ^ valueType ifNil: [TType stop]! ! + +!TMap methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:45'! +valueType: anInteger + valueType := anInteger! ! + +!TSizedObject methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 20:03'! +size + ^ size ifNil: [0]! ! + +!TSizedObject methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 20:06'! +size: anInteger + size := anInteger! ! + +Object subclass: #TSocket + instanceVariableNames: 'host port stream' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Transport'! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:34'! +close + self isOpen ifTrue: [stream close]! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:23'! +connect + ^ (self socketStream openConnectionToHost: + (NetNameResolver addressForName: host) port: port) + timeout: 180; + binary; + yourself! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 20:35'! +flush + stream flush! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:08'! +host: aString + host := aString! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 20:34'! +isOpen + ^ stream isNil not + and: [stream socket isConnected] + and: [stream socket isOtherEndClosed not]! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:22'! +open + stream := self connect! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:09'! +port: anInteger + port := anInteger! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:17'! +read: size + | data | + [data := stream next: size. + data isEmpty ifTrue: [TTransportError signal: 'Could not read ', size asString, ' bytes']. + ^ data] + on: ConnectionClosed + do: [TTransportClosedError signal]! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:18'! +socketStream + ^ Smalltalk at: #FastSocketStream ifAbsent: [SocketStream] ! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:17'! +write: aCollection + [stream nextPutAll: aCollection] + on: ConnectionClosed + do: [TTransportClosedError signal]! ! + +Object subclass: #TStruct + instanceVariableNames: 'name' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TStruct methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:47'! +name + ^ name! ! + +!TStruct methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:47'! +name: aString + name := aString! ! + +Object subclass: #TTransport + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Transport'! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:18'! +close + self subclassResponsibility! ! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:22'! +flush + self subclassResponsibility! ! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:18'! +isOpen + self subclassResponsibility! ! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:18'! +open + self subclassResponsibility! ! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:18'! +read: anInteger + self subclassResponsibility! ! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:22'! +readAll: anInteger + ^ String streamContents: [:str | + [str size < anInteger] whileTrue: + [str nextPutAll: (self read: anInteger - str size)]]! ! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:22'! +write: aString + self subclassResponsibility! ! + +Object subclass: #TType + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift'! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:03'! +bool + ^ 2! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:03'! +byte + ^ 3! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/25/2007 15:55'! +codeOf: aTypeName + self typeMap do: [:each | each first = aTypeName ifTrue: [^ each second]]. + ^ nil! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:03'! +double + ^ 4! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +i16 + ^ 6! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +i32 + ^ 8! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +i64 + ^ 10! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +list + ^ 15! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +map + ^ 13! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/25/2007 15:56'! +nameOf: aTypeCode + self typeMap do: [:each | each second = aTypeCode ifTrue: [^ each first]]. + ^ nil! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +set + ^ 14! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:03'! +stop + ^ 0! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +string + ^ 11! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +struct + ^ 12! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/25/2007 15:51'! +typeMap + ^ #((bool 2) (byte 3) (double 4) (i16 6) (i32 8) (i64 10) (list 15) + (map 13) (set 15) (stop 0) (string 11) (struct 12) (void 1))! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:03'! +void + ^ 1! ! |